diff --git a/.clang-format b/.clang-format index c931e8f068..c6488cb358 100644 --- a/.clang-format +++ b/.clang-format @@ -52,7 +52,7 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 2 Cpp11BracedListStyle: true -DerivePointerAlignment: true +DerivePointerAlignment: false DisableFormat: false ExperimentalAutoDetectBinPacking: false FixNamespaceComments: true @@ -94,7 +94,7 @@ PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left +PointerAlignment: Right RawStringFormats: - Language: Cpp Delimiters: diff --git a/CMakeLists.txt b/CMakeLists.txt index 46804c8dde..7dceca7ad7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,6 @@ cmake_minimum_required(VERSION 3.14) project (MindSpore) - include(${CMAKE_SOURCE_DIR}/cmake/options.cmake) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/") if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") diff --git a/README.md b/README.md index e465f8e3e1..3de87d3fec 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ Check out how MindSpore Open Governance [works](https://gitee.com/mindspore/comm - [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers. - IRC channel at `#mindspore` (only for meeting minutes logging purpose) -- Video Conferencing: meet.jit.si +- Video Conferencing: https://meet.jit.si - Mailing-list: https://mailweb.mindspore.cn/postorius/lists ## Contributing diff --git a/build.bat b/build.bat index 76d7f19262..ddb2e8affe 100644 --- a/build.bat +++ b/build.bat @@ -31,6 +31,7 @@ cd %CD%/mindspore cmake -DCMAKE_BUILD_TYPE=Release -DENABLE_CPU=ON -DENABLE_MINDDATA=ON -DUSE_GLOG=ON -G "CodeBlocks - MinGW Makefiles" ../.. 
IF NOT %errorlevel% == 0 ( + echo "cmake fail." goto run_fail ) @@ -40,6 +41,7 @@ IF "%1%" == "" ( cmake --build . --target package -- -j%1% ) IF NOT %errorlevel% == 0 ( + echo "build fail." goto run_fail ) @@ -49,6 +51,6 @@ goto run_eof :run_fail cd %BASEPATH% - echo "build fail." + set errorlevel=1 :run_eof diff --git a/build.sh b/build.sh index 7550d76c8f..b48014ed93 100755 --- a/build.sh +++ b/build.sh @@ -23,30 +23,30 @@ export BUILD_PATH="${BASEPATH}/build/" usage() { echo "Usage:" - echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge|cpu] [-m infer|train] \\" - echo " [-a on|off] [-g on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" + echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" + echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K]" echo "" echo "Options:" echo " -d Debug mode" echo " -r Release mode, default mode" echo " -v Display build command" - echo " -c Enable code coverage switch, default off" - echo " -t Run testcases switch, default on" + echo " -c Enable code coverage, default off" + echo " -t Run testcases, default on" echo " -g Use glog to output log, default on" echo " -h Print usage" echo " -b Select other backend, available: \\" - echo " ge:graph engine, cpu" - echo " -m Select mode, available: infer, train, default is infer " + echo " ge:graph engine" + echo " -m Select graph engine backend mode, available: infer, train, default is infer" echo " -a Enable ASAN, default off" - echo " -p Enable pipeline profile, default off" + echo " -p Enable pipeline profile, print to stdout, default off" + echo " -R Enable pipeline profile, record to json, default off" echo " -i Enable increment building, default off" echo " -L Enable load ANF-IR as input of 'infer', default off" - echo " -R Enable the time_line record, default off" 
echo " -j[n] Set the threads when building (Default: -j8)" echo " -e Use gpu, d or cpu" echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" - echo " -Q Enable dump end to end, default off" + echo " -Q Enable dump memory, default off" echo " -D Enable dumping of function graph ir, default on" echo " -z Compile dataset & mindrecord, default on" echo " -M Enable MPI and NCCL for GPU training, default on" diff --git a/cmake/dependency_graphengine.cmake b/cmake/dependency_graphengine.cmake index 2a90cc1458..533f9f8246 100644 --- a/cmake/dependency_graphengine.cmake +++ b/cmake/dependency_graphengine.cmake @@ -64,7 +64,7 @@ set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # force __FILE__ to show relative path of file, from source directory -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") add_subdirectory(${GE_SOURCE_DIR}/src/common/graph) if(ENABLE_D) add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) diff --git a/cmake/external_libs/tvm_gpu.cmake b/cmake/external_libs/tvm_gpu.cmake index 2edec52ee1..834e2d159d 100644 --- a/cmake/external_libs/tvm_gpu.cmake +++ b/cmake/external_libs/tvm_gpu.cmake @@ -1,16 +1,15 @@ -set(incubator_tvm_gpu_CFLAGS "-pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") -set(incubator_tvm_gpu_CXXFLAGS "-std=c++11 -pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") -set(USE_CUDA "ON") +set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") +set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") mindspore_add_pkg(incubator_tvm_gpu VER 0.6.0 LIBS tvm URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz MD5 
9cbbd32545a776023acabbba270449fe + CUSTOM_CMAKE ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/ SUBMODULES ${dlpack_DIRPATH} ${dmlc-core_DIRPATH} ${rang_DIRPATH} SOURCEMODULES topi/python/topi python/tvm PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/find_library.patch - ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch - ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch - CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON) -include_directories(${incubator_tvm_gpu_INC}) -add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm) + ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch + ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch + CMAKE_OPTION " ") +add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm) \ No newline at end of file diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 501522a44b..f0a5dc594c 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -205,7 +205,7 @@ set(MS_FIND_NO_DEFAULT_PATH ${MS_FIND_NO_DEFAULT_PATH} PARENT_SCOPE) function(mindspore_add_pkg pkg_name ) set(options ) - set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH) + set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH CUSTOM_CMAKE) set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES SUBMODULES SOURCEMODULES) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) @@ -281,10 +281,6 @@ function(mindspore_add_pkg pkg_name ) file(GLOB ${pkg_name}_INSTALL_SUBMODULE ${_SUBMODULE_FILE}/*) file(COPY ${${pkg_name}_INSTALL_SUBMODULE} DESTINATION ${${pkg_name}_SOURCE_DIR}/3rdparty/${_SUBMODENAME}) endforeach (_SUBMODULE_FILE) - foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) - file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) - file(COPY 
${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) - endforeach (_SUBMODULE_FILE) else() set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) endif () @@ -304,12 +300,20 @@ function(mindspore_add_pkg pkg_name ) message(FATAL_ERROR "Failed patch: ${_LF_PATCH_FILE}") endif() endforeach(_PATCH_FILE) - + foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) + file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) + file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) + endforeach (_SUBMODULE_FILE) file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600) if(NOT ${pkg_name}_LOCK_RET EQUAL "0") message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}") endif() + if (PKG_CUSTOM_CMAKE) + file(GLOB ${pkg_name}_cmake ${PKG_CUSTOM_CMAKE}/CMakeLists.txt) + file(COPY ${${pkg_name}_cmake} DESTINATION ${${pkg_name}_SOURCE_DIR}) + endif () + if(${pkg_name}_SOURCE_DIR) if (PKG_HEAD_ONLY) file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) diff --git a/example/alexnet_cifar10/README.md b/example/alexnet_cifar10/README.md new file mode 100644 index 0000000000..0efd3ca1bf --- /dev/null +++ b/example/alexnet_cifar10/README.md @@ -0,0 +1,58 @@ +# AlexNet Example + +## Description + +Training AlexNet with CIFAR-10 dataset in MindSpore. + +This is the simple tutorial for training AlexNet in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the CIFAR-10 dataset at . The directory structure is as follows: + +``` +├─cifar-10-batches-bin +│ +└─cifar-10-verify-bin +``` + +## Running the example + +```python +# train AlexNet, hyperparameter setting in config.py +python train.py --data_path cifar-10-batches-bin +``` + +You can get loss with each step similar to this: + +```bash +epoch: 1 step: 1, loss is 2.2791853 +... 
+epoch: 1 step: 1536, loss is 1.9366643 +epoch: 1 step: 1537, loss is 1.6983616 +epoch: 1 step: 1538, loss is 1.0221305 +... +``` + +Then, test AlexNet according to network model +```python +# test AlexNet, 1 epoch training accuracy is up to 51.1%; 10 epoch training accuracy is up to 81.2% +python eval.py --data_path cifar-10-verify-bin --mode test --ckpt_path checkpoint_alexnet-1_1562.ckpt +``` + +## Note +There are some optional arguments: + +```bash +-h, --help show this help message and exit +--device_target {Ascend,GPU} + device where the code will be implemented (default: Ascend) +--data_path DATA_PATH + path where the dataset is saved +--dataset_sink_mode DATASET_SINK_MODE + dataset_sink_mode is False or True +``` + +You can run ```python train.py -h``` or ```python eval.py -h``` to get more information. diff --git a/example/convert_to_mindrecord/README.md b/example/convert_to_mindrecord/README.md new file mode 100644 index 0000000000..8d3b25e311 --- /dev/null +++ b/example/convert_to_mindrecord/README.md @@ -0,0 +1,46 @@ +# MindRecord generating guidelines + + + +- [MindRecord generating guidelines](#mindrecord-generating-guidelines) + - [Create work space](#create-work-space) + - [Implement data generator](#implement-data-generator) + - [Run data generator](#run-data-generator) + + + +## Create work space + +Assume the dataset name is 'xyz' +* Create work space from template + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + cp -r template xyz + ``` + +## Implement data generator + +Edit dictionary data generator +* Edit file + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + vi xyz/mr_api.py + ``` + + Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented +- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks. +- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 
'task_id' is 0..N-1, if N is return value of mindrecord_task_number() + + +Tricky for parallel run +- For imagenet, one directory can be a task. +- For TFRecord with multiple files, each file can be a task. +- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K) + + +## Run data generator +* run python script + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + python writer.py --mindrecord_script imagenet [...] + ``` diff --git a/example/convert_to_mindrecord/imagenet/__init__.py b/example/convert_to_mindrecord/imagenet/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/convert_to_mindrecord/imagenet/mr_api.py b/example/convert_to_mindrecord/imagenet/mr_api.py new file mode 100644 index 0000000000..e569b489b5 --- /dev/null +++ b/example/convert_to_mindrecord/imagenet/mr_api.py @@ -0,0 +1,122 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord writer. +Two API must be implemented, + 1. mindrecord_task_number() + # Return number of parallel tasks. return 1 if no parallel + 2. 
mindrecord_dict_data(task_id) + # Yield data for one task + # task_id is 0..N-1, if N is return value of mindrecord_task_number() +""" +import argparse +import os +import pickle + +######## mindrecord_schema begin ########## +mindrecord_schema = {"label": {"type": "int64"}, + "data": {"type": "bytes"}, + "file_name": {"type": "string"}} +######## mindrecord_schema end ########## + +######## Frozen code begin ########## +with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: + ARG_LIST = pickle.load(mindrecord_argument_file_handle) +######## Frozen code end ########## + +parser = argparse.ArgumentParser(description='Mind record imagenet example') +parser.add_argument('--label_file', type=str, default="", help='label file') +parser.add_argument('--image_dir', type=str, default="", help='images directory') + +######## Frozen code begin ########## +args = parser.parse_args(ARG_LIST) +print(args) +######## Frozen code end ########## + + +def _user_defined_private_func(): + """ + Internal function for tasks list + + Return: + tasks list + """ + if not os.path.exists(args.label_file): + raise IOError("map file {} not exists".format(args.label_file)) + + label_dict = {} + with open(args.label_file) as file_handle: + line = file_handle.readline() + while line: + labels = line.split(" ") + label_dict[labels[1]] = labels[0] + line = file_handle.readline() + # get all the dir which are n02087046, n02094114, n02109525 + dir_paths = {} + for item in label_dict: + real_path = os.path.join(args.image_dir, label_dict[item]) + if not os.path.isdir(real_path): + print("{} dir is not exist".format(real_path)) + continue + dir_paths[item] = real_path + + if not dir_paths: + print("not valid image dir in {}".format(args.image_dir)) + return {}, {} + + dir_list = [] + for label in dir_paths: + dir_list.append(label) + return dir_list, dir_paths + + +dir_list_global, dir_paths_global = _user_defined_private_func() + +def mindrecord_task_number(): + """ + Get task size. 
+ + Return: + number of tasks + """ + return len(dir_list_global) + + +def mindrecord_dict_data(task_id): + """ + Get data dict. + + Yields: + data (dict): data row which is dict. + """ + + # get the filename, label and image binary as a dict + label = dir_list_global[task_id] + for item in os.listdir(dir_paths_global[label]): + file_name = os.path.join(dir_paths_global[label], item) + if not item.endswith("JPEG") and not item.endswith( + "jpg") and not item.endswith("jpeg"): + print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name)) + continue + data = {} + data["file_name"] = str(file_name) + data["label"] = int(label) + + # get the image data + image_file = open(file_name, "rb") + image_bytes = image_file.read() + image_file.close() + data["data"] = image_bytes + yield data diff --git a/example/convert_to_mindrecord/run_imagenet.sh b/example/convert_to_mindrecord/run_imagenet.sh new file mode 100644 index 0000000000..11f5dcff75 --- /dev/null +++ b/example/convert_to_mindrecord/run_imagenet.sh @@ -0,0 +1,8 @@ +#!/bin/bash +rm /tmp/imagenet/mr/* + +python writer.py --mindrecord_script imagenet \ +--mindrecord_file "/tmp/imagenet/mr/m" \ +--mindrecord_partitions 16 \ +--label_file "/tmp/imagenet/label.txt" \ +--image_dir "/tmp/imagenet/jpeg" diff --git a/example/convert_to_mindrecord/run_template.sh b/example/convert_to_mindrecord/run_template.sh new file mode 100644 index 0000000000..a4c5142c00 --- /dev/null +++ b/example/convert_to_mindrecord/run_template.sh @@ -0,0 +1,6 @@ +#!/bin/bash +rm /tmp/template/* + +python writer.py --mindrecord_script template \ +--mindrecord_file "/tmp/template/m" \ +--mindrecord_partitions 4 diff --git a/example/convert_to_mindrecord/template/__init__.py b/example/convert_to_mindrecord/template/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/convert_to_mindrecord/template/mr_api.py b/example/convert_to_mindrecord/template/mr_api.py new file mode 100644 index 0000000000..3f7d7dddf0 
--- /dev/null +++ b/example/convert_to_mindrecord/template/mr_api.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord writer. +Two API must be implemented, + 1. mindrecord_task_number() + # Return number of parallel tasks. return 1 if no parallel + 2. mindrecord_dict_data(task_id) + # Yield data for one task + # task_id is 0..N-1, if N is return value of mindrecord_task_number() +""" +import argparse +import pickle + +# ## Parse argument + +with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line + ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line +parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line + +# ## Your arguments below +# parser.add_argument(...) + +args = parser.parse_args(ARG_LIST) # Do NOT change this line +print(args) # Do NOT change this line + + +# ## Default mindrecord vars. Comment them unless default value has to be changed. +# mindrecord_index_fields = ['label'] +# mindrecord_header_size = 1 << 24 +# mindrecord_page_size = 1 << 25 + + +# define global vars here if necessary + + +# ####### Your code below ########## +mindrecord_schema = {"label": {"type": "int32"}} + +def mindrecord_task_number(): + """ + Get task size. 
+ + Return: + number of tasks + """ + return 1 + + +def mindrecord_dict_data(task_id): + """ + Get data dict. + + Yields: + data (dict): data row which is dict. + """ + print("task is {}".format(task_id)) + for i in range(256): + data = {} + data['label'] = i + yield data diff --git a/example/convert_to_mindrecord/writer.py b/example/convert_to_mindrecord/writer.py new file mode 100644 index 0000000000..d34f1fb1c7 --- /dev/null +++ b/example/convert_to_mindrecord/writer.py @@ -0,0 +1,152 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +######################## write mindrecord example ######################## +Write mindrecord by data dictionary: +python writer.py --mindrecord_script /YourScriptPath ... 
+""" +import argparse +import os +import pickle +import time +from importlib import import_module +from multiprocessing import Pool + +from mindspore.mindrecord import FileWriter + + +def _exec_task(task_id, parallel_writer=True): + """ + Execute task with specified task id + """ + print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) + imagenet_iter = mindrecord_dict_data(task_id) + batch_size = 2048 + transform_count = 0 + while True: + data_list = [] + try: + for _ in range(batch_size): + data_list.append(imagenet_iter.__next__()) + transform_count += 1 + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + except StopIteration: + if data_list: + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + break + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Mind record writer') + parser.add_argument('--mindrecord_script', type=str, default="template", + help='path where script is saved') + + parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord", + help='written file name prefix') + + parser.add_argument('--mindrecord_partitions', type=int, default=1, + help='number of written files') + + parser.add_argument('--mindrecord_workers', type=int, default=8, + help='number of parallel workers') + + args = parser.parse_known_args() + + args, other_args = parser.parse_known_args() + + print(args) + print(other_args) + + with open('mr_argument.pickle', 'wb') as file_handle: + pickle.dump(other_args, file_handle) + + try: + mr_api = import_module(args.mindrecord_script + '.mr_api') + except ModuleNotFoundError: + raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api')) + + num_tasks = mr_api.mindrecord_task_number() + + print("Write mindrecord ...") + + mindrecord_dict_data = mr_api.mindrecord_dict_data + + # get number of 
files + writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) + + start_time = time.time() + + # set the header size + try: + header_size = mr_api.mindrecord_header_size + writer.set_header_size(header_size) + except AttributeError: + print("Default header size: {}".format(1 << 24)) + + # set the page size + try: + page_size = mr_api.mindrecord_page_size + writer.set_page_size(page_size) + except AttributeError: + print("Default page size: {}".format(1 << 25)) + + # get schema + try: + mindrecord_schema = mr_api.mindrecord_schema + except AttributeError: + raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") + + # create the schema + writer.add_schema(mindrecord_schema, "mindrecord_schema") + + # add the index + try: + index_fields = mr_api.mindrecord_index_fields + writer.add_index(index_fields) + except AttributeError: + print("Default index fields: all simple fields are indexes.") + + writer.open_and_set_header() + + task_list = list(range(num_tasks)) + + # set number of workers + num_workers = args.mindrecord_workers + + if num_tasks < 1: + num_tasks = 1 + + if num_workers > num_tasks: + num_workers = num_tasks + + if os.name == 'nt': + for window_task_id in task_list: + _exec_task(window_task_id, False) + elif num_tasks > 1: + with Pool(num_workers) as p: + p.map(_exec_task, task_list) + else: + _exec_task(0, False) + + ret = writer.commit() + + os.remove("{}".format("mr_argument.pickle")) + + end_time = time.time() + print("--------------------------------------------") + print("END. Total time: {}".format(end_time - start_time)) + print("--------------------------------------------") diff --git a/example/lenet_mnist/README.md b/example/lenet_mnist/README.md new file mode 100644 index 0000000000..fea92883c6 --- /dev/null +++ b/example/lenet_mnist/README.md @@ -0,0 +1,63 @@ +# LeNet Example + +## Description + +Training LeNet with MNIST dataset in MindSpore. 
+ +This is the simple and basic tutorial for constructing a network in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the MNIST dataset at . The directory structure is as follows: + +``` +└─MNIST_Data + ├─test + │ t10k-images.idx3-ubyte + │ t10k-labels.idx1-ubyte + │ + └─train + train-images.idx3-ubyte + train-labels.idx1-ubyte +``` + +## Running the example + +```python +# train LeNet, hyperparameter setting in config.py +python train.py --data_path MNIST_Data +``` + +You can get loss with each step similar to this: + +```bash +epoch: 1 step: 1, loss is 2.3040335 +... +epoch: 1 step: 1739, loss is 0.06952668 +epoch: 1 step: 1740, loss is 0.05038793 +epoch: 1 step: 1741, loss is 0.05018193 +... +``` + +Then, test LeNet according to network model +```python +# test LeNet, after 1 epoch training, the accuracy is up to 96.5% +python eval.py --data_path MNIST_Data --mode test --ckpt_path checkpoint_lenet-1_1875.ckpt +``` + +## Note +There are some optional arguments: + +```bash +-h, --help show this help message and exit +--device_target {Ascend,GPU,CPU} + device where the code will be implemented (default: Ascend) +--data_path DATA_PATH + path where the dataset is saved +--dataset_sink_mode DATASET_SINK_MODE + dataset_sink_mode is False or True +``` + +You can run ```python train.py -h``` or ```python eval.py -h``` to get more information. 
diff --git a/graphengine b/graphengine index 70bb745b45..43a715bc46 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 70bb745b459ff9a0e7fc1008d15fe4b510f03da7 +Subproject commit 43a715bc461fd70b7837051a2f47f0a1b19c5859 diff --git a/mindspore/_akg/gpu/__init__.py b/mindspore/_akg/gpu/__init__.py index 2ac6d1adb1..f9db48c634 100644 --- a/mindspore/_akg/gpu/__init__.py +++ b/mindspore/_akg/gpu/__init__.py @@ -26,7 +26,12 @@ from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad from .mean import SimpleMean, gpu_schedule_SimpleMean from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad from .mul import Mul, gpu_schedule_Mul -from .hsigmoid import Hsigmoid, gpu_schedule_Hsigmoid -from .hsigmoid_grad import HsigmoidGrad, gpu_schedule_HsigmoidGrad -from .hswish import Hswish, gpu_schedule_Hswish -from .hswish_grad import HswishGrad, gpu_schedule_HswishGrad +from .hsigmoid import HSigmoid, gpu_schedule_HSigmoid +from .hsigmoid_grad import HSigmoidGrad, gpu_schedule_HSigmoidGrad +from .hswish import HSwish, gpu_schedule_HSwish +from .hswish_grad import HSwishGrad, gpu_schedule_HSwishGrad +from .logical_or import LogicalOr, gpu_schedule_LogicalOr +from .logical_not import LogicalNot, gpu_schedule_LogicalNot +from .logical_and import LogicalAnd, gpu_schedule_LogicalAnd +from .sub import Sub, gpu_schedule_Sub +from .less_equal import LessEqual, gpu_schedule_LessEqual diff --git a/mindspore/_akg/gpu/hsigmoid.py b/mindspore/_akg/gpu/hsigmoid.py index b9d5ea74c9..b313c2fd5a 100644 --- a/mindspore/_akg/gpu/hsigmoid.py +++ b/mindspore/_akg/gpu/hsigmoid.py @@ -33,9 +33,9 @@ def topi_nn_hsigmoid(x): (x(*i) + 3) / 6))) -def Hsigmoid(x): +def HSigmoid(x): """ - Hsigmoid + HSigmoid Args: x: @@ -45,9 +45,9 @@ def Hsigmoid(x): return topi_nn_hsigmoid(x) -def gpu_schedule_Hsigmoid(outs): +def gpu_schedule_HSigmoid(outs): """ - gpu schedule Hsigmoid + gpu schedule HSigmoid Args: outs: diff --git a/mindspore/_akg/gpu/hsigmoid_grad.py 
b/mindspore/_akg/gpu/hsigmoid_grad.py index d3e7ac6345..bdde4ed3ca 100644 --- a/mindspore/_akg/gpu/hsigmoid_grad.py +++ b/mindspore/_akg/gpu/hsigmoid_grad.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Hsigmoid grad""" +"""HSigmoid grad""" import _akg.topi as topi import _akg.tvm as tvm -def HsigmoidGrad(y_grad, x): +def HSigmoidGrad(y_grad, x): """ - HsigmoidGrad + HSigmoidGrad Args: y_grad: x: @@ -32,7 +32,7 @@ def HsigmoidGrad(y_grad, x): y_grad(*i) / 6))) -def gpu_schedule_HsigmoidGrad(outs): +def gpu_schedule_HSigmoidGrad(outs): """ gpu schedule ReLU6Grad Args: diff --git a/mindspore/_akg/gpu/hswish.py b/mindspore/_akg/gpu/hswish.py index 904c38c2a2..44fcf10918 100644 --- a/mindspore/_akg/gpu/hswish.py +++ b/mindspore/_akg/gpu/hswish.py @@ -33,9 +33,9 @@ def topi_nn_hswish(x): x(*i) * (x(*i) + 3) / 6))) -def Hswish(x): +def HSwish(x): """ - Hswish + HSwish Args: x: @@ -45,9 +45,9 @@ def Hswish(x): return topi_nn_hswish(x) -def gpu_schedule_Hswish(outs): +def gpu_schedule_HSwish(outs): """ - gpu schedule Hswish + gpu schedule HSwish Args: outs: diff --git a/mindspore/_akg/gpu/hswish_grad.py b/mindspore/_akg/gpu/hswish_grad.py index 5b38f07c84..cadbf0f663 100644 --- a/mindspore/_akg/gpu/hswish_grad.py +++ b/mindspore/_akg/gpu/hswish_grad.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""HswishGrad""" +"""HSwishGrad""" import _akg.topi as topi import _akg.tvm as tvm -def HswishGrad(y_grad, x): +def HSwishGrad(y_grad, x): """ - HswishGrad + HSwishGrad Args: y_grad: x: @@ -34,9 +34,9 @@ def HswishGrad(y_grad, x): return res6 -def gpu_schedule_HswishGrad(outs): +def gpu_schedule_HSwishGrad(outs): """ - gpu schedule HswishGrad + gpu schedule HSwishGrad Args: outs: diff --git a/mindspore/_akg/gpu/less_equal.py b/mindspore/_akg/gpu/less_equal.py new file mode 100644 index 0000000000..c58346e929 --- /dev/null +++ b/mindspore/_akg/gpu/less_equal.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""less_equal""" +import _akg.tvm +from _akg.ops.math import less_equal +from _akg.topi.generic import schedule_elemwise + +def LessEqual(x, y): + """LessEqual.""" + return less_equal.less_equal(x, y) + + +def gpu_schedule_LessEqual(outs): + """ + GPU schedule for LessEqual. + + Args: + outs (tvm.tensor.Tensor): Outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. 
+ """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/gpu/logical_and.py b/mindspore/_akg/gpu/logical_and.py new file mode 100644 index 0000000000..6453901458 --- /dev/null +++ b/mindspore/_akg/gpu/logical_and.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""logical_and""" +import _akg.tvm +from _akg.ops.math import logical_and +from _akg.topi.generic import schedule_elemwise + +def LogicalAnd(x, y): + """LogicalAnd.""" + return logical_and.logical_and(x, y) + + +def gpu_schedule_LogicalAnd(outs): + """ + GPU schedule for LogicalAnd. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. 
+ """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/gpu/logical_not.py b/mindspore/_akg/gpu/logical_not.py new file mode 100644 index 0000000000..0a38107187 --- /dev/null +++ b/mindspore/_akg/gpu/logical_not.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""logical_not""" +import _akg.tvm +from _akg.ops.math import logical_not +from _akg.topi.generic import schedule_elemwise + +def LogicalNot(x): + """LogicalNot.""" + return logical_not.logical_not(x) + + +def gpu_schedule_LogicalNot(outs): + """ + GPU schedule for LogicalNot. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. 
+ """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/gpu/logical_or.py b/mindspore/_akg/gpu/logical_or.py new file mode 100644 index 0000000000..1bd49bedbc --- /dev/null +++ b/mindspore/_akg/gpu/logical_or.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""logical_or""" +import _akg.tvm +from _akg.ops.math import logical_or +from _akg.topi.generic import schedule_elemwise + +def LogicalOr(x, y): + """LogicalOr.""" + return logical_or.logical_or(x, y) + + +def gpu_schedule_LogicalOr(outs): + """ + GPU schedule for LogicalOr. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. 
+ """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/gpu/sub.py b/mindspore/_akg/gpu/sub.py new file mode 100644 index 0000000000..611e4228fd --- /dev/null +++ b/mindspore/_akg/gpu/sub.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""sub""" +import _akg.tvm +from _akg.ops.math import sub +from _akg.topi.generic import schedule_elemwise + +def Sub(x, y): + """Sub.""" + return sub.sub(x, y) + + +def gpu_schedule_Sub(outs): + """ + GPU schedule for Sub. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. + """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/ops/math/less_equal.py b/mindspore/_akg/ops/math/less_equal.py new file mode 100644 index 0000000000..5a566fbbca --- /dev/null +++ b/mindspore/_akg/ops/math/less_equal.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: lessequal""" +import _akg.tvm +import _akg.topi +from _akg.utils.dsl_create import produce_shapes +from _akg.utils import validation_check as vc_util + + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) +def less_equal(input1, input2): + """ + Check whether input1 lessequals to input2. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + input2 (tvm.tensor.Tensor): Tensor. + + Returns: + tvm.tensor.Tensor. If input1 lessequal to input2 return True, else return False. + """ + shape1 = [x.value for x in input1.shape] + shape2 = [x.value for x in input2.shape] + vc_util.check_shape(shape1) + vc_util.check_shape(shape2) + + shape1, shape2, shape = produce_shapes(shape1, shape2) + + vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) + dtype = input1.dtype + + # get lessequal compute + t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") + f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") + + input1_bro = _akg.topi.broadcast_to(input1, shape) + input2_bro = _akg.topi.broadcast_to(input2, shape) + c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] <= input2_bro[indice], + t_value[indice], f_value[indice]), name="C") + res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") + + return res diff --git a/mindspore/_akg/ops/math/logical_and.py b/mindspore/_akg/ops/math/logical_and.py new file mode 100644 index 0000000000..480d4e1741 --- /dev/null +++ 
b/mindspore/_akg/ops/math/logical_and.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: logical_and""" +import _akg.tvm +import _akg.topi +from _akg.utils import validation_check as vc_util + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) +def logical_and(input1, input2): + """ + Compute logical_and of input1 and input2. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + input2 (tvm.tensor.Tensor): Tensor. + + Returns: + tvm.tensor.Tensor. LogicalAnd of input1 and input2. + """ + + vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) + + shape1 = [x.value for x in input1.shape] + shape2 = [x.value for x in input2.shape] + vc_util.check_shape(shape1) + vc_util.check_shape(shape2) + + res = _akg.topi.logical_and(input1, input2) + return res diff --git a/mindspore/_akg/ops/math/logical_not.py b/mindspore/_akg/ops/math/logical_not.py new file mode 100644 index 0000000000..9befe7e816 --- /dev/null +++ b/mindspore/_akg/ops/math/logical_not.py @@ -0,0 +1,32 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: logical_not""" +import _akg.tvm +import _akg.topi +from _akg.utils import validation_check as vc_util + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor) +def logical_not(input1): + """ + Compute logical_not of input1. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + + Returns: + tvm.tensor.Tensor. + """ + res = _akg.topi.logical_not(input1) + return res diff --git a/mindspore/_akg/ops/math/logical_or.py b/mindspore/_akg/ops/math/logical_or.py new file mode 100644 index 0000000000..8fb0b80567 --- /dev/null +++ b/mindspore/_akg/ops/math/logical_or.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: logical_or""" +import _akg.tvm +import _akg.topi +from _akg.utils import validation_check as vc_util + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) +def logical_or(input1, input2): + """ + Compute logical_or of input1 and input2. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + input2 (tvm.tensor.Tensor): Tensor. 
+ + Returns: + tvm.tensor.Tensor. LogicalOr of input1 and input2. + """ + + vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) + + shape1 = [x.value for x in input1.shape] + shape2 = [x.value for x in input2.shape] + vc_util.check_shape(shape1) + vc_util.check_shape(shape2) + + res = _akg.topi.logical_or(input1, input2) + return res diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index dc2f71cc18..707ca748b4 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -14,10 +14,12 @@ # ============================================================================ """Check parameters.""" import re +import inspect +import math from enum import Enum -from functools import reduce +from functools import reduce, wraps from itertools import repeat -from collections import Iterable +from collections.abc import Iterable import numpy as np from mindspore import log as logger @@ -98,7 +100,7 @@ class Validator: """validator for checking input parameters""" @staticmethod - def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): + def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): """ Method for judging relation between two int values or list/tuple made up of ints. 
@@ -108,18 +110,29 @@ class Validator: rel_fn = Rel.get_fns(rel) if not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') - msg_prefix = f'For {prim_name} the' if prim_name else "The" - raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') @staticmethod def check_integer(arg_name, arg_value, value, rel, prim_name): """Integer value judgment.""" rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) + excp_cls = TypeError if type_mismatch else ValueError if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) - raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' - f' but got {arg_value}.') + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' + f' with type `{type(arg_value).__name__}`.') + return arg_value + + @staticmethod + def check_number(arg_name, arg_value, value, rel, prim_name): + """Integer value judgment.""" + rel_fn = Rel.get_fns(rel) + if not rel_fn(arg_value, value): + rel_str = Rel.get_strs(rel).format(value) + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') return arg_value @staticmethod @@ -127,15 +140,53 @@ class Validator: """Method for checking whether an int value is in some range.""" rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, int) + excp_cls = TypeError if type_mismatch else ValueError if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) - raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' - f' but 
got {arg_value}.') + raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' + f' but got `{arg_value}` with type `{type(arg_value).__name__}`.') return arg_value + @staticmethod + def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): + """Method for checking whether a numeric value is in some range.""" + rel_fn = Rel.get_fns(rel) + if not rel_fn(arg_value, lower_limit, upper_limit): + rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') + return arg_value + + @staticmethod + def check_string(arg_name, arg_value, valid_values, prim_name): + """Checks whether a string is in some value list""" + if isinstance(arg_value, str) and arg_value in valid_values: + return arg_value + if len(valid_values) == 1: + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},' + f' but got {arg_value}.') + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},' + f' but got {arg_value}.') + + @staticmethod + def check_pad_value_by_mode(pad_mode, padding, prim_name): + """Validates value of padding according to pad_mode""" + if pad_mode != 'pad' and padding != 0: + raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") + return padding + + @staticmethod + def check_float_positive(arg_name, arg_value, prim_name): + """Float type judgment.""" + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + if isinstance(arg_value, float): + if arg_value > 0: + return arg_value + raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.") + raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") + @staticmethod def check_subclass(arg_name, type_, template_type, prim_name): - """Check whether some type is sublcass of another type""" + 
"""Checks whether some type is subclass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -144,32 +195,51 @@ class Validator: f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') @staticmethod - def check_tensor_type_same(args, valid_values, prim_name): - """check whether the element types of input tensors are the same.""" + def check_const_input(arg_name, arg_value, prim_name): + """Checks valid value.""" + if arg_value is None: + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') + + @staticmethod + def check_type_same(args, valid_values, prim_name): + """Checks whether the types of inputs are the same.""" def _check_tensor_type(arg): arg_key, arg_val = arg - Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) - elem_type = arg_val.element_type() + elem_type = arg_val if not elem_type in valid_values: - raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' - f' but `{arg_key}` is {elem_type}.') + type_names = [] + for t in valid_values: + type_names.append(str(t)) + types_info = '[' + ", ".join(type_names) + ']' + raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' + f' but got {elem_type}.') return (arg_key, elem_type) def _check_types_same(arg1, arg2): arg1_name, arg1_type = arg1 arg2_name, arg2_type = arg2 if arg1_type != arg2_type: - raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,' - f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') + raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' + f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') return arg1 elem_types = map(_check_tensor_type, args.items()) reduce(_check_types_same, 
elem_types) + @staticmethod + def check_tensor_type_same(args, valid_values, prim_name): + """Checks whether the element types of input tensors are the same.""" + tensor_types = [mstype.tensor_type(t) for t in valid_values] + Validator.check_type_same(args, tensor_types, prim_name) @staticmethod - def check_scalar_or_tensor_type_same(args, valid_values, prim_name): - """check whether the types of inputs are the same. if the input args are tensors, check their element types""" + def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): + """ + Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. + + If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. + """ + def _check_argument_type(arg): arg_key, arg_val = arg if isinstance(arg_val, type(mstype.tensor)): @@ -182,16 +252,19 @@ class Validator: def _check_types_same(arg1, arg2): arg1_name, arg1_type = arg1 arg2_name, arg2_type = arg2 - excp_flag = False + except_flag = False if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): arg1_type = arg1_type.element_type() arg2_type = arg2_type.element_type() elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): pass + elif allow_mix: + arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type + arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type else: - excp_flag = True + except_flag = True - if excp_flag or arg1_type != arg2_type: + if except_flag or arg1_type != arg2_type: raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') return arg1 @@ -199,13 +272,15 @@ class Validator: @staticmethod def check_value_type(arg_name, arg_value, valid_types, 
prim_name): - """Check whether a values is instance of some types.""" + """Checks whether a value is instance of some types.""" + valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + def raise_error_msg(): """func for raising error message when check failed""" type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) - raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' - f'{"one of " if num_types > 1 else ""}' + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and @@ -216,6 +291,34 @@ class Validator: return arg_value raise_error_msg() + @staticmethod + def check_type_name(arg_name, arg_type, valid_types, prim_name): + """Checks whether a type in some specified types""" + valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + + def get_typename(t): + return t.__name__ if hasattr(t, '__name__') else str(t) + + if arg_type in valid_types: + return arg_type + type_names = [get_typename(t) for t in valid_types] + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + if len(valid_types) == 1: + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' + f' but got {get_typename(arg_type)}.') + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' + f' but got {get_typename(arg_type)}.') + + @staticmethod + def check_float_legal_value(arg_name, arg_value, prim_name): + """Checks whether a legal value of float type""" + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + if isinstance(arg_value, float): + if math.isinf(arg_value) or math.isnan(arg_value): + raise ValueError(f"{msg_prefix} 
`{arg_name}` must be legal value, but got {arg_value}.") + return arg_value + raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") + class ParamValidator: """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" @@ -268,9 +371,9 @@ class ParamValidator: @staticmethod def check_isinstance(arg_name, arg_value, classes): - """Check arg isintance of classes""" + """Check arg isinstance of classes""" if not isinstance(arg_value, classes): - raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.') + raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') return arg_value @staticmethod @@ -284,7 +387,7 @@ class ParamValidator: @staticmethod def check_subclass(arg_name, type_, template_type, with_type_of=True): - """Check whether some type is sublcass of another type""" + """Check whether some type is subclass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -302,9 +405,9 @@ class ParamValidator: @staticmethod def check_bool(arg_name, arg_value): - """Check arg isintance of bool""" + """Check arg isinstance of bool""" if not isinstance(arg_value, bool): - raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.') + raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') return arg_value @staticmethod @@ -671,3 +774,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII): if re.match(reg, target, flag) is None: raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) return True + + +def args_type_check(*type_args, **type_kwargs): + """Check whether input data type is correct.""" + + def type_check(func): + sig = inspect.signature(func) + bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments + + @wraps(func) + def 
wrapper(*args, **kwargs): + nonlocal bound_types + bound_values = sig.bind(*args, **kwargs) + argument_dict = bound_values.arguments + if "kwargs" in bound_types: + bound_types = bound_types["kwargs"] + if "kwargs" in argument_dict: + argument_dict = argument_dict["kwargs"] + for name, value in argument_dict.items(): + if name in bound_types: + if value is not None and not isinstance(value, bound_types[name]): + raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) + return func(*args, **kwargs) + + return wrapper + + return type_check diff --git a/mindspore/_extends/__init__.py b/mindspore/_extends/__init__.py index 5eabfcd97c..91e1192e7e 100644 --- a/mindspore/_extends/__init__.py +++ b/mindspore/_extends/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """ -Extension functions. +Extension functions. Python functions that will be called in the c++ parts of MindSpore. """ diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index 9fb357597e..7178cd2634 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -83,6 +83,7 @@ convert_object_map = { T.mul: multitype_ops.mul, T.truediv: multitype_ops.div, T.getitem: multitype_ops.getitem, + T.setitem: multitype_ops.setitem, T.floordiv: multitype_ops.floordiv, T.mod: multitype_ops.mod, T.pow: multitype_ops.pow_, @@ -113,12 +114,12 @@ convert_object_map = { T.map: C.HyperMap(), T.partial: F.partial, T.zip: C.zip_operation, + T.print: F.print_, # custom define operation T.iter: M.ms_iter, T.next: M.ms_next, T.hasnext: M.hasnext, - T.setitem: M.setitem, T.make_tuple: F.make_tuple, T.make_dict: F.make_dict, diff --git a/mindspore/_extends/parse/trope.py b/mindspore/_extends/parse/trope.py index 9f8f67fba5..7b40adcd16 100644 --- a/mindspore/_extends/parse/trope.py +++ b/mindspore/_extends/parse/trope.py @@ -27,7 +27,7 @@ from operator 
import ( # noqa # support system function call from builtins import ( # noqa - bool, getattr, setattr, len, iter, next, pow, range, map, zip + bool, getattr, setattr, len, iter, next, pow, range, map, zip, print ) # support functools @@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', 'matmul', 'getitem', 'setitem', 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', - 'partial', + 'partial', 'print', 'exp', 'log', 'sin', 'cos', 'tan'] diff --git a/mindspore/_extends/pynative_helper.py b/mindspore/_extends/pynative_helper.py deleted file mode 100644 index 0b93ab926b..0000000000 --- a/mindspore/_extends/pynative_helper.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""Pynative mode help module.""" -from inspect import signature -from functools import wraps - - -def args_type_check(*type_args, **type_kwargs): - """Check whether input data type is correct.""" - - def type_check(func): - sig = signature(func) - bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments - - @wraps(func) - def wrapper(*args, **kwargs): - nonlocal bound_types - bound_values = sig.bind(*args, **kwargs) - argument_dict = bound_values.arguments - if "kwargs" in bound_types: - bound_types = bound_types["kwargs"] - if "kwargs" in argument_dict: - argument_dict = argument_dict["kwargs"] - for name, value in argument_dict.items(): - if name in bound_types: - if value is not None and not isinstance(value, bound_types[name]): - raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) - return func(*args, **kwargs) - - return wrapper - - return type_check diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 8c33b9051c..eb33de1c4b 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -394,10 +394,6 @@ if(USE_GLOG) target_link_libraries(_c_expression PRIVATE mindspore::glog) endif() -if(ENABLE_GPU) - target_link_libraries(_c_expression PRIVATE mindspore::tvm) -endif() - if(ENABLE_DUMP_PROTO) target_link_libraries(_c_expression PRIVATE mindspore::protobuf) endif() diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index b4e02c8fe6..1174be1f48 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -103,17 +103,39 @@ const std::map, DataTypeTransMode> mode_map{ template void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + 
MS_LOG(EXCEPTION) << "Invalid src or dst data size."; + } for (size_t idx = 0; idx != data_size; idx++) { SrcT src_data = static_cast(args.data)[idx]; static_cast(dst)[idx] = static_cast(src_data); } } +template +void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) { + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(EXCEPTION) << "Invalid src or dst data size."; + } + auto src_data = static_cast(args.data); + auto half_data = static_cast(dst); + for (size_t i = 0; i < data_size; i++) { + half_data[i] = Eigen::half(src_data[i]); + } +} + bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) { switch (mode) { case FROM_FLOAT_TO_FLOAT16: device::FloatToHalf(dst, args.data, data_size); break; + case FROM_INT32_TO_FLOAT16: + TransDataSrc2Fp16(args, dst, data_size); + break; case FROM_FLOAT16_TO_FLOAT: device::HalfToFloat(dst, args.data, data_size); break; @@ -372,27 +394,27 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { } bool TransDataType(const TypeIdArgs &args, void *result) { - MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " - << TypeIdLabel(args.device_data_type); + MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type); MS_EXCEPTION_IF_NULL(result); - std::pair type_info(args.host_data_type, args.device_data_type); + std::pair type_info(args.src_type, args.dst_type); auto iter = mode_map.find(type_info); if (iter == mode_map.end()) { - MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type) - << ", dst_type:" << TypeIdLabel(args.device_data_type); + MS_LOG(ERROR) << "Unsupported datatype trans. 
src_type :" << TypeIdLabel(args.src_type) + << ", dst_type:" << TypeIdLabel(args.dst_type); return false; } auto trans_mode = iter->second; - auto type_size = TypeIdSize(args.device_data_type); - if (type_size < 1) { - MS_LOG(ERROR) << "Invalid host data type."; + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (src_id < 1 || dst_id < 1) { + MS_LOG(ERROR) << "Invalid src or dst data type."; return false; } - if (args.host_shape_size < 1) { - MS_LOG(ERROR) << "Invalid host data size."; + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(ERROR) << "Invalid src or dst data size."; return false; } - if (!CastKernel(args, result, args.host_shape_size, trans_mode)) { + if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) { MS_LOG(ERROR) << "Failed to trans datatype.."; return false; } diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 054fa89a06..e6e81ed359 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -31,9 +31,12 @@ namespace mindspore { namespace trans { struct TypeIdArgs { const void *data; - size_t host_shape_size; // Multiply each dimension elements. 
[a, b, c, d] => a*b*c*d - TypeId host_data_type; - TypeId device_data_type; + size_t src_size; + size_t dst_size; + TypeId src_type; + TypeId dst_type; + size_t src_shape_size; + size_t dst_shape_size; }; struct FormatArgs { diff --git a/mindspore/ccsrc/common/utils.cc b/mindspore/ccsrc/common/utils.cc index 328a059113..7109c121e5 100644 --- a/mindspore/ccsrc/common/utils.cc +++ b/mindspore/ccsrc/common/utils.cc @@ -23,7 +23,7 @@ namespace common { const int CACHED_STR_NUM = 1 << 8; const int CACHED_STR_MASK = CACHED_STR_NUM - 1; std::vector STR_HOLDER(CACHED_STR_NUM); -const char* SafeCStr(const std::string&& str) { +const char *SafeCStr(const std::string &&str) { static std::atomic index{0}; uint32_t cur_index = index++; cur_index = cur_index & CACHED_STR_MASK; diff --git a/mindspore/ccsrc/common/utils.h b/mindspore/ccsrc/common/utils.h index 7cee933ac8..8f6e8f7c0c 100644 --- a/mindspore/ccsrc/common/utils.h +++ b/mindspore/ccsrc/common/utils.h @@ -21,16 +21,16 @@ #include #define DISABLE_COPY_AND_ASSIGN(ClassType) \ - ClassType(const ClassType&) = delete; \ - ClassType& operator=(const ClassType&) = delete; + ClassType(const ClassType &) = delete; \ + ClassType &operator=(const ClassType &) = delete; namespace mindspore { namespace common { -inline const char* SafeCStr(const std::string& str) { return str.c_str(); } -const char* SafeCStr(const std::string&& str); +inline const char *SafeCStr(const std::string &str) { return str.c_str(); } +const char *SafeCStr(const std::string &&str); -static inline std::string GetEnv(const std::string& envvar) { - const char* value = ::getenv(envvar.c_str()); +static inline std::string GetEnv(const std::string &envvar) { + const char *value = ::getenv(envvar.c_str()); if (value == nullptr) { return std::string(); diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt index 0bc4065ac9..8e9b2664dc 100644 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ 
b/mindspore/ccsrc/dataset/CMakeLists.txt @@ -12,10 +12,10 @@ endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--image-base -Wl,0x10000000") -endif() ############################# Options ################################ +if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + add_definitions(-D _CRT_RAND_S) +endif () if (ENABLE_GPUQUE) add_definitions(-D ENABLE_GPUQUE) message(STATUS "GPU queue is enabled") diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 5f61c86f06..c3dfeafe48 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -28,10 +28,11 @@ #include "dataset/engine/datasetops/source/manifest_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" - #include "dataset/util/random.h" #include "dataset/util/status.h" #include "utils/log_adapter.h" @@ -45,7 +46,9 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kShuffle, &DEPipeline::ParseShuffleOp}, {kMindrecord, &DEPipeline::ParseMindRecordOp}, {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, {kBatch, &DEPipeline::ParseBatchOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, {kRepeat, &DEPipeline::ParseRepeatOp}, {kSkip, &DEPipeline::ParseSkipOp}, {kZip, &DEPipeline::ParseZipOp}, @@ -61,7 +64,8 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kVoc, &DEPipeline::ParseVOCOp}, {kCifar10, &DEPipeline::ParseCifar10Op}, {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}}; + 
{kCelebA, &DEPipeline::ParseCelebAOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}}; DEPipeline::DEPipeline() : iterator_(nullptr) { try { @@ -501,6 +505,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * return Status::OK(); } +Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *ptr) { + std::shared_ptr builder = std::make_shared(); + + if (args["predicate"].is_none()) { + RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "predicate") { + py::handle op = args["predicate"]; + if (!py::isinstance(op)) { + RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); + } + py::function predicate_func = op.cast(); + (void)builder->SetPredicateFunc(std::move(predicate_func)); + } else if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)builder->SetInColNames(in_col_names); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr) { if (args["count"].is_none()) { std::string err_msg = "Error: count is invalid or not set."; @@ -589,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr return Status::OK(); } +Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr) { + std::shared_ptr builder = std::make_shared(); + // Right now barrier should only take num_rows_per_buffer = 1 + // The reason for this is because having it otherwise can lead to blocking issues + // See barrier_op.h for more details + (void)builder->SetRowsPerBuffer(1); + for (auto arg : args) { + 
std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "condition_name") { + (void)builder->SetConditionName(ToString(value)); + } else if (key == "condition_func") { + (void)builder->SetConditionFunc(value.cast()); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr) { int32_t prefetch_size = 0; if (args.contains("prefetch_size")) { @@ -670,8 +733,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr * return Status::OK(); } -DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); } - Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr) { // Required arguments std::shared_ptr builder = std::make_shared(); @@ -985,5 +1046,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr) { + // Required arguments + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"])); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + (void)builder->SetNumDevices(ToInt(value)); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 6ff7bb091c..7f9c6c459a 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -40,6 +40,7 @@ enum OpName { kShuffle, kMindrecord, kBatch, + kBarrier, kCache, kRepeat, kSkip, @@ -58,7 +59,8 @@ enum OpName { kVoc, kCifar10, kCifar100, - kCelebA + kCelebA, + kTextFile }; // The C++ binder class that we expose to the python script. @@ -106,12 +108,16 @@ class DEPipeline { Status ParseMapOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseFilterOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr); Status ParseSkipOp(const py::dict &args, std::shared_ptr *ptr); Status ParseBatchOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr); Status ParseRenameOp(const py::dict &args, std::shared_ptr *ptr); @@ -120,8 +126,6 @@ class DEPipeline { Status ParseZipOp(const py::dict &args, std::shared_ptr *ptr); - DsOpPtr ParseFilterOp(const py::dict &args) const; - Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr); Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr); @@ -148,6 +152,8 @@ class DEPipeline { Status ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseTextFileOp(const py::dict &args, std::shared_ptr *ptr); + private: // Execution tree that links the dataset operators. 
std::shared_ptr tree_; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 076f2ecc36..ea2e8352da 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -24,7 +24,6 @@ #endif #include "dataset/kernels/image/cut_out_op.h" #include "dataset/kernels/image/decode_op.h" -#include "dataset/kernels/image/distort_bounding_box_crop_op.h" #include "dataset/kernels/image/hwc_to_chw_op.h" #include "dataset/kernels/image/image_utils.h" #include "dataset/kernels/image/normalize_op.h" @@ -40,6 +39,7 @@ #include "dataset/kernels/image/rescale_op.h" #include "dataset/kernels/image/resize_bilinear_op.h" #include "dataset/kernels/image/resize_op.h" +#include "dataset/kernels/image/uniform_aug_op.h" #include "dataset/kernels/data/type_cast_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/image_folder_op.h" @@ -53,11 +53,14 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "dataset/engine/datasetops/source/sampler/python_sampler.h" #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/jagged_connector.h" +#include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/kernels/data/to_float16_op.h" #include "dataset/util/random.h" #include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_pk_sample.h" #include "mindrecord/include/shard_sample.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -150,9 +153,14 @@ void bindDatasetOps(py::module *m) { }); (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", [](const std::string &path) { + .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) { int64_t count = 
0; - THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count)); + std::shared_ptr op; + if (py::hasattr(sampler, "_create_for_minddataset")) { + auto create = sampler.attr("_create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count)); return count; }); @@ -176,6 +184,17 @@ void bindDatasetOps(py::module *m) { THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count)); return count; }); + + (void)py::class_>(*m, "TextFileOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); + return count; + }); } void bindTensor(py::module *m) { (void)py::class_(*m, "GlobalContext") @@ -251,6 +270,10 @@ void bindTensorOps1(py::module *m) { .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); + (void)py::class_>( + *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") + .def(py::init(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps); + (void)py::class_>( *m, "ResizeBilinearOp", "Tensor operation to resize an image using " @@ -345,18 +368,6 @@ void bindTensorOps3(py::module *m) { } void bindTensorOps4(py::module *m) { - (void)py::class_>( - *m, "DistortBoundingBoxCropOp", - "Tensor operator to crop an image randomly as long as the cropped image has sufficient " - "overlap with any one bounding box associated with original image" - "Takes aspect ratio of the generated crop box, the intersection ratio of crop box and bounding box," - "crop ratio lower and upper bounds" - "Optional parameters: number of attempts for crop, number of attempts of crop box generation") - .def(py::init(), py::arg("aspect_ratio"), py::arg("intersect_ratio"), - 
py::arg("crop_ratio_lower_bound"), py::arg("crop_ratio_upper_bound"), - py::arg("max_attempts") = DistortBoundingBoxCropOp::kDefMaxAttempts, - py::arg("box_gen_attempts") = DistortBoundingBoxCropOp::kDefBoxGenAttempts); - (void)py::class_>( *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") .def(py::init(), py::arg("data_type")) @@ -415,16 +426,30 @@ void bindSamplerOps(py::module *m) { (void)py::class_>(*m, "SequentialSampler") .def(py::init<>()); + (void)py::class_>(*m, "SubsetRandomSampler") .def(py::init>(), py::arg("indices")); (void)py::class_>( *m, "MindrecordSubsetRandomSampler") .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + (void)py::class_>( + *m, "MindrecordPkSampler") + .def(py::init([](int64_t kVal, bool shuffle) { + if (shuffle == true) { + return std::make_shared("label", kVal, std::numeric_limits::max(), + GetSeed()); + } else { + return std::make_shared("label", kVal); + } + })); (void)py::class_>(*m, "WeightedRandomSampler") .def(py::init, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), py::arg("replacement")); + + (void)py::class_>(*m, "PythonSampler") + .def(py::init(), py::arg("pySampler")); } void bindInfoObjects(py::module *m) { @@ -443,6 +468,7 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("STORAGE", OpName::kStorage) .value("SHUFFLE", OpName::kShuffle) .value("BATCH", OpName::kBatch) + .value("BARRIER", OpName::kBarrier) .value("MINDRECORD", OpName::kMindrecord) .value("CACHE", OpName::kCache) .value("REPEAT", OpName::kRepeat) @@ -463,7 +489,8 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("VOC", OpName::kVoc) .value("CIFAR10", OpName::kCifar10) .value("CIFAR100", OpName::kCifar100) - .value("CELEBA", OpName::kCelebA); + .value("CELEBA", OpName::kCelebA) + .value("TEXTFILE", OpName::kTextFile); (void)py::enum_(m, "InterpolationMode", py::arithmetic()) .value("DE_INTER_LINEAR", InterpolationMode::kLinear) diff --git a/mindspore/ccsrc/dataset/core/client.h 
b/mindspore/ccsrc/dataset/core/client.h index b865c54260..40de887aea 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -25,12 +25,14 @@ #include "dataset/core/tensor_shape.h" #include "dataset/engine/data_schema.h" #include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/barrier_op.h" #include "dataset/engine/datasetops/batch_op.h" #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/device_queue_op.h" #include "dataset/engine/datasetops/map_op.h" #include "dataset/engine/datasetops/project_op.h" #include "dataset/engine/datasetops/rename_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "dataset/engine/datasetops/repeat_op.h" #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/datasetops/shuffle_op.h" diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc index a566d51f5c..3f41f27726 100644 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ b/mindspore/ccsrc/dataset/core/tensor.cc @@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) c DS_ASSERT(data_); switch (type_.value()) { - CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t); + CASE_PRINT_HEX(DataType::DE_BOOL, bool); CASE_PRINT_HEX(DataType::DE_INT8, int8_t); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 655a739ada..9e8272d513 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT dataset_op.cc parallel_op.cc pipeline_op.cc + barrier_op.cc batch_op.cc device_queue_op.cc map_op.cc @@ -14,5 +15,6 @@ add_library(engine-datasetops OBJECT take_op.cc shuffle_op.cc zip_op.cc + filter_op.cc ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc 
b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc new file mode 100644 index 0000000000..b0ea7dbd07 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/barrier_op.h" +#include +#include "dataset/core/constants.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/db_connector.h" +#include "dataset/core/config_manager.h" +#include "dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +BarrierOp::Builder::Builder() { + // Some arguments to the BarrierOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the BarrierOp by + // using the various builder set methods. 
+ + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } + +Status BarrierOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, + builder_condition_func_); + return Status::OK(); +} + +// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions +BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func) + : PipelineOp(op_connector_size), + rows_per_buffer_(rows_per_buffer), + buffer_id_(0), + clean_up_(false), + eof_(false), + condition_name_(condition_name), + condition_function_(condition_func) {} + +// destructor +BarrierOp::~BarrierOp() {} + +// Entry point for Barrier, called by launch() +Status BarrierOp::operator()() { + // The children_num_ parameter needs to be put here + // Synchronize with TaskManager once the thread is created. + TaskManager::FindMe()->Post(); + + // create child iterator, right now this barrier is a pipeline operator + int32_t worker_id = 0; + int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Loop until eof is true + while (!eof_) { + // Create new table to put the new tensor rows + std::unique_ptr curr_table = std::make_unique(); + RETURN_IF_NOT_OK(prepare(curr_table.get())); + + // If an eof got picked up during the above prepare, then we're done + if (eof_) { + break; + } + + // we have to output new buffer with possibly different buffer size, possibly one row + while (!clean_up_) { + // 1. If a previous loop iteration sent the current table out, then create a new one. 
+ + if (curr_table == nullptr) { + curr_table = std::make_unique(); + } + + // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished + RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); + + // 3 create and update buffer and send it to the out connector + if (!curr_table->empty()) { + std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); + curr_buffer->set_tensor_table(std::move(curr_table)); + curr_buffer->set_column_name_map(col_name_id_map_); + MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " + << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << "."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + buffer_id_++; + } + } + + // 4 handle drain state. + if (clean_up_) { + MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; + // Send the eoe up. + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + } + } + // 5 handle eof + // propagate eof here. + MS_LOG(INFO) << "Barrier operator got EOF, propagating."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Handles preprocessing of the main loop, used when starting new epoch +Status BarrierOp::prepare(TensorQTable *const table) { + MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; + clean_up_ = false; + buffer_id_ = 0; + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); + } + // fill initial row + TensorRow new_row = {}; + // use iterator to get next row and invoke pyfunc wait + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + + // If the first row fetching resulted in eof, then we are done. 
+ if (eof_) { + return Status::OK(); + } + if (new_row.empty()) { + // This epoch is empty + return Status::OK(); + } + // Pack this first row into our tensor table + // first row we also have to check if we should block + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + // At this point we have 1 row produced, we take the old column map id and use it in the new table + // Initializing col_name_id_map_ from the first data buffer. + col_name_id_map_ = child_iterator_->col_name_id_map(); + // the update code below shouldn't do anything bad if the column name already exists. + return Status::OK(); +} + +// fillBuffer always expects a new table to fill +Status BarrierOp::fillBuffer(TensorQTable *const table) { + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); + } + TensorRow new_row = {}; + while (table->size() < static_cast(rows_per_buffer_)) { + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + // Early exit the loop if we got empty row from any of our child iterations + if (new_row.empty()) { + return Status::OK(); + } + // else we got a row so pack it into the tensor table. + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + } + return Status::OK(); +} + +// function executes a py_func and blocks until condition becomes true. 
+Status BarrierOp::blockCond() { + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // we have condition name, however the flexibility is in python today + try { + // Invoke python function + py::object ret_py_obj = condition_function_(); + // Process the return value + if (!py::isinstance(ret_py_obj)) { + return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +// fetches next Barrier buffer row +Status BarrierOp::getNextTensorRow(TensorRow *new_row) { + // iterate over all iterators and generate a row + RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); + // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row + if (new_row->empty()) { + // If we did not get a row from any of the children, then it's the end of an epoch and we can move + // to drain state. + MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; + clean_up_ = true; + // If we picked up an eof here, then we are completely done. 
+ if ((child_iterator_)->eof_handled()) { + MS_LOG(INFO) << "Barrier operator iterator got EOF."; + eof_ = true; + } + return Status::OK(); + } + return Status::OK(); +} + +// A function that prints info about the Operator +void BarrierOp::Print(std::ostream &out, bool show_all) const { + // Call base class printer first + PipelineOp::Print(out, show_all); + out << "\nBarrierOp:\n" + << "\nCondition " << condition_name_ << "\n\n"; +} + +// overwrite function and handle eof +Status BarrierOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; + return Status::OK(); +} + +// overwrite function and handle eoe +Status BarrierOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h new file mode 100644 index 0000000000..8be55fba7e --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ + +#include +#include +#include +#include +#include +#include "dataset/core/tensor.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/pipeline_op.h" +#include "dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has +// been received. This signal is given from python layer. The current barrier design respects the +// rows per buffer design and will only output a buffer with rows once it has received rows per buffer +// signals from python. + +class BarrierOp : public PipelineOp { + public: + // The nested builder class inside of the BarrierOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @param const std::string & condition_name + // @return Builder setter method returns reference to the builder. 
+ Builder &SetConditionName(const std::string &condition_name) { + builder_condition_name_ = condition_name; + return *this; + } + + // Setter method. + // @param py::function condition_func - blocking condition function + // @return Builder setter method returns reference to the builder. + Builder &SetConditionFunc(py::function condition_func) { + builder_condition_func_ = condition_func; + return *this; + } + + // The builder "build" method creates the BarrierOp dataset Operator. + // @return shared_ptr to the new BarrierOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::string builder_condition_name_; + py::function builder_condition_func_; + + Status SanityCheck() const; + }; + + // Constructor for BarrierOp + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size + // @param condition_name - the condition name associated with this operator + // @param condition_func - the blocking condition check per row + // @note - currently rows_per_buffer should = 1 for barrier. + // The reason for this is having other values would complicate how the pipeline behaves with other operators + // One example of such case is having batch after barrier. 
Batch would be waiting for data and having + // rows per buffer in this case can result in hanging + BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func); + + // Destructor + ~BarrierOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Barrier + // @param out - output stream to print to + // @param show_all - if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { + bo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Handles preprocessing of the main loop, used when starting new epoch + // @param table - a table of tensors to be moved into a buffer + Status prepare(TensorQTable *const table); + + // This function calls takes a table repeatedly adds rows to it. 
+ // @param table - a table of tensors to be moved into a buffer + Status fillBuffer(TensorQTable *const table); + + // Gets next tensor row and sets control signals + Status getNextTensorRow(TensorRow *new_row); + + // This function runs the wait function on condition + Status blockCond(); + + private: + // clean up variable to return imcomplete buffer + bool clean_up_; + // end of file state, we stop reading data and shut down + bool eof_; + // rows per buffer + int32_t rows_per_buffer_; + // buffer_id + int32_t buffer_id_; + // local variable to keep track of the buffer information + std::unordered_map col_name_id_map_; + // iterator to pull new rows, we only have one child + std::unique_ptr child_iterator_; + // condition name, to support multiple barriers + std::string condition_name_; + // Function pointer of blocking function + py::function condition_function_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc new file mode 100644 index 0000000000..ce312ce3d9 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc @@ -0,0 +1,253 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/engine/datasetops/filter_op.h" +#include +#include +#include +#include +#include +#include "dataset/core/config_manager.h" +#include "dataset/core/constants.h" +#include "dataset/core/global_context.h" +#include "dataset/core/tensor.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/db_connector.h" +#include "dataset/engine/execution_tree.h" +#include "dataset/kernels/tensor_op.h" +#include "utils/log_adapter.h" +#include "dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { + +Status FilterOp::Builder::SanityCheck() { + std::string err; + err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; + err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); +} + +FilterOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status FilterOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, + builder_predicate_func_); + return Status::OK(); +} + +FilterOp::FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func) + : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {} + +Status FilterOp::operator()() { + // The operator class just starts off threads by calling the tree_ function. + RETURN_UNEXPECTED_IF_NULL(tree_); + // Synchronize with TaskManager. 
+ TaskManager::FindMe()->Post(); + filter_queues_.Init(num_workers_, oc_queue_size_); + RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK(Collector()); + return Status::OK(); +} + +Status FilterOp::EofReceived(int32_t) { return Status::OK(); } + +Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } + +// Validating if each of the input_columns exists in the DataBuffer. +Status FilterOp::ValidateInColumns(const std::unordered_map &col_name_id_map, + const std::vector *input_columns) { + for (const auto &inCol : *input_columns) { + bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} + +// A print method typically used for debugging. +void FilterOp::Print(std::ostream &out, bool show_all) const { + // Call base class printer first. + ParallelOp::Print(out, show_all); + + // Then display our own stuff. + out << "\nFilterOp:"; + out << "\n Input column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } +} + +Status FilterOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + std::unique_ptr in_buffer; + bool worker_stop = false; + while (worker_stop == false) { + // Getting a databuffer to work on. 
+ RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); + if (in_buffer->eoe()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); + continue; + } else if (in_buffer->eof()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); + worker_stop = true; + continue; + } + + RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_)); + + // if the databuffer was all filtered, it is marked as kFilterEmpty. + // if the databuffer was partially filtered, it is marked as kFilterPartial. + // if the databuffer was not filtered, it is marked as kFilterFull. + int32_t num_rows = in_buffer->NumRows(); + std::unique_ptr new_tensor_table; + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table)); + + if (new_tensor_table->empty()) { + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty))); + } else if (new_tensor_table->size() == num_rows) { + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull))); + } else { // kFilterPartial + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial))); + } + } + return Status::OK(); +} + +Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out) { + *out = std::make_unique(); + int32_t num_rows = in_buffer->NumRows(); + for (int32_t i = 0; i < num_rows; i++) { + TensorRow to_process; + TensorRow cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + if (in_columns_.empty() == true) { + MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; + to_process = cur_row; + } else { + std::unordered_map col_map = 
in_buffer->column_name_map(); + (void)std::transform( + in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process), + [&cur_row, &col_map](const auto &it) -> std::shared_ptr { return cur_row[col_map[it]]; }); + } + bool predicate = true; + RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); + if (predicate) { + (*out)->push_back(std::move(cur_row)); + } + } + return Status::OK(); +} + +// if the filtered DataBuffer is written directly to out_connector_, +// the thread fetching data will block in a queue. +// Collector function will reorder the DataBuffer in order. +// for example in two work queues: +// int filter_queues_: +// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof) +// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe) +// after reorder in out_connector_: +// queue1: DB(data2) DB(data4) DB(eof) +// queue2: DB(eoe) DB(eoe) +Status FilterOp::Collector() { + bool collector_stop = false; + uint64_t task_id_cnt = 0; + uint64_t out_id_cnt = 0; + std::pair, filterCtrl> in_pair; + while (collector_stop == false) { + uint32_t w_id = task_id_cnt % num_workers_; + RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); + if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || + in_pair.second == filterCtrl::kFilterEoe) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + out_id_cnt++; + task_id_cnt++; + } else if (in_pair.second == filterCtrl::kFilterEof) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + collector_stop = true; + } else { // kFilterEmpty + task_id_cnt++; + } + } + return Status::OK(); +} + +// Private function for checking the column legality. 
+Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns) { + int32_t num_rows = in_buf->NumRows(); + int32_t num_cols = in_buf->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); + } + std::unordered_map col_name_id_map = in_buf->column_name_map(); + // Check if there is invalid column name in the inColumns. + RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns)); + return Status::OK(); +} + +Status FilterOp::CheckInput(const TensorRow &input) const { + for (auto &item : input) { + if (item == nullptr) { + RETURN_STATUS_UNEXPECTED("input is null."); + } + } + return Status::OK(); +} + +Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { + RETURN_IF_NOT_OK(CheckInput(input)); + // Acquire Python GIL. + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + // Transform input tensor vector into numpy array vector. + py::tuple input_args(input.size()); + for (size_t i = 0; i < input.size(); i++) { + py::array new_data; + RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); + input_args[i] = new_data; + } + // Invoke python function. 
+ py::object ret_py_obj = predicate_func_(*input_args); + *out_predicate = ret_py_obj.cast(); + } catch (const py::error_already_set &e) { + std::stringstream ss; + ss << e.what() << std::endl; + ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool."; + return Status(StatusCode::kPyFuncException, ss.str()); + } + return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h new file mode 100644 index 0000000000..92312e0843 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h @@ -0,0 +1,181 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ + +#include +#include +#include +#include +#include +#include +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/util/queue.h" + +namespace mindspore { +namespace dataset { + +class FilterOp : public ParallelOp { + public: + // The nested builder class inside of the FilterOp is used to help manage all of + // the arguments for constructing it. 
Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args. + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetPredicateFunc(py::function func) { + builder_predicate_func_ = std::move(func); + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetInColNames(const std::vector &in_col_names) { + build_in_col_names_ = in_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + builder_op_connector_size_ = connector_size; + return *this; + } + + // The builder "build" method creates the final object. + // @param ptr The shared_ptr to the new FilterOp object. + // @return Status. + Status Build(std::shared_ptr *ptr); + + private: + // Sanity check for builder class args. + // @return Status - The error code return. + Status SanityCheck(); + std::vector build_in_col_names_; + py::function builder_predicate_func_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + }; + + enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 }; + + // Constructor of FilterOp + // @note The builder class should be used to call it. + // @param in_col_names A list of input column names,when it is empty the predicate will be + // applied all columns in the dataset. 
+ // @param num_workers The number of worker threads. + // @param op_connector_size The size of each queue in the connector. + // @param predicate_func python callable which returns a boolean value. + FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func); + + // Destructor + ~FilterOp() = default; + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status The error code return + Status operator()() override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EofReceived(int32_t) override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging. + // @param out The output stream to write output to. + // @param show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + private: + // predicate_func python callable which returns a boolean value. + py::function predicate_func_; + + // Variable to store the column name that will feed to predicate function. + std::vector in_columns_; + + // Internal queue for filter. + QueueList, filterCtrl>> filter_queues_; + + // Private function for worker/thread to loop continuously. It comprises the main + // logic of FilterOp, getting the data from previous Op, validating user specified column names, + // applying predicate to each of the data, filter the data when predicate result is false. + // @param worker_id The id assigned to this thread/worker upon creation. + // @return Status The error code return. + Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ + + // Filter the data by predicate function . 
+ // @param in_buffer input data buffer. + // @param to_proess_indices Indices of columns to be processed. + // @param out data buffer that are filtered by predicate. + // @return Status The error code return. + Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out); + + // Collector databuffer. + // @return Status The error code return. + Status Collector(); + + // @param input tensor vector. + // @return Status - The error code return. + Status CheckInput(const TensorRow &input) const; + + // Invoke python func. + // @param input tensor vector. + // @param the result of predicate. + // @return Status - The error code return. + Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); + + // Private function for validating if each of the user specified input column names + // exist in the DataBuffer. + // @param col_name_id_map The column name to index mapping obtained from DataBuffer. + // @param input_columns The vector of input column names used in the current thread. + // @return Status The error code return. + Status ValidateInColumns(const std::unordered_map &col_name_id_map, + const std::vector *input_columns); + + // Private function for checking the column legality + // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory + // and is not shared with other threads. + // @param[out] to_process_indices Indices of columns that will feed to predicate. + // @param input_columns The vector of input column names used in the current thread. 
+ Status CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns); +}; + +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc index 3f8d70b606..b6d603bac9 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc @@ -65,9 +65,6 @@ MapOp::MapOp(const std::vector &in_col_names, const std::vectorGetNextBuffer(&buff, 0)); is_eof = buff->eof(); RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); -#if defined(_WIN32) || defined(_WIN64) - if (is_eof) { - eof_worker_id_ = que_id; - for (int32_t id = 0; id < num_workers_; id++) { - if (id != eof_worker_id_) { - auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(local_queues_[id]->Add(std::move(eof_buffer))); - } - } - } -#endif que_id = (que_id + 1) % num_workers_; } } @@ -173,14 +159,6 @@ Status MapOp::WorkerEntry(int32_t worker_id) { continue; } else if (in_buffer->eof()) { // Calling base class EofReceived to forward eof buffer. -#if defined(_WIN32) || defined(_Win64) - if (perf_mode_) { - if (eof_worker_id_ == worker_id) { - RETURN_IF_NOT_OK(EofReceived(worker_id)); - } - break; - } -#endif RETURN_IF_NOT_OK(EofReceived(worker_id)); break; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h index 5e16bc3fed..4c9d27f9c7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h @@ -193,10 +193,6 @@ class MapOp : public ParallelOp { // cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order. 
bool perf_mode_; -#if defined(_WIN32) || defined(_WIN64) - // EOF worker id is only work on Performance mode, to record the worker id of queue which gets EOF - int32_t eof_worker_id_; -#endif // Private function for worker/thread to loop continuously. It comprises the main // logic of MapOp: getting the data from previous Op, validating user specified column names, // applying a list of TensorOps to each of the data, process the results and then diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc index bdf39b6a39..422c38f2f2 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc @@ -13,6 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#if defined(_WIN32) || defined(_WIN64) +#include +#endif #include #include #include @@ -86,7 +89,9 @@ Status ShuffleOp::SelfReset() { rng_ = std::mt19937_64(shuffle_seed_); } else { #if defined(_WIN32) || defined(_WIN64) - std::random_device random_device; + unsigned int number; + rand_s(&number); + std::mt19937 random_device{static_cast(number)}; #else std::random_device random_device("/dev/urandom"); #endif diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc index 90c160b5bf..d851f2c699 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc @@ -67,9 +67,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work } std::unique_ptr buf; + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); + // Drop first max_skips_ rows while (skip_count_ < max_skips_) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); if (buf->eoe() || buf->eof()) { break; } @@ -77,31 +78,24 @@ Status SkipOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work // 
Consider the rows of buffer more than 1 TensorRow drop_row; int row_num = buf->NumRows(); - for (int i = 0; i < row_num; i++) { + int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; + skip_count_ += drop_num; + for (int i = 0; i < drop_num; i++) { RETURN_IF_NOT_OK(buf->PopRow(&drop_row)); - if (++skip_count_ == max_skips_) { - break; - } } - } - - // If buffer is none or the rows of buffer is 0, - // then get a buffer from child. - if (!buf || buf->NumRows() == 0) { - if (buf && buf->eof()) { - *p_buffer = std::move(buf); - return Status::OK(); + if (buf->NumRows() == 0) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); } - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); } - // Handling eoe and eof - if (buf->eoe() || buf->eof()) { + // Handling eoe + if (buf->eoe()) { RETURN_IF_NOT_OK(EoeReceived(worker_id)); - if (state_ == OpState::kDeOpIdle) { - *p_buffer = std::move(buf); - return Status::OK(); - } + } + + // Handling eof + if (buf->eof()) { + RETURN_IF_NOT_OK(EofReceived(worker_id)); } *p_buffer = std::move(buf); @@ -125,7 +119,7 @@ Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is a // Base-class override for handling cases when an eof is received. 
Status SkipOp::EofReceived(int32_t worker_id) { - MS_LOG(INFO) << "Skip operator EOF received, do nothing now."; + MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt index a7c0dfd725..8801205f6c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt @@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT manifest_op.cc cifar_op.cc celeba_op.cc + text_file_op.cc ) add_dependencies(engine-datasetops-source mindspore::protobuf) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index fbb772af59..72dee6f2e6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() { return Status::OK(); } -Status MindRecordOp::CountTotalRows(const std::string dataset_path, int64_t *count) { +Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, + int64_t *count) { std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, count); + MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count); if (rc == MSRStatus::FAILED) { RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h index aca5c86c2c..899919e529 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h @@ -171,7 +171,8 @@ class 
MindRecordOp : public ParallelOp { int32_t num_rows() const { return num_rows_; } // Getter method - static Status CountTotalRows(const std::string dataset_path, int64_t *count); + static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, + int64_t *count); // Getter method int32_t rows_per_buffer() const { return rows_per_buffer_; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 5d55c8276a..b084e1c125 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(engine-datasetops-source-sampler OBJECT distributed_sampler.cc pk_sampler.cc + python_sampler.cc random_sampler.cc sampler.cc sequential_sampler.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc new file mode 100644 index 0000000000..1747040141 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/engine/datasetops/source/sampler/python_sampler.h" + +#include + +namespace mindspore { +namespace dataset { + +PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer) + : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} + +Status PythonSampler::GetNextBuffer(std::unique_ptr *out_buffer) { + if (need_to_reset_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + std::shared_ptr sample_ids; + { + py::gil_scoped_acquire gil_acquire; + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py::object py_ret = py_sampler_instance.attr("_get_indices")(); + py::array np_sample_ids = py_ret.cast(); + Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index"); + } + } + TensorRow row(1, sample_ids); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + need_to_reset_ = true; + } + return Status::OK(); +} + +Status PythonSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +Status PythonSampler::Reset() { + CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); + 
need_to_reset_ = false; + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("reset")(); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h new file mode 100644 index 0000000000..b8734fee6a --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ + +#include +#include + +#include "dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class PythonSampler : public Sampler { + public: + // Constructor + // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit PythonSampler(py::object py_sampler_instance, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~PythonSampler() = default; + + // Initialize the sampler. 
+ // @return Status + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status Reset() override; + + // Op calls this to get next Buffer that contains all the sampleIds + // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp + // @param int32_t workerId - not meant to be used + // @return - The error code return + Status GetNextBuffer(std::unique_ptr *out_buffer) override; + + private: + bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() + + py::object py_sampler_instance; // The handle to the py_sampler python object +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 3c3f5f48e8..9fe752448a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { std::unique_ptr db; std::shared_ptr sample_ids; - // check samples_per_buffer is properly set and doesn't overflow - CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid"); - // A call to derived class to get sample ids wrapped inside a buffer RETURN_IF_NOT_OK(GetNextBuffer(&db)); // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index a3c4fe2256..6ed06b527f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -42,6 +42,7 @@ Status 
SequentialSampler::GetNextBuffer(std::unique_ptr *out_buffer) } Status SequentialSampler::InitSampler() { + num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; return Status::OK(); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc index 862edcf63a..7f081af2b7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc @@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const { std::ifstream in(schemaFile); nlohmann::json js; in >> js; - num_rows = js.value("numRows", 0); + if (js.find("numRows") == js.end()) { + num_rows = MAX_INTEGER_INT32; + } else { + num_rows = js.value("numRows", 0); + } if (num_rows == 0) { std::string err_msg = "Storage client has not properly done dataset " diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc new file mode 100644 index 0000000000..2b62616366 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc @@ -0,0 +1,459 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/core/config_manager.h" +#include "dataset/util/task_manager.h" +#include "dataset/util/wait_post.h" +#include "dataset/util/random.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +TextFileOp::Builder::Builder() + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status TextFileOp::Builder::ValidateInputs() const { + std::string err_msg; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : ""; + err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +Status TextFileOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! 
+ if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { + builder_num_workers_ = builder_text_files_list_.size(); + MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + + std::shared_ptr text_file_op = std::make_shared( + builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, + std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, + builder_num_devices_, builder_device_id_); + RETURN_IF_NOT_OK(text_file_op->Init()); + *op = std::move(text_file_op); + + return Status::OK(); +} + +TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + std::unique_ptr schema, std::vector text_files_list, + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), + device_id_(device_id), + num_devices_(num_device), + rows_per_buffer_(rows_per_buffer), + num_samples_(num_samples), + text_files_list_(std::move(text_files_list)), + shuffle_files_(shuffle_files), + data_schema_(std::move(schema)), + all_num_rows_(0), + num_rows_per_shard_(0), + filename_index_(std::make_unique()), + finished_reading_dataset_(false), + load_io_block_queue_(true), + load_jagged_connector_(true) { + worker_connector_size_ = worker_connector_size; +} + +Status TextFileOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(text_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + col_name_map_[data_schema_->column(i).name()] = i; + } + + 
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + return Status::OK(); +} + +Status TextFileOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { + TensorRow tRow(1, nullptr); + (*tensor_table)->push_back(std::move(tRow)); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(), + TensorShape(std::vector(1, line.size())), data_schema_->column(0).type(), + const_cast(reinterpret_cast(common::SafeCStr(line))))); + (**tensor_table)[row][0] = std::move(tensor); + return Status::OK(); +} + +Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + std::ifstream handle(file); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Failed to open file " + file); + } + + int64_t rows_each_buffer = 0; + int64_t rows_total = 0; + std::string line; + std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + cur_buffer->set_column_name_map(col_name_map_); + std::unique_ptr tensor_table = std::make_unique(); + + while (getline(handle, line)) { + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + break; + } + // Skip line before start offset. 
+ if (rows_total < start_offset) { + rows_total++; + continue; + } + + RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); + rows_each_buffer++; + rows_total++; + if (rows_each_buffer == rows_per_buffer_) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + + cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + cur_buffer->set_column_name_map(col_name_map_); + tensor_table = std::make_unique(); + rows_each_buffer = 0; + } + } + + if (rows_each_buffer > 0) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + } + + return Status::OK(); +} + +Status TextFileOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// Pops an element from a queue in io_block_queues +Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + 
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status TextFileOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset 
= 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +Status TextFileOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + auto file_it = filename_index_->Search(*it); + file_index.emplace_back(std::pair(file_it.value(), *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +Status TextFileOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if 
(finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +Status TextFileOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); + + // Read data from disk into buffers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + + io_block_queue_wait_post_.Register(tree_->AllTasks()); + NotifyToFillIOBlockQueue(); + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (num_samples_ == 0 || rows_read < num_samples_) { + if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { + int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + 
NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + + return Status::OK(); +} + +int64_t TextFileOp::CountTotalRows(const std::string &file) { + std::ifstream handle(file); + if (!handle.is_open()) { + MS_LOG(ERROR) << "Failed to open file: " << file; + return 0; + } + + std::string line; + int64_t count = 0; + while (getline(handle, line)) { + count++; + } + + return count; +} + +Status TextFileOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED("Number of rows can not be zero"); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +Status TextFileOp::CountAllFileRows(const std::vector &files, int64_t *count) { + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h new file mode 100644 index 0000000000..49f224ffc3 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h @@ -0,0 +1,263 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "dataset/util/status.h" +#include "dataset/util/auto_index.h" +#include "dataset/engine/data_schema.h" +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/util/queue.h" +#include "dataset/util/wait_post.h" +#include "dataset/engine/jagged_connector.h" + +namespace mindspore { +namespace dataset { +using StringIndex = AutoIndexObj; + +class TextFileOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. 
+ // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetTextFilesList(const std::vector &files_list) { + builder_text_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumSamples(int64_t num_samples) { + builder_num_samples_ = num_samples; + return *this; + } + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_num_samples_; + int32_t builder_worker_connector_size_; + std::vector builder_text_files_list_; + bool builder_shuffle_files_; + std::unique_ptr builder_schema_; + }; + + // Constructor of TextFileOp + // @note The builder class should be used to call this constructor. + // @param num_workers - number of worker threads reading data from tf_file files. + // @param rows_per_buffer - number of rows that a full buffer will contain. + // @param total_num_rows - number of rows to read + // @param dataset_files_list - list of filepaths for the dataset files. + // @param data_schema - the data schema object. 
+ // @param op_connector_size - size of each queue in the connector that the child operator pulls from. + // @param columns_to_load - the names of the columns to load data from. + // @param shuffle_files - whether or not to shuffle the files before reading data. + // @param equal_rows_per_shard - whether or not to get equal rows for each process. + TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Default destructor + ~TextFileOp() = default; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all text files. + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, int64_t *count); + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. 
+ // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a text file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - text file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pops an element from a queue in IOBlockQueue. 
+ // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + int32_t device_id_; + int32_t num_devices_; + int64_t rows_per_buffer_; + int64_t num_samples_; + std::vector text_files_list_; + bool shuffle_files_; + std::unique_ptr data_schema_; + int64_t all_num_rows_; + int64_t num_rows_per_shard_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + QueueList> io_block_queues_; + WaitPost io_block_queue_wait_post_; + bool finished_reading_dataset_; + bool load_io_block_queue_; + bool load_jagged_connector_; + std::unordered_map col_name_map_; + std::unique_ptr jagged_buffer_connector_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 0764d7e0ad..6132f628d7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ 
b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -42,6 +42,7 @@ #include "dataset/util/status.h" #include "dataset/util/task_manager.h" #include "dataset/util/wait_post.h" +#include "utils/system/crc32c.h" namespace mindspore { namespace dataset { @@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder() builder_data_schema_ = std::make_unique(); } +bool ValidateFirstRowCrc(const std::string &filename) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + return false; + } + + // read data + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // read crc from file + uint32_t masked_crc = 0; + (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); + + // generate crc from data + uint32_t generated_crc = + system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); + + return masked_crc == generated_crc; +} + Status TFReaderOp::Builder::ValidateInputs() const { std::string err_msg; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : ""; - if (!builder_equal_rows_per_shard_) { - err_msg += builder_dataset_files_list_.size() < static_cast(builder_num_devices_) - ? 
"No enough tf_file files provided\n" - : ""; + + if (builder_num_workers_ <= 0) { + err_msg += "Number of parallel workers is smaller or equal to 0\n"; + } + + if (!builder_equal_rows_per_shard_ && + builder_dataset_files_list_.size() < static_cast(builder_num_devices_)) { + err_msg += "Not enough tfrecord files provided\n"; + } + + if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { + err_msg += "Wrong sharding configs\n"; + } + + std::vector invalid_files(builder_dataset_files_list_.size()); + auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), + [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); + invalid_files.resize(std::distance(invalid_files.begin(), it)); + + if (!invalid_files.empty()) { + err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; + + std::string accumulated_filenames = std::accumulate( + invalid_files.begin(), invalid_files.end(), std::string(""), + [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); + err_msg += accumulated_filenames; } - err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); } @@ -119,6 +163,9 @@ Status TFReaderOp::Init() { if (total_rows_ == 0) { total_rows_ = data_schema_->num_rows(); } + if (total_rows_ < 0) { + RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0"); + } // Build the index with our files such that each file corresponds to a key id. 
RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_)); @@ -523,6 +570,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); rows_read++; } + // ignore crc footer (void)reader.ignore(static_cast(sizeof(int32_t))); rows_total++; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc index d9625b6c26..5d7df58153 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc @@ -67,7 +67,7 @@ Status TakeOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat); if (take_count_ == max_takes_) { if (state_ == OpState::kDeOpRunning) { - MS_LOG(INFO) << "meet max count and push-back eoe buffer."; + MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer."; auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); *p_buffer = std::move(eoe_buffer); state_ = OpState::kDeOpIdle; @@ -80,11 +80,13 @@ Status TakeOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); } } - } else { - MS_LOG(INFO) << "meet max count and push-back eof buffer."; + } else if (state_ == OpState::kDeOpIdle) { + MS_LOG(DEBUG) << "Meet max count and push-back eof buffer."; auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); *p_buffer = std::move(eof_buffer); take_count_ = 0; + } else { + MS_LOG(WARNING) << "Invalid OpState: " << state_; } return Status::OK(); } @@ -116,7 +118,7 @@ Status TakeOp::FillBuffer(std::unique_ptr *buffer, std::unique_ptr new_tensor_table = std::make_unique(); while (take_count_ < max_takes_) { TensorRow new_row; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h index 
f14ecba733..04d8ab0121 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h @@ -34,7 +34,7 @@ class DataBuffer; class ZipOp : public PipelineOp { public: - // The nested builder class inside of the BatchOp is used to help manage all of + // The nested builder class inside of the ZipOp is used to help manage all of // the arguments for constructing it. Use the builder by setting each argument // with the provided set methods, and then finally call the build method to execute // the actual construction. @@ -76,8 +76,8 @@ class ZipOp : public PipelineOp { }; // Constructor for ZipOp - // @param rows_per_buffer number of rows in output buffer - // @param op_connector_size connector + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); // Destructor @@ -88,8 +88,8 @@ class ZipOp : public PipelineOp { Status EoeReceived(int32_t) override; // Print function for Zip - // @param out output stream to print to - // @param show_all if it should print everything + // @param out - output stream to print to + // @param show_all - if it should print everything void Print(std::ostream &out, bool show_all) const override; // Provide stream operator for displaying it @@ -113,14 +113,14 @@ class ZipOp : public PipelineOp { Status fillBuffer(TensorQTable *const table); // Special handle case where an empty row has been received from child iterator - // @note we need to drain eoe signals from all children connectors. - // @details when this function is called, then we encountered eoe at child iterator + // @note - we need to drain eoe signals from all children connectors. 
+ // @details - when this function is called, then we encountered eoe at child iterator // we have to drain rows from other child iterators until we hit eoe from all other child iterators Status drainPipeline(); // Merges 1 row from each childIterator together - // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty - // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true + // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty + // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true // @details merge rows from iterator together. This is the main functionality for ZipOp // this function takes one row and fills it with tensors from rows fetched // from childIterators. diff --git a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt index 23a26d5214..43b68d8933 100644 --- a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt @@ -3,7 +3,6 @@ if (WIN32) center_crop_op.cc cut_out_op.cc decode_op.cc - distort_bounding_box_crop_op.cc hwc_to_chw_op.cc image_utils.cc normalize_op.cc @@ -19,6 +18,7 @@ if (WIN32) rescale_op.cc resize_bilinear_op.cc resize_op.cc + uniform_aug_op.cc ) else() add_library(kernels-image OBJECT @@ -26,7 +26,6 @@ else() change_mode_op.cc cut_out_op.cc decode_op.cc - distort_bounding_box_crop_op.cc hwc_to_chw_op.cc image_utils.cc normalize_op.cc @@ -42,5 +41,6 @@ else() rescale_op.cc resize_bilinear_op.cc resize_op.cc + uniform_aug_op.cc ) -endif() +endif() \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc index 9327d785fc..74d9df5d6b 100644 --- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc +++ 
b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc @@ -33,7 +33,8 @@ const uint8_t CutOutOp::kDefFillB = 0; // constructor CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) - : box_height_(box_height), + : rnd_(GetSeed()), + box_height_(box_height), box_width_(box_width), num_patches_(num_patches), random_color_(random_color), @@ -46,8 +47,8 @@ Status CutOutOp::Compute(const std::shared_ptr &input, std::shared_ptr inputCV = CVTensor::AsCVTensor(input); // cut out will clip the erasing area if the box is near the edge of the image and the boxes are black - RETURN_IF_NOT_OK( - Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, fill_r_, fill_g_, fill_b_)); + RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_, + fill_g_, fill_b_)); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h index 9a76572a54..2198f23e44 100644 --- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h @@ -62,6 +62,7 @@ class CutOutOp : public TensorOp { Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; private: + std::mt19937 rnd_; int32_t box_height_; int32_t box_width_; int32_t num_patches_; diff --git a/mindspore/ccsrc/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/dataset/kernels/image/decode_op.h index 50d2d3cb68..6e7180958a 100644 --- a/mindspore/ccsrc/dataset/kernels/image/decode_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/decode_op.h @@ -34,11 +34,11 @@ class DecodeOp : public TensorOp { ~DecodeOp() = default; - Status Compute(const std::shared_ptr& input, std::shared_ptr* output) override; + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - void 
Print(std::ostream& out) const override { out << "DecodeOp"; } - Status OutputShape(const std::vector& inputs, std::vector& outputs) override; - Status OutputType(const std::vector& inputs, std::vector& outputs) override; + void Print(std::ostream &out) const override { out << "DecodeOp"; } + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + Status OutputType(const std::vector &inputs, std::vector &outputs) override; private: bool is_rgb_format_ = true; diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc deleted file mode 100644 index e7a8cc3496..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/kernels/image/distort_bounding_box_crop_op.h" -#include -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t DistortBoundingBoxCropOp::kDefMaxAttempts = 100; -const int32_t DistortBoundingBoxCropOp::kDefBoxGenAttempts = 10; - -DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, - float crop_ratio_ub, int32_t max_attempts, int32_t box_gen_attempts) - : max_attempts_(max_attempts), - box_gen_attempts_(box_gen_attempts), - aspect_ratio_(aspect_ratio), - intersect_ratio_(intersect_ratio), - crop_ratio_lb_(crop_ratio_lb), - crop_ratio_ub_(crop_ratio_ub) { - seed_ = GetSeed(); - rnd_.seed(seed_); -} - -Status DistortBoundingBoxCropOp::Compute(const std::vector>& input, - std::vector>* output) { - IO_CHECK_VECTOR(input, output); - if (input.size() != NumInput()) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); - - CHECK_FAIL_RETURN_UNEXPECTED(input[1]->shape().Size() >= 1, "The shape of the second tensor is abnormal"); - int64_t num_boxes = 0; - for (uint64_t i = 1; i < input.size(); i++) { - if (i == 1) num_boxes = input[i]->shape()[0]; - if (num_boxes != input[i]->shape()[0]) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Numbers of boxes do not match"); - - if (input[i]->type() != DataType::DE_FLOAT32) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "boxes' type is not DE_FLOAT21"); - } - - // assume input Tensor vector in the order of [img, bbox_y_min, bbox_y_max, bbox_x_min, bbox_x_max] - CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of the first tensor is abnormal"); - int h_in = input[0]->shape()[0]; - int w_in = input[0]->shape()[1]; - - std::vector bounding_boxes; - for (int64_t i = 0; i < num_boxes; ++i) { - // 
bbox coordinates are floats relative to the image width and height - float y_min, y_max, x_min, x_max; - RETURN_IF_NOT_OK(input[1]->GetItemAt(&y_min, {i})); - RETURN_IF_NOT_OK(input[2]->GetItemAt(&y_max, {i})); - RETURN_IF_NOT_OK(input[3]->GetItemAt(&x_min, {i})); - RETURN_IF_NOT_OK(input[4]->GetItemAt(&x_max, {i})); - bounding_boxes.emplace_back(static_cast(x_min * w_in), static_cast(y_min * h_in), - static_cast((x_max - x_min) * w_in), static_cast((y_max - y_min) * h_in)); - } - cv::Rect output_box; - bool should_crop = false; - - // go over iterations, if no satisfying box found we return the original image - for (int32_t t = 0; t < max_attempts_; ++t) { - // try to generate random box - RETURN_IF_NOT_OK(GenerateRandomCropBox(h_in, w_in, aspect_ratio_, crop_ratio_lb_, crop_ratio_ub_, - box_gen_attempts_, // int maxIter, should not be needed here - &output_box, seed_)); - RETURN_IF_NOT_OK(CheckOverlapConstraint(output_box, - bounding_boxes, // have to change, should take tensor or add bbox logic - intersect_ratio_, &should_crop)); - if (should_crop) { - // found a box to crop - break; - } - } - // essentially we have to check this again at the end to return original tensor - if (should_crop) { - std::shared_ptr out; - RETURN_IF_NOT_OK(Crop(input[0], &out, output_box.x, output_box.y, output_box.width, output_box.height)); - output->push_back(out); - } else { - output->push_back(input[0]); - } - return Status::OK(); -} - -Status DistortBoundingBoxCropOp::OutputShape(const std::vector& inputs, - std::vector& outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out = TensorShape{-1, -1}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -Status DistortBoundingBoxCropOp::OutputType(const std::vector& inputs, 
std::vector& outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = inputs[0]; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h deleted file mode 100644 index 6d5dca99fb..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ - -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DistortBoundingBoxCropOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefMaxAttempts; - static const int32_t kDefBoxGenAttempts; - - // Constructor for DistortBoundingBoxCropOp - // @param max_attempts tries before the crop happens - // @param box_gen_attempts crop box generation attempts - // @param aspect_ratio aspect ratio of the generated crop box - // @param intersect_ratio area overlap ratio, condition for crop only if area over lap between the generated - // crop box has sufficient overlap with any 1 bounding box - // @param crop_ratio_lb the crop ratio lower bound - // @param crop_ratio_ub the crop ratio upper bound - // @param seed - DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, float crop_ratio_ub, - int32_t max_attempts = kDefMaxAttempts, int32_t box_gen_attempts = kDefBoxGenAttempts); - - ~DistortBoundingBoxCropOp() override = default; - - void Print(std::ostream& out) const override { - out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; - } - - Status Compute(const std::vector>& input, - std::vector>* output) override; - - uint32_t NumInput() override { return 5; } - Status OutputShape(const std::vector& inputs, std::vector& outputs) override; - Status OutputType(const std::vector& inputs, std::vector& outputs) override; - - private: - int32_t max_attempts_; - int32_t box_gen_attempts_; - float aspect_ratio_; - float intersect_ratio_; - float crop_ratio_lb_; - float crop_ratio_ub_; - std::mt19937 rnd_; - uint32_t seed_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // 
DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/dataset/kernels/image/image_utils.cc index 63c9bb2641..e4570b876d 100644 --- a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/dataset/kernels/image/image_utils.cc @@ -636,76 +636,10 @@ Status AdjustHue(const std::shared_ptr &input, std::shared_ptr * return Status::OK(); } -Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, - cv::Rect *crop_box, uint32_t seed) { - try { - std::mt19937 rnd; - rnd.seed(GetSeed()); - if (input_height <= 0 || input_width <= 0 || ratio <= 0.0 || lb <= 0.0 || lb > ub) { - RETURN_STATUS_UNEXPECTED("Invalid inputs GenerateRandomCropBox"); - } - std::uniform_real_distribution rd_crop_ratio(lb, ub); - float crop_ratio; - int crop_width, crop_height; - bool crop_success = false; - int64_t input_area = input_height * input_width; - for (auto i = 0; i < max_itr; i++) { - crop_ratio = rd_crop_ratio(rnd); - crop_width = static_cast(std::round(std::sqrt(input_area * static_cast(crop_ratio) / ratio))); - crop_height = static_cast(std::round(crop_width * ratio)); - if (crop_width <= input_width && crop_height <= input_height) { - crop_success = true; - break; - } - } - if (crop_success == false) { - ratio = static_cast(input_height) / input_width; - crop_ratio = rd_crop_ratio(rnd); - crop_width = static_cast(std::lround(std::sqrt(input_area * static_cast(crop_ratio) / ratio))); - crop_height = static_cast(std::lround(crop_width * ratio)); - crop_height = (crop_height > input_height) ? input_height : crop_height; - crop_width = (crop_width > input_width) ? 
input_width : crop_width; - } - std::uniform_int_distribution<> rd_x(0, input_width - crop_width); - std::uniform_int_distribution<> rd_y(0, input_height - crop_height); - *crop_box = cv::Rect(rd_x(rnd), rd_y(rnd), crop_width, crop_height); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("error in GenerateRandomCropBox."); - } -} - -Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector &bounding_boxes, - float min_intersect_ratio, bool *is_satisfied) { - try { - // not satisfied if the crop box contains no pixel - if (crop_box.area() < 1.0) { - *is_satisfied = false; - } - for (const auto &b_box : bounding_boxes) { - const float b_box_area = b_box.area(); - // not satisfied if the bounding box contains no pixel - if (b_box_area < 1.0) { - continue; - } - const float intersect_ratio = (crop_box & b_box).area() / b_box_area; - if (intersect_ratio >= min_intersect_ratio) { - *is_satisfied = true; - break; - } - } - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("error in CheckOverlapConstraint."); - } -} - Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, - int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r, uint8_t fill_g, - uint8_t fill_b) { + int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, + uint8_t fill_g, uint8_t fill_b) { try { - std::mt19937 rnd; - rnd.seed(GetSeed()); std::shared_ptr input_cv = CVTensor::AsCVTensor(input); if (input_cv->mat().data == nullptr || (input_cv->Rank() != 3 && input_cv->shape()[2] != 3)) { RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); @@ -731,8 +665,8 @@ Status Erase(const std::shared_ptr &input, std::shared_ptr *outp // rows in cv mat refers to the height of the cropped box // we determine h_start and w_start using two different distributions as erasing is used by two different // image 
augmentations. The bounds are also different in each case. - int32_t h_start = (bounded) ? height_distribution_bound(rnd) : (height_distribution_unbound(rnd) - box_height); - int32_t w_start = (bounded) ? width_distribution_bound(rnd) : (width_distribution_unbound(rnd) - box_width); + int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); + int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; @@ -744,9 +678,9 @@ Status Erase(const std::shared_ptr &input, std::shared_ptr *outp for (int x = h_start; x < max_height; x++) { if (random_color) { // fill each box with a random value - input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(rnd)); - input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(rnd)); - input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(rnd)); + input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(*rnd)); } else { input_img.at(cv::Point(y, x))[0] = fill_r; input_img.at(cv::Point(y, x))[1] = fill_g; diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/dataset/kernels/image/image_utils.h index 51090fb9ea..394323974a 100644 --- a/mindspore/ccsrc/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/dataset/kernels/image/image_utils.h @@ -196,12 +196,6 @@ Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &hue); -Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, - cv::Rect *crop_box, uint32_t seed = 
std::mt19937::default_seed); - -Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector &bounding_boxes, - float min_intersect_ratio, bool *is_satisfied); - // Masks out a random section from the image with set dimension // @param input: input Tensor // @param output: cutOut Tensor @@ -214,8 +208,8 @@ Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector &input, std::shared_ptr *output, int32_t box_height, - int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r = 0, - uint8_t fill_g = 0, uint8_t fill_b = 0); + int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, + uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); // Pads the input image and puts the padded image in the output // @param input: input Tensor diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc index 3cf6065659..a3cf8cefb5 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc @@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ rnd_.seed(GetSeed()); } -Status RandomCropAndResizeOp::Compute(const std::shared_ptr& input, std::shared_ptr* output) { +Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); @@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr& input, std: (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); } -Status RandomCropAndResizeOp::OutputShape(const std::vector& inputs, std::vector& outputs) { +Status 
RandomCropAndResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); outputs.clear(); TensorShape out = TensorShape{target_height_, target_width_}; @@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector& inputs if (!outputs.empty()) return Status::OK(); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); } -Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { +Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { double scale, aspect; *crop_width = w_in; *crop_height = h_in; diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc new file mode 100644 index 0000000000..5725c10908 --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#include "dataset/kernels/image/uniform_aug_op.h" +#include "dataset/kernels/py_func_op.h" +#include "dataset/util/random.h" + +namespace mindspore { +namespace dataset { +const int UniformAugOp::kDefNumOps = 2; + +UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) { + std::shared_ptr tensor_op; + // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ + for (auto op : op_list) { + if (py::isinstance(op)) { + // python op + tensor_op = std::make_shared(op.cast()); + } else if (py::isinstance(op)) { + // C++ op + tensor_op = op.cast>(); + } + tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); + } + + rnd_.seed(GetSeed()); +} +// compute method to apply uniformly random selected augmentations from a list +Status UniformAugOp::Compute(const std::vector> &input, + std::vector> *output) { + IO_CHECK_VECTOR(input, output); + + // variables to generate random number to select ops from the list + std::vector random_indexes; + + // variables to copy the result to output if it is not already + std::vector> even_out; + std::vector> *even_out_ptr = &even_out; + int count = 1; + + // select random indexes for candidates to be applied + for (int i = 0; i < num_ops_; ++i) { + random_indexes.insert(random_indexes.end(), + std::uniform_int_distribution(0, tensor_op_list_.size() - 1)(rnd_)); + } + + for (auto it = random_indexes.begin(); it != random_indexes.end(); ++it) { + // Do NOT apply the op, if second random generator returned zero + if (std::uniform_int_distribution(0, 1)(rnd_)) { + continue; + } + std::shared_ptr tensor_op = tensor_op_list_[*it]; + + // apply python/C++ op + if (count == 1) { + (*tensor_op).Compute(input, output); + } else if (count % 2 == 0) { + (*tensor_op).Compute(*output, even_out_ptr); + } else { + (*tensor_op).Compute(even_out, output); + } + count++; + } + + // copy the result to output if it is not in output + if (count == 1) { + *output = input; + } else if ((count % 2 == 1)) { 
+ (*output).swap(even_out); + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h new file mode 100644 index 0000000000..336bc8c859 --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ +#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ +
+#include +#include +#include +#include + +#include "dataset/core/tensor.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/util/status.h" +#include "dataset/kernels/py_func_op.h" + +#include "pybind11/stl.h" + +namespace mindspore { +namespace dataset { +class UniformAugOp : public TensorOp { + public: + // Default number of Operations to be applied + static const int kDefNumOps; + + // Constructor for UniformAugOp + // @param list op_list: list of candidate python operations + // @param list num_ops: number of augmentation operations to be applied + UniformAugOp(py::list op_list, int32_t num_ops); + + ~UniformAugOp() override = default; + + void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; } + + // Overrides the base class compute function + // @return Status - The error code return + Status Compute(const std::vector> &input, + 
std::vector> *output) override; + + private: + int32_t num_ops_; + std::vector> tensor_op_list_; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ diff --git a/mindspore/ccsrc/dataset/util/random.cc b/mindspore/ccsrc/dataset/util/random.cc index 2a0762c920..43b3ee4afd 100644 --- a/mindspore/ccsrc/dataset/util/random.cc +++ b/mindspore/ccsrc/dataset/util/random.cc @@ -18,6 +18,9 @@ #include "dataset/util/random.h" +#if defined(_WIN32) || defined(_WIN64) +#include +#endif #include #include #include @@ -33,7 +36,9 @@ uint32_t GetSeed() { uint32_t seed = GlobalContext::config_manager()->seed(); if (seed == std::mt19937::default_seed) { #if defined(_WIN32) || defined(_WIN64) - std::random_device random_device; + unsigned int number; + rand_s(&number); + std::mt19937 random_device{static_cast(number)}; #else std::random_device random_device("/dev/urandom"); #endif diff --git a/mindspore/ccsrc/dataset/util/services.cc b/mindspore/ccsrc/dataset/util/services.cc index ea7b11014c..a2b3f734c2 100644 --- a/mindspore/ccsrc/dataset/util/services.cc +++ b/mindspore/ccsrc/dataset/util/services.cc @@ -18,6 +18,8 @@ #include #if !defined(_WIN32) && !defined(_WIN64) #include +#else +#include #endif #include #include @@ -49,7 +51,9 @@ int Services::GetLWP() { return syscall(SYS_gettid); } std::string Services::GetUniqueID() { const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; #if defined(_WIN32) || defined(_WIN64) - std::mt19937 gen{std::random_device{}()}; + unsigned int number; + rand_s(&number); + std::mt19937 gen{static_cast(number)}; #else std::mt19937 gen{std::random_device{"/dev/urandom"}()}; #endif diff --git a/mindspore/ccsrc/debug/anf_ir_dump.h b/mindspore/ccsrc/debug/anf_ir_dump.h index 5c4bc5eacd..a53888348d 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.h +++ b/mindspore/ccsrc/debug/anf_ir_dump.h @@ -22,7 +22,7 @@ namespace mindspore { constexpr char PARALLEL_STRATEGY[] = 
"strategy"; -void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); +void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); } // namespace mindspore diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 8e626d6f9a..6ebe3ad43f 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; // get MindSpore Intermediate Representation Path std::string GetMsIrPath(void) { std::string path; - const char* path_ptr = getenv("MS_IR_PATH"); + const char *path_ptr = getenv("MS_IR_PATH"); if (path_ptr != nullptr) { path = path_ptr; char real_path[PATH_MAX] = {0}; @@ -62,13 +62,13 @@ std::string GetMsIrPath(void) { return path; } -std::string dump_obj(const py::object& obj, const std::string& path) { +std::string dump_obj(const py::object &obj, const std::string &path) { py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); return py::str(name); } -py::object load_obj(const std::string& path) { +py::object load_obj(const std::string &path) { py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); return obj; @@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) { // ============================================= MindSpore IR Exporter ============================================= -std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { +std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) { abstract::ShapePtr shape = nd->Shape() == nullptr ? 
nullptr : dyn_cast(nd->Shape()); TypePtr type = dyn_cast(nd->Type()); std::ostringstream oss; @@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { return oss.str(); } -std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const { +std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const { std::string pkl_path = GetMsIrPath(); // if not specified env 'MS_IR_PATH', do not create any files if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { @@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca return file_prefix + file_name; } -int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) { +int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { if (func_graph == nullptr || param == nullptr) { return -1; } @@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& // try to find index of parameter for SymbolicKeyInstance from all exported graphs // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different -int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { +int AnfExporter::GetParamIndexFromExported(const AnfNodePtr ¶m) { if (param == nullptr) { return -1; } int ret = -1; - for (const auto& item : exported) { + for (const auto &item : exported) { auto pram_iter = item.second.find(param); if (pram_iter != item.second.end()) { return pram_iter->second; @@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { return ret; } -std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) { +std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(node); return GetValueText(fg, node->value()); } -std::string 
AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) { +std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) { auto py_funcs = mt_func_graph->GetPyFunctions(); if (py_funcs.empty()) { return ""; @@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap oss << "{"; bool is_first = true; - for (const auto& py_func : py_funcs) { + for (const auto &py_func : py_funcs) { if (is_first) { is_first = false; } else { @@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap * ├── GradOperation * └── TupleAdd */ -std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) { +std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) { if (meta_func_graph == nullptr) { return ""; } @@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_ return oss.str(); } -std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { +std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) { std::ostringstream oss; if (prim == nullptr) { return oss.str(); @@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { if (prim->isa()) { auto do_signature = dyn_cast(prim); - auto& func = do_signature->function(); + auto &func = do_signature->function(); if (func->isa()) { auto sig_prim = dyn_cast(func); oss << sig_prim->GetAttrsText(); @@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { return oss.str(); } -std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { +std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) { std::ostringstream oss; if (ns == nullptr) { return oss.str(); @@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { return oss.str(); } -std::string 
AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, - const SymbolicKeyInstancePtr& sym_inst) { +std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, + const SymbolicKeyInstancePtr &sym_inst) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(sym_inst); AnfNodePtr sym_node = sym_inst->node(); @@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra return oss.str(); } -std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; // output ValueList, ValueTuple ValueSequeuePtr seq = dyn_cast(value); @@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V return oss.str(); } -std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; ValueDictionaryPtr dict = value->cast(); oss << "{"; bool first_flag = true; - for (const auto& elem : dict->value()) { + for (const auto &elem : dict->value()) { if (first_flag) { first_flag = false; } else { @@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value return oss.str(); } -std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { +std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) { std::ostringstream oss; if (check_integrity_) { @@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& return oss.str(); } -std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; 
bool is_null_ptr = (func_graph == nullptr || value == nullptr); if (is_null_ptr) { @@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu } // this function is used to output node in CNode's inputs -std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map) { +std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map) { std::ostringstream oss; if (func_graph == nullptr || node == nullptr) { return oss.str(); @@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An return oss.str(); } -void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector& parameters, - OrderedMap* param_map) { +void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector ¶meters, + OrderedMap *param_map) { bool first_flag = true; - for (const AnfNodePtr& param : parameters) { + for (const AnfNodePtr ¶m : parameters) { if (first_flag) { first_flag = false; ofs << " "; @@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vectorinputs(); + auto &inputs = node->inputs(); if (inputs.size() > 1) { ofs << " #("; for (size_t i = 1; i < inputs.size(); ++i) { @@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod ofs << " #scope: " << node->scope()->name(); } -void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector& nodes, - const FuncGraphPtr& func_graph) { +void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector &nodes, + const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } int idx = 1; std::map apply_map; - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const 
std::vector } auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); // non-return node if (node != func_graph->get_return()) { @@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector } } -void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) { +void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun ofs << "}\n"; } -void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { +void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt ofs.close(); } -void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector& graphs) { +void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector &graphs) { if (graphs.empty()) { return; } @@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector param_index = 1; - for (const auto& tagged_graph : graphs) { + for (const auto &tagged_graph : graphs) { tagged_cnodes_ = tagged_graph.second; ExportOneFuncGraph(ofs, tagged_graph.first); tagged_cnodes_.clear(); @@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector } #ifdef ENABLE_DUMP_IR -void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) { +void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const 
std::string& id, const FuncGrap ChangeFileMode(filename, S_IRUSR); } -void ExportIR(const std::string& filename, const std::vector& graphs) { +void ExportIR(const std::string &filename, const std::vector &graphs) { AnfExporter exporter("", false); ChangeFileMode(filename, S_IRWXU); exporter.ExportFuncGraph(filename, graphs); @@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector& graph ChangeFileMode(filename, S_IRUSR); } #else -void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { +void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { << "please recompile source to enable it. See help of building script."; } -void ExportIR(const std::string& filename, const std::vector& graphs) { +void ExportIR(const std::string &filename, const std::vector &graphs) { static bool already_printed = false; if (already_printed) { return; @@ -732,7 +732,7 @@ enum Token : int { TOK_ERROR // file read error }; -std::map token_text = { +std::map token_text = { {TOK_INVALID, "invalid"}, // invalid token {TOK_LPARENTHESIS, "("}, // ( left parenthesis {TOK_RPARENTHESIS, ")"}, // ) right parenthesis @@ -761,14 +761,14 @@ std::map token_text = { class Lexer { public: // filename is checked in ImportIR; - explicit Lexer(const char* filename) : fin(filename) {} + explicit Lexer(const char *filename) : fin(filename) {} ~Lexer() { try { if (fin.is_open()) { fin.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file"; } catch (...) 
{ std::string exName(abi::__cxa_current_exception_type()->name()); @@ -776,7 +776,7 @@ class Lexer { } } - bool IsSingleCharToken(char ch, Token* token_ptr) { + bool IsSingleCharToken(char ch, Token *token_ptr) { // clang-format off std::unordered_map char_to_token = { {'(', TOK_LPARENTHESIS}, @@ -806,7 +806,7 @@ class Lexer { Token GetNextToken() { #ifdef DEBUG Token token = GetNextTokenInner(); - const char* str = token_text[token]; + const char *str = token_text[token]; std::string text = (str == nullptr ? GetTokenText() : str); MS_LOG(DEBUG) << "------Parse token] " << text; return token; @@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE; class IrParser { public: - explicit IrParser(const char* filename) : lexer_(filename) {} + explicit IrParser(const char *filename) : lexer_(filename) {} ~IrParser() {} - py::object LoadObject(const std::string& file_name) const { + py::object LoadObject(const std::string &file_name) const { std::string pkl_path = GetMsIrPath(); py::object default_obj = load_obj(pkl_path + "/" + file_name); return default_obj; @@ -1087,7 +1087,7 @@ class IrParser { MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); } - Token ParseParent(FuncGraphPtr* const parent_ptr) { + Token ParseParent(FuncGraphPtr *const parent_ptr) { if (lexer_.GetNextToken() != TOK_IDENTIFIER) { return TOK_ERROR; } @@ -1168,7 +1168,7 @@ class IrParser { return func_graph; } - FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) { + FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) { Token tok = lexer_.SkipWhiteToken(); while (tok == TOK_VARIABLE) { if (ParseStatement(func_graph) == nullptr) { @@ -1264,56 +1264,56 @@ class IrParser { return func_graph; } - void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const { + void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const { if (ptr == nullptr) { return; } *ptr = dtype; } - void SetTupleType(TypePtr* ptr) { + void SetTupleType(TypePtr *ptr) { if (ptr == nullptr) { return; } *ptr = 
std::make_shared(); } - void SetTupleType(TypePtr* ptr, const TypePtrList& elems) { + void SetTupleType(TypePtr *ptr, const TypePtrList &elems) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elems); } - void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector&) { + void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector &) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem_type); } - void SetListType(TypePtr* ptr) { + void SetListType(TypePtr *ptr) { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetListType(TypePtr* ptr, const TypePtrList& elems) { + void SetListType(TypePtr *ptr, const TypePtrList &elems) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elems); } - void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) { + void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem); } - void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const { + void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const { if (ptr == nullptr) { return; } @@ -1321,45 +1321,45 @@ class IrParser { } // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} - void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const { + void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {} - void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {} + void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {} + void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {} - void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { + void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { if (ptr == nullptr) { return; } // if one of elems is nullptr, just return - if 
(std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { + if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { return; } *ptr = std::make_shared(elems); } - void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector& shape) { + void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector &shape) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem_type, shape); } - void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { + void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { if (ptr == nullptr) { return; } - if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { + if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { return; } *ptr = std::make_shared(elems); } - void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) { + void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) { if (ptr == nullptr) { return; } @@ -1367,7 +1367,7 @@ class IrParser { } template - Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) { + Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) { if (tok != TOK_LBRACKET) { MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; return tok; @@ -1415,7 +1415,7 @@ class IrParser { } template - Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { + Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { if (tok != TOK_LPARENTHESIS) { if (ptr != nullptr) { SetBasicType(ptr, std::make_shared()); @@ -1454,7 +1454,7 @@ class IrParser { return 
lexer_.GetNextToken(); } - bool IsNumberType(const std::string& type, TypeId* typeid_ptr) { + bool IsNumberType(const std::string &type, TypeId *typeid_ptr) { // clang-format off static std::unordered_map basic_types = { {"Bool", kNumberTypeBool}, @@ -1486,7 +1486,7 @@ class IrParser { } template - void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) { + void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) { TypePtr dtype = nullptr; std::unordered_map type_map = { @@ -1519,7 +1519,7 @@ class IrParser { } template - Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) { + Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) { if (type == "NoneType") { SetBasicType(ptr, std::make_shared()); return lexer_.GetNextToken(); @@ -1541,7 +1541,7 @@ class IrParser { } template - Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { + Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { if (tok != TOK_IDENTIFIER) { return TOK_ERROR; } @@ -1588,11 +1588,11 @@ class IrParser { } } - Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) { + Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) { return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); } - Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { + Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { Token tok = ParseAttribute(func_graph, prim); while (tok == TOK_COMMA) { tok = ParseAttribute(func_graph, prim); @@ -1603,7 +1603,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { + Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { Token tok = lexer_.GetNextToken(); if (tok != TOK_IDENTIFIER) { return 
TOK_ERROR; @@ -1670,7 +1670,7 @@ class IrParser { return tok == TOK_RPARENTHESIS ? func_graph : nullptr; } - FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector* const inputs_ptr) { + FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector *const inputs_ptr) { Token tok = ParseArgument(func_graph, inputs_ptr); while (tok == TOK_COMMA) { tok = ParseArgument(func_graph, inputs_ptr); @@ -1681,9 +1681,9 @@ class IrParser { return func_graph; } - AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) { + AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string ¶m_name) { while (func_graph != nullptr) { - for (auto& ptr : func_graph->parameters()) { + for (auto &ptr : func_graph->parameters()) { MS_EXCEPTION_IF_NULL(ptr); ParameterPtr param = ptr->cast(); MS_EXCEPTION_IF_NULL(param); @@ -1701,12 +1701,12 @@ class IrParser { return nullptr; } - bool Match(const std::string& str, const std::string& pattern) const { + bool Match(const std::string &str, const std::string &pattern) const { return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; } template - Token ParseScalar(ValuePtr* const val_ptr) { + Token ParseScalar(ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_NUMBER) { return TOK_ERROR; } @@ -1725,7 +1725,7 @@ class IrParser { } template - Token ParseScalar(ValuePtr* const val_ptr, Token tok) { + Token ParseScalar(ValuePtr *const val_ptr, Token tok) { if (tok != TOK_LPARENTHESIS) { *val_ptr = std::make_shared(); return tok; @@ -1735,7 +1735,7 @@ class IrParser { } template - Token ParseScalar(ValuePtr* const val_ptr, Token tok) { + Token ParseScalar(ValuePtr *const val_ptr, Token tok) { if (tok != TOK_LPARENTHESIS) { *val_ptr = std::make_shared(nbits); return tok; @@ -1745,7 +1745,7 @@ class IrParser { } template - T StringToScalar(const std::string& text) { + T StringToScalar(const std::string &text) { std::stringstream ss; T value; ss << text; @@ -1753,7 +1753,7 @@ class IrParser { 
return value; } - Token ParseTensor(ValuePtr* const val_ptr) { + Token ParseTensor(ValuePtr *const val_ptr) { // parse type TypeId type; if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { @@ -1803,7 +1803,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParsePrimType(Token tok, PrimType* prim_type_ptr) { + Token ParsePrimType(Token tok, PrimType *prim_type_ptr) { if (tok != TOK_LBRACE) { return tok; } @@ -1830,7 +1830,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { + Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { if (tok != TOK_LPARENTHESIS) { return TOK_ERROR; } @@ -1855,7 +1855,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { + Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { if (tok != TOK_LBRACE) { return tok; } @@ -1868,7 +1868,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseBoolValue(const std::string& key, bool* val_ptr) { + Token ParseBoolValue(const std::string &key, bool *val_ptr) { if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { return TOK_ERROR; } @@ -1892,7 +1892,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) { + Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_LBRACE) { return TOK_ERROR; } @@ -1920,7 +1920,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) { + Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) { if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { return TOK_ERROR; } @@ -1951,7 
+1951,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) { + Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_AT_FILE) { return TOK_ERROR; } @@ -1984,7 +1984,7 @@ class IrParser { return next; } - Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) { + Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) { if (Match(id, "MultitypeFuncGraph::")) { std::string name = id.substr(strlen("MultitypeFuncGraph::")); auto mt_func_graph = std::make_shared(name); @@ -2024,8 +2024,8 @@ class IrParser { } } - Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr, - AnfNodePtr* const node_ptr = nullptr) { + Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr, + AnfNodePtr *const node_ptr = nullptr) { if (id == "None") { *val_ptr = std::make_shared(); return lexer_.GetNextToken(); @@ -2075,9 +2075,9 @@ class IrParser { } } - Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid, - const std::vector& elems, const std::vector& nodes, - ValuePtr* const val_ptr, AnfNodePtr* node_ptr) { + Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid, + const std::vector &elems, const std::vector &nodes, + ValuePtr *const val_ptr, AnfNodePtr *node_ptr) { if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { if (node_is_valid && node_ptr != nullptr) { MS_EXCEPTION_IF_NULL(func_graph); @@ -2097,8 +2097,8 @@ class IrParser { } } - Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, - AnfNodePtr* node_ptr = nullptr) { + Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, + 
AnfNodePtr *node_ptr = nullptr) { Token left_tok = tok; std::vector elems; @@ -2138,7 +2138,7 @@ class IrParser { return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); } - Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) { + Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) { // tuple or list if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); @@ -2152,7 +2152,7 @@ class IrParser { return TOK_ERROR; } - Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr, + Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr, Token tok = TOK_INVALID) { if (tok == TOK_INVALID) { tok = lexer_.GetNextToken(); @@ -2193,7 +2193,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseArgument(const FuncGraphPtr& func_graph, std::vector* const inputs_ptr) { + Token ParseArgument(const FuncGraphPtr &func_graph, std::vector *const inputs_ptr) { Token tok = lexer_.GetNextToken(); if (tok == TOK_RPARENTHESIS) { return tok; @@ -2208,7 +2208,7 @@ class IrParser { return tok; } - const std::vector& GetFuncGraphs() const { return func_graphs_; } + const std::vector &GetFuncGraphs() const { return func_graphs_; } private: Lexer lexer_; @@ -2226,14 +2226,14 @@ class IrParser { std::map param_nodes_; // map parameter name to parameter }; -std::vector ImportIR(const std::string& filename) { +std::vector ImportIR(const std::string &filename) { IrParser parser(filename.c_str()); parser.ParseFile(); return parser.GetFuncGraphs(); } #ifdef ENABLE_DUMP_IR -void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { +void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { if (func_graph == nullptr) { MS_LOG(ERROR) << "Func 
graph is nullptr"; return; @@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { return; } char real_path[PATH_MAX] = {0}; - char* real_path_ret = nullptr; + char *real_path_ret = nullptr; #if defined(_WIN32) || defined(_WIN64) real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); #else @@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { ChangeFileMode(file_path, S_IRUSR); } #else -void DumpIRProto(const FuncGraphPtr&, const std::string&) { +void DumpIRProto(const FuncGraphPtr &, const std::string &) { static bool already_printed = false; if (already_printed) { return; diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 5342c1ab96..6c8601c4af 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -39,7 +39,7 @@ namespace mindspore { struct ParamPtrEqual { - bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { + bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const { const ParameterPtr param1 = dyn_cast(t1); const ParameterPtr param2 = dyn_cast(t2); @@ -52,7 +52,7 @@ struct ParamPtrEqual { }; struct ParamPtrHasher { - std::size_t operator()(AnfNodePtr const& param) const { + std::size_t operator()(AnfNodePtr const ¶m) const { const ParameterPtr parameter = dyn_cast(param); if (parameter == nullptr) { return 0; @@ -64,39 +64,39 @@ struct ParamPtrHasher { class AnfExporter { public: - explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) + explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false) : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { func_graph_set.clear(); exported.clear(); } virtual ~AnfExporter() {} - void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); - void ExportFuncGraph(const 
std::string& filename, const std::vector& graphs); + void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); + void ExportFuncGraph(const std::string &filename, const std::vector &graphs); protected: - virtual std::string GetNodeType(const AnfNodePtr& nd); - int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); - int GetParamIndexFromExported(const AnfNodePtr& param); - std::string DumpObject(const py::object& obj, const std::string& category) const; - std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); - std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); - std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); - std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetPrimitiveText(const PrimitivePtr& prim); - std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetNameSpaceText(const parse::NameSpacePtr& ns); - std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); - std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map); - void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); - void OutputParameters(std::ofstream& ofs, const std::vector& parameters, - OrderedMap* param_map); - - void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); - void OutputCNodes(std::ofstream& ofs, const std::vector& nodes, const FuncGraphPtr& func_graph); + virtual std::string GetNodeType(const AnfNodePtr &nd); + int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); + int 
GetParamIndexFromExported(const AnfNodePtr ¶m); + std::string DumpObject(const py::object &obj, const std::string &category) const; + std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); + std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); + std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); + std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetPrimitiveText(const PrimitivePtr &prim); + std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetNameSpaceText(const parse::NameSpacePtr &ns); + std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph); + std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map); + void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); + void OutputParameters(std::ofstream &ofs, const std::vector ¶meters, + OrderedMap *param_map); + + void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); + void OutputCNodes(std::ofstream &ofs, const std::vector &nodes, const FuncGraphPtr &func_graph); int param_index; OrderedSet func_graph_set{}; @@ -108,16 +108,16 @@ class AnfExporter { abstract::AnfNodeConfigPtr node_cfg_ = nullptr; }; -void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); -void ExportIR(const std::string& filename, const std::vector& graphs); +void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph); +void ExportIR(const std::string &filename, const std::vector &graphs); -std::vector ImportIR(const std::string& filename); +std::vector ImportIR(const 
std::string &filename); -std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); +std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); -void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); +void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); -std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index 3e8cbfba19..d3b92532fa 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -34,7 +34,7 @@ namespace draw { namespace { // Only for ValueNode -std::string ValueType(const ValueNodePtr& node) { +std::string ValueType(const ValueNodePtr &node) { if (node == nullptr) { return ""; } @@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) { return v->type_name(); } -std::string ReplaceSpecialChar(const std::string& str) { +std::string ReplaceSpecialChar(const std::string &str) { std::ostringstream oss; for (size_t i = 0; i < str.size(); i++) { if (str[i] == '<') { @@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) { } // namespace // API of debug utils -void DrawNodes(const std::vector& nodes, OrderedMap>* sub_graphs, +void DrawNodes(const std::vector &nodes, OrderedMap> *sub_graphs, bool is_user) { if (sub_graphs == nullptr) { return; } - for (auto& nd : nodes) { + for (auto &nd : nodes) { MS_EXCEPTION_IF_NULL(nd); auto sub_graph = nd->func_graph(); if (sub_graph != nullptr) { @@ -84,16 +84,16 @@ void DrawNodes(const std::vector& nodes, OrderedMap& nodes, - OrderedMap>* sub_graphs) { +void DrawValueNodes(const std::vector &nodes, + OrderedMap> *sub_graphs) { if (sub_graphs == nullptr) { return; } int dup_idx = 0; - for (auto& nd : nodes) { - for (auto& t : SuccIncoming(nd)) { + for (auto &nd : nodes) { + for (auto 
&t : SuccIncoming(nd)) { MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(nd); if (t->isa() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { @@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector& nodes, } } -void DrawEdges(const std::vector& nodes, const std::shared_ptr& digraph, bool is_user) { +void DrawEdges(const std::vector &nodes, const std::shared_ptr &digraph, bool is_user) { if (digraph == nullptr) { return; } @@ -120,11 +120,11 @@ void DrawEdges(const std::vector& nodes, const std::shared_ptrisa() || t->isa()) { if ((!is_user) || (i != 0)) { @@ -143,7 +143,7 @@ void DrawEdges(const std::vector& nodes, const std::shared_ptrSubGraph(gsub.first, gsub.second); } @@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use } #ifdef ENABLE_DUMP_IR -void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { +void Draw(const std::string &filename, const FuncGraphPtr &func_graph) { const std::string dot_suffix = ".dot"; std::string filename_with_suffix = (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; DrawByOpt(filename_with_suffix, func_graph, false); } -void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { +void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { DrawByOpt(filename, func_graph, true); } #else -void Draw(const std::string&, const FuncGraphPtr&) { +void Draw(const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) { << "please recompile source to enable it. 
See help of building script."; } -void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { +void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) { return "plaintext"; } -std::string Graphviz::Color(const AnfNodePtr& node) { +std::string Graphviz::Color(const AnfNodePtr &node) { if (node == nullptr) { return ""; } @@ -259,7 +259,7 @@ void BaseDigraph::Start() { buffer_ << "compound=true" << std::endl; } -void BaseDigraph::Head(const AnfNodePtr& node, int id) { +void BaseDigraph::Head(const AnfNodePtr &node, int id) { if (node == nullptr) { return; } @@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) { } } -void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { +void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) { if (node == nullptr) { return; } @@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { buffer_ << ":" << idx; } -void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { +void BaseDigraph::Tail(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -304,12 +304,12 @@ void BaseDigraph::End() { } } -void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { +void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { buffer_ << "parameters_" << key << "[shape=plaintext "; buffer_ << "label=<"; buffer_ << ""; int count = 0; - for (auto& parameter : key->parameters()) { + for (auto ¶meter : key->parameters()) { buffer_ << "
parameters
"; buffer_ << parameter->ToString(); auto py_p = dyn_cast(parameter)->default_param(); @@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { buffer_ << "
>,];"; } -void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr& gsub) { +void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr &gsub) { if (key == nullptr || gsub == nullptr) { return; } @@ -361,12 +361,12 @@ Digraph::~Digraph() { if (fout_.is_open()) { fout_.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file " << filename_; } } -static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { +static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) { size_t start_pos = 0; while ((start_pos = str.find(from, start_pos)) != std::string::npos) { (void)str.replace(start_pos, from.length(), to); @@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st return str; } -static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { +static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(graph_obj); graph_obj->buffer() << "label=<"; @@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { graph_obj->buffer() << ""; graph_obj->buffer() << "
"; int i = 0; - for (const auto& attr : attrs) { + for (const auto &attr : attrs) { if (i != 0) { graph_obj->buffer() << "
"; } @@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { graph_obj->buffer() << "
>,"; } -static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { +static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr) { return; } @@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { } } -static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { +static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr || node->size() == 0) { return; } @@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { } graph_obj->buffer() << ">"; int i = 0; - for (auto& attr : attrs) { + for (auto &attr : attrs) { if (i != 0) { graph_obj->buffer() << "
"; } @@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() { if (fout_.is_open()) { fout_.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "exception when closing file " << filename_; } } diff --git a/mindspore/ccsrc/debug/draw.h b/mindspore/ccsrc/debug/draw.h index 4781a6c231..7804c6e94a 100644 --- a/mindspore/ccsrc/debug/draw.h +++ b/mindspore/ccsrc/debug/draw.h @@ -31,9 +31,9 @@ namespace parse = mindspore::parse; class Graphviz { public: - Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} + Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {} - explicit Graphviz(const std::string& name) : name_(name) {} + explicit Graphviz(const std::string &name) : name_(name) {} virtual ~Graphviz() {} @@ -41,8 +41,8 @@ class Graphviz { virtual void End() {} virtual std::string Shape(AnfNodePtr node); - std::string Color(const AnfNodePtr& node); - std::ostringstream& buffer() { return buffer_; } + std::string Color(const AnfNodePtr &node); + std::ostringstream &buffer() { return buffer_; } std::ostringstream buffer_; protected: @@ -53,8 +53,8 @@ class Graphviz { class BaseDigraph : public Graphviz { public: - BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} - explicit BaseDigraph(const std::string& name) : Graphviz(name) {} + BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {} + explicit BaseDigraph(const std::string &name) : Graphviz(name) {} ~BaseDigraph() override = default; virtual void Node(AnfNodePtr node, int id = 0) = 0; @@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz { void Start() override; void End() override; virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); - void FuncGraphParameters(const FuncGraphPtr& key); - void SubGraph(const FuncGraphPtr& key, const 
std::shared_ptr& gsub); + void FuncGraphParameters(const FuncGraphPtr &key); + void SubGraph(const FuncGraphPtr &key, const std::shared_ptr &gsub); - const std::string& name() const { return name_; } + const std::string &name() const { return name_; } protected: - void Head(const AnfNodePtr& node, int id = 0); - void Tail(const AnfNodePtr& node, int idx, int id = 0); - void Tail(const FuncGraphPtr& func_graph); + void Head(const AnfNodePtr &node, int id = 0); + void Tail(const AnfNodePtr &node, int idx, int id = 0); + void Tail(const FuncGraphPtr &func_graph); }; class Digraph : public BaseDigraph { public: - Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} - explicit Digraph(const std::string& name) : BaseDigraph(name) {} + Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} + explicit Digraph(const std::string &name) : BaseDigraph(name) {} ~Digraph() override; void Node(AnfNodePtr node, int id = 0) override; @@ -86,8 +86,8 @@ class Digraph : public BaseDigraph { class ModelDigraph : public BaseDigraph { public: - ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} - explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} + ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} + explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {} ~ModelDigraph() override; std::string Shape(AnfNodePtr node) override; @@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph { }; // API to draw -void Draw(const std::string& filename, const FuncGraphPtr& func_graph); -void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); +void Draw(const std::string &filename, const FuncGraphPtr &func_graph); +void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); } // namespace draw } // namespace mindspore diff --git 
a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc index a7a1e208a4..83ab1e4505 100644 --- a/mindspore/ccsrc/debug/dump_proto.cc +++ b/mindspore/ccsrc/debug/dump_proto.cc @@ -33,38 +33,38 @@ class ProtoExporter { ProtoExporter() {} ~ProtoExporter() {} - std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); + std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); private: void InitModelInfo(); - void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto); - std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map, - std::map* const_map_ptr); - void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto); - void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto); - void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto); - void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto); - void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto); - void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto); - - void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); - void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); - void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, - std::map* const_map_ptr); - void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* apply_map_ptr, - std::map* const_map_ptr, irpb::GraphProto* graph_proto); - void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, - const std::map& apply_map, std::map* const_map_ptr, - irpb::GraphProto* graph_proto); - void ExportValueNodes(const std::map& const_map, irpb::GraphProto* graph_proto); + void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const 
AnfNodePtr &node, irpb::NodeProto *node_proto); + std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr); + void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto); + void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto); + void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto); + void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto); + void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto); + void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto); + + void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); + void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, + std::map *const_map_ptr); + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *apply_map_ptr, + std::map *const_map_ptr, irpb::GraphProto *graph_proto); + void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, std::map *const_map_ptr, + irpb::GraphProto *graph_proto); + void ExportValueNodes(const std::map &const_map, irpb::GraphProto *graph_proto); static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } irpb::ModelProto model_; }; -static irpb::DataType GetNumberDataType(const TypePtr& type) { +static irpb::DataType GetNumberDataType(const TypePtr &type) { switch (type->type_id()) { case kNumberTypeBool: return irpb::DT_BOOL; @@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) { } } -void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) { +void ProtoExporter::SetNodeOutputType(const TypePtr 
&type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { if (type_proto == nullptr) { return; } @@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s type_proto->set_data_type(irpb::DT_TENSOR); if (shape != nullptr && shape->isa()) { abstract::ShapePtr shape_info = dyn_cast(shape); - for (const auto& elem : shape_info->shape()) { + for (const auto &elem : shape_info->shape()) { type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); } } } else if (type->isa()) { TuplePtr tuple_type = dyn_cast(type); type_proto->set_data_type(irpb::DT_TUPLE); - for (const auto& elem_type : tuple_type->elements()) { + for (const auto &elem_type : tuple_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { @@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s } else if (type->isa()) { ListPtr list_type = dyn_cast(type); type_proto->set_data_type(irpb::DT_LIST); - for (const auto& elem_type : list_type->elements()) { + for (const auto &elem_type : list_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { @@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s } } -void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) { +void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) { if (node == nullptr || type_proto == nullptr) { return; } SetNodeOutputType(node->Type(), node->Shape(), type_proto); } -void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const StringImmPtr& 
value = dyn_cast(val); + const StringImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_STRING); value_proto->set_str_val(value->value()); } else if (val->isa()) { @@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value } else if (val->isa()) { tensor::TensorPtr tensor_ptr = dyn_cast(val); value_proto->set_dtype(irpb::DT_TENSOR); - irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val(); + irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); - for (auto& elem : tensor_ptr->shape()) { + for (auto &elem : tensor_ptr->shape()) { tensor_proto->add_dims(elem); } } else if (val->isa()) { value_proto->set_dtype(irpb::DT_TYPE); - irpb::TypeProto* type_proto = value_proto->mutable_type_val(); + irpb::TypeProto *type_proto = value_proto->mutable_type_val(); type_proto->set_data_type(irpb::DT_TENSOR); TypePtr elem_type = dyn_cast(val)->element(); type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); @@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value } } -void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const BoolImmPtr& value = dyn_cast(val); + const BoolImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_BOOL); value_proto->set_bool_val(value->value()); } else if (val->isa()) { - const Int8ImmPtr& value = dyn_cast(val); + const Int8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT8); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int16ImmPtr& value = dyn_cast(val); + const Int16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT16); 
value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int32ImmPtr& value = dyn_cast(val); + const Int32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT32); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int64ImmPtr& value = dyn_cast(val); + const Int64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT64); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const UInt8ImmPtr& value = dyn_cast(val); + const UInt8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT8); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt16ImmPtr& value = dyn_cast(val); + const UInt16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT16); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt32ImmPtr& value = dyn_cast(val); + const UInt32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT32); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt64ImmPtr& value = dyn_cast(val); + const UInt64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT64); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const FP32ImmPtr& value = dyn_cast(val); + const FP32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_FLOAT32); value_proto->set_float_val(value->value()); } else if (val->isa()) { - const FP64ImmPtr& value = dyn_cast(val); + const FP64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_FLOAT64); value_proto->set_double_val(value->value()); } else { @@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val } } -void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if 
(val->isa()) { - const ValueTuplePtr& value = dyn_cast(val); + const ValueTuplePtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_TUPLE); - for (const auto& item : value->value()) { + for (const auto &item : value->value()) { SetValueToProto(item, value_proto->add_values()); } } else if (val->isa()) { - const ValueListPtr& value = dyn_cast(val); + const ValueListPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_LIST); - for (const auto& item : value->value()) { + for (const auto &item : value->value()) { SetValueToProto(item, value_proto->add_values()); } } } -void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } value_proto->set_dtype(irpb::DT_DICT); - for (const auto& item : val->value()) { - irpb::NamedValueProto* named_val = value_proto->add_dict_val(); + for (const auto &item : val->value()) { + irpb::NamedValueProto *named_val = value_proto->add_dict_val(); named_val->set_key(item.first); SetValueToProto(item.second, named_val->mutable_value()); } } -void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) { +void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) { if (node == nullptr || node_proto == nullptr) { return; } @@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); } - const PrimitivePtr& prim = GetValueNode(node); + const PrimitivePtr &prim = GetValueNode(node); node_proto->set_op_type(prim->name()); - for (const auto& attr : prim->attrs()) { - irpb::AttributeProto* attr_proto = node_proto->add_attribute(); + for (const auto &attr : prim->attrs()) { + irpb::AttributeProto 
*attr_proto = node_proto->add_attribute(); attr_proto->set_name(attr.first); SetValueToProto(attr.second, attr_proto->mutable_value()); } node_proto->set_scope(node->scope()->name()); } -std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node, - const std::map& apply_map, - std::map* const_map_ptr) { +std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr) { if (node == nullptr || const_map_ptr == nullptr) { return ""; } @@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; } -std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { +std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ""; } InitModelInfo(); - irpb::GraphProto* graph_proto = model_.mutable_graph(); + irpb::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_.SerializeAsString(); } -void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } @@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP ExportValueNodes(const_map, graph_proto); } -void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } std::vector parameters = func_graph->parameters(); - for (auto& param : parameters) { - irpb::ParameterProto* param_proto = graph_proto->add_parameters(); 
+ for (auto ¶m : parameters) { + irpb::ParameterProto *param_proto = graph_proto->add_parameters(); param_proto->set_name(param->ToString()); SetNodeOutputType(param, param_proto->mutable_type()); @@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph } } -void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, - std::map* const_map_ptr) { +void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, + std::map *const_map_ptr) { if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { return; } // topo sort nodes std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::map apply_map; - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt } } -void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* apply_map_ptr, - std::map* const_map_ptr, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *apply_map_ptr, + std::map *const_map_ptr, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || graph_proto == nullptr) { return; @@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& auto apply_idx = apply_map_ptr->size() + 1; (*apply_map_ptr)[node] = apply_idx; - auto& inputs = node->inputs(); + auto &inputs = node->inputs(); if (inputs.size() < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } AnfNodePtr op = inputs[0]; - irpb::NodeProto* node_proto = graph_proto->add_node(); + irpb::NodeProto *node_proto = graph_proto->add_node(); // 
CNode/ConstGraph/Const/Parameter if (op->isa() || IsValueNode(op) || op->isa()) { @@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& // process OP inputs for (size_t i = 1; i < inputs.size(); ++i) { - irpb::InputProto* input_proto = node_proto->add_input(); + irpb::InputProto *input_proto = node_proto->add_input(); input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); input_proto->set_name(id); @@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& } } -void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, - const std::map& apply_map, - std::map* const_map_ptr, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, + std::map *const_map_ptr, irpb::GraphProto *graph_proto) { if (ret_node == nullptr || !ret_node->isa()) { MS_LOG(EXCEPTION) << "Graph return node is illegal"; } @@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } - irpb::OutputProto* output_proto = graph_proto->add_outputs(); + irpb::OutputProto *output_proto = graph_proto->add_outputs(); if (output_proto == nullptr) { MS_LOG(EXCEPTION) << "output_proto is nullptr"; } @@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const SetNodeOutputType(arg, output_proto->mutable_type()); } -static bool CompareValue(const std::pair& x, const std::pair& y) { +static bool CompareValue(const std::pair &x, const std::pair &y) { return x.second < y.second; } -void ProtoExporter::ExportValueNodes(const std::map& const_map, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportValueNodes(const std::map 
&const_map, irpb::GraphProto *graph_proto) { std::vector> nodes; (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), - [](const std::pair& item) { return item; }); + [](const std::pair &item) { return item; }); sort(nodes.begin(), nodes.end(), CompareValue); - for (auto& item : nodes) { + for (auto &item : nodes) { if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } - irpb::NamedValueProto* named_value = graph_proto->add_const_vals(); + irpb::NamedValueProto *named_value = graph_proto->add_const_vals(); MS_EXCEPTION_IF_NULL(named_value); named_value->set_key(GetConstNodeId(item.second)); SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); @@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map& const_m void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } -std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { +std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { ProtoExporter exporter; return exporter.GetFuncGraphProtoString(func_graph); } diff --git a/mindspore/ccsrc/debug/e2e_dump.cc b/mindspore/ccsrc/debug/e2e_dump.cc index fbe76cdc47..34d401191a 100644 --- a/mindspore/ccsrc/debug/e2e_dump.cc +++ b/mindspore/ccsrc/debug/e2e_dump.cc @@ -36,7 +36,7 @@ Dump::Dump() dump_iter_(0), cur_iter_(0) {} -bool Dump::IsKernelNeedDump(const std::string& kernel_name) { +bool Dump::IsKernelNeedDump(const std::string &kernel_name) { if (dump_mode_ == 0) { // Dump All Kernels mode return true; @@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) { return false; } -bool Dump::ParseDumpConfig(const std::string& dump_config_file) { +bool Dump::ParseDumpConfig(const std::string &dump_config_file) { std::ifstream jsonFile(dump_config_file); if (!jsonFile.is_open()) { MS_LOG(ERROR) << dump_config_file << " open failed."; @@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) { 
return true; } -bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { +bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) { if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || @@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { return true; } -bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { +bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { auto trans_flag = dumpSettings.at("trans_flag"); auto enable = dumpSettings.at("enable"); auto mode = dumpSettings.at("mode"); @@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { dump_path_ = path; dump_net_name_ = net_name; dump_iter_ = iteration; - for (const auto& kernel : kernels) { + for (const auto &kernel : kernels) { dump_kernels_.push_back(kernel); } return true; } bool Dump::SetDumpConfFromJsonFile() { - const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); + const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); if (config_path_str != nullptr) { MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; } else { @@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() { return ParseDumpConfig(dump_config_file); } -bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { +bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) { if (filename.empty() || data == nullptr || len == 0) { MS_LOG(ERROR) << "Incorrect parameter."; return false; @@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) MS_LOG(ERROR) << "Open file " << realpath << " fail."; return false; } - (void)fd.write(reinterpret_cast(data), 
SizeToLong(len)); + (void)fd.write(reinterpret_cast(data), SizeToLong(len)); fd.close(); return true; } -bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { +bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) { MS_EXCEPTION_IF_NULL(outpath); auto path_split_pos = inpath.find_last_of('/'); if (path_split_pos == std::string::npos) { @@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { return true; } -bool Dump::CreateNotExistDirs(const std::string& path) { +bool Dump::CreateNotExistDirs(const std::string &path) { std::shared_ptr fs = system::Env::GetFileSystem(); MS_EXCEPTION_IF_NULL(fs); char temp_path[PATH_MAX] = {0}; diff --git a/mindspore/ccsrc/debug/e2e_dump.h b/mindspore/ccsrc/debug/e2e_dump.h index 2410dfb09a..4c3e8308da 100644 --- a/mindspore/ccsrc/debug/e2e_dump.h +++ b/mindspore/ccsrc/debug/e2e_dump.h @@ -43,11 +43,11 @@ class Dump { uint32_t cur_iter() const { return cur_iter_; } - bool IsKernelNeedDump(const std::string& kernel_name); + bool IsKernelNeedDump(const std::string &kernel_name); bool SetDumpConfFromJsonFile(); - static bool DumpToFile(const std::string& filename, const void* data, size_t len); + static bool DumpToFile(const std::string &filename, const void *data, size_t len); protected: bool dump_enable_; @@ -59,14 +59,14 @@ class Dump { uint32_t cur_iter_; std::vector dump_kernels_; - static bool GetRealPath(const std::string& inpath, std::string* outpath); + static bool GetRealPath(const std::string &inpath, std::string *outpath); - static bool CreateNotExistDirs(const std::string& path); + static bool CreateNotExistDirs(const std::string &path); private: - bool ParseDumpConfig(const std::string& dump_config_file); - bool IsConfigExist(const nlohmann::json& dumpSettings); - bool IsConfigValid(const nlohmann::json& dumpSettings); + bool ParseDumpConfig(const std::string &dump_config_file); + bool IsConfigExist(const nlohmann::json &dumpSettings); + bool 
IsConfigValid(const nlohmann::json &dumpSettings); }; using DumpConfPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/debug/info.cc b/mindspore/ccsrc/debug/info.cc index 3c43bfa9b1..7903e554d9 100644 --- a/mindspore/ccsrc/debug/info.cc +++ b/mindspore/ccsrc/debug/info.cc @@ -23,7 +23,7 @@ #include "pipeline/parse/python_adapter.h" namespace mindspore { -std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { +std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { std::string temp_line = line; if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && tip != kSourceLineTipDiscard) { @@ -101,14 +101,14 @@ DebugInfo::DebugInfo() { name_ = ""; } -DebugInfo::DebugInfo(const std::string& name) { +DebugInfo::DebugInfo(const std::string &name) { InitValueFromContext(); unique_id_ = gen_unique_id(); debug_id_ = -1; name_ = name; } -DebugInfo::DebugInfo(const LocationPtr& loc) { +DebugInfo::DebugInfo(const LocationPtr &loc) { InitValueFromContext(); unique_id_ = gen_unique_id(); debug_id_ = -1; @@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() { } int64_t DebugInfo::unique_id_through_copy() const { - TraceInfoPtr trace_info = const_cast(this)->trace_info(); + TraceInfoPtr trace_info = const_cast(this)->trace_info(); if (trace_info != nullptr) { if (trace_info->isa() && trace_info->debug_info() != nullptr) { return trace_info->debug_info()->unique_id_through_copy(); @@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() { } return DebugInfo::location(); } -void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } +void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } TraceContextPtr TraceManager::CurrentContextInfo() { if (!TraceManager::trace_context_stack_.empty()) { @@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() { return nullptr; 
} -void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { +void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { TraceContextPtr context = std::make_shared(location); context->set_func_name(func_name); TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const LocationPtr& location) { +void TraceManager::DebugTrace(const LocationPtr &location) { TraceContextPtr context = std::make_shared(location); TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { +void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { if (trace_info == nullptr) { MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; } @@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { +void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { if (trace_info == nullptr) { MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; } diff --git a/mindspore/ccsrc/debug/info.h b/mindspore/ccsrc/debug/info.h index da641ab74b..a34d6e3df5 100644 --- a/mindspore/ccsrc/debug/info.h +++ b/mindspore/ccsrc/debug/info.h @@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou // Location class record the location in source code. 
class Location { public: - Location(const std::string& file_name, int line, int column, int line_end, int column_end) + Location(const std::string &file_name, int line, int column, int line_end, int column_end) : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} - Location(const Location& loc) + Location(const Location &loc) : file_name_(loc.file_name_), line_(loc.line_), column_(loc.column_), @@ -77,21 +77,21 @@ class TraceManager { TraceManager() = default; ~TraceManager() = default; static TraceContextPtr CurrentContextInfo(); - static void DebugTrace(const std::string& func_name, const LocationPtr& location); - static void DebugTrace(const LocationPtr& location); - static void DebugTrace(const TraceInfoPtr& trace_info); + static void DebugTrace(const std::string &func_name, const LocationPtr &location); + static void DebugTrace(const LocationPtr &location); + static void DebugTrace(const TraceInfoPtr &trace_info); // debug trace with a cloned trace info with debug_info - static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); + static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); static void EndTrace(); static std::stack trace_context_stack_; }; class TraceGuard { public: - explicit TraceGuard(const std::string func_name, const LocationPtr& location) { + explicit TraceGuard(const std::string func_name, const LocationPtr &location) { TraceManager::DebugTrace(func_name, location); } - explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } + explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } ~TraceGuard() { TraceManager::EndTrace(); } }; @@ -106,23 +106,23 @@ class TraceContext { public: ~TraceContext() = default; - explicit TraceContext(const LocationPtr& loc) { + explicit TraceContext(const LocationPtr &loc) { ProcessAttributeFromContext(); location_ = loc; } - explicit 
TraceContext(const std::string& func_name) { + explicit TraceContext(const std::string &func_name) { ProcessAttributeFromContext(); func_name_ = func_name; } - explicit TraceContext(const TraceInfoPtr& trace_info) { + explicit TraceContext(const TraceInfoPtr &trace_info) { ProcessAttributeFromContext(); trace_info_ = trace_info; } - void set_location(const LocationPtr& loc) { location_ = loc; } + void set_location(const LocationPtr &loc) { location_ = loc; } LocationPtr location() { return location_; } - void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } TraceInfoPtr trace_info() { return trace_info_; } - void set_func_name(const std::string& func_name) { func_name_ = func_name; } + void set_func_name(const std::string &func_name) { func_name_ = func_name; } std::string func_name() { return func_name_; } }; @@ -130,9 +130,9 @@ class DebugInfo : public Base { public: DebugInfo(); - explicit DebugInfo(const std::string& name); + explicit DebugInfo(const std::string &name); - explicit DebugInfo(const LocationPtr& loc); + explicit DebugInfo(const LocationPtr &loc); virtual ~DebugInfo() = default; MS_DECLARE_PARENT(DebugInfo, Base); @@ -141,12 +141,12 @@ class DebugInfo : public Base { int64_t unique_id_through_copy() const; std::string get_id() { return std::to_string(debug_id()); } - void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } TraceInfoPtr trace_info() { return trace_info_; } - void set_location(const LocationPtr& loc) { location_ = loc; } + void set_location(const LocationPtr &loc) { location_ = loc; } virtual LocationPtr location() { return location_; } std::string name() { return name_; } - void set_name(const std::string& name) { name_ = name; } + void set_name(const std::string &name) { name_ = name; } virtual std::string 
debug_name(); virtual std::string get_python_func_belonged() { return ""; } @@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo { py_func_belonged_ = context_info->func_name(); } } - explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { + explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { if (TraceManager::CurrentContextInfo() != nullptr) { auto context_info = TraceManager::CurrentContextInfo(); py_func_belonged_ = context_info->func_name(); @@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo { ~NodeDebugInfo() override = default; std::string debug_name() override; - void set_node(const std::shared_ptr& node) { node_ = AnfNodeWeakPtr(node); } + void set_node(const std::shared_ptr &node) { node_ = AnfNodeWeakPtr(node); } std::shared_ptr get_node() const { return node_.lock(); } - void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } + void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } std::string get_python_func_belonged() override { return py_func_belonged_; } AnfNodeWeakPtr node_; std::string py_func_belonged_; @@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo { } } - explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { + explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { if (TraceManager::CurrentContextInfo() != nullptr) { auto context_info = TraceManager::CurrentContextInfo(); py_func_name_ = context_info->func_name(); @@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo { std::string debug_name() override; LocationPtr location() override; LocationPtr deco_location() { return deco_loc_; } - void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } + void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } FuncGraphPtr get_graph() const { return func_graph_.lock(); } - void set_full_name(const std::string& name) { full_name_ = name; } + 
void set_full_name(const std::string &name) { full_name_ = name; } std::string get_full_name() { return full_name_; } - void set_deco_location(const LocationPtr& deco_list_loc); + void set_deco_location(const LocationPtr &deco_list_loc); std::string get_python_func_belonged() override { return py_func_name_; } FuncGraphWeakPtr func_graph_; LocationPtr deco_loc_; diff --git a/mindspore/ccsrc/debug/label.cc b/mindspore/ccsrc/debug/label.cc index f0e16e831e..d8c4986482 100644 --- a/mindspore/ccsrc/debug/label.cc +++ b/mindspore/ccsrc/debug/label.cc @@ -31,7 +31,7 @@ struct NameWithTrace { std::string name; std::vector trace_labels; }; -static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { +static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { switch (trace_label) { case TraceLabelType::kShortSymbol: return trace_info->symbol(); @@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t } } -NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { +NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { NameWithTrace trace_name; // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node auto temp_info = debug_info; @@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe return trace_name; } -std::string CombineTraceTypes(const std::string& root_name, const std::vector& trace_labels) { +std::string CombineTraceTypes(const std::string &root_name, const std::vector &trace_labels) { std::string tags = ""; - for (auto& itr : trace_labels) { + for (auto &itr : trace_labels) { std::string symbol = itr; tags = tags + symbol; } @@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { return debug_with_loc_vec; } -DebugInfoPtr 
GetSourceCodeDebugInfo(const DebugInfoPtr& info) { +DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); if (debug_with_loc_vec.size() > 0) { return debug_with_loc_vec[0]; @@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { } } -std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { if (info == nullptr) { return ""; } @@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { // a trace info identifies a node transform, so we can trace the node transform through // a link of trace info and debug info -std::string GetInfoWithAction(const std::vector& info_vec, SourceLineTip tip) { +std::string GetInfoWithAction(const std::vector &info_vec, SourceLineTip tip) { if (info_vec.size() < 1) { return ""; } @@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector& info_vec, SourceL return traced_info; } -std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { +std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { if (info == nullptr) { return ""; } @@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { return ""; } -std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { std::ostringstream oss; if (info == nullptr) { return ""; @@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So return oss.str(); } -std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { +std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) { std::ostringstream oss; oss << "graph:" << graph->ToString() 
<< " with args["; auto params = graph->parameters(); @@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas return oss.str(); } -void DumpInferStack(std::ostringstream& oss) { - auto& infer_stack = GetCurrenGraphInferStack(); +void DumpInferStack(std::ostringstream &oss) { + auto &infer_stack = GetCurrenGraphInferStack(); if (infer_stack.empty()) { return; } @@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) { } std::reverse(infer_vec.begin(), infer_vec.end()); int index = 0; - for (auto& item : infer_vec) { + for (auto &item : infer_vec) { auto graph_infer = std::dynamic_pointer_cast(item.first); if (graph_infer == nullptr) { MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; @@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) { } void TraceGraphInfer() { - auto& infer_stack = GetCurrenGraphInferStack(); + auto &infer_stack = GetCurrenGraphInferStack(); std::ostringstream oss; if (infer_stack.empty()) { return; @@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter { AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} ~AnalyzedFuncGraphExporter() override = default; - void ExportFuncGraph(const std::string& filename, const std::vector& node_cfgs); + void ExportFuncGraph(const std::string &filename, const std::vector &node_cfgs); private: - std::string GetNodeType(const AnfNodePtr& nd) override; + std::string GetNodeType(const AnfNodePtr &nd) override; }; std::unordered_map CalcTaggedFuncGraphs() { std::unordered_map tagged_func_graphs; - auto& list = GetCNodeDebugStack(); + auto &list = GetCNodeDebugStack(); for (size_t i = 0; i < list.size(); ++i) { auto node_cfg = list[i]; auto fg = node_cfg->context()->func_graph(); @@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() { exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); } -std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { +std::string 
AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { if (node_cfg_ == nullptr) { return AnfExporter::GetNodeType(node); } @@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { return oss.str(); } -void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, - const std::vector& node_cfgs) { +void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, + const std::vector &node_cfgs) { if (node_cfgs.empty()) { MS_LOG(DEBUG) << "Node configs is empty"; return; @@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, auto tagged_func_graphs = CalcTaggedFuncGraphs(); // first output graph on the analysis stack - for (const auto& node_cfg : node_cfgs) { + for (const auto &node_cfg : node_cfgs) { auto fg = node_cfg->context()->func_graph(); // the graph is already output, skip it if (exported.find(fg) != exported.end()) { @@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, ofs.close(); } -void GetInferStackInfo(std::ostringstream& oss) { +void GetInferStackInfo(std::ostringstream &oss) { MS_LOG(INFO) << "Get graph analysis information begin"; auto stack = GetCNodeDebugStack(); if (stack.empty()) { @@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) { static std::stack> graph_infer_stack; // trace the cnode infer debug info static std::vector cnode_debug_stack{}; -void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { +void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { if (eval == nullptr) { MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; } @@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An } } -void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { +void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { if (eval 
== nullptr) { MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; } @@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { } } -void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } +void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } -std::vector& GetCNodeDebugStack() { return cnode_debug_stack; } +std::vector &GetCNodeDebugStack() { return cnode_debug_stack; } -std::stack>& GetCurrenGraphInferStack() { +std::stack> &GetCurrenGraphInferStack() { return graph_infer_stack; } void ClearTraceStack() { diff --git a/mindspore/ccsrc/debug/trace.h b/mindspore/ccsrc/debug/trace.h index 5fba86fddd..2704a80a35 100644 --- a/mindspore/ccsrc/debug/trace.h +++ b/mindspore/ccsrc/debug/trace.h @@ -31,19 +31,19 @@ namespace mindspore { namespace trace { -std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); -std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip = kSourceLineTipNextLine); -DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); +DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); void TraceGraphInfer(); -void GetInferStackInfo(std::ostringstream& oss); -void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); -void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); -void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); +void GetInferStackInfo(std::ostringstream &oss); +void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); +void TraceGraphInferLeave(const 
abstract::EvaluatorPtr &eval); +void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); void TraceInferCNodeLeave(); -std::vector& GetCNodeDebugStack(); -std::stack>& GetCurrenGraphInferStack(); -std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); +std::vector &GetCNodeDebugStack(); +std::stack> &GetCurrenGraphInferStack(); +std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); void ClearTraceStack(); } // namespace trace } // namespace mindspore diff --git a/mindspore/ccsrc/debug/trace_info.cc b/mindspore/ccsrc/debug/trace_info.cc index b01cd15010..19358e197a 100644 --- a/mindspore/ccsrc/debug/trace_info.cc +++ b/mindspore/ccsrc/debug/trace_info.cc @@ -23,7 +23,7 @@ #include "pipeline/parse/python_adapter.h" namespace mindspore { -std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { +std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { if (info == nullptr) { return ""; } diff --git a/mindspore/ccsrc/debug/trace_info.h b/mindspore/ccsrc/debug/trace_info.h index 16be9031e2..e7a8c83dad 100644 --- a/mindspore/ccsrc/debug/trace_info.h +++ b/mindspore/ccsrc/debug/trace_info.h @@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr; // namespace to support intermediate representation definition class TraceInfo : public Base { public: - TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) { + TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) { symbol_ = symbol; full_name_ = full_name; name_ = full_name_; debug_info_ = info; } - TraceInfo(const TraceInfo& info) + TraceInfo(const TraceInfo &info) : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} virtual ~TraceInfo() = default; MS_DECLARE_PARENT(TraceInfo, Base); @@ -55,8 +55,8 @@ class TraceInfo : public Base { virtual std::string full_name() { return full_name_; } virtual TraceInfoPtr clone() { 
return shared_from_base(); } virtual std::string action_name() { return ""; } - virtual std::string GetActionBetweenNode(const DebugInfoPtr& info); - void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; } + virtual std::string GetActionBetweenNode(const DebugInfoPtr &info); + void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; } DebugInfoPtr debug_info() { return debug_info_; } DebugInfoPtr DebugInfoHasLoc(); std::vector> GetSourceCodeDebugInfo(); @@ -70,7 +70,7 @@ class TraceInfo : public Base { class TracePhi : public TraceInfo { public: - explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {} + explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {} MS_DECLARE_PARENT(TracePhi, TraceInfo); ~TracePhi() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -78,8 +78,8 @@ class TracePhi : public TraceInfo { class TraceIfStmtTrueBranch : public TraceInfo { public: - TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default; - explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {} + TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default; + explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {} MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); ~TraceIfStmtTrueBranch() override = default; TraceInfoPtr clone() override { @@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo { class TraceIfStmtFalseBranch : public TraceInfo { public: - TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default; - explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {} + TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default; + explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {} MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); ~TraceIfStmtFalseBranch() override = 
default; TraceInfoPtr clone() override { @@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo { class TraceIfStmtAfterBranch : public TraceInfo { public: - explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {} + explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {} MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); ~TraceIfStmtAfterBranch() override = default; TraceInfoPtr clone() override { @@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo { class TraceIfExpTrueBranch : public TraceInfo { public: - explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {} + explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {} MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); ~TraceIfExpTrueBranch() override = default; TraceInfoPtr clone() override { @@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo { class TraceIfExpFalseBranch : public TraceInfo { public: - explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {} + explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {} MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); ~TraceIfExpFalseBranch() override = default; TraceInfoPtr clone() override { @@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo { class TraceCopy : public TraceInfo { public: TraceCopy() : TraceInfo(nullptr, "copy", "") {} - explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {} + explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {} MS_DECLARE_PARENT(TraceCopy, TraceInfo); ~TraceCopy() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo { class TraceIterator : public TraceInfo { public: - explicit 
TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {} + explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {} MS_DECLARE_PARENT(TraceIterator, TraceInfo); ~TraceIterator() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo { class TraceWhileHeader : public TraceInfo { public: - explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {} + explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {} MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); ~TraceWhileHeader() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo { class TraceWhileBody : public TraceInfo { public: - explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {} + explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {} MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); ~TraceWhileBody() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo { class TraceWhileAfter : public TraceInfo { public: - explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {} + explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {} MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); ~TraceWhileAfter() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo { class TraceForHeader : public TraceInfo { public: - explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {} + explicit TraceForHeader(const DebugInfoPtr &info) : 
TraceInfo(info, "for_header", "⤾") {} MS_DECLARE_PARENT(TraceForHeader, TraceInfo); ~TraceForHeader() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo { class TraceForBody : public TraceInfo { public: - explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {} + explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {} MS_DECLARE_PARENT(TraceForBody, TraceInfo); ~TraceForBody() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo { class TraceForAfter : public TraceInfo { public: - explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {} + explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {} MS_DECLARE_PARENT(TraceForAfter, TraceInfo); ~TraceForAfter() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo { class TraceEquiv : public TraceInfo { public: - explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {} + explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} MS_DECLARE_PARENT(TraceEquiv, TraceInfo); ~TraceEquiv() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo { class TraceGradFpropApp : public TraceInfo { public: TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} - explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {} + explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {} MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); ~TraceGradFpropApp() override = 
default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo { class TraceGradBpropApp : public TraceInfo { public: TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} - explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {} + explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {} MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); ~TraceGradBpropApp() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo { class TraceGradFprop : public TraceInfo { public: TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} - explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {} + explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {} MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); ~TraceGradFprop() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo { class TraceGradBprop : public TraceInfo { public: TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} - explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {} + explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {} MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); ~TraceGradBprop() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo { class TraceGradSens : public TraceInfo { public: TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} - explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {} + explicit TraceGradSens(const DebugInfoPtr &info) : 
TraceInfo(info, "grad_sens", "∇") {} MS_DECLARE_PARENT(TraceGradSens, TraceInfo); ~TraceGradSens() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo { class TraceSpecialize : public TraceInfo { public: - explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } + explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); std::string name() override { return full_name_ + counter_; } std::string symbol() override { return counter_ + "_"; } @@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo { class TraceGradOperation : public TraceInfo { public: - explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {} + explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {} MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); ~TraceGradOperation() override = default; TraceInfoPtr clone() override { @@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo { class TraceForceBool : public TraceInfo { public: - explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {} + explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {} MS_DECLARE_PARENT(TraceForceBool, TraceInfo); ~TraceForceBool() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo { class TraceExpandJ : public TraceInfo { public: - explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {} + explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {} MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); ~TraceExpandJ() override = default; TraceInfoPtr clone() override { 
return std::make_shared(*shared_from_base()); } @@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo { class TraceGenMetaFuncGraph : public TraceInfo { public: - explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {} + explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {} MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); ~TraceGenMetaFuncGraph() override = default; TraceInfoPtr clone() override { @@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo { class TraceEvaluatorGenGraph : public TraceInfo { public: - explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {} + explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {} MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); ~TraceEvaluatorGenGraph() override = default; TraceInfoPtr clone() override { @@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo { class TraceResolve : public TraceInfo { public: - explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {} + explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {} MS_DECLARE_PARENT(TraceResolve, TraceInfo); ~TraceResolve() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo { class TraceTransform : public TraceInfo { public: TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } - explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") { + explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") { transform_name_ = transform_name; } @@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo { class TraceGenerateVarArg : public TraceInfo { public: - explicit 
TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {} + explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {} MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); ~TraceGenerateVarArg() override = default; TraceInfoPtr clone() override { @@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo { class TraceGenerateKwArg : public TraceInfo { public: - explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {} + explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {} MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); ~TraceGenerateKwArg() override = default; TraceInfoPtr clone() override { @@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo { class TraceTrasformK : public TraceInfo { public: - explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {} + explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {} MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); ~TraceTrasformK() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo { class TracePartialTransform : public TraceInfo { public: - explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {} + explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {} MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); ~TracePartialTransform() override = default; TraceInfoPtr clone() override { @@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo { class TraceGetEnv : public TraceInfo { public: - explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {} + explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {} MS_DECLARE_PARENT(TraceGetEnv, 
TraceInfo); ~TraceGetEnv() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo { class TraceDoSignature : public TraceInfo { public: - explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {} + explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {} MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); ~TraceDoSignature() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo { class TraceCombileLikeGraphs : public TraceInfo { public: TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} - explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {} + explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {} MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); ~TraceCombileLikeGraphs() override = default; TraceInfoPtr clone() override { diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index 79241df612..df49400341 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); } else { - auto shape_size = trans::ShapeSize(host_shape); + auto host_size = trans::ShapeSize(host_shape); auto host = std::vector(size_); SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; + const trans::TypeIdArgs type_args{host.data(), size_, size, 
type_id_, type, host_size, host_size}; sync_ok = trans::TransDataType(type_args, host_ptr); if (!sync_ok) { MS_LOG(ERROR) << "trans data type failed."; @@ -153,14 +153,15 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector(size_); sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; + auto host_size = trans::ShapeSize(host_shape); + auto device_size = trans::ShapeSize(device_shape); + const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, device_size, host_size}; sync_ok = trans::TransDataType(type_args, host_ptr); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } } else { @@ -168,7 +169,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); } else { - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_}; + auto host_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size}; auto host_tmp = std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans data type failed."; + MS_LOG(ERROR) << "Trans data type failed."; return false; } SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); @@ -234,12 +235,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << 
"trans datatype failed."; + MS_LOG(ERROR) << "Trans datatype failed."; return false; } const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, @@ -247,7 +249,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransFormat(format_args, dst_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); @@ -256,7 +258,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransFormat(format_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index 935e694636..44cf3f8fa8 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -283,18 +283,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); // the streams' flag not HEAD_STREAM - std::vector wait_active_stream_list = assign_instance.GetWaitStreams(); - std::vector force_copy_stream_list = assign_instance.GetHcomStreams(); + std::vector wait_active_stream_list; + assign_instance.GetWaitStreams(&wait_active_stream_list); + auto force_copy_stream_list = assign_instance.hcom_streams(); MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() - << ", total event num:" << assign_instance.GetTotalEventNum() + << ", total event num:" << assign_instance.total_event_num() << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << 
force_copy_stream_list.size(); std::vector> empty_list; std::shared_ptr model = std::make_shared( task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0); + 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc index 2c38e4290d..69c6dca576 100644 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc +++ b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace device { namespace ascend { -size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { +size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { if (has_malloc_) { MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; } @@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { return size; } -bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) { +bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { MS_EXCEPTION_IF_NULL(addr); has_malloc_ = false; free_mem_size_ = total_mem_size_; @@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const { size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } -void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { +void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { MS_EXCEPTION_IF_NULL(device_mem_pool_base); device_mem_pool_base_ = device_mem_pool_base; } diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h index a02bd453b2..7fa3ebc23e 100644 --- 
a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h +++ b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h @@ -26,12 +26,12 @@ namespace ascend { class AscendMemoryPool : public DynamicMemPoolBestFit { public: ~AscendMemoryPool() override = default; - AscendMemoryPool(const AscendMemoryPool&) = delete; - AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; + AscendMemoryPool(const AscendMemoryPool &) = delete; + AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; - size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; - bool FreeDeviceMem(const DeviceMemPtr& addr) override; - void set_device_mem_pool_base(uint8_t* device_mem_pool_base); + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; + void set_device_mem_pool_base(uint8_t *device_mem_pool_base); void set_device_mem_pool_size(uint64_t device_mem_pool_size) { device_mem_pool_size_ = device_mem_pool_size; free_mem_size_ = device_mem_pool_size_; @@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { size_t free_mem_size() override; size_t total_mem_size() override; - static AscendMemoryPool& GetInstance() { + static AscendMemoryPool &GetInstance() { static AscendMemoryPool instance; return instance; } @@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { private: AscendMemoryPool() = default; bool has_malloc_{false}; - uint8_t* device_mem_pool_base_{nullptr}; + uint8_t *device_mem_pool_base_{nullptr}; uint64_t device_mem_pool_size_{0}; size_t free_mem_size_{0}; size_t total_mem_size_{0}; diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc index 8c4d1f4a8f..e2cf469cd8 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc @@ -25,8 +25,8 @@ #include "session/anf_runtime_algorithm.h" #include "device/kernel_adjust.h" #include 
"predict/generator/utils/ir_model_util.h" -#include "device/kernel_info.h" #include "pre_activate/common/helper.h" +#include "utils/utils.h" namespace mindspore { namespace device { @@ -54,6 +54,7 @@ void AscendStreamAssign::ResetNew() { inner_parallel_streams_.clear(); processed_parallel_streams_.clear(); hcom_stream_list_.clear(); + need_first_active_streams_.clear(); } void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) { @@ -200,13 +201,12 @@ void AscendStreamAssign::AssignAllNodesStream(const shared_ptr AscendStreamAssign::TransLogicToPhysic(const vector &logic_ids) { - vector physic_ids; +void AscendStreamAssign::TransLogicToPhysic(const vector &logic_ids, vector *physic_ids) { for (auto &id : logic_ids) { auto it = logic_to_physic_map_.find(id); if (it != logic_to_physic_map_.end()) { MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]"; - physic_ids.push_back(it->second); + (*physic_ids).push_back(it->second); } else { MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id"; } @@ -214,10 +214,9 @@ vector AscendStreamAssign::TransLogicToPhysic(const vector & auto it_independ = logic_to_independent_map_.find(id); if (it_independ != logic_to_independent_map_.end()) { MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]"; - physic_ids.push_back(it_independ->second); + (*physic_ids).push_back(it_independ->second); } } - return physic_ids; } void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { @@ -227,7 +226,8 @@ void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { MS_EXCEPTION_IF_NULL(primitive); vector active_logic_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); // out StreamAcitve active physic stream is not parallel now, if parallel, should deal here. 
- vector active_physic_ids = TransLogicToPhysic(active_logic_ids); + vector active_physic_ids; + TransLogicToPhysic(active_logic_ids, &active_physic_ids); ValuePtr active_physic_value = MakeValue>(active_physic_ids); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr); } @@ -242,7 +242,8 @@ void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CN MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id << "]"; vector logic_ids{true_logic_id}; - vector physic_ids = TransLogicToPhysic(logic_ids); + vector physic_ids; + TransLogicToPhysic(logic_ids, &physic_ids); if (physic_ids.empty()) { MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id"; } @@ -334,8 +335,8 @@ bool AscendStreamAssign::IsProcessedParallelStream(uint32_t stream_id) { return false; } -vector AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id) { - vector parallel_streams; +void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, + vector *parallel_streams) { for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { auto cur_parallel_streams = inner_parallel_streams_[i]; auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); @@ -347,17 +348,17 @@ vector AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, u << "is same with streamacvite stream id" << stream_acitve_id; continue; } - parallel_streams.emplace_back(cur_parallel_streams[j]); + (*parallel_streams).emplace_back(cur_parallel_streams[j]); } // record processed parallel streams - (void)std::copy(parallel_streams.begin(), parallel_streams.end(), + (void)std::copy((*parallel_streams).begin(), (*parallel_streams).end(), std::back_inserter(processed_parallel_streams_)); - return parallel_streams; + return; } } - return vector{cur_stream_id}; + 
(*parallel_streams).push_back(cur_stream_id); } void AscendStreamAssign::InsertActiveNew(const std::shared_ptr &graph_ptr) { @@ -379,30 +380,32 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr active_index_list = GetParallelStream(cur_stream_id, pre_stream_id); + std::vector active_index_list; + GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); - } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive" && - AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { + } + // inner_active is not a if/else relationship with the next if/else. such as:StreamActive(S7)-->StreamActive(S8) + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName && + AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { // 2)outter stream assign, update active op update_cnode_list.emplace_back(cur_cnode_ptr); UpdateStreamActive(cur_cnode_ptr); - } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamSwitch") { + } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { // 3)update switch op MS_LOG(INFO) << "Insert active op after switch"; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateSteamActiveOp(graph_ptr); + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); update_cnode_list.emplace_back(cur_cnode_ptr); update_cnode_list.emplace_back(active_ptr); UpdateStreamSwitch(cur_cnode_ptr, active_ptr); @@ -417,6 +420,37 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr &graph_ptr) { + MS_LOG(INFO) << "start"; + MS_EXCEPTION_IF_NULL(graph_ptr); + CNodePtr cur_cnode_ptr = nullptr; + // key:virutal event id, value:real event id + std::unordered_map event_id_map; + uint32_t event_id; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + 
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { + auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(primitive); + event_id = GetValue(primitive->GetAttr(kAttrEventId)); + // before stream assign, send/recv event_id assign from kFirstEventId + if (event_id < kFirstEventId) { + continue; + } + auto it = event_id_map.find(event_id); + if (it == event_id_map.end()) { + event_id_map.insert(std::make_pair(event_id, total_event_num_)); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(total_event_num_), cur_cnode_ptr); + total_event_num_++; + } else { + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(it->second), cur_cnode_ptr); + } + } + } +} + void AscendStreamAssign::UpdateStreamId(const shared_ptr &graph_ptr) { MS_LOG(INFO) << "start"; MS_EXCEPTION_IF_NULL(graph_ptr); @@ -427,7 +461,7 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr & MS_EXCEPTION_IF_NULL(cur_cnode_ptr); uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); if (cur_stream_id < kIndependFirstStreamId) { - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive") { + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName) { auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(primitive); vector active_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); @@ -471,6 +505,29 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr & MS_LOG(INFO) << "end"; } +void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr &graph_ptr) { + MS_EXCEPTION_IF_NULL(graph_ptr); + CNodePtr cur_cnode_ptr = nullptr; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(primitive); + auto value_ptr = 
primitive->GetAttr(kStreamNeedActivedFirst); + if (value_ptr == nullptr) { + continue; + } + + auto need_active = GetValue(value_ptr); + if (need_active) { + auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + MS_LOG(INFO) << "stream id:" << stream_id << " is need actived at first"; + need_first_active_streams_.push_back(stream_id); + } + } +} + void AscendStreamAssign::AssignStreamNew(const shared_ptr &graph_ptr) { if (IsTaskSink()) { ResetNew(); @@ -480,13 +537,15 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr InsertSendRecvForHcomParallel(graph_ptr); InsertSendRecvForIndependent(graph_ptr); UpdateStreamId(graph_ptr); + UpdateEventId(graph_ptr); + GetNeedActiveStreams(graph_ptr); MS_LOG(INFO) << "after finish stream assign"; PrintGraphExeOrders(graph_ptr); // Get info for D Model - generator::IRModelUtil::GetInstance().set_event_num(GetTotalEventNum()); - generator::IRModelUtil::GetInstance().set_stream_num(GetTotalCommonStreamNum() + GetTotalIndependStreamNum()); + generator::IRModelUtil::GetInstance().set_event_num(total_event_num()); + generator::IRModelUtil::GetInstance().set_stream_num(total_common_stream_num() + total_independ_stream_num()); // Init to 1,temporarily generator::IRModelUtil::GetInstance().set_batch_num(1); } @@ -495,7 +554,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id) { MS_EXCEPTION_IF_NULL(graph_ptr); - auto send_op = std::make_shared("Send"); + auto send_op = std::make_shared(kSendOpName); MS_EXCEPTION_IF_NULL(send_op); auto send_apply = std::make_shared(send_op); MS_EXCEPTION_IF_NULL(send_apply); @@ -505,7 +564,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr(); MS_EXCEPTION_IF_NULL(abstract_none); send_node_ptr->set_abstract(abstract_none); @@ -516,7 +575,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t 
event_id, uint32_t stream_id) { MS_EXCEPTION_IF_NULL(graph_ptr); - auto recv_op = std::make_shared("Recv"); + auto recv_op = std::make_shared(kRecvOpName); MS_EXCEPTION_IF_NULL(recv_op); auto recv_apply = std::make_shared(recv_op); MS_EXCEPTION_IF_NULL(recv_apply); @@ -526,7 +585,7 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr(); MS_EXCEPTION_IF_NULL(abstract_none); @@ -605,7 +664,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { return false; } - if (AnfAlgo::GetCNodeName(node_ptr) == "GetNext") { + if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { MS_LOG(INFO) << "GetNext should not be independent node"; return false; } @@ -638,20 +697,23 @@ bool AscendStreamAssign::IsTaskSink() { } } -std::vector AscendStreamAssign::GetWaitStreams() { - vector wait_active_stream_list; +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { if (total_common_stream_num_ == 0) { MS_LOG(INFO) << "total_common_stream_num is zero"; - return wait_active_stream_list; + return; } // common stream:active first common stream MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]"; for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) { - MS_LOG(INFO) << "wait common stream id = " << i; - wait_active_stream_list.push_back(i); + auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); + if (it == need_first_active_streams_.end()) { + MS_LOG(INFO) << "wait common stream id = " << i; + (*wait_active_stream_list).push_back(i); + } } + // all independ stream id before first physical stream id should be actived auto it = logic_to_independent_map_.find(first_logic_id_); if (it != logic_to_independent_map_.end()) { uint32_t independent_id = it->second; @@ -675,16 +737,14 @@ std::vector AscendStreamAssign::GetWaitStreams() { if (i + total_common_stream_num_ <= max_before_physic) { continue; } - MS_LOG(INFO) << "wait independent stream id:" << i + 
total_common_stream_num_; - wait_active_stream_list.push_back(i + total_common_stream_num_); + // all wait streams should not in need_first_active_streams_ + auto iter = + std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i + total_common_stream_num_); + if (iter == need_first_active_streams_.end()) { + MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; + (*wait_active_stream_list).push_back(i + total_common_stream_num_); + } } - - return wait_active_stream_list; -} - -std::vector AscendStreamAssign::GetHcomStreams() { - MS_LOG(INFO) << "hcom total stream nums:" << hcom_stream_list_.size(); - return hcom_stream_list_; } uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; } @@ -695,7 +755,7 @@ void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr& graph_ptr); - void AssignAllNodesStream(const std::shared_ptr& graph_ptr); + void InsertActiveNew(const std::shared_ptr &graph_ptr); + void AssignAllNodesStream(const std::shared_ptr &graph_ptr); void ResetNew(); - void AssignStreamNew(const std::shared_ptr& graph_ptr); - bool IsIndependentNode(const CNodePtr& node_ptr); - const std::unordered_map GetIndependentMap() { return logic_to_independent_map_; } - const std::unordered_map GetPhysicMap() { return logic_to_physic_map_; } - std::vector GetWaitStreams(); - std::vector GetHcomStreams(); + void AssignStreamNew(const std::shared_ptr &graph_ptr); + bool IsIndependentNode(const CNodePtr &node_ptr); + const std::unordered_map &logic_to_independent_map() { return logic_to_independent_map_; } + const std::unordered_map &logic_to_physic_map() { return logic_to_physic_map_; } + const std::vector> &inner_parallel_streams() { return inner_parallel_streams_; } + void GetWaitStreams(vector *wait_active_stream_list); + const std::vector &hcom_streams() { return hcom_stream_list_; } + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, 
uint32_t event_id, + uint32_t stream_id); + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, + uint32_t stream_id); private: AscendStreamAssign() = default; ~AscendStreamAssign() = default; - CNodePtr CreateSendApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, - uint32_t stream_id); - CNodePtr CreateRecvApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, - uint32_t stream_id); - vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, - const CNodePtr& node); + const CNodePtr &node); - bool IsHcom(const CNodePtr& apply_kernel); + bool IsHcom(const CNodePtr &apply_kernel); bool IsProcessed(uint32_t logic_id); - vector TransLogicToPhysic(const vector& logic_ids); - void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, - uint32_t* cur_stream_id); + void TransLogicToPhysic(const vector &logic_ids, vector *physic_ids); + void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, + uint32_t *cur_stream_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id); - void UpdateStreamActive(const CNodePtr& active_ptr); - void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr); + void UpdateStreamActive(const CNodePtr &active_ptr); + void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr); bool IsTaskSink(); - void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); - void UpdateStreamId(const std::shared_ptr& graph_ptr); - void PrintGraphExeOrders(const std::shared_ptr& graph_ptr); - void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); - uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); + void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id); + void UpdateStreamId(const std::shared_ptr &graph_ptr); + void UpdateEventId(const 
std::shared_ptr &graph_ptr); + void PrintGraphExeOrders(const std::shared_ptr &graph_ptr); + void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); + uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr); void SetCommonStreamNum(uint32_t cur_stream_id); - void FindAllReduceParallel(const std::shared_ptr& graph_ptr); + void FindAllReduceParallel(const std::shared_ptr &graph_ptr); bool IsProcessedParallelStream(uint32_t stream_id); - vector GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id); - void InsertSendRecvForIndependent(const std::shared_ptr& graph_ptr); - void InsertSendRecvForHcomParallel(const std::shared_ptr& graph_ptr); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); + void InsertSendRecvForIndependent(const std::shared_ptr &graph_ptr); + void InsertSendRecvForHcomParallel(const std::shared_ptr &graph_ptr); + void GetNeedActiveStreams(const std::shared_ptr &graph_ptr); uint32_t total_common_stream_num_{0}; uint32_t total_independ_stream_num_{0}; @@ -112,6 +112,7 @@ class AscendStreamAssign { std::vector> inner_parallel_streams_{}; std::vector processed_parallel_streams_{}; std::vector hcom_stream_list_{}; + std::vector need_first_active_streams_{}; // new policy end }; } // namespace ascend diff --git a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h index 668b54b78c..bf4977bf9a 100644 --- a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h +++ b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h @@ -28,14 +28,14 @@ namespace device { namespace ascend { class PluginImpl : public PluginIntf { public: - explicit PluginImpl(const std::string& module); + explicit PluginImpl(const std::string &module); ~PluginImpl() override = default; - int Init(const Reporter* reporter) override; + int Init(const Reporter *reporter) override; int UnInit() override; - static 
Reporter* GetPluginReporter() { return reporter_; } + static Reporter *GetPluginReporter() { return reporter_; } private: - static Reporter* reporter_; + static Reporter *reporter_; std::string module_; }; } // namespace ascend diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc index 3a1dc4689b..cbecb3030d 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc @@ -20,12 +20,12 @@ namespace mindspore { namespace device { namespace ascend { -PluginIntf* ProfilingEngineImpl::CreatePlugin() { +PluginIntf *ProfilingEngineImpl::CreatePlugin() { MS_LOG(INFO) << "Create Plugin."; return new (std::nothrow) PluginImpl("Framework"); } -int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) { +int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { if (plugin != nullptr) { delete plugin; } diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h index e8dbfc7087..c7cbc4b7dd 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h @@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf { ProfilingEngineImpl() = default; ~ProfilingEngineImpl() override = default; - PluginIntf* CreatePlugin() override; - int ReleasePlugin(PluginIntf* plugin) override; + PluginIntf *CreatePlugin() override; + int ReleasePlugin(PluginIntf *plugin) override; }; } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc index 29193e5cfa..c3f622ffee 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc @@ -35,7 +35,7 @@ 
using Json = nlohmann::json; namespace mindspore { namespace device { namespace ascend { -ProfilingManager& ProfilingManager::GetInstance() { +ProfilingManager &ProfilingManager::GetInstance() { static ProfilingManager inst; return inst; } @@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { } uint64_t ProfilingManager::GetJobId() const { - const char* job_id = std::getenv("JOB_ID"); + const char *job_id = std::getenv("JOB_ID"); return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); } -bool ProfilingManager::ReportProfilingData(const map& op_taskId_map) const { +bool ProfilingManager::ReportProfilingData(const map &op_taskId_map) const { if (!IsProfiling()) { MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; return false; @@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map& op_taskI MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); Msprof::Engine::ReporterData reporter_data = {}; - for (const auto& iter : op_taskId_map) { + for (const auto &iter : op_taskId_map) { auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; reporter_data.deviceId = UintToInt(device_id_); - reporter_data.data = (unsigned char*)(const_cast(data.c_str())); + reporter_data.data = (unsigned char *)(const_cast(data.c_str())); reporter_data.dataLen = data.size(); auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); if (ret != 0) { @@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map& op_taskI return true; } -static std::vector Split(const std::string& str, const char delim) { +static std::vector Split(const std::string &str, const char delim) { std::vector elems; if (str.empty()) { @@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { device_id_ = device_id; // exp: export PROFILING_MODE=true // export PROFILING_OPTIONS=training_trace - const 
char* prof_options_str = std::getenv("PROFILING_OPTIONS"); + const char *prof_options_str = std::getenv("PROFILING_OPTIONS"); // register Framework to profiling int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); if (result != 0) { @@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const { MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; return true; } - Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter(); + Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); if (reporter != nullptr) { MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); } diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc index aa71aa0566..7960a08938 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc @@ -39,13 +39,9 @@ ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNullexecution_order(); ProfilingTraceInfo profiling_trace; profiling_trace.trace_begin = GetTraceBegin(cnode_exec_order); - profiling_trace.trace_bp_end = GetTraceBpEnd(); + profiling_trace.trace_bp_end = GetTraceBpEnd(cnode_exec_order); profiling_trace.trace_netoutput = GetTraceNetoutput(cnode_exec_order); - MS_LOG(INFO) << "[profiling] trace_begin:" << profiling_trace.trace_begin - << " trace_bp_end:" << profiling_trace.trace_bp_end - << " trace_netoutput:" << profiling_trace.trace_netoutput; - for (uint32_t i = 1; i <= kMaxProfilingNodeNum; ++i) { std::string env_str = std::string(kCustomNode) + std::to_string(i); const char *node_full_name = std::getenv(env_str.c_str()); @@ -56,9 +52,25 @@ ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull &cnode_exec_order, + NotNull profiling_trace) { + for (const auto &node : cnode_exec_order) { + if (AnfAlgo::IsCommunicationOp(node)) { + MS_EXCEPTION_IF_NULL(node); + 
profiling_trace->trace_custom_node.insert(node->fullname_with_scope()); + MS_LOG(INFO) << "[profiling]Get hccl node:" << node->fullname_with_scope(); + } + } +} + std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exec_order) { const char *trace_begin = std::getenv(kFpStartNode); auto &first_cnode = cnode_exec_order.front(); @@ -66,9 +78,45 @@ std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exe return trace_begin == nullptr ? first_cnode->fullname_with_scope() : std::string(trace_begin); } -std::string ProfilingUtils::GetTraceBpEnd() { +std::string ProfilingUtils::GetTraceBpEnd(const std::vector &cnode_exec_order) { const char *trace_bp_end = std::getenv(kBpEndNode); - return trace_bp_end == nullptr ? "" : std::string(trace_bp_end); + + if (trace_bp_end != nullptr) { + return std::string(trace_bp_end); + } + std::string bp_end_str = ""; + // Contain hccl kernel + auto iter = cnode_exec_order.rbegin(); + while (iter != cnode_exec_order.rend()) { + if (AnfAlgo::IsCommunicationOp(*iter)) { + // store communication op input nodes' name + std::set ar_input_node_names; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); + auto input_node = input_node_with_index.first; + ar_input_node_names.insert(input_node->fullname_with_scope()); + } + // start from previous node + ++iter; + // find input names in previous node + while (iter != cnode_exec_order.rend()) { + if (ar_input_node_names.find((*iter)->fullname_with_scope()) != ar_input_node_names.end()) { + bp_end_str = (*iter)->fullname_with_scope(); + break; + } + ++iter; + } + break; + } + ++iter; + } + + if (bp_end_str.empty()) { + auto last_cnode = cnode_exec_order.back(); + MS_EXCEPTION_IF_NULL(last_cnode); + bp_end_str = last_cnode->fullname_with_scope(); + } + return bp_end_str; } std::string ProfilingUtils::GetTraceNetoutput(const std::vector &cnode_exec_order) { @@ -109,6 +157,7 @@ void 
ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node NotNull graph_ptr, NotNull *> kernel_list) { if (profiling_trace_info.trace_begin == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match FpStart:" << profiling_trace_info.trace_begin; auto job_id = ProfilingManager::GetInstance().GetJobId(); ProfilingContent job_profiling_context = {false, job_id, 0}; auto job_profiling_node = CreateProfilingCNodeWithStream(anf_node, job_profiling_context, graph_ptr); @@ -137,6 +186,7 @@ void ProfilingUtils::ProfilingCustomOp(const AnfNodePtr &anf_node, const Profili if (iter == profiling_trace_info.trace_custom_node.end()) { return; } + MS_LOG(INFO) << "Profiling Match CustomOp:" << anf_node->fullname_with_scope(); // custom op profiling job start from 3. ProfilingContent front_profiling_content = {false, 2 * custom_node_index_ + 1, 0}; CNodePtr front_node = CreateProfilingCNodeWithStream(anf_node, front_profiling_content, graph_ptr); @@ -153,6 +203,7 @@ void ProfilingUtils::ProfilingTraceBpEnd(const AnfNodePtr &anf_node, const Profi NotNull *> kernel_list) { MS_EXCEPTION_IF_NULL(anf_node); if (profiling_trace_info.trace_bp_end == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match BpEnd:" << profiling_trace_info.trace_bp_end; ProfilingContent bp_end_profiling_content = {false, kProfilingBpEndLogId, 0}; CNodePtr bp_end_node = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); kernel_list->emplace_back(bp_end_node); @@ -165,6 +216,7 @@ void ProfilingUtils::ProfilingTraceEnd(const AnfNodePtr &anf_node, const Profili MS_EXCEPTION_IF_NULL(anf_node); auto full_scope_name = anf_node->fullname_with_scope(); if (profiling_trace_info.trace_netoutput == full_scope_name) { + MS_LOG(INFO) << "Profiling Match IterEnd:" << profiling_trace_info.trace_netoutput; ProfilingContent bp_end_profiling_content = {true, kProfilingIterEndLogId, 0}; CNodePtr bp_kernel_ptr = CreateProfilingCNodeWithStream(anf_node, 
bp_end_profiling_content, graph_ptr); kernel_list->emplace_back(bp_kernel_ptr); diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h index c59e856249..f9f08c9d3f 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h @@ -43,7 +43,7 @@ struct ProfilingTraceInfo { // 3. insert profiling_trace_bp_end. // 4. insert profiling_trace_net_output if profiling_trace_bp_end is not empty. - bool IsValid() const { return !(trace_begin.empty() || trace_bp_end.empty() || trace_netoutput.empty()); } + bool IsValid() const { return !(trace_begin.empty() || trace_netoutput.empty()); } }; struct ProfilingContent { @@ -109,8 +109,10 @@ class ProfilingUtils { static CNodePtr CreateProfilingCNodeWithStream(const AnfNodePtr &anf_node, const ProfilingContent &profiling_content, NotNull graph_ptr); static std::string GetTraceBegin(const std::vector &cnode_exec_order); - static std::string GetTraceBpEnd(); + static std::string GetTraceBpEnd(const std::vector &cnode_exec_order); static std::string GetTraceNetoutput(const std::vector &cnode_exec_order); + static void GetTraceHccl(const std::vector &cnode_exec_order, + NotNull profiling_trace); // graph id --> (kernel name list) static std::unordered_map> graph_kernel_name_; diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc b/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc index 5aba329e12..77a3345344 100644 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc +++ b/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc @@ -31,7 +31,9 @@ CPUKernelFactory &CPUKernelFactory::Get() { void CPUKernelFactory::Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) { if (kernel_creators_.find(kernel_name) == kernel_creators_.end()) { (void)kernel_creators_.emplace(kernel_name, kernel_creator); +#if !defined(_WIN32) && !defined(_WIN64) MS_LOG(DEBUG) << 
"CPUKernelFactory register operator: " << kernel_name; +#endif } } diff --git a/mindspore/ccsrc/device/gpu/blocking_queue.h b/mindspore/ccsrc/device/gpu/blocking_queue.h index ccf481858f..a1594c21a9 100644 --- a/mindspore/ccsrc/device/gpu/blocking_queue.h +++ b/mindspore/ccsrc/device/gpu/blocking_queue.h @@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, class GpuQueue { public: - GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity); + GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity); virtual ~GpuQueue(); - void RegisterRelease(const std::function& func) { host_release_ = func; } + void RegisterRelease(const std::function &func) { host_release_ = func; } inline bool IsEmpty() const { return head_ == tail_; } inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } - BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size); - BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const; + BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size); + BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const; BlockQueueStatus_T Pop(); bool Destroy(); private: struct NodeInfo { std::unique_ptr event_; - void* host_feature_addr_; - void* host_label_addr_; + void *host_feature_addr_; + void *host_label_addr_; }; - void* buffer_; + void *buffer_; size_t head_; size_t tail_; size_t feature_size_; @@ -61,10 +61,10 @@ class GpuQueue { size_t capacity_; cudaStream_t stream_; std::unique_ptr node_info_; - std::function host_release_; + std::function host_release_; - GpuQueue(const GpuQueue&) = delete; - GpuQueue& operator=(const GpuQueue&) = delete; + GpuQueue(const GpuQueue &) = delete; + GpuQueue &operator=(const GpuQueue &) = delete; }; class BlockingQueue { @@ 
-72,11 +72,11 @@ class BlockingQueue { BlockingQueue() : queue_(nullptr) {} ~BlockingQueue() = default; - BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity); - void RegisterRelease(const std::function& func); - BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size, + BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity); + void RegisterRelease(const std::function &func); + BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size, unsigned int timeout_in_sec); - BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size); + BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size); BlockQueueStatus_T Pop(); bool Destroy(); diff --git a/mindspore/ccsrc/device/gpu/cuda_common.h b/mindspore/ccsrc/device/gpu/cuda_common.h index 5a5b6416ce..b79ba8bc28 100644 --- a/mindspore/ccsrc/device/gpu/cuda_common.h +++ b/mindspore/ccsrc/device/gpu/cuda_common.h @@ -56,7 +56,8 @@ class CudaCommon { #define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) #define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() #define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() -#define MINIUM_SM 7 +#define MINIUM_SM 6 +#define RECOMMEND_SM 7 } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc index d212c56ae7..d7ab95bbe8 100644 --- a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc +++ b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc @@ -20,17 +20,17 @@ namespace mindspore { namespace device { namespace gpu { -CollectiveInitializer& 
CollectiveInitializer::instance() { +CollectiveInitializer &CollectiveInitializer::instance() { static CollectiveInitializer instance = {}; return instance; } bool CollectiveInitializer::collective_inited() const { return collective_inited_; } -const void* CollectiveInitializer::collective_handle() const { return collective_handle_; } +const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } void CollectiveInitializer::InitCollective() { - void* handle = dlopen("libgpu_collective.so", RTLD_LAZY); + void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); if (handle == nullptr) { MS_LOG(EXCEPTION) << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc index b25ba2906b..e505fdc218 100644 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc +++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc @@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() { CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); } -bool GPUDeviceManager::CreateStream(DeviceStream* stream) { +bool GPUDeviceManager::CreateStream(DeviceStream *stream) { CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); gpu_streams_.emplace_back(*stream); return true; } -const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; } +const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } @@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } -const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } +const cudnnHandle_t 
&GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } -const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } +const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } -bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); } +bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } -bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const { +bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { return CudaDriver::CopyDeviceMemToHost(dst, src, size); } -bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const { +bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { return CudaDriver::CopyHostMemToDevice(dst, src, size); } } // namespace gpu diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/device/gpu/gpu_device_manager.h index 3b3d2aecb5..a546b999a4 100644 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.h +++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.h @@ -37,17 +37,17 @@ class GPUDeviceManager { uint32_t cur_device_id() const; bool is_device_id_init() const; - bool CreateStream(DeviceStream* stream); - bool SyncStream(const DeviceStream& stream) const; - const DeviceStream& default_stream() const; + bool CreateStream(DeviceStream *stream); + bool SyncStream(const DeviceStream &stream) const; + const DeviceStream &default_stream() const; - const cudnnHandle_t& GetCudnnHandle() const; - const cublasHandle_t& GetCublasHandle() const; + const cudnnHandle_t &GetCudnnHandle() const; + const cublasHandle_t &GetCublasHandle() const; - bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const; - bool 
CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const; + bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; + bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; - static GPUDeviceManager& GetInstance() { + static GPUDeviceManager &GetInstance() { static GPUDeviceManager instance; return instance; } @@ -55,8 +55,8 @@ class GPUDeviceManager { private: GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} ~GPUDeviceManager() = default; - GPUDeviceManager(const GPUDeviceManager&) = delete; - GPUDeviceManager& operator=(const GPUDeviceManager&) = delete; + GPUDeviceManager(const GPUDeviceManager &) = delete; + GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; // default CUDA stream used for all the kernels. DeviceStream default_stream_{nullptr}; diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 11b8bdc162..5dd4facb25 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -111,7 +111,8 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(mem_manager_); mem_manager_->ResetDynamicMemory(); - AssignStaticMemory(graph); + AssignStaticMemoryInput(graph); + AssignStaticMemoryValueNode(graph); bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); if (is_enable_dynamic_mem) { // Use the dynamic memory pool. @@ -181,7 +182,7 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto graph_id = graph->graph_id(); - // The inputs and outputs memory of communication kernel are special, so separate processing. 
+ // The inputs and outputs memory of communication kernel need be continuous, so separate processing. AllocCommunicationOpDynamicRes(graph); auto &kernels = graph->execution_order(); @@ -229,15 +230,12 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod for (size_t i = 0; i < output_sizes.size(); ++i) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); - auto device_ptr = device_address->ptr_; - if (device_ptr == nullptr) { - device_ptr = mem_manager_->MallocMemFromMemPool(output_sizes[i]); - MS_EXCEPTION_IF_NULL(device_ptr); - device_address->ptr_ = device_ptr; + if (device_address->ptr_ == nullptr) { + mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); } kernel::AddressPtr output = std::make_shared(); MS_EXCEPTION_IF_NULL(output); - output->addr = device_ptr; + output->addr = device_address->ptr_; output->size = output_sizes[i]; kernel_outputs->push_back(output); } @@ -267,7 +265,6 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph if (kernel_name == kAllReduceOpName) { AllocCommunicationOpInputDynamicRes(kernel); AllocCommunicationOpOutputDynamicRes(kernel); - return; } } } @@ -275,48 +272,30 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); - // The reference count of communication kernel input is not 0. 
- if (communication_op_input_ref_count_ != 0) { - MS_LOG(ERROR) << "The reference count of communication kernel input is not 0."; - return; - } - - size_t total = 0; - std::vector> addr_size; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); // The inputs of communication kernel are not released. - if ((i == 0) && (device_address->ptr_ != nullptr)) { - MS_LOG(ERROR) << "The inputs of communication kernel are not released."; - return; + if (device_address->ptr_ != nullptr) { + MS_LOG(INFO) << "The inputs of communication kernel are not released."; + mem_manager_->FreeMemFromMemPool(device_address); } - auto output_size = device_address->size_; - total += output_size; - addr_size.emplace_back(device_address.get(), output_size); - } - - auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); - MS_EXCEPTION_IF_NULL(device_mem_ptr); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(device_mem_ptr); - communication_op_input_ref_count_++; - device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); + total_size += device_address->size_; + size_list.emplace_back(device_address->size_); + addr_list.emplace_back(device_address); } + mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); } void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); - // The reference count of communication kernel output is not 0. 
- if (communication_op_output_ref_count_ != 0) { - MS_LOG(ERROR) << "The reference count of communication kernel output is not 0."; - return; - } - - size_t total = 0; - std::vector> addr_size; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); @@ -324,22 +303,15 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); // The outputs of communication kernel are not released. - if ((i == 0) && (device_address->ptr_ != nullptr)) { - MS_LOG(ERROR) << "The outputs of communication kernel are not released."; - return; + if (device_address->ptr_ != nullptr) { + MS_LOG(INFO) << "The outputs of communication kernel are not released."; + mem_manager_->FreeMemFromMemPool(device_address); } - total += output_sizes[i]; - addr_size.emplace_back(device_address.get(), output_sizes[i]); - } - - auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); - MS_EXCEPTION_IF_NULL(device_mem_ptr); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(device_mem_ptr); - communication_op_output_ref_count_++; - device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); + total_size += output_sizes[i]; + size_list.emplace_back(output_sizes[i]); + addr_list.emplace_back(device_address); } + mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); } void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, @@ -362,14 +334,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } kernel_ref_count_ptr->ref_count_dynamic_use_--; if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + 
mem_manager_->FreeMemFromMemPool(device_address); // Reset the reference count. kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; - bool is_communication_op = false; - FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); - if (!is_communication_op) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - mem_manager_->FreeMemFromMemPool(device_address); - } } } // Free the output of kernel, if output has no reference. @@ -393,40 +361,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } } } - -void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, - bool *is_communication_op) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - // The inputs memory of communication kernel is one piece memory, need release together. - if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) { - communication_op_input_ref_count_--; - if (communication_op_input_ref_count_ == 0) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); - mem_manager_->FreeMemFromMemPool(device_address); - } - *is_communication_op = true; - return; - } - - auto cnode = kernel->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << cnode->inputs().size() - 1 - << "."; - } - auto input_node = cnode->input(input_idx + 1); - auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); - // The outputs memory of communication kernel is one piece memory, need release together. 
- if (AnfAlgo::GetCNodeName(kernel_input.first) == kAllReduceOpName) { - communication_op_output_ref_count_--; - if (communication_op_output_ref_count_ == 0) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); - mem_manager_->FreeMemFromMemPool(device_address); - } - *is_communication_op = true; - } -} } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h index e0eb2dc3f1..33d4b4be70 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h @@ -60,9 +60,6 @@ class GPUKernelRuntime : public KernelRuntime { void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, uint32_t graph_id); - void FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, bool *is_communication_op); - size_t communication_op_input_ref_count_{0}; - size_t communication_op_output_ref_count_{0}; std::unordered_map mem_reuse_util_map_; }; MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc index cbd43645ab..3a1a53c600 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc +++ b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc @@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() { return true; } -bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) { +bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { auto alloc_size = AllocDeviceMem(size, addr); buffer_q_addr_ = *addr; // Buffer queue needs to ensure that the alloc_size and size is equal. return (alloc_size == size) ? 
true : false; } -size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { +size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { if (size == 0) { MS_LOG(EXCEPTION) << "The memory alloc size is 0."; } @@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { return alloc_size; } -bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); } +bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h index 0d2f0f8a39..36374bfaad 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h +++ b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h @@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit { ~GPUMemoryAllocator() override = default; bool Init(); bool Finalize(); - bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr); + bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); - size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; - bool FreeDeviceMem(const DeviceMemPtr& addr) override; + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; size_t free_mem_size() override; size_t total_mem_size() override; - static GPUMemoryAllocator& GetInstance() { + static GPUMemoryAllocator &GetInstance() { static GPUMemoryAllocator instance; return instance; } private: GPUMemoryAllocator() = default; - GPUMemoryAllocator(const GPUMemoryAllocator&) = delete; - GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete; + GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; + GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; // Used to track address of data buffer 
queue. DeviceMemPtr buffer_q_addr_{nullptr}; diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc index 8bb65963d8..6e81130b9c 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc +++ b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc @@ -29,6 +29,10 @@ void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); } +std::vector GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); +} + void GPUMemoryManager::MallocDeviceMemory() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h index cc5dac2a5e..c79fb9cc22 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h +++ b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#include #include "device/memory_manager.h" namespace mindspore { namespace device { @@ -30,6 +31,7 @@ class GPUMemoryManager : public MemoryManager { void *MallocMemFromMemPool(size_t size) override; void FreeMemFromMemPool(void *device_ptr) override; + std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); protected: uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc index 05ecf380d1..6ccb4c8cde 100644 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc @@ -33,8 +33,8 @@ namespace gpu { using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; using 
mindspore::kernel::KernelBuildInfo; namespace { -bool CheckKernelInfo(const std::shared_ptr& alternative_kernel_info, - const std::shared_ptr& selected_kernel_info) { +bool CheckKernelInfo(const std::shared_ptr &alternative_kernel_info, + const std::shared_ptr &selected_kernel_info) { MS_EXCEPTION_IF_NULL(selected_kernel_info); MS_EXCEPTION_IF_NULL(alternative_kernel_info); size_t selected_input_num = selected_kernel_info->GetInputNum(); @@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr& alternative_kernel_ return true; } -std::string SupportedTypeList(const CNodePtr& kernel_node) { +std::string SupportedTypeList(const CNodePtr &kernel_node) { std::string supported_type_lists = kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); if (!supported_type_lists.empty()) { @@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) { return supported_type_lists; } -bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr& selected_kernel_info) { +bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr &selected_kernel_info) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(selected_kernel_info); std::vector> kernel_info_list; @@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr& alternative_kernel_info) { + [&](const std::shared_ptr &alternative_kernel_info) { return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); }); if (!match) { @@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptrinput(input_index + 1); @@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co } } // namespace -void SetKernelInfo(const CNodePtr& kernel_node) { +void SetKernelInfo(const CNodePtr &kernel_node) { std::vector inputs_format; std::vector inputs_type; std::shared_ptr builder = diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.h 
b/mindspore/ccsrc/device/gpu/kernel_info_setter.h index e3dc2241a9..b351f74fa3 100644 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.h +++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.h @@ -27,7 +27,7 @@ namespace mindspore { namespace device { namespace gpu { -void SetKernelInfo(const CNodePtr& apply_kernel_ptr); +void SetKernelInfo(const CNodePtr &apply_kernel_ptr); class KernelAttr { public: @@ -35,24 +35,24 @@ class KernelAttr { KernelAttr() : all_same_(false) {} ~KernelAttr() = default; - KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { + KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { input_type_.emplace_back(ms_type, format); return *this; } - KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { + KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { output_type_.emplace_back(ms_type, format); return *this; } - KernelAttr& AddAllSameAttr(const bool& all_same) { + KernelAttr &AddAllSameAttr(const bool &all_same) { all_same_ = all_same; return *this; } - const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; } - const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; } - const bool& GetAllSame() const { return all_same_; } + const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } + const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } + const bool &GetAllSame() const { return all_same_; } size_t GetInputSize() const { return input_type_.size(); } size_t GetOutputSize() const { return output_type_.size(); } diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index c1588d7d53..b557436db9 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -32,16 +32,8 @@ 
#include "utils/utils.h" #include "device/ascend/profiling/profiling_manager.h" #include "device/ascend/kernel_select_ascend.h" -#include "device/kernel_info.h" #include "runtime/base.h" - -constexpr auto kLoopCountParamName = "loop_count"; -constexpr auto kIterLoopParamName = "iter_loop"; -constexpr auto kZeroParamName = "zero"; -constexpr auto kOneParamName = "one"; -constexpr auto kStreamSwitch = "StreamSwitch"; -constexpr auto kStreamActive = "StreamActive"; -constexpr auto kAssignAdd = "AssignAdd"; +#include "device/ascend/ascend_stream_assign.h" namespace mindspore { namespace device { using device::ascend::ProfilingUtils; @@ -70,6 +62,63 @@ bool KernelAdjust::NeedInsertSwitch() { ConfigManager::GetInstance().iter_num() > 1); } +uint32_t KernelAdjust::FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + auto cnode_ptr_list = kernel_graph_ptr->execution_order(); + CNodePtr cur_cnode_ptr = nullptr; + uint32_t label = kInvalidDistincLabel; + for (uint32_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { + label = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()); + break; + } + } + + return label; +} + +CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto send_op = std::make_shared(kSendOpName); + MS_EXCEPTION_IF_NULL(send_op); + auto send_apply = std::make_shared(send_op); + MS_EXCEPTION_IF_NULL(send_apply); + std::vector send_input_list = {send_apply}; + CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); + MS_EXCEPTION_IF_NULL(send_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), 
send_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + send_node_ptr->set_abstract(abstract_none); + return send_node_ptr; +} + +CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto recv_op = std::make_shared(kRecvOpName); + MS_EXCEPTION_IF_NULL(recv_op); + auto recv_apply = std::make_shared(recv_op); + MS_EXCEPTION_IF_NULL(recv_apply); + std::vector recv_input_list = {recv_apply}; + CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); + MS_EXCEPTION_IF_NULL(recv_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + recv_node_ptr->set_abstract(abstract_none); + return recv_node_ptr; +} + void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { if (!NeedInsertSwitch()) { return; @@ -93,21 +142,95 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr } } } + + auto orders = kernel_graph_ptr->execution_order(); + if (orders.empty()) { + MS_LOG(EXCEPTION) << "graph execution order is empty"; + } + uint32_t first_cnode_stream_label = AnfAlgo::GetStreamDistinctionLabel(orders[0].get()); + std::vector exec_order; - CNodePtr stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(stream_switch_app); - exec_order.push_back(stream_switch_app); + CNodePtr first_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(first_stream_switch_app); + 
AnfAlgo::SetStreamDistinctionLabel(kFirstStreamSwitchLabel, first_stream_switch_app.get()); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(kGetNextLabel), first_stream_switch_app); + + CNodePtr second_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(second_stream_switch_app); + AnfAlgo::SetStreamDistinctionLabel(kSecondStreamSwitchLabel, second_stream_switch_app.get()); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_cnode_stream_label), second_stream_switch_app); + // add attr "stream_need_active" + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), second_stream_switch_app); + + CNodePtr first_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(first_stream_active_app); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, first_stream_active_app.get()); + std::vector first_active_streams = {kFirstStreamSwitchLabel}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(first_active_streams), + first_stream_active_app); + + CNodePtr second_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(second_stream_active_app); + // specific deal for common ctrl stream policy + uint32_t first_common_stream_switch_label = FindFirstStreamSwitchLabel(kernel_graph_ptr); + if (first_common_stream_switch_label == kInvalidDistincLabel) { + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, second_stream_active_app.get()); + } else { + AnfAlgo::SetStreamDistinctionLabel(first_common_stream_switch_label, second_stream_active_app.get()); + } - CNodePtr stream_active_switch_app = CreateStreamActiveSwitchOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(stream_active_switch_app); + std::vector second_active_streams = {kSecondStreamSwitchLabel}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(second_active_streams), + second_stream_active_app); CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, 
switch_loop_input); MS_EXCEPTION_IF_NULL(assign_add_one); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, assign_add_one.get()); + + CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); + AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, send.get()); + CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, recv.get()); + + // reorder graph orders + exec_order.push_back(first_stream_switch_app); + size_t i = 0; + for (; i < orders.size(); i++) { + auto node = orders[i]; + exec_order.push_back(node); + AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, exec_order[exec_order.size() - 1].get()); + if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { + break; + } + } + + exec_order.push_back(send); + exec_order.push_back(second_stream_switch_app); + exec_order.push_back(recv); exec_order.push_back(assign_add_one); - auto original_exec_order = kernel_graph_ptr->execution_order(); - (void)std::copy(original_exec_order.begin(), original_exec_order.end(), std::back_inserter(exec_order)); - exec_order.push_back(stream_active_switch_app); + std::vector memcpy_list; + std::vector before_list; + std::vector after_list; + bool first_memcpy_found = false; + CNodePtr cur_cnode = nullptr; + for (size_t idx = i + 1; idx < orders.size(); idx++) { + cur_cnode = orders[idx]; + if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { + memcpy_list.emplace_back(cur_cnode); + first_memcpy_found = true; + } else if (first_memcpy_found) { + after_list.emplace_back(cur_cnode); + } else { + before_list.emplace_back(cur_cnode); + } + } + + (void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order)); + (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); + exec_order.push_back(first_stream_active_app); + (void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order)); + 
exec_order.push_back(second_stream_active_app); kernel_graph_ptr->set_execution_order(exec_order); } @@ -167,7 +290,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr(); - auto stream_switch = std::make_shared(kStreamSwitch); + auto stream_switch = std::make_shared(kStreamSwitchOpName); std::vector inputs; inputs.push_back(NewValueNode(stream_switch)); inputs.push_back(switch_loop_input.at(kLoopCountParamName)); @@ -181,28 +304,19 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr(RT_LESS); ValuePtr cond = MakeValue(condition); AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); - // set attr:true branch graph id ,which is same to stream distinction label - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto first_stream = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_stream), stream_switch_app); // set attr:data_type int data_type = static_cast(RT_SWITCH_INT64); ValuePtr dt = MakeValue(data_type); AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); // set distinction label and graph id - AnfAlgo::SetGraphId(kInvalidGraphId - 1, stream_switch_app.get()); - AnfAlgo::SetStreamDistinctionLabel(kInvalidDistincLabel - 1, stream_switch_app.get()); return stream_switch_app; } -CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr) { +CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr) { kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActive); + auto stream_active_others = 
std::make_shared(kStreamActiveOpName); std::vector inputs; inputs.push_back(NewValueNode(stream_active_others)); MS_EXCEPTION_IF_NULL(kernel_graph_ptr); @@ -213,57 +327,6 @@ CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_switch = std::make_shared(kStreamActive); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_switch)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_switch_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_switch_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_switch_app.get()); - stream_active_switch_app->set_abstract(typeNone_abstract); - // set attr,which stream to active - std::vector active_index_value = {kInvalidDistincLabel - 1}; - auto value = MakeValue>(active_index_value); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, value, stream_active_switch_app); - // set the distinction label of stream active - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - // find the first switch's distinction label - for (auto node : kernel_graph_ptr->execution_order()) { - if (AnfAlgo::GetCNodeName(node) == "StreamSwitch") { - label = AnfAlgo::GetStreamDistinctionLabel(node.get()); - break; - } - } - AnfAlgo::SetStreamDistinctionLabel(label, stream_active_switch_app.get()); - return stream_active_switch_app; -} - -CNodePtr KernelAdjust::CreateStreamActiveOtherOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder 
selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActive); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_others)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_others_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); - stream_active_others_app->set_abstract(typeNone_abstract); - // set attr - ValuePtr active_target = MakeValue(kValueTargetOther); - AnfAlgo::SetNodeAttr(kAttrActiveTarget, active_target, stream_active_others_app); - return stream_active_others_app; -} - CNodePtr KernelAdjust::CreateStreamAssignAddnOP( const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input) { @@ -273,7 +336,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); // AssignAdd - auto assign_add = std::make_shared(kAssignAdd); + auto assign_add = std::make_shared(kAssignAddOpName); std::vector inputs; inputs.push_back(NewValueNode(assign_add)); inputs.push_back(switch_loop_input.at(kLoopCountParamName)); @@ -290,70 +353,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); - // set the distinction label of assign add - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto label = 
AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - AnfAlgo::SetStreamDistinctionLabel(label, assign_add_one.get()); return assign_add_one; } -void KernelAdjust::SetStreamActiveOPs(const std::shared_ptr &kernel_graph_ptr, - const std::unordered_set &ctrl_stream_list, - const std::unordered_set &comm_stream_list, - const std::unordered_set &momentum_stream_list) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); - ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); - std::vector index_list; - index_list.clear(); - if (GetValue(active_target) == kValueTargetSwitch) { - index_list.insert(index_list.end(), ctrl_stream_list.begin(), ctrl_stream_list.end()); - } else if (GetValue(active_target) == kValueTargetOther) { - for (uint32_t index : comm_stream_list) { - if (AnfAlgo::GetStreamId(cnode_ptr) == index) { - continue; - } - index_list.emplace_back(index); - } - index_list.insert(index_list.end(), momentum_stream_list.begin(), momentum_stream_list.end()); - } - ValuePtr index_list_value = MakeValue(index_list); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, index_list_value, cnode_ptr); - } - } -} - -void KernelAdjust::SetStreamSwitchOps(const std::shared_ptr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr switch_cnode_ptr = nullptr; - uint32_t target_stream_id = 0; - for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamSwitch) { - switch_cnode_ptr = cnode_ptr; - } - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); - ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); - if (GetValue(active_target) == kValueTargetOther) { - target_stream_id = 
AnfAlgo::GetStreamId(cnode_ptr); - } - } - } - if (switch_cnode_ptr != nullptr) { - // set attr:true stream - ValuePtr true_index = MakeValue(target_stream_id); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, true_index, switch_cnode_ptr); - MS_LOG(INFO) << "switch to true_index:" << target_stream_id; - } -} - bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &context, const std::shared_ptr &kernel_graph_ptr) { if (!NeedInsertSwitch()) { diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index ca01d51e54..3dced257c1 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -28,10 +28,22 @@ #include "session/session_context.h" #include "ir/meta_tensor.h" #include "device/ascend/profiling/profiling_utils.h" +#include "device/kernel_info.h" using mindspore::device::ascend::ProfilingTraceInfo; using mindspore::device::ascend::ProfilingUtils; namespace mindspore { +constexpr auto kLoopCountParamName = "loop_count"; +constexpr auto kIterLoopParamName = "iter_loop"; +constexpr auto kZeroParamName = "zero"; +constexpr auto kOneParamName = "one"; +constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; + +const uint32_t kFirstStreamSwitchLabel = kInvalidDistincLabel - 1; +const uint32_t kGetNextLabel = kInvalidDistincLabel - 2; +const uint32_t kSecondStreamSwitchLabel = kInvalidDistincLabel - 3; +const uint32_t kInvalidEventId = UINT32_MAX; +const uint32_t kFirstEventId = kInvalidEventId / 2; namespace device { class KernelAdjust { public: @@ -41,26 +53,23 @@ class KernelAdjust { } void Reorder(const std::shared_ptr &kernel_graph_ptr); void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); - void SetStreamActiveOPs(const std::shared_ptr &kernel_graph_ptr, - const std::unordered_set &ctrl_stream_list, - const std::unordered_set &comm_stream_list, - const std::unordered_set &momentum_stream_list); - void SetStreamSwitchOps(const std::shared_ptr &kernel_graph_ptr); 
bool StepLoadCtrlInputs(const std::shared_ptr &context, const std::shared_ptr &kernel_graph_ptr); void Profiling(NotNull kernel_graph_ptr); static bool NeedInsertSwitch(); - CNodePtr CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr); + CNodePtr CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr); private: KernelAdjust() = default; ~KernelAdjust() = default; + + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr); void CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, std::map *switch_loop_input); CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); - CNodePtr CreateStreamActiveSwitchOp(const std::shared_ptr &kernel_graph_ptr); - CNodePtr CreateStreamActiveOtherOp(const std::shared_ptr &kernel_graph_ptr); CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector &formats, diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 7f3d31d8d0..d1a068b584 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -135,10 +135,11 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) { } void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, - const session::KernelGraph *graph) { + session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); // assign memory for input nodes RunOpAssignInputMemory(input_tensors, graph); + AssignStaticMemoryValueNode(graph); for (const auto &cnode : graph->execution_order()) { // assign memory for output nodes RunOpAssignOutputMemory(cnode); diff --git a/mindspore/ccsrc/device/kernel_runtime.h 
b/mindspore/ccsrc/device/kernel_runtime.h index 61b43fd5c0..b15cb31e17 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -46,7 +46,7 @@ class KernelRuntime { virtual ~KernelRuntime(); virtual bool Init() = 0; virtual void AssignMemory(session::KernelGraph *graph); - void RunOpAssignMemory(const std::vector &input_tensors, const session::KernelGraph *graph); + void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); virtual bool Run(session::KernelGraph *graph); virtual bool DumpData(session::KernelGraph *graph); virtual bool RunTask(const session::KernelGraph *graph); @@ -67,6 +67,7 @@ class KernelRuntime { TypeId type_id) = 0; virtual bool SyncStream() = 0; void AssignStaticMemory(session::KernelGraph *graph); + void AssignStaticMemoryValueNode(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph); void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); @@ -81,7 +82,6 @@ class KernelRuntime { private: void AssignStaticMemoryOutput(const session::KernelGraph *graph); - void AssignStaticMemoryValueNode(session::KernelGraph *graph); void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); bool LaunchKernelMod(const session::KernelGraph &graph); diff --git a/mindspore/ccsrc/device/memory_manager.cc b/mindspore/ccsrc/device/memory_manager.cc index 2fad5fc10e..dce54495b0 100644 --- a/mindspore/ccsrc/device/memory_manager.cc +++ b/mindspore/ccsrc/device/memory_manager.cc @@ -167,5 +167,28 @@ void MemoryManager::FreeMemFromMemPool(void *device_ptr) { MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; } } + +void MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + auto 
device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); + if (addr_list.size() != device_ptr_list.size()) { + MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; + } + for (size_t i = 0; i < addr_list.size(); i++) { + MS_EXCEPTION_IF_NULL(device_ptr_list[i]); + MS_EXCEPTION_IF_NULL(addr_list[i]); + addr_list[i]->ptr_ = device_ptr_list[i]; + addr_list[i]->from_mem_pool_ = true; + } +} + +std::vector MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + if (total_size == 0) { + MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; + } + std::vector device_ptr_list; + device_ptr_list.emplace_back(nullptr); + return device_ptr_list; +} } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/memory_manager.h b/mindspore/ccsrc/device/memory_manager.h index c90ffc380e..dae0861506 100644 --- a/mindspore/ccsrc/device/memory_manager.h +++ b/mindspore/ccsrc/device/memory_manager.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ #include +#include #include "pre_activate/mem_reuse/mem_reuse.h" #include "pre_activate/mem_reuse/mem_reuse_allocator.h" namespace mindspore { @@ -49,6 +50,9 @@ class MemoryManager { virtual void *MallocMemFromMemPool(size_t size); virtual void FreeMemFromMemPool(const DeviceAddressPtr address); virtual void FreeMemFromMemPool(void *device_ptr); + virtual void MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); + virtual std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); size_t GetCommonAlignSize(size_t input_size) const; size_t GetCommunicationAlignSize(size_t input_size) const; diff --git a/mindspore/ccsrc/gvar/typeid_manager.cc b/mindspore/ccsrc/gvar/typeid_manager.cc index 97250a6571..f40052411a 100644 --- 
a/mindspore/ccsrc/gvar/typeid_manager.cc +++ b/mindspore/ccsrc/gvar/typeid_manager.cc @@ -24,7 +24,7 @@ namespace mindspore { -struct TypeIdManager* TypeIdManager::Get() { +struct TypeIdManager *TypeIdManager::Get() { static TypeIdManager manager; return &manager; } diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc index 658fb578b7..dd86e46713 100644 --- a/mindspore/ccsrc/ir/anf.cc +++ b/mindspore/ccsrc/ir/anf.cc @@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } std::string AnfNode::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } -CNode::CNode(const std::vector& inputs, const FuncGraphPtr& func_graph) +CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} // Check if CNode is an apply with the specific Primitive. 
-bool CNode::IsApply(const PrimitivePtr& value) const { +bool CNode::IsApply(const PrimitivePtr &value) const { if (value == nullptr) { return false; } @@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const { return false; } -void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; } +void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } std::string CNode::DebugString(int recursive_level) const { std::ostringstream buffer; @@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const { buffer << ToString() << "{"; bool is_first_node = true; int idx = 0; - for (auto& node : inputs_) { + for (auto &node : inputs_) { MS_EXCEPTION_IF_NULL(node); if (is_first_node) { is_first_node = false; @@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const { return buffer.str(); } -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) { +OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { if (operator_info_ != nullptr) { MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() << ", using the new one: " << operator_info->name(); @@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() { return fullname_with_scope_; } -void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } -void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } -void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } +void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { +bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if 
(cnode != nullptr) { @@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { return false; } -PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) { +PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { if (node == nullptr) { return nullptr; } @@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) { return ""; } -bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { +bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { if (IsValueNode(node)) { PrimitivePtr fn_value = GetValueNode(node); MS_EXCEPTION_IF_NULL(value); @@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { } namespace id_generator { static std::unordered_map node_ids; -std::string get_id(const AnfNodePtr& node) { +std::string get_id(const AnfNodePtr &node) { auto type_name = node->type_name(); if (node_ids.find(type_name) == node_ids.end()) { node_ids[type_name] = 0; diff --git a/mindspore/ccsrc/ir/base.h b/mindspore/ccsrc/ir/base.h index 6a3537306f..7ccef13876 100644 --- a/mindspore/ccsrc/ir/base.h +++ b/mindspore/ccsrc/ir/base.h @@ -39,15 +39,15 @@ struct is_shared_ptr> : public std::true_type {}; class Base : public std::enable_shared_from_this { public: constexpr Base() = default; - Base(const Base& other) : std::enable_shared_from_this(other) {} - virtual bool operator==(const Base& rhs) { + Base(const Base &other) : std::enable_shared_from_this(other) {} + virtual bool operator==(const Base &rhs) { if (this == &rhs) { return true; } return false; } - virtual Base& operator=(const Base&) { return *this; } + virtual Base &operator=(const Base &) { return *this; } virtual ~Base() = default; virtual std::size_t hash() const { return tid(); } virtual std::string ToString() const { return type_name(); } @@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this { virtual const bool IsFromTypeId(uint32_t tid) const; virtual std::string type_name() const { return 
"Base"; } - static uint32_t GetTypeId(const char* const type_key); + static uint32_t GetTypeId(const char *const type_key); virtual uint32_t tid() const { static const uint32_t tid = GetTypeId(typeid(Base).name()); return tid; } template ::value && std::is_base_of::value, T>::type* = nullptr> + typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> inline bool isa() const { static const uint32_t tid = GetTypeId(typeid(T).name()); return this->IsFromTypeId(tid); @@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr; using BaseWeakPtr = std::weak_ptr; template -inline T* cast(U* source) { +inline T *cast(U *source) { if (source != nullptr && source->template isa()) { - return static_cast(source); + return static_cast(source); } else { return nullptr; } @@ -100,7 +100,7 @@ inline T* cast(U* source) { template < typename T, typename U, - typename std::enable_if::value && std::is_base_of::value, T>::type* = nullptr> + typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> inline std::shared_ptr dyn_cast(const std::shared_ptr r) { if (r != nullptr && r->template isa()) { return std::static_pointer_cast(r); @@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager { std::mutex mutex; std::atomic type_counter{0}; std::unordered_map map; - static TypeIdManager* Get(); + static TypeIdManager *Get(); TypeIdManager() : mutex(), type_counter(0), map() {} }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/ccsrc/ir/dtype.cc index 65a42bc3fa..97291a3dc0 100644 --- a/mindspore/ccsrc/ir/dtype.cc +++ b/mindspore/ccsrc/ir/dtype.cc @@ -48,11 +48,11 @@ std::string Keyword::ToString() const { return buffer.str(); } -bool Keyword::operator==(const Type& other) const { +bool Keyword::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_keyword = static_cast(other); + const auto &other_keyword = static_cast(other); return (other_keyword.key_ == key_ && 
*other_keyword.value_ == *value_); } @@ -87,11 +87,11 @@ std::string Slice::ToString() const { return buffer.str(); } -bool Slice::operator==(const Type& other) const { +bool Slice::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_slice = static_cast(other); + auto other_slice = static_cast(other); return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); } @@ -122,11 +122,11 @@ std::string TensorType::DumpText() const { } } -bool TensorType::operator==(const Type& other) const { +bool TensorType::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_elem_type = static_cast(other).element_type_; + auto other_elem_type = static_cast(other).element_type_; // When element_type_ = nullptr, which means any type of Array. if (element_type_ == nullptr && other_elem_type == nullptr) { return true; @@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) { retval_ = nullptr; } -Function::Function(const std::vector& args, const TypePtr retval) +Function::Function(const std::vector &args, const TypePtr retval) : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} TypePtr Function::DeepCopy() const { @@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const { TypePtrList args; TypePtr retval = nullptr; (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), - [](const TypePtr& arg) { return arg->DeepCopy(); }); + [](const TypePtr &arg) { return arg->DeepCopy(); }); if (retval_ != nullptr) { retval = retval_->DeepCopy(); } @@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const { } } -bool Function::operator==(const Type& other) const { +bool Function::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_function = static_cast(other); + const auto &other_function = static_cast(other); if ((retval_ != nullptr) && 
(other_function.retval_ != nullptr)) { if (*retval_ != *other_function.retval_) { return false; @@ -188,7 +188,7 @@ std::string Function::ToString() const { } else { buffer << "Func[("; bool begin = true; - for (auto& attr : args_) { + for (auto &attr : args_) { if (!begin) { buffer << ", "; } else { @@ -242,34 +242,34 @@ std::string JTagged::DumpText() const { return buffer.str(); } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr problem) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem) { MS_EXCEPTION_IF_NULL(problem); os << problem->ToString(); return os; } -std::size_t TypeHasher::operator()(TypePtr const& type) const { +std::size_t TypeHasher::operator()(TypePtr const &type) const { MS_EXCEPTION_IF_NULL(type); std::size_t hash = std::hash()(type->type_id()); return hash; } -std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const { +std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { std::size_t hash_sum = 0; - for (auto& type : type_list) { + for (auto &type : type_list) { auto type_id = static_cast(type->type_id()); hash_sum = hash_combine(hash_sum, type_id); } return hash_sum; } -bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const { +bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t2); return t1->type_id() == t2->type_id(); } -bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const { +bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { if (lhs.size() != rhs.size()) { return false; } @@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) { namespace { template -TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) { +TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { TypePtr type = nullptr; if (type_name == num_type_name) { type = 
std::make_shared(); @@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_ } auto bits = std::stoi(type_name.substr(num_type_name.size())); type = std::make_shared(bits); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); } } return type; } -std::vector StringToVectorOfType(const std::string& type_names) { +std::vector StringToVectorOfType(const std::string &type_names) { std::vector types; if (type_names.length() == 0) { return types; @@ -371,7 +371,7 @@ std::vector StringToVectorOfType(const std::string& type_names) { return types; } -TypePtr TensorStrToType(const std::string& type_name) { +TypePtr TensorStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Tensor") { type = std::make_shared(); @@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) { return nullptr; } type = std::make_shared(element_type); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) { return type; } -TypePtr ListStrToType(const std::string& type_name) { +TypePtr ListStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "List") { type = std::make_shared(); @@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) { std::string element_strs = type_name.substr(start, end - start); std::vector element_types = StringToVectorOfType(element_strs); bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); if (wrong) { return nullptr; } type = std::make_shared(element_types); - } catch (const 
std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) { return type; } -TypePtr TupleStrToType(const std::string& type_name) { +TypePtr TupleStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Tuple") { type = std::make_shared(); @@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) { std::string element_strs = type_name.substr(start, end - start); std::vector element_types = StringToVectorOfType(element_strs); bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); if (wrong) { return nullptr; } type = std::make_shared(element_types); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } return type; } -TypePtr FunctionStrToType(const std::string& type_name) { +TypePtr FunctionStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Function") { @@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) { std::vector args_type = StringToVectorOfType(str_args); TypePtr retval = StringToType(str_retval); - bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; }); + bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); if (retval == nullptr || wrong) { return nullptr; } type = std::make_shared(args_type, retval); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -491,10 +491,12 @@ TypePtr FunctionStrToType(const 
std::string& type_name) { } } // namespace -TypePtr StringToType(const std::string& type_name) { +TypePtr StringToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name.compare("None") == 0) { type = std::make_shared(); + } else if (type_name.compare("Ellipsis") == 0) { + type = std::make_shared(); } else if (type_name.compare("TypeType") == 0) { type = std::make_shared(); } else if (type_name.compare("SymbolicKeyType") == 0) { @@ -542,7 +544,7 @@ TypePtr StringToType(const std::string& type_name) { return type; } -bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { if (x == nullptr || base_type == nullptr) { MS_LOG(ERROR) << "Type is nullptr."; return false; @@ -564,7 +566,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { } } -bool IsSubType(TypePtr const& t1, TypePtr const& t2) { +bool IsSubType(TypePtr const &t1, TypePtr const &t2) { MS_EXCEPTION_IF_NULL(t1); if (t1->type_id() == kTypeUnknown) { return false; @@ -576,17 +578,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) { } REGISTER_PYBIND_DEFINE( - typing, ([](py::module* const m) { + typing, ([](py::module *const m) { auto m_sub = m->def_submodule("typing", "submodule for dtype"); py::enum_(m_sub, "TypeId"); (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); (void)m_sub.def("load_type", &TypeIdToType, "load type"); (void)m_sub.def( - "dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type"); + "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); (void)py::class_>(m_sub, "Type") .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) .def("__eq__", - [](const TypePtr& t1, const TypePtr& t2) { + [](const TypePtr &t1, const TypePtr &t2) { if (t1 != nullptr && t2 != nullptr) { return *t1 == *t2; } @@ -595,7 +597,7 @@ REGISTER_PYBIND_DEFINE( .def("__hash__", &Type::hash) .def("__str__", 
&Type::ToString) .def("__repr__", &Type::ReprString) - .def("__deepcopy__", [](const TypePtr& t, py::dict) { + .def("__deepcopy__", [](const TypePtr &t, py::dict) { if (t == nullptr) { return static_cast(nullptr); } @@ -605,21 +607,21 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(m_sub, "Bool") .def(py::init()) .def(py::pickle( - [](const Bool&) { // __getstate__ + [](const Bool &) { // __getstate__ return py::make_tuple(); }, - [](const py::tuple&) { // __setstate__ + [](const py::tuple &) { // __setstate__ return std::make_shared(); })); (void)py::class_>(m_sub, "Int") .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const Int& t) { // __getstate__ + [](const Int &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -631,11 +633,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const UInt& t) { // __getstate__ + [](const UInt &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -647,11 +649,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const Float& t) { // __getstate__ + [](const Float &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -670,11 +672,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init(), py::arg("element")) .def("element_type", 
&TensorType::element) .def(py::pickle( - [](const TensorType& t) { // __getstate__ + [](const TensorType &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } diff --git a/mindspore/ccsrc/ir/dtype.h b/mindspore/ccsrc/ir/dtype.h index e3e2099b5e..cefdf42099 100644 --- a/mindspore/ccsrc/ir/dtype.h +++ b/mindspore/ccsrc/ir/dtype.h @@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr; class Keyword : public Object { public: Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} - Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} + Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} ~Keyword() override = default; MS_DECLARE_PARENT(Keyword, Object) @@ -70,7 +70,7 @@ class Keyword : public Object { std::string ToString() const override; std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::string GetKey() const { return key_; } TypePtr GetValue() const { return value_; } @@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr; class Slice : public Object { public: Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} - Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step) + Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step) : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} ~Slice() override = default; @@ -95,7 +95,7 @@ class Slice : public Object { std::string ToString() const override; std::string DumpText() const override; - bool operator==(const Type& other) const override; + 
bool operator==(const Type &other) const override; TypePtr get_start() const { return start_; } TypePtr get_stop() const { return stop_; } @@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr; class TensorType : public Object { public: TensorType() : Object(kObjectTypeTensorType) {} - explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} + explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} ~TensorType() override = default; MS_DECLARE_PARENT(TensorType, Object) TypeId generic_type_id() const override { return kObjectTypeTensorType; } const TypePtr element() const { return element_type_; } - void set_element(const TypePtr& element_type) { element_type_ = element_type; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } TypePtr DeepCopy() const override; std::string ToString() const override; std::string ToReprString() const override { return "tensor"; } std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; private: TypePtr element_type_; @@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr; class Function : public Object { public: Function(); - Function(const std::vector& args, const TypePtr retval); + Function(const std::vector &args, const TypePtr retval); ~Function() override = default; MS_DECLARE_PARENT(Function, Object) @@ -141,11 +141,11 @@ class Function : public Object { // Add temporarily for return abstraction to avoid type checking. 
bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } - const std::vector& args() const { return args_; } - const TypePtr& retval() const { return retval_; } + const std::vector &args() const { return args_; } + const TypePtr &retval() const { return retval_; } TypePtr DeepCopy() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::string ToString() const override; std::string ToReprString() const override { return "function"; } @@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr; class JTagged : public Object { public: JTagged() : Object(kObjectTypeJTagged) {} - explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} + explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} ~JTagged() override = default; MS_DECLARE_PARENT(JTagged, Object) @@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr; class Problem : public Type { public: Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} - explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {} + explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {} ~Problem() override = default; MS_DECLARE_PARENT(Problem, Type) @@ -222,7 +222,7 @@ class Problem : public Type { std::string ToString() const override { return kind_.name(); } std::string DumpText() const override { return "ProblemType"; } - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr problem); + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem); private: Named kind_; @@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr; // helper template template -TypePtr Clone(const T& t) { +TypePtr Clone(const T &t) { return t.Clone(); } -TypePtr StringToType(const std::string& type_name); +TypePtr StringToType(const std::string &type_name); // Judge whether x is predicate or is a 
subclass of predicate. -bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type); +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); // Whether t1 is identity or a subclass of t2. -bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr); +bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); struct TypeHasher { - std::size_t operator()(TypePtr const& type) const; + std::size_t operator()(TypePtr const &type) const; }; struct TypeListHasher { - std::size_t operator()(const TypePtrList& type_list) const; + std::size_t operator()(const TypePtrList &type_list) const; }; struct TypeEqual { - bool operator()(TypePtr const& t1, TypePtr const& t2) const; + bool operator()(TypePtr const &t1, TypePtr const &t2) const; }; struct TypeListEqual { - bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const; + bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const; }; extern const TypePtr kTypeExternal; diff --git a/mindspore/ccsrc/ir/dtype/container.cc b/mindspore/ccsrc/ir/dtype/container.cc index 8bca29f793..3f8244c2e3 100644 --- a/mindspore/ccsrc/ir/dtype/container.cc +++ b/mindspore/ccsrc/ir/dtype/container.cc @@ -24,7 +24,7 @@ #include "pybind_api/export_flags.h" namespace mindspore { -static std::string DumpTypeVector(const std::vector& elements, bool is_dumptext) { +static std::string DumpTypeVector(const std::vector &elements, bool is_dumptext) { std::ostringstream oss; bool begin = true; int cnt = 0; @@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const { } else { TypePtrList elements; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), - [](const TypePtr& ele) { return ele->DeepCopy(); }); + [](const TypePtr &ele) { return ele->DeepCopy(); }); auto copy = std::make_shared(elements); return copy; } @@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const { return elements_[dim]; } -bool List::operator==(const Type& other) const { +bool 
List::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const List& other_list = static_cast(other); + const List &other_list = static_cast(other); if (elements_.size() != other_list.elements_.size()) { return false; } @@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const { return true; } -Class::Class(const Named& tag, const ClassAttrVector& attributes, - const std::unordered_map& methods) +Class::Class(const Named &tag, const ClassAttrVector &attributes, + const std::unordered_map &methods) : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} std::string List::ToString() const { @@ -122,7 +122,7 @@ std::string List::DumpText() const { return buffer.str(); } -bool Class::operator==(const Type& other) const { +bool Class::operator==(const Type &other) const { // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. return &other == this; } @@ -143,7 +143,7 @@ std::string Class::ToString() const { } else { bool begin = true; buffer << "cls." << tag_ << "["; - for (auto& attr : attributes_) { + for (auto &attr : attributes_) { if (!begin) { buffer << ", "; } else { @@ -163,7 +163,7 @@ std::string Class::DumpText() const { } else { bool begin = true; buffer << "Cls." 
<< tag_ << "["; - for (auto& attr : attributes_) { + for (auto &attr : attributes_) { if (!begin) { buffer << ", "; } else { @@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const { } else { TypePtrList elements; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), - [](const TypePtr& ele) { return ele->DeepCopy(); }); + [](const TypePtr &ele) { return ele->DeepCopy(); }); auto copy = std::make_shared(elements); return copy; } } -bool Tuple::operator==(const Type& other) const { +bool Tuple::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_tuple = static_cast(other); + auto other_tuple = static_cast(other); if (elements_.size() != other_tuple.elements_.size()) { return false; } @@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const { std::vector> kv; (void)std::transform( key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const std::pair& item) { return std::make_pair(item.first, item.second->DeepCopy()); }); + [](const std::pair &item) { return std::make_pair(item.first, item.second->DeepCopy()); }); return std::make_shared(kv); } } @@ -259,7 +259,7 @@ std::string Dictionary::ToString() const { std::ostringstream buffer; std::vector keys; std::vector values; - for (const auto& kv : key_values_) { + for (const auto &kv : key_values_) { keys.push_back(kv.first); values.push_back(kv.second); } @@ -276,12 +276,12 @@ std::string Dictionary::ToString() const { std::string Dictionary::DumpText() const { return ToString(); } -bool Dictionary::operator==(const mindspore::Type& other) const { +bool Dictionary::operator==(const mindspore::Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_dict = static_cast(other); + const auto &other_dict = static_cast(other); if (key_values_.size() != other_dict.key_values_.size()) { return false; } diff --git a/mindspore/ccsrc/ir/dtype/container.h b/mindspore/ccsrc/ir/dtype/container.h 
index 04ed484cf7..0612d24c4d 100644 --- a/mindspore/ccsrc/ir/dtype/container.h +++ b/mindspore/ccsrc/ir/dtype/container.h @@ -40,10 +40,10 @@ namespace mindspore { class List : public Object { public: List() : Object(kObjectTypeList) {} - List(const std::initializer_list& objs) + List(const std::initializer_list &objs) : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} // Shadow copy; - explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {} + explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {} ~List() override {} MS_DECLARE_PARENT(List, Object) @@ -51,7 +51,7 @@ class List : public Object { TypeId generic_type_id() const override { return kObjectTypeList; } TypePtr DeepCopy() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::size_t size() const { return elements_.size(); } TypePtrList elements() const { return elements_; } std::string ToString() const override; @@ -68,22 +68,22 @@ using ClassAttrVector = std::vector>; class Class : public Object { public: Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} - Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map& methods); + Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map &methods); ~Class() override {} MS_DECLARE_PARENT(Class, Object) TypeId generic_type_id() const override { return kObjectTypeClass; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override; std::string ToString() const override; std::string DumpText() const override; - void set_value(const std::unordered_map& v) { attributes_value_ = v; } + void set_value(const std::unordered_map &v) { attributes_value_ = v; } Named tag() { return tag_; } std::unordered_map GetValue() { return attributes_value_; } std::unordered_map 
methods() { return methods_; } - ClassAttrVector& GetAttributes() { return attributes_; } + ClassAttrVector &GetAttributes() { return attributes_; } ClassAttrVector attributes_; @@ -99,11 +99,11 @@ class Tuple : public Object { public: Tuple() : Object(kObjectTypeTuple) {} // usage : Tuple t = {std::make_shared(), std::make_shared(32)}; - Tuple(const std::initializer_list& objs) + Tuple(const std::initializer_list &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} // Shadow copy - explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} + explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} ~Tuple() override {} MS_DECLARE_PARENT(Tuple, Object) @@ -115,7 +115,7 @@ class Tuple : public Object { std::string ToReprString() const override { return "tuple_"; } std::string DumpText() const override; const TypePtr operator[](size_t dim) const; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtrList elements() const { return elements_; } std::size_t size() const { return elements_.size(); } @@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr; class Dictionary : public Object { public: Dictionary() : Object(kObjectTypeDictionary) {} - explicit Dictionary(const std::vector>& key_values) + explicit Dictionary(const std::vector> &key_values) : Object(kObjectTypeDictionary, false), key_values_(key_values) {} ~Dictionary() override {} @@ -136,7 +136,7 @@ class Dictionary : public Object { TypeId generic_type_id() const override { return kObjectTypeDictionary; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override; std::string ToString() const override; std::string DumpText() const override; diff --git a/mindspore/ccsrc/ir/dtype/empty.cc b/mindspore/ccsrc/ir/dtype/empty.cc index 
3d4f74bf31..5cb3a91806 100644 --- a/mindspore/ccsrc/ir/dtype/empty.cc +++ b/mindspore/ccsrc/ir/dtype/empty.cc @@ -18,6 +18,5 @@ namespace mindspore { const TypePtr kTypeNone = std::make_shared(); -const TypePtr kTypeAnything = std::make_shared(); const TypePtr kAnyType = std::make_shared(); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype/empty.h b/mindspore/ccsrc/ir/dtype/empty.h index a13dc084ca..76cf8ea0eb 100644 --- a/mindspore/ccsrc/ir/dtype/empty.h +++ b/mindspore/ccsrc/ir/dtype/empty.h @@ -71,8 +71,20 @@ class TypeNull : public Type { }; using TypeNullPtr = std::shared_ptr; +class Ellipsis : public Type { + public: + Ellipsis() : Type(kMetaTypeEllipsis) {} + ~Ellipsis() override {} + MS_DECLARE_PARENT(Ellipsis, Type) + + TypeId generic_type_id() const override { return kMetaTypeEllipsis; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "Ellipsis"; } + std::string DumpText() const override { return "Ellipsis"; } +}; +using EllipsisPtr = std::shared_ptr; + extern const TypePtr kTypeNone; -extern const TypePtr kTypeAnything; extern const TypePtr kAnyType; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype/number.cc b/mindspore/ccsrc/ir/dtype/number.cc index d9ef6bb3bd..44ac9e8e6a 100644 --- a/mindspore/ccsrc/ir/dtype/number.cc +++ b/mindspore/ccsrc/ir/dtype/number.cc @@ -24,11 +24,11 @@ #include "pybind_api/export_flags.h" namespace mindspore { -bool Number::operator==(const Type& other) const { +bool Number::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_number = static_cast(other); + auto other_number = static_cast(other); return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); } diff --git a/mindspore/ccsrc/ir/dtype/number.h b/mindspore/ccsrc/ir/dtype/number.h index cb3b0a607c..3930f51d73 100644 --- a/mindspore/ccsrc/ir/dtype/number.h +++ 
b/mindspore/ccsrc/ir/dtype/number.h @@ -49,12 +49,12 @@ class Number : public Object { TypeId type_id() const override { return number_type_; } TypeId generic_type_id() const override { return kObjectTypeNumber; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override { return std::make_shared(); } std::string ToString() const override { return "Number"; } std::string ToReprString() const override { return "number"; } std::string DumpText() const override { return "Number"; } - std::string GetTypeName(const std::string& type_name) const { + std::string GetTypeName(const std::string &type_name) const { std::ostringstream oss; oss << type_name; if (nbits() != 0) { diff --git a/mindspore/ccsrc/ir/dtype/ref.h b/mindspore/ccsrc/ir/dtype/ref.h index 7f1dc4a95f..7d8159289f 100644 --- a/mindspore/ccsrc/ir/dtype/ref.h +++ b/mindspore/ccsrc/ir/dtype/ref.h @@ -51,7 +51,7 @@ class RefKeyType : public Object { class RefType : public Object { public: RefType() : Object(kObjectTypeRef) {} - RefType(const TypePtr& subtype, const TypePtr& subtype_origin) + RefType(const TypePtr &subtype, const TypePtr &subtype_origin) : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} ~RefType() override {} MS_DECLARE_PARENT(RefType, Object) diff --git a/mindspore/ccsrc/ir/dtype/type.cc b/mindspore/ccsrc/ir/dtype/type.cc index 6fbd7f8111..30bf0c8e3f 100644 --- a/mindspore/ccsrc/ir/dtype/type.cc +++ b/mindspore/ccsrc/ir/dtype/type.cc @@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) { } } -const char* MetaIdLabel(const TypeId& v) { +const char *MetaIdLabel(const TypeId &v) { switch (v) { case kTypeUnknown: return "kTypeUnknown"; @@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) { } } -const char* ObjectIdLabel(const TypeId& v) { +const char *ObjectIdLabel(const TypeId &v) { switch (v) { case kObjectTypeNumber: return "kObjectTypeNumber"; @@ -129,7 +129,7 @@ 
const char* ObjectIdLabel(const TypeId& v) { } } -const char* NumberIdLabel(const TypeId& v) { +const char *NumberIdLabel(const TypeId &v) { switch (v) { case kNumberTypeBool: return "kNumberTypeBool"; @@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) { } } -const char* TypeIdLabel(const TypeId& v) { +const char *TypeIdLabel(const TypeId &v) { if (v < kMetaTypeEnd) { return MetaIdLabel(v); } else { @@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) { } } -bool IsSameObjectType(const Type& lhs, const Type& rhs) { +bool IsSameObjectType(const Type &lhs, const Type &rhs) { if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { return false; } return lhs.object_type() == rhs.object_type(); } -size_t GetTypeByte(const TypePtr& type_ptr) { +size_t GetTypeByte(const TypePtr &type_ptr) { if (type_ptr && type_ptr->isa()) { auto number = dyn_cast(type_ptr); if (!number) { @@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) { } } -bool Type::operator==(const Value& other) const { +bool Type::operator==(const Value &other) const { if (other.isa()) { - auto other_type = static_cast(&other); + auto other_type = static_cast(&other); return *this == *other_type; } else { return false; @@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() { return ptr; } -std::ostream& operator<<(std::ostream& os, const Type& type) { +std::ostream &operator<<(std::ostream &os, const Type &type) { os << type.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const TypePtr type) { +std::ostream &operator<<(std::ostream &os, const TypePtr type) { os << type->ToString(); return os; } @@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const { return false; } -std::ostream& operator<<(std::ostream& os, const Object& obj) { +std::ostream &operator<<(std::ostream &os, const Object &obj) { os << obj.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr obj) { 
+std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj) { os << obj->ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const TypePtrList& types) { +std::ostream &operator<<(std::ostream &os, const TypePtrList &types) { os << "["; for (size_t i = 0; i < types.size(); ++i) { if (i > 0) { diff --git a/mindspore/ccsrc/ir/dtype/type.h b/mindspore/ccsrc/ir/dtype/type.h index 9454596538..1c67b6a855 100644 --- a/mindspore/ccsrc/ir/dtype/type.h +++ b/mindspore/ccsrc/ir/dtype/type.h @@ -49,6 +49,7 @@ enum TypeId : int { kMetaTypeExternal, kMetaTypeNone, kMetaTypeNull, + kMetaTypeEllipsis, kMetaTypeEnd, // // Object types @@ -95,10 +96,10 @@ enum TypeId : int { TypeId IntBitsToTypeId(const int nbits); TypeId UIntBitsToTypeId(const int nbits); TypeId FloatBitsToTypeId(const int nbits); -const char* TypeIdLabel(const TypeId& v); +const char *TypeIdLabel(const TypeId &v); TypeId NormalizeTypeId(const TypeId type_id); -bool IsSameObjectType(const Type& lhs, const Type& rhs); -size_t GetTypeByte(const TypePtr& type_ptr); +bool IsSameObjectType(const Type &lhs, const Type &rhs); +size_t GetTypeByte(const TypePtr &type_ptr); // Base class for all types // forward declaration. 
@@ -110,14 +111,14 @@ class Type : public Value { ~Type() override = default; MS_DECLARE_PARENT(Type, Value) - bool operator==(const Value& other) const override; + bool operator==(const Value &other) const override; TypeId meta_type() const { return meta_type_; } virtual TypeId type_id() const { return meta_type_; } virtual TypeId generic_type_id() const { return kMetaTypeType; } - virtual bool operator!=(const Type& other) const { return !(*this == other); } - virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); } + virtual bool operator!=(const Type &other) const { return !(*this == other); } + virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } virtual bool equal(const TypePtr other) const { return *this == *other; } virtual TypeId object_type() const { return kTypeUnknown; } @@ -134,8 +135,8 @@ class Type : public Value { bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } bool IsGeneric() const { return is_generic_; } abstract::AbstractBasePtr ToAbstract() override; - friend std::ostream& operator<<(std::ostream& os, const Type& type); - friend std::ostream& operator<<(std::ostream& os, const TypePtr type); + friend std::ostream &operator<<(std::ostream &os, const Type &type); + friend std::ostream &operator<<(std::ostream &os, const TypePtr type); const bool parse_info_ = true; @@ -163,14 +164,14 @@ class Object : public Type { bool equal(const TypePtr other) const override; std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } - friend std::ostream& operator<<(std::ostream& os, const Object& obj); - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr obj); + friend std::ostream &operator<<(std::ostream &os, const Object &obj); + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj); private: const TypeId object_type_; }; -std::ostream& operator<<(std::ostream& os, const 
TypePtrList& types); +std::ostream &operator<<(std::ostream &os, const TypePtrList &types); } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 93fd9c0936..8a58f320f1 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -61,7 +61,7 @@ FuncGraph::FuncGraph() AbstractFunctionPtr FuncGraph::abstract() { AbstractBasePtrList args_spec_list; - for (auto& p : parameters_) { + for (auto &p : parameters_) { MS_EXCEPTION_IF_NULL(p); if (p->abstract() == nullptr) { MS_LOG(ERROR) << "Error!!"; @@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() { return std::make_shared(args_spec_list, output()->abstract()); } -abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) { +abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) { AnalysisContextPtr temp_context = context; if (temp_context == nullptr) { temp_context = abstract::AnalysisContext::DummyContext(); @@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const { } } -void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) { +void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { if (force_new_ret || return_ == nullptr) { std::vector params({NewValueNode(prim::kPrimReturn), value}); FuncGraphPtr this_graph = shared_from_base(); @@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() { return p; } -void FuncGraph::add_parameter(const ParameterPtr& p) { +void FuncGraph::add_parameter(const ParameterPtr &p) { if (manager_.lock()) { std::vector new_params = parameters_; new_params.push_back(p); @@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) { } } -ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { +ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { FuncGraphPtr this_graph = shared_from_base(); 
ParameterPtr p = std::make_shared(this_graph); p->set_name(name); @@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { return p; } -bool FuncGraph::has_flag(const std::string& flag) { +bool FuncGraph::has_flag(const std::string &flag) { if (flags_.count(flag)) { return flags_[flag]; } return false; } -CNodePtr FuncGraph::NewCNode(const std::vector& inputs) { +CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { CNodePtr cnode = std::make_shared(inputs, shared_from_base()); if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { order_.push_back(cnode); @@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector& inputs) { return cnode; } -CNodePtr FuncGraph::NewCNodeWithScope(const std::vector& inputs, const ScopePtr& scope) { +CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { CNodePtr app = NewCNode(inputs); app->set_scope(scope); return app; @@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector& inputs, con void FuncGraph::DumpCNodeList() { MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; - for (const auto& cnode : order_) { + for (const auto &cnode : order_) { MS_LOG(INFO) << cnode->DebugString(); } } std::string FuncGraph::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } GraphDebugInfoPtr FuncGraph::debug_info() { @@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() { return this->debug_info_; } -const AnfNodeSet& FuncGraph::nodes() { +const AnfNodeSet &FuncGraph::nodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& nodes = mng->nodes(); + auto &nodes = mng->nodes(); return nodes[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::value_nodes() { +const AnfNodeCounterMap &FuncGraph::value_nodes() { auto mng = manager_.lock(); 
MS_EXCEPTION_IF_NULL(mng); - auto& cts = mng->valuenodes(); + auto &cts = mng->valuenodes(); return cts[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::free_variables_direct() { +const AnfNodeCounterMap &FuncGraph::free_variables_direct() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& fv_direct = mng->free_variables_direct(); + auto &fv_direct = mng->free_variables_direct(); return fv_direct[shared_from_base()]; } -const BaseRefCounterMap& FuncGraph::free_variables_total() { +const BaseRefCounterMap &FuncGraph::free_variables_total() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& fv_total = mng->free_variables_total(); + auto &fv_total = mng->free_variables_total(); return fv_total[shared_from_base()]; } std::vector FuncGraph::free_variables_nodes() { std::vector nodes; - const auto& fv_total = this->free_variables_total(); - for (auto& p : fv_total) { + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { auto key = p.first; if (utils::isa(key)) { nodes.push_back(utils::cast(key)); @@ -238,8 +238,8 @@ std::vector FuncGraph::free_variables_nodes() { std::vector FuncGraph::free_variables_func_graphs() { std::vector func_graphs; - const auto& fv_total = this->free_variables_total(); - for (auto& p : fv_total) { + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { auto key = p.first; if (utils::isa(key)) { func_graphs.push_back(utils::cast(key)); @@ -249,31 +249,31 @@ std::vector FuncGraph::free_variables_func_graphs() { return func_graphs; } -const FuncGraphCounterMap& FuncGraph::func_graphs_used() { +const FuncGraphCounterMap &FuncGraph::func_graphs_used() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& used = mng->func_graphs_used(); + auto &used = mng->func_graphs_used(); return used[shared_from_base()]; } -const FuncGraphSet& FuncGraph::func_graphs_used_total() { +const FuncGraphSet &FuncGraph::func_graphs_used_total() { auto mng = 
manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& used = mng->func_graphs_used_total(shared_from_base()); + auto &used = mng->func_graphs_used_total(shared_from_base()); return used; } -const FuncGraphCounterMap& FuncGraph::func_graph_users() { +const FuncGraphCounterMap &FuncGraph::func_graph_users() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& users = mng->func_graph_users(); + auto &users = mng->func_graph_users(); return users[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() { +const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& users = mng->func_graph_user_cnodes(); + auto &users = mng->func_graph_user_cnodes(); return users[shared_from_base()]; } @@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() { return mng->parent(shared_from_base()); } -const FuncGraphSet& FuncGraph::children() { +const FuncGraphSet &FuncGraph::children() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); return mng->children(shared_from_base()); } -const FuncGraphSet& FuncGraph::scope() { +const FuncGraphSet &FuncGraph::scope() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); return mng->scopes(shared_from_base()); @@ -312,9 +312,9 @@ std::shared_ptr> FuncGraph::recursive_graphs() { return mng->recursive_graphs(shared_from_base()); } -void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base()); } +void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base()); } -AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { +AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { auto itr = this->parameter_default_value_.find(name); if (itr == parameter_default_value_.end()) { return nullptr; @@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { } // set the default values -void 
FuncGraph::SetDefaultValues(const std::vector& name_list, const std::vector& value_list) { +void FuncGraph::SetDefaultValues(const std::vector &name_list, const std::vector &value_list) { auto all_is_null = std::all_of(value_list.begin(), value_list.end(), - [](const AnfNodePtr& node) { return IsValueNode(node); }); + [](const AnfNodePtr &node) { return IsValueNode(node); }); if (value_list.empty()) { all_is_null = true; } @@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } size_t FuncGraph::GetDefaultValueCount() { int null_count = std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), - [](const std::pair& pair) { return IsValueNode(pair.second); }); + [](const std::pair &pair) { return IsValueNode(pair.second); }); return parameter_default_value_.size() - IntToSize(null_count); } @@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const { return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); } -AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { +AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { for (size_t i = 0; i < parameters_.size(); ++i) { MS_EXCEPTION_IF_NULL(parameters_[i]); auto param_cast = parameters_[i]->cast(); @@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { return nullptr; } -void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, - std::vector* specialized_parameter_list, - std::unordered_map* repl_nodes, int variable_args_count, +void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + std::unordered_map *repl_nodes, int variable_args_count, int pos_args_input_count) { // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple if (specialized_graph->has_vararg()) { @@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, } 
} -void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, - std::vector* specialized_parameter_list, - const std::vector& kwarg_list, - std::unordered_map* repl_nodes) { +void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + const std::vector &kwarg_list, + std::unordered_map *repl_nodes) { std::vector kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; std::vector kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; - for (const auto& kwarg : kwarg_list) { + for (const auto &kwarg : kwarg_list) { MS_EXCEPTION_IF_NULL(kwarg); std::string kw_param_name = kwarg->get_key(); MS_EXCEPTION_IF_NULL(specialized_graph); @@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; MS_EXCEPTION_IF_NULL(specialized_parameter_list); auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), - [param_name](const AnfNodePtr& node) { + [param_name](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto param = node->cast(); return param != nullptr && param->name() == param_name; @@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); } -void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, - std::unordered_map* repl_nodes, - const std::vector& kwarg_keys_tuple_nodes, - const std::vector& kwarg_values_tuple_nodes) { +void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, + std::unordered_map *repl_nodes, + const std::vector &kwarg_keys_tuple_nodes, + const std::vector &kwarg_values_tuple_nodes) { if (has_kwarg()) { MS_EXCEPTION_IF_NULL(specialized_graph); TraceManager::DebugTrace( @@ -544,7 +544,7 @@ void 
FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, } } -bool FuncGraph::NeedGenerate(const std::vector& kwarg_list) { +bool FuncGraph::NeedGenerate(const std::vector &kwarg_list) { // if the function does not have any vararg/kwarg/kwonly/default value/kw args input // return the original graph if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { @@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector& return true; } -void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, - const std::vector& specialized_parameter_list, - std::unordered_map* repl_nodes) { +void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, + const std::vector &specialized_parameter_list, + std::unordered_map *repl_nodes) { MS_EXCEPTION_IF_NULL(specialized_graph); for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { auto param_node = specialized_graph->parameters()[i]; @@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, } } -FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { std::vector kwarg_list; size_t arguments_count = args_spec_list.size(); - for (const auto& arg : args_spec_list) { + for (const auto &arg : args_spec_list) { // if it is a keyword argument MS_EXCEPTION_IF_NULL(arg); if (arg->isa()) { @@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) MS_EXCEPTION_IF_NULL(specialized_graph); auto params = specialized_graph->parameters(); (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), - std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; }); + std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); 
std::shared_ptr manager = mindspore::Manage(specialized_graph, false); auto tr = manager->Transact(); - for (auto& node_pair : repl_nodes) { + for (auto &node_pair : repl_nodes) { MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" << node_pair.second->DebugString(); (void)tr.Replace(node_pair.first, node_pair.second); @@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) return specialized_graph; } -void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); } +void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } std::list FuncGraph::GetOrderedCnodes() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { @@ -651,7 +651,7 @@ std::list FuncGraph::GetOrderedCnodes() { std::list cnodes; auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); - for (const auto& node : nodes) { + for (const auto &node : nodes) { auto cnode = dyn_cast(node); if (cnode) { cnodes.push_back(cnode); @@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() { } } -void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) { +void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { order_.remove(n->cast()); MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; @@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { MS_LOG(DEBUG) << "Check graph " << ToString(); for (auto it = order_.begin(); it != order_.end(); (void)it++) { - for (const auto& input_node : (*it)->inputs()) { + for (const auto &input_node : (*it)->inputs()) { if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { // Need to reorder the wrong order node. 
auto found = std::find(order_.begin(), it, input_node); @@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() { } auto mng = manager_.lock(); if (mng != nullptr) { - const auto& nodes = mng->nodes()[shared_from_base()]; + const auto &nodes = mng->nodes()[shared_from_base()]; if (nodes.size() != (order_.size() + parameters_.size())) { DumpCNodeList(); MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " @@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() { const char kPrimHasEffect[] = "_side_effect_flag"; -bool FuncGraph::HasEffect(const CNodePtr& cnode) { +bool FuncGraph::HasEffect(const CNodePtr &cnode) { auto prim = GetCNodePrimitive(cnode); if (prim != nullptr && prim->isa()) { auto do_sig = prim->cast(); @@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) { return false; } -std::shared_ptr> FindRoots(const std::vector& segment) { +std::shared_ptr> FindRoots(const std::vector &segment) { std::shared_ptr> roots = std::make_shared>(segment); - for (const auto& node : segment) { + for (const auto &node : segment) { if (roots->size() == 1) { return roots; } @@ -757,9 +757,9 @@ std::shared_ptr> FindRoots(const std::vector& seg return roots; } -std::shared_ptr> FindLeaves(const std::vector& segment) { +std::shared_ptr> FindLeaves(const std::vector &segment) { std::shared_ptr> nodes = std::make_shared>(segment); - for (const auto& node : segment) { + for (const auto &node : segment) { if (nodes->size() == 1) { return nodes; } @@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { std::list depends_order; std::vector segment; - for (const auto& cnode : order_) { + for (const auto &cnode : order_) { if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { continue; } @@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { } } -void FuncGraph::SetEffectDepends(const std::vector& depend_inputs) { +void FuncGraph::SetEffectDepends(const std::vector 
&depend_inputs) { auto old_ret = output(); std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index d90cdbacf2..c086b8d7d1 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -26,29 +26,29 @@ // namespace to support intermediate representation definition namespace mindspore { -Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, - bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation) +Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, + bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) : clone_all_valuenodes_(clone_all_valuenodes), clone_all_child_graphs_(clone_all_child_graphs), clone_all_used_graphs_(clone_all_used_graphs), relation_(relation), target_relation_(target_relation == nullptr ? 
relation : target_relation) { - for (auto& func_graph : func_graphs) { + for (auto &func_graph : func_graphs) { AddClone(func_graph); } scope_ = kDefaultScope; type_ = kBasic; } -void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& params, CloneType type) { +void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList ¶ms, CloneType type) { if (func_graph != nullptr) { todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); type_ = type; } } -void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); if (repl_node_.find(node) != repl_node_.end() || node->isa()) { return; @@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { } } -void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) { +void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, TraceManager::EndTrace(); } -void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { TraceManager::EndTrace(); } -void Cloner::CloneValueNode(const AnfNodePtr& node) { +void Cloner::CloneValueNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); TraceManager::DebugTrace(node->debug_info(), relation_); ValueNodePtr new_const 
= NewValueNode(GetValueNode(node)); @@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) { TraceManager::EndTrace(); } -void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) TraceManager::EndTrace(); } -void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { +void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_valuenodes_) { return; } - auto& value_nodes = manager_->valuenodes()[func_graph]; - for (auto& value_node : value_nodes) { + auto &value_nodes = manager_->valuenodes()[func_graph]; + for (auto &value_node : value_nodes) { auto old_node = value_node.first; MS_EXCEPTION_IF_NULL(old_node); if (repl_node_.count(old_node) == 0) { @@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { } } -void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) { +void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_child_graphs_) { return; } - auto& scopes = manager_->scopes(func_graph); - for (auto& graph : scopes) { + auto &scopes = manager_->scopes(func_graph); + for (auto &graph : scopes) { if (graph != func_graph) { todo_.push_back({graph, nullptr, {}}); } } } -void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) { +void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_used_graphs_) { return; } - auto& used_graphs = manager_->func_graphs_used()[func_graph]; - for (auto& used_graph : used_graphs) { + auto &used_graphs = 
manager_->func_graphs_used()[func_graph]; + for (auto &used_graph : used_graphs) { todo_.push_back({used_graph.first, nullptr, {}}); } } -void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); - for (auto& item : func_graph->parameter_default_value()) { + for (auto &item : func_graph->parameter_default_value()) { auto nodes = DeepLinkedGraphSearch(item.second); - for (auto& node : nodes) { + for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { CloneNode(node, target_func_graph); @@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F } } -void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); @@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func } target_func_graph->set_return(return_node); - auto& value_nodes = manager_->func_graph_valuenodes()[func_graph]; - for (auto& value_node : value_nodes) { + auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; + for (auto &value_node : value_nodes) { CloneValueNode(value_node.first, target_func_graph); } } -void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) { +void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { MS_EXCEPTION_IF_NULL(func_graph); - auto& old_params = func_graph->parameters(); + auto &old_params = func_graph->parameters(); if (old_params.size() != params.size()) { MS_LOG(EXCEPTION) 
<< "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; return; @@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode } } -void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) { +void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); @@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons TraceManager::EndTrace(); } -void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); - auto& params = func_graph->parameters(); - for (auto& param : params) { + auto ¶ms = func_graph->parameters(); + for (auto ¶m : params) { CloneParameter(param, target_func_graph, true); } repl_func_graph_[func_graph] = target_func_graph; } -void Cloner::GenParameters(const FuncGraphPtr& func_graph) { +void Cloner::GenParameters(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); - auto& free_vars = manager_->free_variables_total(); + auto &free_vars = manager_->free_variables_total(); auto iter = free_vars.find(func_graph); if (iter == free_vars.end()) { return; } - for (auto& fv_map : iter->second) { - auto& free_var = fv_map.first; + for (auto &fv_map : iter->second) { + auto &free_var = fv_map.first; if (utils::isa(free_var)) { repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast(free_var))); } } } -void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { +void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { 
param->set_abstract(node->abstract()); if (node->isa()) { ParameterPtr old_param = dyn_cast(node); @@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { } } -ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) { +ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { TraceManager::DebugTrace(std::make_shared(node->debug_info())); ParameterPtr param = std::make_shared(func_graph); TraceManager::EndTrace(); @@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP return param; } -void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, - AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) { +void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, + AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { AnfNodePtrList parameters; std::unordered_set old_params; - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { auto iter = repl_node_.find(param); if (iter != repl_node_.end()) { (void)old_params.insert(iter->second); @@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& } } AnfNodePtr new_param = nullptr; - for (auto& param : params) { + for (auto ¶m : params) { auto old_param = repl_node_[param]; if (old_param->isa() && old_param->func_graph() == func_graph) { repl_node_[old_param] = old_param; @@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& func_graph->set_parameters(parameters); } -void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params) { +void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { AnfNodePtr 
node = nullptr; - auto& repl_func_graph = repl_map_func_graph_[func_graph_user]; + auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; auto iter = repl_func_graph.find(func_graph); if (iter == repl_func_graph.end()) { node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); @@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& OrderParameters(func_graph, inputs); } -void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) { +void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { std::unordered_set old_params; - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { (void)old_params.insert(repl_node_[param]); } std::unordered_set new_params; @@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis (void)new_params.insert(new_param); } } - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { if (new_params.find(param) == new_params.end()) { parameters.push_back(param); } @@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis func_graph->set_parameters(parameters); } -void Cloner::SetEdges(const FuncGraphPtr& func_graph) { +void Cloner::SetEdges(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); - for (auto& node : func_graph->nodes()) { + for (auto &node : func_graph->nodes()) { if (node == nullptr) { continue; } @@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { continue; } auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); for (size_t i = 0; i < inputs.size(); i++) { - auto& input = inputs[i]; + auto &input = inputs[i]; if (IsValueNode(input)) { auto graph = GetValueNode(input); - auto& repl_func_graph = repl_map_func_graph_[func_graph]; + auto 
&repl_func_graph = repl_map_func_graph_[func_graph]; if (repl_func_graph.find(graph) != repl_func_graph.end()) { transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); } } else { - auto& repl_node = repl_map_node_[func_graph]; + auto &repl_node = repl_map_node_[func_graph]; if (repl_node.find(input) != repl_node.end()) { transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); } @@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { } } -void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params) { +void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { AnfNodePtrList lift_params; AnfNodePtrList input_params; AddParameters(func_graph_user, params, &lift_params, &input_params); @@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph if (lift_params.empty()) { return; } - for (auto& user : func_graph_user->func_graph_users()) { + for (auto &user : func_graph_user->func_graph_users()) { LiftParameters(user.first, func_graph_user, lift_params); } } void Cloner::Lift() { - for (auto& func_graph_params : repl_func_graph_params_) { - auto& func_graph = func_graph_params.first; - auto& params = func_graph_params.second; - for (auto& user : func_graph->func_graph_users()) { + for (auto &func_graph_params : repl_func_graph_params_) { + auto &func_graph = func_graph_params.first; + auto ¶ms = func_graph_params.second; + for (auto &user : func_graph->func_graph_users()) { LiftParameters(user.first, func_graph, params); } } @@ -404,18 +404,18 @@ void Cloner::Lift() { void Cloner::LiftParameters() { MS_EXCEPTION_IF_NULL(manager_); transaction_ = manager_->Transact(); - const FuncGraphSet& func_graphs = manager_->func_graphs(); - for (auto& func_graph : func_graphs) { + const FuncGraphSet &func_graphs = manager_->func_graphs(); + for (auto &func_graph : 
func_graphs) { GenParameters(func_graph); } Lift(); - for (auto& func_graph : func_graphs) { + for (auto &func_graph : func_graphs) { SetEdges(func_graph); } transaction_.Commit(); } -bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { +bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { MS_EXCEPTION_IF_NULL(func_graph); // Make sure only inline once if (status_.count(func_graph) != 0) { @@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { return true; } -void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); - const AnfNodeSet& nodes = manager_->nodes()[func_graph]; - for (auto& node : nodes) { + const AnfNodeSet &nodes = manager_->nodes()[func_graph]; + for (auto &node : nodes) { CloneNode(node, target_func_graph); } } @@ -449,7 +449,7 @@ void Cloner::Run() { // Basic and Inline Clone FuncGraphPtrList func_graphs; (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), - [](const CloneInfo& item) -> FuncGraphPtr { return item.origin; }); + [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); manager_ = Manage(func_graphs, false); CloneNodes(); LinkEdges(); @@ -495,13 +495,13 @@ void Cloner::CloneNodes() { } void Cloner::LinkEdges() { - for (auto& node_pair : nodes_) { + for (auto &node_pair : nodes_) { CNodePtr old_node = node_pair.first; CNodePtr new_node = node_pair.second; MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(new_node); - for (auto& input : old_node->inputs()) { - auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; + for (auto &input : old_node->inputs()) { + auto &new_input = (repl_node_.count(input) == 0) ? 
input : repl_node_[input]; new_node->add_input(new_input); } } @@ -509,10 +509,10 @@ void Cloner::LinkEdges() { // For the graphs cloned, update its default value map to the cloned nodes void Cloner::SetDefaults() { - for (auto& item : graph_set_) { + for (auto &item : graph_set_) { MS_EXCEPTION_IF_NULL(item); if (repl_func_graph_.count(item) != 0) { - for (auto& param_def : item->parameter_default_value()) { + for (auto ¶m_def : item->parameter_default_value()) { MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); if (repl_node_.count(param_def.second) != 0) { repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); @@ -524,7 +524,7 @@ void Cloner::SetDefaults() { } } -AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { +AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { MS_EXCEPTION_IF_NULL(root); if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; @@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; } -AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { +AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { #ifdef ENABLE_PROFILE double time = GetTime(); #endif @@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); } -FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { +FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { #ifdef ENABLE_PROFILE double time = GetTime(); #endif @@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { return ((repl_func_graph_.count(func_graph) == 0) ? 
func_graph : repl_func_graph_[func_graph]); } -FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) { +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); Cloner cloner({func_graph}, false, true, true, std::make_shared(), nullptr); return cloner[func_graph]; } -AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& func_graph_args, const ScopePtr& scope) { +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); Cloner cloner({}, false); @@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe return cloner[func_graph->output()]; } -FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) { +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); Cloner cloner({}, false); cloner.AddClone(func_graph, nullptr, {}, kLifting); return cloner[func_graph]; } -ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { MS_EXCEPTION_IF_NULL(func_graph); FuncGraphPtrList func_graphs = {func_graph}; ClonerPtr cloner = @@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r return cloner; } -FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { MS_EXCEPTION_IF_NULL(func_graph); TraceManager::DebugTrace(func_graph->debug_info(), relation); auto new_func_graph = std::make_shared(); TraceManager::EndTrace(); - auto& parameters = func_graph->parameters(); - (void)std::for_each(parameters.begin(), parameters.end(), 
[&new_func_graph](const AnfNodePtr& param) -> void { + auto ¶meters = func_graph->parameters(); + (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { MS_EXCEPTION_IF_NULL(param); TraceManager::DebugTrace(std::make_shared(param->debug_info())); (void)new_func_graph->add_parameter(); @@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); new_func_graph->set_is_generate(func_graph->is_generated()); - for (auto& item : func_graph->parameter_default_value()) { + for (auto &item : func_graph->parameter_default_value()) { new_func_graph->set_param_default_value(item.first, cloner[item.second]); } diff --git a/mindspore/ccsrc/ir/func_graph_cloner.h b/mindspore/ccsrc/ir/func_graph_cloner.h index dd228cf79f..426cf447a3 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.h +++ b/mindspore/ccsrc/ir/func_graph_cloner.h @@ -43,26 +43,26 @@ struct CloneInfo { class Cloner { public: - explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false, + explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false, bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, - const TraceInfoPtr& relation = std::make_shared(), - const TraceInfoPtr& target_relation = nullptr); + const TraceInfoPtr &relation = std::make_shared(), + const TraceInfoPtr &target_relation = nullptr); ~Cloner() = default; - void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr, - const AnfNodePtrList& params = {}, CloneType type = kBasic); + void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr, + const AnfNodePtrList ¶ms = {}, CloneType type = kBasic); void Run(); // Interfaces for specializer - AnfNodePtr 
CloneDisconnected(const AnfNodePtr& root); - AnfNodePtr operator[](const AnfNodePtr& node); - FuncGraphPtr operator[](const FuncGraphPtr& func_graph); + AnfNodePtr CloneDisconnected(const AnfNodePtr &root); + AnfNodePtr operator[](const AnfNodePtr &node); + FuncGraphPtr operator[](const FuncGraphPtr &func_graph); // Map of replicate nodes and graphs - std::unordered_map* cloned_node() { return &repl_node_; } + std::unordered_map *cloned_node() { return &repl_node_; } std::unordered_map cloned_func_graph() { return repl_func_graph_; } // Scope of cloned graphs - void set_scope(const ScopePtr& scope) { scope_ = scope; } + void set_scope(const ScopePtr &scope) { scope_ = scope; } const ScopePtr scope() const { return scope_; } std::unordered_map repl_node_; @@ -71,31 +71,31 @@ class Cloner { void CloneNodes(); void LinkEdges(); void SetDefaults(); - void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneValueNode(const AnfNodePtr& node); - void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false); - void CloneValueNodes(const FuncGraphPtr& func_graph); - void AddChildGraphs(const FuncGraphPtr& func_graph); - void AddTotalGraphs(const FuncGraphPtr& func_graph); - bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline); - void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params); - void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph); - void CloneParameters(const FuncGraphPtr& func_graph, 
const FuncGraphPtr& target_func_graph); - void GenParameters(const FuncGraphPtr& func_graph); - void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node); - ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true); - void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params, - AnfNodePtrList* const input_params); - void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params); - void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs); - void SetEdges(const FuncGraphPtr& func_graph); - void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params); + void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneValueNode(const AnfNodePtr &node); + void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false); + void CloneValueNodes(const FuncGraphPtr &func_graph); + void AddChildGraphs(const FuncGraphPtr &func_graph); + void AddTotalGraphs(const FuncGraphPtr &func_graph); + bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); + void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); + void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph); + void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void 
GenParameters(const FuncGraphPtr &func_graph); + void CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node); + ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true); + void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, AnfNodePtrList *const lift_params, + AnfNodePtrList *const input_params); + void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); + void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs); + void SetEdges(const FuncGraphPtr &func_graph); + void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms); void Lift(); void LiftParameters(); @@ -118,17 +118,17 @@ class Cloner { std::unordered_map repl_func_graph_params_; }; -FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph); +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph); -AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr); +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr); -FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph); +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph); -ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation); +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation); -FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, - const TraceInfoPtr& relation = std::make_shared()); +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, + const TraceInfoPtr &relation = std::make_shared()); } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ diff --git a/mindspore/ccsrc/ir/manager.cc 
b/mindspore/ccsrc/ir/manager.cc index 889a091711..a53c9e95ae 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -27,17 +27,17 @@ namespace mindspore { -FuncGraphManagerPtr MakeManager(const std::vector& func_graphs, bool manage) { +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs, bool manage) { auto m = std::make_shared(func_graphs, manage); m->Init(); return m; } -FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool manage) { +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage) { FuncGraphManagerPtr m = nullptr; bool root = false; - for (auto& fg : func_graphs) { + for (auto &fg : func_graphs) { if (fg == nullptr) { continue; } @@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool ma root = true; } - for (auto& fg : func_graphs) { + for (auto &fg : func_graphs) { if (fg == nullptr) { continue; } @@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { return Manage(func_graphs, manage); } -FuncGraphManager::FuncGraphManager(const std::vector& roots, bool manage) +FuncGraphManager::FuncGraphManager(const std::vector &roots, bool manage) : roots_(roots), is_manage_(manage) { Reset(); } @@ -103,12 +103,12 @@ void FuncGraphManager::Init() { auto roots = roots_; roots_ = FuncGraphSet(); - for (auto& fg : roots) { + for (auto &fg : roots) { AddFuncGraph(fg, true); } } -FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); func_graph_parents_total_->Recompute(fg); @@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; } -FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { 
+FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(func_graph_parent_); MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); @@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { return func_graph_parent_->parent_analysis()[fg]; } -FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(children_); MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); @@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { return children_->children_analysis()[fg]; } -FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(scopes_); MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); @@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { return scopes_->scope_analysis()[fg]; } -FVTotalMap& FuncGraphManager::free_variables_total() const { +FVTotalMap &FuncGraphManager::free_variables_total() const { MS_EXCEPTION_IF_NULL(free_variables_total_); free_variables_total_->Recompute(); return free_variables_total_->fv_total_analysis(); } -FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(func_graphs_used_total_); func_graphs_used_total_->Recompute(fg); return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; } -bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { +bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); recursive_->Recompute(fg); if (recursive_->recursive_analysis().count(fg) == 0) { @@ -168,7 +168,7 
@@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { return recursive_->recursive_analysis()[fg]; } -std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const { +std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); if (recursive(fg)) { if (!recursive_->recursive_map().count(fg)) { @@ -185,7 +185,7 @@ std::shared_ptr> FuncGraphManager::recursive_graphs(cons } } -bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const { +bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(j_total_); MS_EXCEPTION_IF_NULL(fg); j_total_->Recompute(fg); @@ -225,10 +225,10 @@ void FuncGraphManager::Clear() { signals_->InvalidateComputer(); } -void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { +void FuncGraphManager::KeepRoots(const std::vector &func_graphs) { MS_LOG(DEBUG) << "Start keep roots"; bool root_exist = false; - for (auto& item : func_graphs) { + for (auto &item : func_graphs) { if (roots_.contains(item)) { root_exist = true; break; @@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { roots = roots_; } else { roots_.clear(); - for (auto& item : roots) { + for (auto &item : roots) { AddFuncGraph(item, true); } } FuncGraphSet keep; - for (auto& item : roots) { + for (auto &item : roots) { MS_LOG(DEBUG) << "roots: " << item->ToString(); keep.update(func_graphs_used_total(item)); #ifdef DEBUG - for (auto& k : keep) { + for (auto &k : keep) { MS_LOG(DEBUG) << "keep: " << k->ToString(); } #endif @@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { } else { Clear(); FuncGraphSet roots(func_graphs); - for (auto& item : roots) { + for (auto &item : roots) { AddFuncGraph(item, true); } } @@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() { MaybeDropFuncGraphs(func_graphs_, true); } -void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { 
+void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(fg); if (is_manage_) { if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { @@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { func_graphs_.add(fg); } -void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) { +void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { FuncGraphSet todo(func_graphs); std::set dropped; // int count = 0; @@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool continue; } MS_EXCEPTION_IF_NULL(func_graph_users_); - auto& users = func_graph_users_->count_func_graphs_map()[func_graph]; + auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; if (!users.empty() && !ignore_users) { MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); continue; @@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool todo.update(MaybeDropNodes(return_vec)); } MS_EXCEPTION_IF_NULL(signals_); - for (auto& fg : dropped) { + for (auto &fg : dropped) { MS_EXCEPTION_IF_NULL(fg); signals_->DropFuncGraph(fg); all_nodes_.difference_update(fg->parameters()); @@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E MS_EXCEPTION_IF_NULL(inp); if (direction == kDecEdge) { MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); - auto& users_node = node_users_[inp]; + auto &users_node = node_users_[inp]; if (!users_node.contains(make_pair(node, index))) { return; } @@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); AddFuncGraph(GetValueNode(inp)); } - auto& users_node = node_users_[inp]; + auto &users_node = node_users_[inp]; 
users_node.add(make_pair(node, index)); MS_EXCEPTION_IF_NULL(signals_); signals_->AddEdge(node, index, inp); } } -void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) { +void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { auto cnode = node->cast(); int index = 0; - for (auto& inp : cnode->inputs()) { + for (auto &inp : cnode->inputs()) { ProcessEdge(cnode, index, inp, direction); ++index; } } } -IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { +IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { if (all_nodes_.contains(node)) { return EXCLUDE; } else { @@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { } } -void FuncGraphManager::AcquireNodes(const std::vector& nodes) { +void FuncGraphManager::AcquireNodes(const std::vector &nodes) { AnfNodeSet acq; - for (auto& node : nodes) { + for (auto &node : nodes) { std::function limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit)); @@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector& nodes) { acq.update(new_nodes); } - for (auto& node : acq) { + for (auto &node : acq) { MS_EXCEPTION_IF_NULL(node); FuncGraphPtr fg = node->func_graph(); if (fg != nullptr) { @@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector& nodes) { } } -FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& nodes) { +FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { AnfNodeSet nodes_ordered(nodes); FuncGraphSetPtr func_graphs_to_check = std::make_shared(); MS_EXCEPTION_IF_NULL(signals_); @@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& if (!all_nodes_.contains(node)) { continue; } - AnfNodeIndexSet& users = node_users_[node]; + AnfNodeIndexSet &users = 
node_users_[node]; std::vector parameters; if (!users.empty() || @@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& return func_graphs_to_check; } -void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector& parameters) { +void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters) { auto tr = Transact(); tr.SetParameters(fg, parameters); tr.Commit(); } -bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { +bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { auto tr = Transact(); bool success = tr.Replace(old_node, new_node); if (success) { @@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new return success; } -void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) { +void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { auto tr = Transact(); tr.SetEdge(node, index, value); tr.Commit(); } -void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) { +void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { AnfNodePtr source_return = source->get_return(); AnfNodePtr source_output = source->output(); AnfNodePtr source_prim = source_return->cast()->input(0); @@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t (void)all_nodes_.erase(source_return); (void)node_users_.erase(source_return); signals_->DropNode(source_return); - for (auto& node : source->nodes()) { + for (auto &node : source->nodes()) { node->set_func_graph(target); if (node->scope() == kDefaultScope) { node->set_scope(scope); } } - for (auto& used : source->func_graphs_used()) { + for (auto &used : source->func_graphs_used()) { (void)func_graph_users_->Inc(used.first, 
target, used.second); (void)this->func_graph_users()[used.first].erase(source); } - for (auto& child : this->func_graph_child_direct()[source]) { + for (auto &child : this->func_graph_child_direct()[source]) { (void)func_graph_parents_direct_->Inc(child.first, target, child.second); (void)this->func_graph_parents_direct()[child.first].erase(source); } - for (auto& fv_count : this->free_variables_direct()[source]) { + for (auto &fv_count : this->free_variables_direct()[source]) { auto fv_g = fv_count.first->func_graph(); - auto& count_on_g = this->func_graph_child_direct()[fv_g]; + auto &count_on_g = this->func_graph_child_direct()[fv_g]; auto pair = count_on_g.find(source); if (fv_g != target && pair != count_on_g.end()) { (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); @@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() { return tr; } -void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupleCounter* add_edges, - EdgeTupleCounter* rm_edges, Counter* adds, Counter* rms) { - for (auto& iter : changes) { +void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, + EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms) { + for (auto &iter : changes) { auto operation = iter.op; auto args = iter.args; if (operation == Change::kTxSetEdge) { @@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupl auto param = args.cast(); MS_EXCEPTION_IF_NULL(param.func_graph); auto old_parameters = param.func_graph->parameters(); - for (auto& p : param.params) { + for (auto &p : param.params) { (*adds)[p] += 1; } - for (auto& p : old_parameters) { + for (auto &p : old_parameters) { (*rms)[p] += 1; } param.func_graph->set_parameters(param.params); @@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupl } } -void FuncGraphManager::CommitChanges(const std::vector& changes) { +void FuncGraphManager::CommitChanges(const std::vector 
&changes) { EdgeTupleCounter add_edges; EdgeTupleCounter rm_edges; Counter adds; @@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); auto sub_edges = add_edges - rm_edges; - for (auto& iter : sub_edges) { + for (auto &iter : sub_edges) { auto root_node = iter.first.first; int index = iter.first.second.first; auto new_node = iter.first.second.second; @@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { auto sub_nodes = adds - rms; std::vector nodes; (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), - [](const std::pair& iter) -> AnfNodePtr { return iter.first; }); + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); AcquireNodes(nodes); auto sub_edges_reverse = rm_edges - add_edges; - for (auto& iter : sub_edges_reverse) { + for (auto &iter : sub_edges_reverse) { auto root_node = iter.first.first; int index = iter.first.second.first; auto old_node = iter.first.second.second; @@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { std::vector nodes_reverse; (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), - [](const std::pair& iter) -> AnfNodePtr { return iter.first; }); + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); auto drop_func_graphs = MaybeDropNodes(nodes_reverse); MaybeDropFuncGraphs(*drop_func_graphs); } -void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector& params) { +void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector ¶ms) { changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); } -bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { +bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { MS_EXCEPTION_IF_NULL(old_node); 
MS_EXCEPTION_IF_NULL(new_node); FuncGraphPtr old_func_graph = old_node->func_graph(); @@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& return false; } auto users = manager_->node_users()[old_node]; - for (auto& node : users) { + for (auto &node : users) { SetEdge(node.first, node.second, new_node); } return true; } -void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) { +void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { if (k < 0) { MS_LOG(EXCEPTION) << "Invalid value k = " << k; } @@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() { manager_->CommitChanges(changes); } -FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) +FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) : manager_(manager), include_func_graph_none_(false) { manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); @@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); } -NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() { +NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { include_func_graph_none_ = true; nodes_analysis_[nullptr] = AnfNodeSet(); @@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) { void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // change the owner of node except for the src's return node - for (auto& it : nodes_analysis_[src]) { + for (auto &it : nodes_analysis_[src]) { nodes_analysis_[dst].add(it); } (void)nodes_analysis_.erase(src); @@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, 
FuncGraphPtr dst) { void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } -DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { +DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); } void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } -bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { - auto& d = count_nodes_map_[func_graph]; +bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { + auto &d = count_nodes_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; return true; @@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { +bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { MS_EXCEPTION_IF_NULL(func_graph); - auto& d = count_nodes_map_[func_graph]; + auto &d = count_nodes_map_[func_graph]; if (d.count(key) != 0) { if (d[key] == count) { (void)d.erase(key); @@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) { +bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { @@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP } } -bool CounterFuncGraphCollector::Inc(const 
FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { - auto& d = count_func_graphs_map_[func_graph]; +bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; return true; @@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr return false; } -bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { - auto& d = count_func_graphs_map_[func_graph]; +bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) != 0) { if (d[key] == count) { (void)d.erase(key); @@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr return false; } -bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) { +bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { @@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr } void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed } void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr 
node, int, AnfNodePtr inp, EdgeProc } void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { FuncGraphPtr fg2 = it.first->func_graph(); if (fg2 != dst) { (void)Inc(dst, it.first, it.second); @@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { (void)count_nodes_map_.erase(src); } -static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) { +static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { FuncGraphPtr gn = std::make_shared(); (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); return gn; @@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP } void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { FuncGraphPtr fg = it.first; if (fg != dst) { (void)Inc(dst, fg, it.second); @@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr } void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { if (it.first != dst) { (void)Inc(dst, it.first, it.second); } @@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // all graph use in src need to change to dst, so meger the to dst use - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_func_graphs_map_[dst].erase(src); @@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp } void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, 
FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // all graph use in src need to change to dst, so meger the to dst use - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_func_graphs_map_.erase(src); } -DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { +DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); validate_ = false; @@ -914,20 +914,20 @@ void DepComputer::Recompute() { } } -void DepComputer::Recompute(const FuncGraphPtr& fg) { +void DepComputer::Recompute(const FuncGraphPtr &fg) { if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { RealRecompute(fg); func_graphs_validate_[fg] = true; } } -FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { +FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { if (path == nullptr || path->contains(fg)) { return std::make_shared(); } FuncGraphSetPtr parents = std::make_shared(); - FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_; - for (auto& dep : deps[fg]) { + FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; + for (auto &dep : deps[fg]) { MS_EXCEPTION_IF_NULL(dep.first); auto proxy = dep.first->transforms().find("proxy"); if (proxy != dep.first->transforms().end()) { @@ -950,7 +950,7 @@ void 
FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); } -bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) { +bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { auto l1 = lhs.second.size(); auto l2 = rhs.second.size(); return l1 < l2; @@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { } else { // return nearest parent as parent FuncGraphSet deps_copy(deps); - for (auto& dep : deps) { + for (auto &dep : deps) { auto parent_deps = this->manager_->func_graph_parents_total(dep); - for (auto& p_d : parent_deps) { + for (auto &p_d : parent_deps) { if (deps_copy.count(p_d)) { (void)deps_copy.erase(p_d); } @@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); auto used_fg_total = manager_->func_graphs_used_total(fg); - for (auto& used_fg : used_fg_total) { + for (auto &used_fg : used_fg_total) { if (manager_->parent(used_fg) == fg) { children_analysis_[fg].add(used_fg); } @@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { void ScopeComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); - auto& children = manager_->children(fg); + auto &children = manager_->children(fg); scope_analysis_[fg] = FuncGraphSet(); scope_analysis_[fg].add(fg); - for (auto& child : children) { + for (auto &child : children) { scope_analysis_[fg].add(child); } } @@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() { auto manager = DepComputer::manager_; MS_EXCEPTION_IF_NULL(manager); - for (auto& fg : manager->func_graphs()) { + for (auto &fg : manager->func_graphs()) { fv_total_analysis_[fg] = OrderedMap(); count_nodes_map_[fg] = OrderedMap(); count_func_graphs_map_[fg] = OrderedMap(); } - for (auto& fg : manager->func_graphs()) { + for (auto 
&fg : manager->func_graphs()) { AnfNodeCounterMap items = manager->free_variables_direct()[fg]; - for (auto& iter : items) { + for (auto &iter : items) { auto curr = fg; while (curr) { (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); curr = manager->parent(curr); - const AnfNodeSet& nodes = manager->nodes()[curr]; + const AnfNodeSet &nodes = manager->nodes()[curr]; if (nodes.contains(iter.first)) { break; } @@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() { } auto items_fg = manager->func_graphs_used()[fg]; - for (auto& iter : items_fg) { + for (auto &iter : items_fg) { auto p = manager->parent(iter.first); if (p == nullptr) { continue; @@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() { } } } - for (auto& fg : manager->func_graphs()) { - auto& fvp = count_nodes_map_[fg]; - auto& fvg = count_func_graphs_map_[fg]; - for (auto& item : fvp) { + for (auto &fg : manager->func_graphs()) { + auto &fvp = count_nodes_map_[fg]; + auto &fvg = count_func_graphs_map_[fg]; + for (auto &item : fvp) { fv_total_analysis_[fg][item.first] = item.second; } - for (auto& item : fvg) { + for (auto &item : fvg) { fv_total_analysis_[fg][item.first] = item.second; } } @@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() { void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); - auto& used = this->manager_->func_graphs_used(); + auto &used = this->manager_->func_graphs_used(); std::vector todo; std::vector todo_new; todo.push_back(fg); while (!todo.empty()) { todo_new.clear(); - for (auto& gt : todo) { - for (auto& item : used[gt]) { + for (auto > : todo) { + for (auto &item : used[gt]) { auto used_fg = item.first; if (used_fg == fg) { func_graph_used_total_analysis_[fg].add(used_fg); @@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { } } -bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) { +bool CheckRecursive(const 
FuncGraphManager *const manager, const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(manager); - auto& used = manager->func_graphs_used(); + auto &used = manager->func_graphs_used(); std::vector todo; std::vector todo_new; todo.push_back(fg); FuncGraphSet used_total; while (!todo.empty()) { todo_new.clear(); - for (auto& gt : todo) { - for (auto& item : used[gt]) { + for (auto > : todo) { + for (auto &item : used[gt]) { auto used_g = item.first; if (used_g == fg) { return true; @@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); } -void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list* trace) { +void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace) { MS_EXCEPTION_IF_NULL(trace); auto res = std::find(trace->begin(), trace->end(), fg); // find recursive @@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::listpush_back(fg); - auto& used_fgs = manager_->func_graphs_used()[fg]; + auto &used_fgs = manager_->func_graphs_used()[fg]; for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { CheckRecursiveGraphs(iter->first, trace); } @@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::listcontains(fg)) { MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked"; return false; } MS_EXCEPTION_IF_NULL(manager_); - auto& func_graph_counter_map = manager_->func_graph_j_direct(); + auto &func_graph_counter_map = manager_->func_graph_j_direct(); if (!func_graph_counter_map[fg].empty()) { // check g1->J(fg)->g2->g cycle; auto contains_j = @@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt path->add(fg); // check if func graphs used contains J(func_graph); - auto& used = this->manager_->func_graphs_used(); - for (auto& item : used[fg]) { + auto &used = 
this->manager_->func_graphs_used(); + for (auto &item : used[fg]) { auto used_g = item.first; if (SeekJ(used_g, path)) { MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString() diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index aaf5a0aa5f..54c1e8a692 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -46,13 +46,13 @@ class FuncGraphManager; using FuncGraphManagerPtr = std::shared_ptr; struct AnfNodeIndexPairHasher { - std::size_t operator()(const std::pair& p1) const { - return std::hash{}(p1.first.get()); + std::size_t operator()(const std::pair &p1) const { + return std::hash{}(p1.first.get()); } }; struct AnfNodeIndexPairEqual { - bool operator()(const std::pair& lhs, const std::pair& rhs) const { + bool operator()(const std::pair &lhs, const std::pair &rhs) const { return lhs == rhs; } }; @@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair; using FuncGraphSetPtr = std::shared_ptr; using EdgeTuple = std::pair>; struct EdgeTupleHasher { - std::size_t operator()(const EdgeTuple& p1) const { - return hash_combine({std::hash{}(p1.first.get()), std::hash{}(p1.second.first), - std::hash{}(p1.second.second.get())}); + std::size_t operator()(const EdgeTuple &p1) const { + return hash_combine({std::hash{}(p1.first.get()), std::hash{}(p1.second.first), + std::hash{}(p1.second.second.get())}); } }; struct EdgeTupleEqual { - bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const { + bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const { return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second; } }; @@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter; // FuncGraphManagerPtr: return created manager FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true); -FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool manage = true); +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool 
manage = true); -FuncGraphManagerPtr MakeManager(const std::vector& func_graphs = {}, bool manage = true); +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs = {}, bool manage = true); struct Signals { Signal AddFuncGraph; @@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap; // graphs analysis which compute in write, read needn't recompute class DepCollector : public FuncGraphAnalysis { public: - explicit DepCollector(const FuncGraphManager* manager); + explicit DepCollector(const FuncGraphManager *manager); ~DepCollector() override = default; void Reset() { ExtraReset(); } @@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis { class NodesCollector final : public DepCollector { public: - explicit NodesCollector(const FuncGraphManager* m); + explicit NodesCollector(const FuncGraphManager *m); ~NodesCollector() override = default; - const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; } + const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } size_t size() const override { return nodes_analysis_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } @@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector { class CounterFuncGraphCollector : public DepCollector { public: - explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {} + explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} ~CounterFuncGraphCollector() override = default; - FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } // inherit from FuncGraphAnalysis size_t size() const override { return count_func_graphs_map_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap(); } void OnDropFuncGraph(FuncGraphPtr fg) final { 
(void)count_func_graphs_map_.erase(fg); } - bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); - bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); - bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); + bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); FuncGraphToFuncGraphCounterMap count_func_graphs_map_; @@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector { class CounterAnfNodeCollector : public DepCollector { public: - explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {} + explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } size_t size() const override { return count_nodes_map_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap(); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); - bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); - bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); + bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); FuncGraphToAnfNodeCounterMap count_nodes_map_; @@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector { class ValueNodesCollector final : public CounterAnfNodeCollector { public: - explicit 
ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~ValueNodesCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { public: - explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~FuncGraphValueNodesCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { class FVDirectCollector final : public CounterAnfNodeCollector { public: - explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~FVDirectCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector { class FuncGraphChildDirect final : public CounterFuncGraphCollector { public: - explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphChildDirect() override = default; @@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector { // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { public: - explicit 
FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} ~FuncGraphParentsDirectCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { // graph's all used graphs: key is g, value is g used graph class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphsUsedCollector() override = default; @@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { // graph's all user graphs: key is g, value is graphs who used g class FuncGraphUsersCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphUsersCollector() override = default; @@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector { // graph's all user cnodes: key is g, value is cnodes who used g class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { public: - explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphUserNodesCollector() override = default; @@ -303,7 +303,7 @@ class 
FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; ~FuncGraphJDirectCollector() override = default; @@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap; // graphs analysis which need dynamic compute by DepCollector in each read class DepComputer : public FuncGraphAnalysis { public: - explicit DepComputer(const FuncGraphManager* manager); + explicit DepComputer(const FuncGraphManager *manager); ~DepComputer() override = default; void Reset() { @@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis { void Recompute(); - void Recompute(const FuncGraphPtr& fg); + void Recompute(const FuncGraphPtr &fg); bool IsValidate() const { return validate_; } - bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; } + bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } @@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis { // graph g's all direct or proxy parents class FuncGraphParentsTotalComputer final : public DepComputer { public: - explicit FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {} + explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } - FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } + FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } size_t size() 
const override { return func_graph_parents_total_analysis_.size(); } @@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer { void RealRecompute(FuncGraphPtr fg) override; private: - FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared()); + FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared()); // when SeekParents calls itself recursively, it can access these variables by class member // other than pass by formal parameters, it can save 1 parameter for SeekParents(). - FuncGraphToFuncGraphCounterMap* all_parents_direct_; + FuncGraphToFuncGraphCounterMap *all_parents_direct_; }; using FuncGraphToFuncGraphMap = OrderedMap; @@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap; // graph's nearest parent in parents total class ParentComputer final : public DepComputer { public: - explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ParentComputer() override = default; - FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; } + FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; } size_t size() const override { return parent_analysis_.size(); } @@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer { // graph's children graph except self class ChildrenComputer final : public DepComputer { public: - explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ChildrenComputer() override = default; - FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; } + FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; } size_t size() const override { return children_analysis_.size(); } @@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer { // graph's 
children graph include self class ScopeComputer final : public DepComputer { public: - explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ScopeComputer() override = default; - FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; } + FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; } size_t size() const override { return scope_analysis_.size(); } @@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap* trace); + void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace); size_t size() const override { return recursive_analysis_.size(); } @@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer { class FuncGraphJTotalComputer final : public DepComputer { public: - explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} ~FuncGraphJTotalComputer() override = default; - FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; } + FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; } size_t size() const override { return j_total_analysis_.size(); } @@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer { void ExtraReset() override { j_total_analysis_.clear(); } void RealRecompute(FuncGraphPtr fg) override; - bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path); + bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); }; class FuncGraphManager : public std::enable_shared_from_this { public: - explicit FuncGraphManager(const std::vector& roots, bool manage = true); + explicit FuncGraphManager(const std::vector &roots, bool manage = true); ~FuncGraphManager() { if (is_manage_) { RemoveRoots(); @@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this { void Init(); void Clear(); void AddFuncGraph(FuncGraphPtr func_graph, bool 
is_root = false); - void KeepRoots(const std::vector& roots = {}); + void KeepRoots(const std::vector &roots = {}); void RemoveRoots(); - void SetParameters(const FuncGraphPtr& fg, const std::vector& parameters); - void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false); - bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); - void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value); - void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope); + void SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters); + void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); + bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); + void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); + void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope); FuncGraphTransaction Transact(); - void CommitChanges(const std::vector& changes); + void CommitChanges(const std::vector &changes); bool IsManaged() const { return is_manage_; } - const FuncGraphSet& roots() const { return roots_; } + const FuncGraphSet &roots() const { return roots_; } - const FuncGraphSet& func_graphs() const { return func_graphs_; } + const FuncGraphSet &func_graphs() const { return func_graphs_; } - AnfNodeSet& all_nodes() { return all_nodes_; } + AnfNodeSet &all_nodes() { return all_nodes_; } - NodeUsersMap& node_users() { return node_users_; } + NodeUsersMap &node_users() { return node_users_; } - FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; } + FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } - FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return 
free_variables_direct_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } - FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const { + FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { return func_graph_child_direct_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const { + FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { return func_graph_parents_direct_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } - FVTotalMap& free_variables_total() const; + FVTotalMap &free_variables_total() const; - FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const; + FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; - FuncGraphSet& scopes(const FuncGraphPtr& fg) const; + 
FuncGraphSet &scopes(const FuncGraphPtr &fg) const; - FuncGraphPtr parent(const FuncGraphPtr& fg) const; + FuncGraphPtr parent(const FuncGraphPtr &fg) const; - FuncGraphSet& children(const FuncGraphPtr& fg) const; + FuncGraphSet &children(const FuncGraphPtr &fg) const; - FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const; + FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const; - bool recursive(const FuncGraphPtr& fg) const; - std::shared_ptr> recursive_graphs(const FuncGraphPtr& fg) const; + bool recursive(const FuncGraphPtr &fg) const; + std::shared_ptr> recursive_graphs(const FuncGraphPtr &fg) const; - bool func_graph_j_total(const FuncGraphPtr& fg) const; + bool func_graph_j_total(const FuncGraphPtr &fg) const; std::shared_ptr signals() const { return signals_; } - IncludeType Limit(const AnfNodePtr& node); + IncludeType Limit(const AnfNodePtr &node); // Static Analysis NodeUsersMap node_users_; @@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this { std::shared_ptr func_graph_parent_; private: - void AddIntoManaged(const FuncGraphPtr& fg); + void AddIntoManaged(const FuncGraphPtr &fg); void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); - void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction); - void AcquireNodes(const std::vector& nodes); - FuncGraphSetPtr MaybeDropNodes(const std::vector& nodes); - void ParseChanges(const std::vector& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges, - Counter* adds, Counter* rms); + void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction); + void AcquireNodes(const std::vector &nodes); + FuncGraphSetPtr MaybeDropNodes(const std::vector &nodes); + void ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, + Counter *adds, Counter *rms); FuncGraphSet roots_; // managed roots FuncGraphSet func_graphs_; // managed func graphs @@ 
-637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this { class FuncGraphTransaction { public: - explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() { + explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() { MS_EXCEPTION_IF_NULL(manager_); if (!manager_->IsManaged()) { MS_LOG(DEBUG) << "The manager is not managed yet"; @@ -648,19 +648,19 @@ class FuncGraphTransaction { ~FuncGraphTransaction() { manager_ = nullptr; } // set parameters of a func graph - void SetParameters(FuncGraphPtr fg, const std::vector& params); + void SetParameters(FuncGraphPtr fg, const std::vector ¶ms); // replace old_node with new_node - bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); + bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); // set esge, i.e., declare setting node.inputs[key] to value. - void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v); + void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v); // commit all changes void Commit(); private: - FuncGraphManager* manager_; + FuncGraphManager *manager_; std::vector changes_; }; @@ -668,9 +668,9 @@ class FuncGraphTransaction { struct ArgsOfSetParams { FuncGraphPtr func_graph; std::vector params; - bool operator==(const ArgsOfSetParams& other) const { return &other == this; } + bool operator==(const ArgsOfSetParams &other) const { return &other == this; } - friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) { + friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) { os << "[ArgsOfSetParams]"; return os; } @@ -681,9 +681,9 @@ struct ArgsOfSetEdge { CNodePtr root_node; AnfNodePtr new_node; size_t index; - bool operator==(const ArgsOfSetEdge& other) const { return &other == this; } + bool operator==(const ArgsOfSetEdge &other) const { return &other == this; } - friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& 
other) { + friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) { os << "[ArgsOfSetEdge]"; return os; } @@ -693,7 +693,7 @@ struct Change { enum OpName { kTxSetParams, kTxSetEdge }; OpName op; Any args; - Change(OpName name, const Any& para) : op(name), args(para) {} + Change(OpName name, const Any ¶) : op(name), args(para) {} }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/meta_func_graph.h b/mindspore/ccsrc/ir/meta_func_graph.h index 69da925e3d..482b5f9025 100644 --- a/mindspore/ccsrc/ir/meta_func_graph.h +++ b/mindspore/ccsrc/ir/meta_func_graph.h @@ -42,25 +42,25 @@ namespace mindspore { // generate a graph corresponding to these types. class MetaFuncGraph : public FuncGraphBase { public: - explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); } + explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } ~MetaFuncGraph() override = default; MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); - abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node); + abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node); // Return normalized versions of the arguments. // By default, this returns args unchanged. - virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const { + virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { return args_spec_list; } - const std::vector& signatures() const { return signatures_; } - void set_signatures(const std::vector& signatures) { signatures_ = signatures; } + const std::vector &signatures() const { return signatures_; } + void set_signatures(const std::vector &signatures) { signatures_ = signatures; } // Generate a Graph for the given abstract arguments. 
- virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) { + virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { TypePtrList types; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), - [](const AbstractBasePtr& arg) -> TypePtr { + [](const AbstractBasePtr &arg) -> TypePtr { MS_EXCEPTION_IF_NULL(arg); return arg->BuildType(); }); @@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase { } // Generate a Graph for this type signature. - virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) { + virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; } @@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase { std::string ToString() const override { return name_; } std::size_t hash() const override { return tid(); } - virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; } - bool operator==(const Value& other) const override { + virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } + bool operator==(const Value &other) const override { if (other.isa()) { return &other == this; } else { diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/ccsrc/ir/meta_tensor.cc index e9221039a7..fe41abcef4 100644 --- a/mindspore/ccsrc/ir/meta_tensor.cc +++ b/mindspore/ccsrc/ir/meta_tensor.cc @@ -31,7 +31,7 @@ namespace mindspore { namespace tensor { -void DataBuf2Contiguous(const py::array& src, py::array* const dest) { +void DataBuf2Contiguous(const py::array &src, py::array *const dest) { if (dest == nullptr) { MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; } @@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) { // MetaTensor has default type_id_ which is TypeId::kTypeUnknown. 
MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {} -MetaTensor::MetaTensor(const TypeId data_type, const std::vector& shape) : data_type_(data_type), shape_(shape) {} +MetaTensor::MetaTensor(const TypeId data_type, const std::vector &shape) : data_type_(data_type), shape_(shape) {} -MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { +MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) { TypeId data_type = TypeId::kTypeUnknown; if (type_ptr != nullptr) { data_type = type_ptr->type_id(); @@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { } } -MetaTensor::MetaTensor(const MetaTensor& meta_tensor) +MetaTensor::MetaTensor(const MetaTensor &meta_tensor) : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {} -MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { +MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) { if (&meta_tensor == this) { return *this; } @@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { return *this; } -bool MetaTensor::operator==(const MetaTensor& meta_tensor) const { +bool MetaTensor::operator==(const MetaTensor &meta_tensor) const { return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape(); } @@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) { return type_ptr; } -void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) { +void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) { DeviceInfo info(format, data_type); set_device_info(info); } @@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const { return oss.str(); } -Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { +Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) { TypeId data_type = TypeId::kTypeUnknown; if (type_ptr != nullptr) { data_type = type_ptr->type_id(); @@ -151,24 
+151,27 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { init(data_type_, shape_, &data_); } -Tensor::Tensor(TypeId data_type, const std::vector& shape) { init(data_type, shape, &data_); } +Tensor::Tensor(TypeId data_type, const std::vector &shape) { init(data_type, shape, &data_); } -Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); } +Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); } -Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type) - : MetaTensor(tensor), device_address_(tensor.device_address()) { +Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type) + : MetaTensor(tensor), dirty_(tensor.dirty_), device_address_(tensor.device_address_) { init(tensor.data_, data_type); + if (device_address_ != nullptr) { + (void)data_sync(); + } } -Tensor& Tensor::operator=(const Tensor& tensor) { +Tensor &Tensor::operator=(const Tensor &tensor) { if (this != &tensor) { MetaTensor::operator=(tensor); dirty_ = tensor.is_dirty(); @@ -178,11 +181,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) { return *this; } -bool 
Tensor::operator==(const Tensor& tensor) const { +bool Tensor::operator==(const Tensor &tensor) const { return (MetaTensor::operator==(tensor) && data_ == tensor.data_); } -bool Tensor::ValueEqualPy(const py::object& other) const { +bool Tensor::ValueEqualPy(const py::object &other) const { if (!py::isinstance(other)) { MS_LOG(WARNING) << "compare other not a tensor"; return false; @@ -190,7 +193,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const { return ValueEqual(py::cast(other)); } -bool Tensor::ValueEqual(const Tensor& other) const { +bool Tensor::ValueEqual(const Tensor &other) const { auto equal = [&other, this]() -> bool { auto np = py::module::import("numpy"); auto equal = np.attr("equal")(data_, other.data_); @@ -218,7 +221,7 @@ int Tensor::data_type_c() const { return static_cast(data_type_); } std::vector Tensor::shape_c(void) const { return shape(); } -void* Tensor::data_c(bool writable) { +void *Tensor::data_c(bool writable) { // operand of bit operation should be unsigned int. unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; bool is_c_contiguous = (flags != 0) ? 
true : false; @@ -231,7 +234,7 @@ void* Tensor::data_c(bool writable) { return data_.request(writable).ptr; } -TypeId Tensor::GetDataType(const py::buffer_info& buf) const { +TypeId Tensor::GetDataType(const py::buffer_info &buf) const { TypeId data_type = TypeId::kTypeUnknown; if (buf.format.compare("e") == 0) { data_type = TypeId::kNumberTypeFloat16; @@ -263,7 +266,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const { return data_type; } -void Tensor::init(const py::array& input, const TypePtr& type_ptr) { +void Tensor::init(const py::array &input, const TypePtr &type_ptr) { TypeId data_type = TypeId::kTypeUnknown; if (type_ptr != nullptr) { data_type = type_ptr->type_id(); @@ -271,7 +274,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) { init(input, data_type); } -void Tensor::init(const py::array& input, const TypeId& data_type) { +void Tensor::init(const py::array &input, const TypeId &data_type) { py::buffer_info buf = input.request(); data_type_ = GetDataType(buf); @@ -301,7 +304,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) { } } -void Tensor::init(TypeId data_type, const std::vector& shape, py::array* const data) { +void Tensor::init(TypeId data_type, const std::vector &shape, py::array *const data) { data_type_ = data_type; shape_ = shape; switch (data_type) { @@ -368,7 +371,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) { return data_type_; } -bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out, +bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out, const TypeId out_data_type) { if (out == nullptr) { return false; @@ -458,7 +461,7 @@ py::array Tensor::data_sync() { return data_; } -REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { // dtype should define before Tensor, because Tensor init depend dtype (void)py::class_>(*m, "Tensor") 
.def(py::init(), py::arg("dtype"), py::arg("shape")) @@ -541,11 +544,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { .def("__repr__", &Tensor::ToStringRepr) .def("__eq__", &Tensor::ValueEqualPy) .def(py::pickle( - [](const Tensor& t) { // __getstate__ + [](const Tensor &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(t.data()); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/ccsrc/ir/meta_tensor.h index 3e28f29f37..1f6c866f11 100644 --- a/mindspore/ccsrc/ir/meta_tensor.h +++ b/mindspore/ccsrc/ir/meta_tensor.h @@ -131,16 +131,16 @@ class MetaTensor : public Value { // information of a Tensor. The following codes will create a 2x3 float // param data_type The data type of the tensor. // param shape The shape of the tensor. - MetaTensor(const TypeId data_type, const std::vector& shape); + MetaTensor(const TypeId data_type, const std::vector &shape); - MetaTensor(const TypePtr& type_ptr, const py::tuple& shape); + MetaTensor(const TypePtr &type_ptr, const py::tuple &shape); // brief Constructs a MetaTensor object from an existing MetaTensor instance. // // The constructed MetaTensor object will have the same data type and shape as the // meta_tensor. // // param meta_tensor An existing MetaTensor object. - MetaTensor(const MetaTensor& meta_tensor); + MetaTensor(const MetaTensor &meta_tensor); ~MetaTensor() override = default; MS_DECLARE_PARENT(MetaTensor, Value) @@ -149,7 +149,7 @@ class MetaTensor : public Value { // The constructed MetaTensor object has the same type and shape with meta_tensor. // // param meta_tensor An existing MetaTensor object. - virtual MetaTensor& operator=(const MetaTensor& meta_tensor); + virtual MetaTensor &operator=(const MetaTensor &meta_tensor); // brief Compares two MetaTensor objects. 
// @@ -157,7 +157,7 @@ class MetaTensor : public Value { // // param meta_tensor The MetaTensor object to be compared. // return true: If having same type and shape, return true, or return false. - virtual bool operator==(const MetaTensor& meta_tensor) const; + virtual bool operator==(const MetaTensor &meta_tensor) const; // brief Returns the data type of the tensor in its MetaTensor. // @@ -193,7 +193,7 @@ class MetaTensor : public Value { // // param shape The shape of the tensor. // return The shape's size. - size_t set_shape(const std::vector& shape) { + size_t set_shape(const std::vector &shape) { this->shape_ = shape; return shape_.size(); } @@ -202,9 +202,9 @@ class MetaTensor : public Value { DeviceInfo device_info() const { return device_info_; } // Set tensor's device info. - void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; } + void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } - void SetDeviceInfo(const std::string& format, const TypePtr& data_type); + void SetDeviceInfo(const std::string &format, const TypePtr &data_type); // Get the size of a given dimension by its index number. int DimensionSize(size_t index) const; @@ -222,9 +222,9 @@ class MetaTensor : public Value { } return hash_value; } - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -262,49 +262,49 @@ class Tensor : public MetaTensor { // // param type_ptr [TypePty] Data type of the tensor. // param py_shape [py::tuple] The shape represented by py::tuple of the tensor. - Tensor(const TypePtr& type_ptr, const py::tuple& shape); + Tensor(const TypePtr &type_ptr, const py::tuple &shape); // brief Constructor for C++. // // param data_type [TypeId] Data type of the tensor. 
// param shape The shape represented by std::vector of the tensor. - Tensor(TypeId data_type, const std::vector& shape); + Tensor(TypeId data_type, const std::vector &shape); // brief Constructor for Python. // // param input [py::array] Data value of the tensor. // param data_type [TypeId] Data type of the tensor. - explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::list] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::tuple] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::float_] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::int_] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [Tensor] the data for tensor // param data_type [TypeId] data type - Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr); + Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr); ~Tensor() override = default; @@ -315,7 +315,7 @@ class Tensor : public MetaTensor { // The constructed Tensor object has the same type and shape with tensor. // // param tensor An existing Tensor object. 
- Tensor& operator=(const Tensor& tensor); + Tensor &operator=(const Tensor &tensor); // brief Compares two Tensor objects. // @@ -324,17 +324,17 @@ class Tensor : public MetaTensor { // // param tensor The Tensor object to be compared. // return true: If having same type, shape and data, return true, or return false. - bool operator==(const Tensor& tensor) const; + bool operator==(const Tensor &tensor) const; // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. - bool ValueEqual(const Tensor& other) const; + bool ValueEqual(const Tensor &other) const; // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. - bool ValueEqualPy(const py::object& other) const; + bool ValueEqualPy(const py::object &other) const; - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -375,13 +375,13 @@ class Tensor : public MetaTensor { // // param writable true if writable, false if read only // return The pointer to the object - void* data_c(bool writable = false); + void *data_c(bool writable = false); // brief Get data type from tensor data. // // param buf The buffer info of the py::array data. // return The [TypeId] of the tensor data. - TypeId GetDataType(const py::buffer_info& buf) const; + TypeId GetDataType(const py::buffer_info &buf) const; // brief Sets the data type of a tensor. // @@ -401,23 +401,23 @@ class Tensor : public MetaTensor { // param input [py::array] the data for tensor // param data_type [TypeId] data type // return true if succeed, false if failed. 
- void init(const py::array& input, const TypeId& data_type); - void init(const py::array& input, const TypePtr& type_ptr); + void init(const py::array &input, const TypeId &data_type); + void init(const py::array &input, const TypePtr &type_ptr); // brief init tensor attribute // // param data_type [TypeId] Data type of the tensor. // param shape [py::array] The shape of the tensor. // return true if succeed, false if failed. - void init(TypeId data_type, const std::vector& shape, py::array* data); + void init(TypeId data_type, const std::vector &shape, py::array *data); - bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type); + bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type); public: bool is_dirty() const { return dirty_; } void set_dirty(const bool dirty) { dirty_ = dirty; } DeviceAddressPtr device_address() const { return device_address_; } - void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; } + void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } py::array data_sync(); private: diff --git a/mindspore/ccsrc/ir/named.cc b/mindspore/ccsrc/ir/named.cc index 3d12e8a453..0a679e6011 100644 --- a/mindspore/ccsrc/ir/named.cc +++ b/mindspore/ccsrc/ir/named.cc @@ -18,9 +18,9 @@ #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { -bool Named::operator==(const Value& other) const { +bool Named::operator==(const Value &other) const { if (other.isa()) { - auto other_named = static_cast(other); + auto other_named = static_cast(other); return *this == other_named; } else { return false; @@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared(); abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared(); } -const NamedPtr kNullObj = std::make_shared(); +const NamedPtr kNull = 
std::make_shared(); + +abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared(); } +const NamedPtr kEllipsis = std::make_shared(); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/named.h b/mindspore/ccsrc/ir/named.h index 0651307a91..2d679c58b1 100644 --- a/mindspore/ccsrc/ir/named.h +++ b/mindspore/ccsrc/ir/named.h @@ -27,18 +27,18 @@ namespace mindspore { class Named : public Value { public: - explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash{}(name); } - Named(const Named& other) : Value(other) { + explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash{}(name); } + Named(const Named &other) : Value(other) { this->name_ = other.name_; hash_id_ = std::hash{}(other.name_); } ~Named() override = default; MS_DECLARE_PARENT(Named, Value); - const std::string& name() const { return name_; } - virtual bool operator==(const Named& other) const { return name_ == other.name(); } - bool operator==(const Value& other) const override; - Named& operator=(const Named& other) { + const std::string &name() const { return name_; } + virtual bool operator==(const Named &other) const { return name_ == other.name(); } + bool operator==(const Value &other) const override; + Named &operator=(const Named &other) { if (&other != this) { this->type_ = other.type_; this->name_ = other.name_; @@ -50,7 +50,7 @@ class Named : public Value { std::size_t Hash() const { return hash_id_; } std::size_t hash() const override { return hash_id_; } - friend std::ostream& operator<<(std::ostream& os, const Named& nmd) { + friend std::ostream &operator<<(std::ostream &os, const Named &nmd) { os << nmd.name(); return os; } @@ -61,7 +61,6 @@ class Named : public Value { std::string name_; std::size_t hash_id_; }; - using NamedPtr = std::shared_ptr; class None : public Named { @@ -71,7 +70,6 @@ class None : public Named { MS_DECLARE_PARENT(None, Named); abstract::AbstractBasePtr ToAbstract() override; }; - extern const 
NamedPtr kNone; class NullObj : public Named { @@ -81,7 +79,15 @@ class NullObj : public Named { MS_DECLARE_PARENT(NullObj, Named); abstract::AbstractBasePtr ToAbstract() override; }; +extern const NamedPtr kNull; -extern const NamedPtr kNullObj; +class EllipsisObj : public Named { + public: + EllipsisObj() : Named("Ellipsis") {} + ~EllipsisObj() override = default; + MS_DECLARE_PARENT(EllipsisObj, Named); + abstract::AbstractBasePtr ToAbstract() override; +}; +extern const NamedPtr kEllipsis; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_NAMED_H_ diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc index a576c1e76b..d40f8a265d 100644 --- a/mindspore/ccsrc/ir/primitive.cc +++ b/mindspore/ccsrc/ir/primitive.cc @@ -31,7 +31,7 @@ namespace mindspore { using mindspore::abstract::AbstractFunction; -abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) { +abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { auto prim_func = std::make_shared(shared_from_base(), anf_node); return prim_func; } @@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() { return fn; } -bool Primitive::operator==(const Value& other) const { +bool Primitive::operator==(const Value &other) const { if (other.isa()) { - auto other_prim = static_cast(other); + auto other_prim = static_cast(other); return *this == other_prim; } else { return false; } } -bool Primitive::operator==(const Primitive& other) const { +bool Primitive::operator==(const Primitive &other) const { if (name() != other.name()) { return false; } if (attrs_.size() != other.attrs_.size()) { return false; } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair& item) -> bool { + auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { if (item.second == nullptr) { return false; } @@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const { void 
Primitive::set_signatures( std::vector> signatures) { signatures_.clear(); - for (auto& signature : signatures) { + for (auto &signature : signatures) { std::string name; SignatureEnumRW rw; SignatureEnumKind kind; @@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const { std::ostringstream oss; oss << "["; bool is_first = true; - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { if (is_first) { is_first = false; } else { @@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const { } py::function PrimitivePy::GetBpropFunction() { - static const char* const get_bprop_func_name = "get_bprop"; + static const char *const get_bprop_func_name = "get_bprop"; if (py::hasattr(python_obj_, get_bprop_func_name)) { py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); return fn; @@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() { } py::function PrimitivePy::GetComputeFunction() { - static const char* const compute_func_name = "vm_impl"; + static const char *const compute_func_name = "vm_impl"; if (py::hasattr(python_obj_, compute_func_name)) { MS_LOG(INFO) << "" << name() << " compute_func_name"; @@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() { return vm_fn; } -void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { +void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { std::string attr_name = name; ValuePtr converted_ret = nullptr; if (py::isinstance(obj)) { @@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { py::dict PrimitivePy::GetAttrDict() { py::dict attr_dict; - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); } return attr_dict; } -REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { (void)py::enum_(*m, "prim_type", py::arithmetic()) .value("unknown", PrimType::kPrimTypeUnknown) 
.value("builtin", PrimType::kPrimTypeBuiltIn) @@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { .value("user_custom", PrimType::kPrimTypeUserCustom); (void)py::class_>(*m, "Primitive_") .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) - .def(py::init()) + .def(py::init()) .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h index 7dd37eb15f..73941c1058 100644 --- a/mindspore/ccsrc/ir/primitive.h +++ b/mindspore/ccsrc/ir/primitive.h @@ -48,25 +48,25 @@ enum PrimType { class Primitive : public Named { public: - explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn) + explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn) : Named(name), signatures_(), prim_type_(prim_type) {} - Primitive(const Primitive& prim) + Primitive(const Primitive &prim) : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} MS_DECLARE_PARENT(Primitive, Named); - abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node); + abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); std::string ToString() const override { return name(); } virtual py::function GetBpropFunction(); virtual py::function GetComputeFunction(); - Primitive& AddAttr(const std::string& name, const ValuePtr& attr) { + Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; return *this; } - Primitive& SetAttrs(const std::unordered_map& attrs) { - for (auto& attr : attrs) { + Primitive &SetAttrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { attrs_[attr.first] = attr.second; } return *this; @@ -76,21 +76,21 @@ class Primitive : public Named { std::vector> signatures); - 
const std::vector& signatures() const { return signatures_; } + const std::vector &signatures() const { return signatures_; } - void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; } - void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); } + void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } + void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } - ValuePtr GetAttr(const std::string& attrName) const { + ValuePtr GetAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); return iter == attrs_.cend() ? nullptr : iter->second; } - const std::unordered_map& attrs() const { return attrs_; } + const std::unordered_map &attrs() const { return attrs_; } // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. bool HasAttr() const { return !attrs_.empty(); } - bool HasAttr(const std::string& attrName) const { + bool HasAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); return !(iter == attrs_.cend()); } @@ -103,8 +103,8 @@ class Primitive : public Named { PrimType prim_type() const { return prim_type_; } std::string instance_name() const { return instance_name_; } std::string GetAttrsText() const; - bool operator==(const Value& other) const override; - bool operator==(const Primitive& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Primitive &other) const; ~Primitive() override = default; protected: @@ -118,18 +118,18 @@ class Primitive : public Named { class PrimitivePy : public Primitive { public: - PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {} + PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {} ~PrimitivePy() override = default; MS_DECLARE_PARENT(PrimitivePy, Primitive); py::function 
GetBpropFunction() override; py::function GetComputeFunction() override; - void AddPyAttr(const py::str& name, const py::object& obj); + void AddPyAttr(const py::str &name, const py::object &obj); py::dict GetAttrDict(); const bool parse_info_ = true; - const py::object& GetPyObj() const { return python_obj_; } + const py::object &GetPyObj() const { return python_obj_; } bool is_tuple_input_ = false; private: @@ -138,13 +138,13 @@ class PrimitivePy : public Primitive { using PrimitivePyPtr = std::shared_ptr; -inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) { +inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { os << *p; return os; } struct PrimitiveEqual { - bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t2); return t1->name() == t2->name(); @@ -152,10 +152,7 @@ struct PrimitiveEqual { }; struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const& prim) const { - std::size_t hash = std::hash()(prim->name()); - return hash; - } + std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } }; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ diff --git a/mindspore/ccsrc/ir/scalar.h b/mindspore/ccsrc/ir/scalar.h index 3e0a827b07..ab6c485540 100644 --- a/mindspore/ccsrc/ir/scalar.h +++ b/mindspore/ccsrc/ir/scalar.h @@ -55,8 +55,8 @@ class BoolImm : public Scalar { bool value() const { return v_; } bool IsZero() override { return v_ == false; } bool IsOne() override { return v_ == true; } - bool operator==(const Value& other) const override; - bool operator==(const BoolImm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const BoolImm &other) const; std::string ToString() const override { if (v_) { return "true"; @@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool) class IntergerImm : public Scalar { 
public: IntergerImm() = default; - explicit IntergerImm(const TypePtr& t) : Scalar(t) {} + explicit IntergerImm(const TypePtr &t) : Scalar(t) {} ~IntergerImm() override = default; MS_DECLARE_PARENT(IntergerImm, Scalar) }; @@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int8_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int8Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int16_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int16Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int32_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int64_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int64Imm& other) const; + bool 
operator==(const Value &other) const override; + bool operator==(const Int64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint8_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt8Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint16_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt16Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint32_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint64_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt64Imm& other) const; + bool operator==(const 
Value &other) const override; + bool operator==(const UInt64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t); class FloatImm : public Scalar { public: FloatImm() = default; - explicit FloatImm(const TypePtr& t) : Scalar(t) {} + explicit FloatImm(const TypePtr &t) : Scalar(t) {} ~FloatImm() override = default; MS_DECLARE_PARENT(FloatImm, Scalar) }; @@ -312,8 +312,8 @@ class FP32Imm : public FloatImm { bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } float value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const FP32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const FP32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -338,8 +338,8 @@ class FP64Imm : public FloatImm { bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } double value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const FP64Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const FP64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { diff --git a/mindspore/ccsrc/ir/signature.cc b/mindspore/ccsrc/ir/signature.cc index b7eec921d4..8f312d5b98 100644 --- a/mindspore/ccsrc/ir/signature.cc +++ b/mindspore/ccsrc/ir/signature.cc @@ -21,8 +21,8 @@ #include "pipeline/parse/data_converter.h" namespace mindspore { -Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, - const py::object& arg_default, const SignatureEnumDType& arg_dtype) 
+Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, + const py::object &arg_default, const SignatureEnumDType &arg_dtype) : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { if (py::isinstance(arg_default) && py::cast(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { @@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, } } -Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind) +Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) : name(arg_name), rw(rw_tag), kind(arg_kind), default_value(nullptr), dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} -REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { (void)py::enum_(*m, "signature_rw", py::arithmetic()) .value("RW_READ", SignatureEnumRW::kRWRead) .value("RW_WRITE", SignatureEnumRW::kRWWrite) diff --git a/mindspore/ccsrc/ir/signature.h b/mindspore/ccsrc/ir/signature.h index 8e7409ab26..48be7e0f31 100644 --- a/mindspore/ccsrc/ir/signature.h +++ b/mindspore/ccsrc/ir/signature.h @@ -61,9 +61,9 @@ struct Signature { SignatureEnumKind kind; ValuePtr default_value; // nullptr for no default value SignatureEnumDType dtype; - Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, - const py::object& arg_default, const SignatureEnumDType& arg_dtype); - Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind); + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, + const py::object &arg_default, const SignatureEnumDType &arg_dtype); + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind 
&arg_kind); }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/value.cc b/mindspore/ccsrc/ir/value.cc index f9e8abaee9..e386e1ffd2 100644 --- a/mindspore/ccsrc/ir/value.cc +++ b/mindspore/ccsrc/ir/value.cc @@ -24,7 +24,7 @@ #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { -const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const { +const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const { if (dim >= size()) { MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "]."; } @@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) { } } -bool BoolImm::operator==(const Value& other) const { +bool BoolImm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; } +bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; } -bool Int8Imm::operator==(const Value& other) const { +bool Int8Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; } -bool Int16Imm::operator==(const Value& other) const { +bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; } +bool Int16Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; } -bool Int32Imm::operator==(const Value& other) const { +bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; } +bool Int32Imm::operator==(const Value &other) const { if (other.isa()) { 
- auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; } -bool Int64Imm::operator==(const Value& other) const { +bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; } +bool Int64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; } -bool UInt8Imm::operator==(const Value& other) const { +bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; } +bool UInt8Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; } -bool UInt16Imm::operator==(const Value& other) const { +bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; } +bool UInt16Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; } -bool UInt32Imm::operator==(const Value& other) const { +bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; } +bool UInt32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; } -bool UInt64Imm::operator==(const Value& other) const { +bool UInt32Imm::operator==(const UInt32Imm &other) 
const { return v_ == other.v_; } +bool UInt64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; } -bool FP32Imm::operator==(const Value& other) const { +bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; } +bool FP32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } -bool FP64Imm::operator==(const Value& other) const { +bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } +bool FP64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSequeue::operator==(const Value& other) const { +bool ValueSequeue::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSequeue::operator==(const ValueSequeue& other) const { +bool ValueSequeue::operator==(const ValueSequeue &other) const { if (other.elements_.size() != elements_.size()) { return false; } return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), - [](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; }); + [](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; }); } std::string ValueSequeue::ToString() const { std::ostringstream buffer; bool begin = true; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -179,28 
+179,28 @@ std::string ValueSequeue::DumpText() const { return oss.str(); } -bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } -bool StringImm::operator==(const Value& other) const { +bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } +bool StringImm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; } +bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } -bool RefKey::operator==(const Value& other) const { +bool RefKey::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; } +bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } -bool AnyValue::operator==(const Value& other) const { +bool AnyValue::operator==(const Value &other) const { if (other.isa()) { return true; } else { @@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_sharedToAbstract(); }); @@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() { abstract::AbstractBasePtr ValueList::ToAbstract() { abstract::AbstractBasePtrList a_list; - (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { + (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { MS_EXCEPTION_IF_NULL(ele); return ele->ToAbstract(); }); @@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const { return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); } -bool 
ValueSlice::operator==(const Value& other) const { +bool ValueSlice::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSlice::operator==(const ValueSlice& other) const { +bool ValueSlice::operator==(const ValueSlice &other) const { MS_EXCEPTION_IF_NULL(start_); MS_EXCEPTION_IF_NULL(stop_); MS_EXCEPTION_IF_NULL(step_); @@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const { return hash_combine({tid(), std::hash{}(key_), value_->hash()}); } -bool KeywordArg::operator==(const Value& other) const { +bool KeywordArg::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); } +bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); } std::string KeywordArg::ToString() const { std::ostringstream buffer; @@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() { return std::make_shared(key_, argument); } -const ValuePtr ValueDictionary::operator[](const std::string& key) const { +const ValuePtr ValueDictionary::operator[](const std::string &key) const { auto it = std::find_if(key_values_.begin(), key_values_.end(), - [key](const std::pair& item) { return item.first == key; }); + [key](const std::pair &item) { return item.first == key; }); if (it == key_values_.end()) { MS_LOG(EXCEPTION) << "The key " << key << " is not in the map"; } return it->second; } -bool ValueDictionary::operator==(const Value& other) const { +bool ValueDictionary::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } 
} -bool ValueDictionary::operator==(const ValueDictionary& other) const { +bool ValueDictionary::operator==(const ValueDictionary &other) const { if (key_values_.size() != other.key_values_.size()) { return false; } @@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() { std::vector> kv; (void)std::transform( key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const std::pair& item) { return std::make_pair(item.first, item.second->ToAbstract()); }); + [](const std::pair &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); return std::make_shared(kv); } REGISTER_PYBIND_DEFINE( - RefKey, ([](const py::module* m) { + RefKey, ([](const py::module *m) { (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); })); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/value.h b/mindspore/ccsrc/ir/value.h index 85f514b57b..c80e22f735 100644 --- a/mindspore/ccsrc/ir/value.h +++ b/mindspore/ccsrc/ir/value.h @@ -35,19 +35,19 @@ namespace mindspore { class ValueSequeue : public Value { public: - explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) { + explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) { TypePtrList t_list; - (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) { + (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) { MS_EXCEPTION_IF_NULL(ele); return ele->type(); }); TypePtr t = std::make_shared(t_list); type_ = t; } - ValueSequeue(const std::initializer_list& elements) : elements_(elements.begin(), elements.end()) { + ValueSequeue(const std::initializer_list &elements) : elements_(elements.begin(), elements.end()) { TypePtrList t_list; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), - [](const ValuePtr& ele) { return ele->type(); }); + [](const ValuePtr &ele) { return ele->type(); }); TypePtr t = 
std::make_shared(t_list); type_ = t; } @@ -56,10 +56,10 @@ class ValueSequeue : public Value { std::size_t hash() const override { return hash_combine(tid(), std::hash{}(elements_.size())); } std::size_t size() const { return elements_.size(); } bool erase(size_t idx); - const ValuePtr operator[](const std::size_t& dim) const; - const ValuePtrList& value() const { return elements_; } - bool operator==(const Value& other) const override; - bool operator==(const ValueSequeue& other) const; + const ValuePtr operator[](const std::size_t &dim) const; + const ValuePtrList &value() const { return elements_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueSequeue &other) const; std::string ToString() const override; std::string DumpText() const override; @@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr; class ValueTuple : public ValueSequeue { public: - explicit ValueTuple(const std::vector& elements) : ValueSequeue(elements) {} - ValueTuple(const std::initializer_list& elements) : ValueSequeue(elements) {} + explicit ValueTuple(const std::vector &elements) : ValueSequeue(elements) {} + ValueTuple(const std::initializer_list &elements) : ValueSequeue(elements) {} ~ValueTuple() override = default; MS_DECLARE_PARENT(ValueTuple, ValueSequeue) abstract::AbstractBasePtr ToAbstract() override; @@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr; class ValueList : public ValueSequeue { public: - explicit ValueList(const std::vector& elements) : ValueSequeue(elements) {} - ValueList(const std::initializer_list& elements) : ValueSequeue(elements) {} + explicit ValueList(const std::vector &elements) : ValueSequeue(elements) {} + ValueList(const std::initializer_list &elements) : ValueSequeue(elements) {} ~ValueList() override = default; MS_DECLARE_PARENT(ValueList, ValueSequeue) abstract::AbstractBasePtr ToAbstract() override; @@ -94,7 +94,7 @@ class ValueList : public ValueSequeue { }; using ValueListPtr = std::shared_ptr; -inline 
ValuePtr MakeValue(const std::vector& v) { return std::make_shared(v); } +inline ValuePtr MakeValue(const std::vector &v) { return std::make_shared(v); } inline ValuePtr MakeValue(std::initializer_list v) { return std::make_shared(v); } template @@ -103,7 +103,7 @@ template struct is_vector> : public std::true_type {}; template ::value, typename T::value_type>::type> -ValuePtr MakeValue(const T& vec) { +ValuePtr MakeValue(const T &vec) { std::vector list; (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); return std::make_shared(list); @@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) { class ValueSlice : public Value { public: - ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step) + ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step) : start_(start), stop_(stop), step_(step) {} ~ValueSlice() override = default; MS_DECLARE_PARENT(ValueSlice, Value) std::size_t hash() const override; - bool operator==(const Value& other) const override; - bool operator==(const ValueSlice& other) const; + bool operator==(const Value &other) const override; + bool operator==(const ValueSlice &other) const; std::string ToString() const override; @@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr; class KeywordArg : public Value { public: - KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {} + KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {} ~KeywordArg() override = default; MS_DECLARE_PARENT(KeywordArg, Value) std::size_t hash() const override; ValuePtr get_value() const { return value_; } - bool operator==(const Value& other) const override; - bool operator==(const KeywordArg& other) const; + bool operator==(const Value &other) const override; + bool operator==(const KeywordArg &other) const; std::string ToString() const override; @@ -154,31 +154,31 @@ using KeywordArgPtr = 
std::shared_ptr; class ValueDictionary : public Value { public: - explicit ValueDictionary(const std::vector>& key_values) : key_values_(key_values) {} + explicit ValueDictionary(const std::vector> &key_values) : key_values_(key_values) {} ~ValueDictionary() override = default; MS_DECLARE_PARENT(ValueDictionary, Value) std::size_t hash() const override { return hash_combine(tid(), std::hash{}(key_values_.size())); } std::size_t size() const { return key_values_.size(); } - const ValuePtr operator[](const std::string& key) const; - const std::vector>& value() const { return key_values_; } - bool operator==(const Value& other) const override; - bool operator==(const ValueDictionary& other) const; + const ValuePtr operator[](const std::string &key) const; + const std::vector> &value() const { return key_values_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueDictionary &other) const; std::string ToString() const override { std::ostringstream buffer; std::vector keys; std::vector values; - for (const auto& kv : key_values_) { + for (const auto &kv : key_values_) { keys.push_back(kv.first); values.push_back(kv.second); } buffer << "(Dict: " << " keys:("; - for (const auto& key : keys) { + for (const auto &key : keys) { buffer << key << ", "; } buffer << ") values:("; - for (const auto& value : values) { + for (const auto &value : values) { MS_EXCEPTION_IF_NULL(value); buffer << value->DumpText() << ", "; } @@ -195,14 +195,14 @@ using ValueDictionaryPtr = std::shared_ptr; class StringImm : public Value { public: - explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} + explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} ~StringImm() override = default; MS_DECLARE_PARENT(StringImm, Value) std::size_t hash() const override { return hash_; } - const std::string& value() const { return str_; } - bool operator==(const Value& other) const 
override; - bool operator==(const StringImm& other) const; + const std::string &value() const { return str_; } + bool operator==(const Value &other) const override; + bool operator==(const StringImm &other) const; abstract::AbstractBasePtr ToAbstract() override; std::string ToString() const override { return str_; } @@ -218,18 +218,18 @@ class StringImm : public Value { }; using StringImmPtr = std::shared_ptr; IMM_TRAITS(StringImmPtr, std::string) -IMM_TRAITS(StringImmPtr, const char*) +IMM_TRAITS(StringImmPtr, const char *) class RefKey : public Value { public: - explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} + explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} ~RefKey() override = default; MS_DECLARE_PARENT(RefKey, Value) std::size_t hash() const override { return hash_; } - const std::string& tag() const { return tag_; } - bool operator==(const Value& other) const override; - bool operator==(const RefKey& other) const; + const std::string &tag() const { return tag_; } + bool operator==(const Value &other) const override; + bool operator==(const RefKey &other) const; abstract::AbstractBasePtr ToAbstract() override; std::string ToString() const override { return "RefKey[" + tag_ + "]"; } @@ -251,13 +251,13 @@ class AnyValue : public Value { ~AnyValue() override = default; MS_DECLARE_PARENT(AnyValue, Value) std::size_t hash() const override { return tid(); } - bool operator==(const Value& other) const override; + bool operator==(const Value &other) const override; abstract::AbstractBasePtr ToAbstract() override; }; extern const ValuePtr kAnyValue; template <> -inline const char* GetValue(const ValuePtr& value) { +inline const char *GetValue(const ValuePtr &value) { if (value == nullptr) { MS_LOG(EXCEPTION) << "Value is nullptr"; } @@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) { template ::type, typename U = typename std::enable_if::value, 
typename S::value_type>::type> -std::vector GetValue(const ValuePtr& value) { +std::vector GetValue(const ValuePtr &value) { if (value == nullptr) { MS_LOG(EXCEPTION) << "Value is nullptr"; } @@ -280,21 +280,21 @@ std::vector GetValue(const ValuePtr& value) { << ">"; } std::vector rets; - const std::vector& vals = value->cast()->value(); + const std::vector &vals = value->cast()->value(); (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), - [](const ValuePtr& v) { return GetValue(v); }); + [](const ValuePtr &v) { return GetValue(v); }); return rets; } -inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared(t); } +inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared(t); } template ::value>::type> -inline ValueNodePtr NewValueNode(const std::shared_ptr& x) { +inline ValueNodePtr NewValueNode(const std::shared_ptr &x) { return NewValueNode(MakeValue(x)); } template ::value>::type> -inline ValueNodePtr NewValueNode(const T& x) { +inline ValueNodePtr NewValueNode(const T &x) { return NewValueNode(MakeValue(x)); } } // namespace mindspore diff --git a/mindspore/ccsrc/ir/visitor.h b/mindspore/ccsrc/ir/visitor.h index 5305d1fe85..e771f7ad28 100644 --- a/mindspore/ccsrc/ir/visitor.h +++ b/mindspore/ccsrc/ir/visitor.h @@ -22,15 +22,15 @@ #include "optimizer/opt.h" namespace mindspore { -using VisitFuncType = std::function; +using VisitFuncType = std::function; class AnfVisitor { public: - virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&); - virtual void Visit(const AnfNodePtr&); - virtual void Visit(const CNodePtr&); - virtual void Visit(const ValueNodePtr&); - virtual void Visit(const ParameterPtr&); - VisitFuncType Match(const PrimitivePtr&, const std::vector& = {}); + virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &); + virtual void Visit(const AnfNodePtr &); + virtual void Visit(const CNodePtr &); + virtual void Visit(const ValueNodePtr &); + virtual 
void Visit(const ParameterPtr &); + VisitFuncType Match(const PrimitivePtr &, const std::vector & = {}); virtual ~AnfVisitor() = default; }; } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index 808e87edc0..d6217ff1cc 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -162,18 +162,17 @@ void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *p ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); for (const auto &attr_ptr : attrs_ptr) { std::string attr_name = attr_ptr->name(); - std::string real_name; auto value = primitive->GetAttr(attr_name); if (value != nullptr) { if (attr_name == kQueueName || attr_name == kSharedName) { - real_name = kChannelName; + attr_name = kChannelName; } else if (attr_name == kSeed) { - real_name = "seed"; + attr_name = "seed"; } else if (attr_name == kSeed2) { - real_name = "seed2"; + attr_name = "seed2"; } std::string type = attr_ptr->type(); - ParseAttrValue(type, real_name, value, node_attr); + ParseAttrValue(type, attr_name, value, node_attr); } } MS_LOG(INFO) << "Set node attr end!"; @@ -182,7 +181,7 @@ void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *p void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); if (input_num == 0) { - MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input. 
"; + MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; return; } diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc index 316df63922..a617f56f8f 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc @@ -27,6 +27,7 @@ namespace kernel { static std::map MS_PROTO_DATA_TYPE_MAP = { {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN}, {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL}, + {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32}, {mindspore::TypeId::kNumberTypeInt8, mindspore::DataType::MS_INT8}, {mindspore::TypeId::kNumberTypeInt16, mindspore::DataType::MS_INT16}, {mindspore::TypeId::kNumberTypeInt32, mindspore::DataType::MS_INT32}, @@ -34,8 +35,10 @@ static std::map MS_PROTO_DATA_TYPE_MAP = { {mindspore::TypeId::kNumberTypeUInt, mindspore::DataType::MS_UINT32}, {mindspore::TypeId::kNumberTypeUInt8, mindspore::DataType::MS_UINT8}, {mindspore::TypeId::kNumberTypeUInt16, mindspore::DataType::MS_UINT16}, + {mindspore::TypeId::kNumberTypeUInt32, mindspore::DataType::MS_UINT32}, {mindspore::TypeId::kNumberTypeUInt64, mindspore::DataType::MS_UINT64}, {mindspore::TypeId::kNumberTypeFloat16, mindspore::DataType::MS_FLOAT16}, + {mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32}, {mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32}, {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, }; diff --git a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h index 198e8687fc..1c9cf925ea 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h @@ -44,7 +44,7 @@ class TransposeGpuFwdKernel : public GpuKernel { "cudaMemcpyAsync input_shape failed"); 
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcphalfyAsync input_axis failed"); + "cudaMemcpyAsync input_axis failed"); int size = SizeToInt(input_size_ / sizeof(T)); CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, reinterpret_cast(stream_ptr)); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu deleted file mode 100644 index a3d2e3558c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include "cross_entropy_cuda_impl.cuh" -#include "include/cuda_runtime.h" - -__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits, - const float *labels, const int batch_size, const int num_classes, - float *loss, float *dx) { - extern __shared__ float loss_shared[]; - const float mean_scale = 1.0f / static_cast(batch_size); - - loss_shared[threadIdx.x] = 0; - for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) { - loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i]; - dx[i] = (softmax_logits[i] - labels[i]) * mean_scale; - } - __syncthreads(); - if (threadIdx.x == 0) { - *loss = 0; - for (int i = 0; i < batch_size; i++) { - *loss += loss_shared[i]; - } - *loss *= mean_scale; - } -} - -void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, - const int batch_size, const int num_classes, float *loss, float *dx, - cudaStream_t cuda_stream) { - CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>( - softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx); -} diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu index 4d0503ba97..11c16581d6 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu @@ -52,38 +52,12 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label } template -__global__ void CrossEntropyWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, - const size_t class_num, T *losses) { - T epsilon = 1e-6; - for (size_t i = 0; i < batch_size; ++i) { - T logit = 0.0; - for (size_t j = 0; j < class_num; j++) { - if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { - logit = logits[i * class_num + j]; - break; - } - } - if 
(logit <= 0) { - logit += epsilon; - } - losses[i] = -logf(logit); +__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { + losses[threadIdx.x] = 0; + for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { + losses[threadIdx.x] -= logf(logits[i]) * labels[i]; + dlogits[i] = logits[i] - labels[i]; } - return; -} - -template -__global__ void CrossEntropyGradWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, - const size_t class_num, T *grad) { - for (size_t i = 0; i < batch_size; i++) { - for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) { - if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { - grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size; - } else { - grad[i * class_num + j] = logits[i * class_num + j] / batch_size; - } - } - } - return; } template @@ -102,18 +76,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b } template -void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *losses, cudaStream_t cuda_stream) { - CrossEntropyWithoutSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, losses); - return; -} - -template -void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *grad, cudaStream_t cuda_stream) { - CrossEntropyGradWithoutSparseKernel<<>>( - logits, labels, batch_size, class_num, grad); - return; +void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, + T *dlogits, cudaStream_t cuda_stream) { + CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits); } template void CrossEntropyWithSparse(const float *logits, const int *labels, const size_t batch_size, @@ -126,8 +91,6 @@ template 
void CrossEntropyGradWithSparse(const float *logits, const template void CrossEntropyGradWithSparse(const float *logits, const int64_t *labels, const size_t batch_size, const size_t class_num, float *grad, cudaStream_t cuda_stream); -template void CrossEntropyWithoutSparse(const float *logits, const float *labels, const size_t batch_size, - const size_t class_num, float *losses, cudaStream_t cuda_stream); -template void CrossEntropyGradWithoutSparse(const float *logits, const float *labels, - const size_t batch_size, const size_t class_num, float *grad, - cudaStream_t cuda_stream); +template void CrossEntropy(const float *logits, const float *labels, const size_t batch_size, + const size_t class_num, float *losses, float *dlogits, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh index 00ec13553d..54ae072892 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh @@ -28,11 +28,6 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b T *grad, cudaStream_t cuda_stream); template -void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *losses, cudaStream_t cuda_stream); - -template -void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *grad, cudaStream_t cuda_stream); - +void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, + T *dlogits, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu new file mode 100644 index 0000000000..c2fd5ecd70 --- /dev/null +++ 
b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/cuda_runtime.h" +#include "kernel/gpu/cuda_impl/float_status_impl.cuh" + +template +__global__ void IsNan(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsNan(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void IsInf(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsInf(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void 
IsFinite(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) == 0 && !isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsFinite(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void FloatStatus(const size_t size, const T* input, T* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0 || isnan(input[pos])) { + out[0] = 1; + } + } + return; +} +template <> +__global__ void FloatStatus(const size_t size, const half* input, half* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { + out[0] = 1; + } + } + return; +} + +template +void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { + FloatStatus<<>>(size, input, output); + return; +} +template +void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsNan<<>>(size, input, output); + return; +} +template +void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsInf<<>>(size, input, output); + return; +} +template +void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsFinite<<>>(size, input, output); + return; +} + +template void CalFloatStatus(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); +template void CalFloatStatus(const size_t size, const 
half* input, half* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh similarity index 51% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh rename to mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh index 25b1624a46..da488ff937 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh @@ -14,13 +14,15 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ - +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ #include "device/gpu/cuda_common.h" - -void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, - const int batch_size, const int num_classes, float *loss, float *dx, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ +template +void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream); +template +void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream); +template +void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream); +template +void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu index 3cebefec17..5e7a25b8e6 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu @@ -53,6 +53,21 @@ __global__ void ReciprocalKernel(T *input, T *output, size_t count) { return; } template +__global__ void SquareKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = input[i] * input[i]; + } + return; +} +template +__global__ void ZeroslikeKernel(T *output, size_t count) { + T zero = 0.0; + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = zero; + } + return; +} +template void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ExponentialKernel<<>>(input, 
output, count); return; @@ -72,12 +87,26 @@ void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ReciprocalKernel<<>>(input, output, count); return; } +template +void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + SquareKernel<<>>(input, output, count); + return; +} +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) { + ZeroslikeKernel<<>>(output, count); + return; +} template void Exponential(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Negative(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Square(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Zeroslike(float *output, size_t count, cudaStream_t cuda_stream); template void Exponential(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Negative(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Square(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Zeroslike(half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh index 2e7227eb32..8ba9cb4a52 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh @@ -26,5 +26,9 @@ template void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(T *input, T *output, size_t count, cudaStream_t 
cuda_stream); +template +void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc index 21f5d084a9..e38cc02e23 100644 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc @@ -41,8 +41,9 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel size_t attr_index) { if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { if (iter_second->at(attr_index).first.GetAllSame()) { + auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first; for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddInputAttr(kernel_info->GetInputDeviceType(0)); + (void)iter_second->at(attr_index).first.AddInputAttr(dtype); } } else { MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; @@ -50,8 +51,9 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel } if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { if (iter_second->at(attr_index).first.GetAllSame()) { + auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first; for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddOutputAttr(kernel_info->GetOutputDeviceType(0)); + (void)iter_second->at(attr_index).first.AddOutputAttr(dtype); } } else { MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; @@ -94,9 +96,13 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & bool flag = true; // data type matching check of all input parameters of kernel for (size_t input_index = 0; input_index < 
kernel_info->GetInputNum(); input_index++) { - if (marjor_sm < MINIUM_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { - MS_LOG(EXCEPTION) << "Half precision op can be used on Devices which compute capacity is above " << MINIUM_SM - << ", but your device's compute capacity is " << marjor_sm; + if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { + if (marjor_sm < MINIUM_SM) { + MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM + << ", but the current device's computing capacity is " << marjor_sm; + } + MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM + << ", but the current device's computing capacity is " << marjor_sm; } if (kernel_info->GetInputDeviceType(input_index) != (iter->second)[attr_index].first.GetInputAttr(input_index).first) { diff --git a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc index 56a0905e4e..4fe2acb726 100644 --- a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc @@ -38,5 +38,13 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), BinaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BinaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BinaryOpGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h index 522ec2b37e..b929bbee50 100644 
--- a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h @@ -27,12 +27,16 @@ #include "kernel/gpu/kernel_constants.h" namespace mindspore { namespace kernel { -enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_INVALID_TYPE = 255 }; -const std::map kBinaryOpTypeMap = { - {"Sub", BINARY_OP_SUB}, - {"Mul", BINARY_OP_MUL}, - {"RealDiv", BINARY_OP_DIV}, +enum BinaryOpType { + BINARY_OP_ADD = 0, + BINARY_OP_SUB, + BINARY_OP_MUL, + BINARY_OP_DIV, + BINARY_OP_MAX, + BINARY_OP_INVALID_TYPE = 255 }; +static const std::map kBinaryOpTypeMap = { + {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}, {"Maximum", BINARY_OP_MAX}}; template class BinaryOpGpuKernel : public GpuKernel { public: @@ -84,6 +88,10 @@ class BinaryOpGpuKernel : public GpuKernel { inputB_addr = workspace_addr; break; } + case BINARY_OP_MAX: { + inputB_addr = input_addr2; + break; + } default: { MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported."; } @@ -201,6 +209,10 @@ class BinaryOpGpuKernel : public GpuKernel { tensor_op_ = CUDNN_OP_TENSOR_ADD; break; } + case BINARY_OP_MAX: { + tensor_op_ = CUDNN_OP_TENSOR_MAX; + break; + } default: { MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported."; } diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc new file mode 100644 index 0000000000..374644eaf5 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/math/float_status_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h new file mode 100644 index 0000000000..bdd93d5d54 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * 
Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H + +#include +#include +#include +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/float_status_impl.cuh" + +namespace mindspore { +namespace kernel { +enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; +static const std::map kOpTypeMap = { + {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; +template +class FloatStatusGpuKernel : public GpuKernel { + public: + FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} + ~FloatStatusGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, uintptr_t stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + + switch (kernel_name_) { + case OP_STATUS: { + T *output = GetDeviceAddress(outputs, 0); + CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } 
+ case OP_INF: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_NAN: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_FINITE: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; + } + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t x : shape) { + input_size_ = input_size_ * x; + } + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kOpTypeMap.find(kernel_name); + if (iter == kOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; + } else { + kernel_name_ = iter->second; + } + if (kernel_name_ == OP_STATUS) { + output_size_ = sizeof(T); + } else { + output_size_ = input_size_ / sizeof(T) * sizeof(bool); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + 
std::vector output_size_list_; + std::vector workspace_size_list_; + Optype kernel_name_; + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc index 1b7318c511..69716e9165 100644 --- a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc @@ -26,5 +26,8 @@ MS_REG_GPU_KERNEL_ONE( TensorAdd, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), TensorAddGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE( + TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + TensorAddGpuFwdKernel, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h index a203567aa8..4dfbf4c3d4 100644 --- a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h @@ -71,6 +71,9 @@ class TensorAddGpuFwdKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) { InitResource(); cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; + if (cudnn_data_type_ == CUDNN_DATA_INT32) { + cudnn_data_type_ = CUDNN_DATA_FLOAT; + } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 2) { MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc index d69706663e..bfdbe11422 100644 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc +++ 
b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc @@ -38,5 +38,9 @@ MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).A UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h index af78ea4e73..d8fea7370b 100644 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h @@ -33,13 +33,15 @@ enum UnaryOptype { UNARY_OP_NEG, UNARY_OP_RECIPROCAL, UNARY_OP_ZEROSLIKE, + UNARY_OP_SQUARE, UNARY_OP_INVALID_TYPE = 255 }; -const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, - {"Log", UNARY_OP_LOG}, - {"Neg", UNARY_OP_NEG}, - {"Reciprocal", UNARY_OP_RECIPROCAL}, - {"ZerosLike", UNARY_OP_ZEROSLIKE}}; +static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, + {"Log", UNARY_OP_LOG}, + {"Neg", UNARY_OP_NEG}, + {"Reciprocal", UNARY_OP_RECIPROCAL}, + {"ZerosLike", UNARY_OP_ZEROSLIKE}, + {"Square", UNARY_OP_SQUARE}}; template class UnaryOpGpuKernel : public GpuKernel { public: @@ -74,7 +76,12 @@ class UnaryOpGpuKernel : public GpuKernel { Reciprocal(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); break; } + case UNARY_OP_SQUARE: { + Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } case UNARY_OP_ZEROSLIKE: { + Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); return true; } default: { @@ -93,12 +100,12 @@ class 
UnaryOpGpuKernel : public GpuKernel { } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but negative op needs 1 inputs."; + MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 inputs."; return false; } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but negative op needs 1 output."; + MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output."; return false; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); diff --git a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h index fd73f378d8..5c7153a172 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h @@ -101,7 +101,7 @@ class BiasAddGradGpuKernel : public GpuKernel { cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), "cudnnSetTensorNdDescriptor failed"); CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, cudnn_data_type_, CUDNN_NOT_PROPAGATE_NAN, + cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), "cudnnSetReduceTensorDescriptor failed"); diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h index 37d0aadfbc..975dbd0082 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h @@ -48,14 +48,10 @@ class FlattenGpuFwdKernel : public GpuKernel { } bool Init(const CNodePtr &kernel_node) override { auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = 
sizeof(T); for (size_t i = 0; i < shape.size(); ++i) { - if (input_size_ == 0) { - input_size_ = 1; - } input_size_ *= shape[i]; } - input_size_ = input_size_ * sizeof(T); - InitSizeLists(); return true; } diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc index 4ddc710a4c..91747d24d8 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc @@ -55,7 +55,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), FusedBatchNormGpuKernel, float) MS_REG_GPU_KERNEL_ONE(BatchNorm, @@ -69,7 +68,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, .AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16), FusedBatchNormGpuKernel, half) } // namespace kernel diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h index 6f0c59e29a..5ca85f8e63 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h @@ -156,9 +156,6 @@ class FusedBatchNormGpuKernel : public GpuKernel { output_size_list_.push_back(para_size); // running variance output_size_list_.push_back(para_size); // save mean output_size_list_.push_back(para_size); // save variance - if (!is_train_) { - output_size_list_.push_back(para_size); // reserve - } return; } diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h index 3822a326fb..4d50d4753d 100644 --- 
a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h @@ -58,8 +58,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { } T *logits_addr = GetDeviceAddress(inputs, 0); S *labels_addr = GetDeviceAddress(inputs, 1); - T *output1_addr = GetDeviceAddress(outputs, 0); - T *output2_addr = GetDeviceAddress(outputs, 1); + T *loss_addr = GetDeviceAddress(outputs, 0); + T *dlogits_addr = GetDeviceAddress(outputs, 1); T *softmax_output_logits = GetDeviceAddress(workspace, 0); const float alpha = 1; @@ -69,10 +69,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { softmax_output_descriptor_, softmax_output_logits), "cudnnSoftmaxForward failed."); - CrossEntropyWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output1_addr, - reinterpret_cast(stream_ptr)); - CrossEntropyGradWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output2_addr, - reinterpret_cast(stream_ptr)); + CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr, + reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.cc b/mindspore/ccsrc/kernel/hccl/hcom_util.cc index 8e5f9cb7e6..d1c0a30113 100644 --- a/mindspore/ccsrc/kernel/hccl/hcom_util.cc +++ b/mindspore/ccsrc/kernel/hccl/hcom_util.cc @@ -49,7 +49,7 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector>* kernel_info_list) { +void FilterInvaildKernelInfo(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_info_list); std::vector> filtered_list; (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), - [&](const std::shared_ptr& kernel_build_info) { + [&](const std::shared_ptr &kernel_build_info) { return 
AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); }); @@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node, } } } // namespace -void KernelQuery(const CNodePtr& kernel_node, std::vector>* kernel_info_list) { +void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_info_list); TbeMetadataInfo(kernel_node, kernel_info_list); diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h index 215df21776..670830a8b1 100644 --- a/mindspore/ccsrc/kernel/oplib/opinfo.h +++ b/mindspore/ccsrc/kernel/oplib/opinfo.h @@ -38,11 +38,11 @@ class OpAttr { std::string value() const { return value_; } std::string default_value() const { return default_value_; } - void set_name(const std::string& name) { name_ = name; } - void set_param_type(const std::string& param_type) { param_type_ = param_type; } - void set_type(const std::string& type) { type_ = type; } - void set_value(const std::string& value) { value_ = value; } - void set_default_value(const std::string& default_value) { default_value_ = default_value; } + void set_name(const std::string &name) { name_ = name; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_type(const std::string &type) { type_ = type; } + void set_value(const std::string &value) { value_ = value; } + void set_default_value(const std::string &default_value) { default_value_ = default_value; } private: std::string name_; @@ -67,13 +67,13 @@ class OpIOInfo { std::vector formats() const { return formats_; } void set_index(const int index) { index_ = index; } - void set_name(const std::string& name) { name_ = name; } + void set_name(const std::string &name) { name_ = name; } void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } - void 
set_param_type(const std::string& param_type) { param_type_ = param_type; } - void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } - void set_shape(const std::string& shape) { shape_ = shape; } - void set_dtypes(const std::vector& dtype) { dtypes_ = dtype; } - void set_formats(const std::vector& formats) { formats_ = formats; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } + void set_shape(const std::string &shape) { shape_ = shape; } + void set_dtypes(const std::vector &dtype) { dtypes_ = dtype; } + void set_formats(const std::vector &formats) { formats_ = formats; } private: int index_ = 0; @@ -104,24 +104,24 @@ class OpInfo { std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } - const std::unordered_map& ref_infos() const { return ref_infos_; } + const std::unordered_map &ref_infos() const { return ref_infos_; } - void set_op_name(const std::string& op_name) { op_name_ = op_name; } + void set_op_name(const std::string &op_name) { op_name_ = op_name; } void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } - void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; } - void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; } + void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } + void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } - void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; } + void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } - void 
set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } + void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } - void add_attrs_ptr(const std::shared_ptr& attr) { attrs_ptr_.push_back(attr); } - void add_inputs_ptr(const std::shared_ptr& input) { inputs_ptr_.push_back(input); } - void add_outputs_ptr(const std::shared_ptr& output) { outputs_ptr_.push_back(output); } - void set_inputs_ptr(const std::vector>& inputs) { inputs_ptr_ = inputs; } - void set_outputs_ptr(const std::vector>& outputs) { outputs_ptr_ = outputs; } + void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } + void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } + void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } + void set_inputs_ptr(const std::vector> &inputs) { inputs_ptr_ = inputs; } + void set_outputs_ptr(const std::vector> &outputs) { outputs_ptr_ = outputs; } bool is_ref() const { return !ref_infos_.empty(); } bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index c8cc1530ce..f5f2e1601b 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) { return "unknow"; } } -bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) { +bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { bool ret = false; try { 
auto op_json = nlohmann::json::parse(json_string); @@ -83,18 +83,18 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) OpImplyType imply_type = kAICPU; ret = DecodeOpInfo(op_json, imply_type, impl_path); } else { - MS_LOG(DEBUG) << "Not support imply_type"; + MS_LOG(ERROR) << "Not support imply_type"; } if (!ret) { - MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; + MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string; } - } catch (const std::exception& e) { - MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "get op json elements failed: " << e.what(); } return ret; } -void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info) { +void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { op_info->set_async_flag(obj.at(kAsyncFlag)); op_info->set_binfile_name(obj.at(kBinfileName)); op_info->set_compute_cost(obj.at(kComputeCost)); @@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p } } -bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type, - const std::string& impl_path) { +bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, + const std::string &impl_path) { std::shared_ptr op_info = std::make_shared(); MS_EXCEPTION_IF_NULL(op_info); op_info->set_op_name(obj.at(kOpName)); @@ -120,9 +120,9 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI DecodeTBESpecificInfo(obj, op_info); } auto attrs = obj.at(kAttr); - for (const auto& attr : attrs) { + for (const auto &attr : attrs) { if (!DecodeAttr(attr, imply_type, op_info)) { - MS_LOG(DEBUG) << "DecodeAttr Failed"; + MS_LOG(ERROR) << "DecodeAttr Failed"; return false; } } @@ -131,33 +131,33 @@ bool 
OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI dtype_format = obj.at(kDtypeFormat); } auto inputs = obj.at(kIputs); - for (const auto& input : inputs) { + for (const auto &input : inputs) { if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { - MS_LOG(DEBUG) << "DecodeInputOutput Failed"; + MS_LOG(ERROR) << "DecodeInputOutput Failed"; return false; } } auto outputs = obj.at(kOutputs); - for (const auto& output : outputs) { + for (const auto &output : outputs) { if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { - MS_LOG(DEBUG) << "DecodeInputOutput Failed"; + MS_LOG(ERROR) << "DecodeInputOutput Failed"; return false; } } if (!GetRefInfo(op_info)) { - MS_LOG(DEBUG) << "GetRefInfo Failed"; + MS_LOG(ERROR) << "GetRefInfo Failed"; return false; } if (!CheckRepetition(op_info)) { - MS_LOG(DEBUG) << "CheckRepetition Failed"; + MS_LOG(ERROR) << "CheckRepetition Failed"; return false; } op_info_.push_back(op_info); return true; } -bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, - const std::shared_ptr& op_info) { +bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); bool ret = true; try { @@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, op_attr->set_default_value(obj.at(kDefaultValue)); } op_info->add_attrs_ptr(op_attr); - } catch (const std::exception& e) { - MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeAttr failed:" << e.what(); ret = false; } return ret; } -bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, +bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, size_t index) { bool ret = true; try { std::vector dtype; std::vector format; - for (const auto& it : 
dtype_format) { + for (const auto &it : dtype_format) { dtype.emplace_back(it[index][0]); format.emplace_back(it[index][1]); } op_io->set_dtypes(dtype); op_io->set_formats(format); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); ret = false; } return ret; } -bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info, const nlohmann::json& dtype_format) { +bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { bool ret = true; try { std::shared_ptr op_io = std::make_shared(); @@ -219,8 +219,8 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply op_io->set_formats(obj.at(kFormat)); } if (op_io->dtypes().size() != op_io->formats().size()) { - MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes() - << "is not equal to format size:" << op_io->formats(); + MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes() + << " is not equal to format size: " << op_io->formats(); return false; } if (obj.find(kParamType) != obj.end()) { @@ -243,45 +243,45 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply } else if (io_type == kOutput) { op_info->add_outputs_ptr(op_io); } - } catch (const std::exception& e) { - MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what(); ret = false; } return ret; } -std::shared_ptr OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) { +std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); bool is_gpu = (context->device_target() == kGPUDevice); if 
((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || (!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) { - MS_LOG(ERROR) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) - << ", current op num:" << op_info_.size(); + MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); return nullptr; } - for (const auto& op_info : op_info_) { + for (const auto &op_info : op_info_) { MS_EXCEPTION_IF_NULL(op_info); if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { return op_info; } } - MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) - << ", current op num:" << op_info_.size(); + MS_LOG(DEBUG) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); return nullptr; } -bool OpLib::GetRefInfo(const std::shared_ptr& op_info) { +bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); - const auto& output_infos = op_info->outputs_ptr(); - const auto& input_infos = op_info->inputs_ptr(); + const auto &output_infos = op_info->outputs_ptr(); + const auto &input_infos = op_info->inputs_ptr(); for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { - const auto& out_name = output_infos[out_index]->name(); + const auto &out_name = output_infos[out_index]->name(); for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { - const auto& in_name = input_infos[in_index]->name(); + const auto &in_name = input_infos[in_index]->name(); if (out_name == in_name) { if (op_info->has_ref_index(out_index)) { - MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; + MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info"; return false; } op_info->add_ref_pair(out_index, in_index); @@ -293,14 +293,14 @@ 
bool OpLib::GetRefInfo(const std::shared_ptr& op_info) { return true; } -bool OpLib::CheckRepetition(const std::shared_ptr& op_info) { +bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); - for (const auto& exist_op_info : op_info_) { + for (const auto &exist_op_info : op_info_) { MS_EXCEPTION_IF_NULL(exist_op_info); if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && exist_op_info->impl_path() != op_info->impl_path()) { - MS_LOG(DEBUG) << "Has already exist, drop the latter one, op name:" << op_info->op_name() - << "op type:" << ImplTypeToStr(op_info->imply_type()); + MS_LOG(ERROR) << "Op has already exist, please use other name, op name: " << op_info->op_name() + << " op type: " << ImplTypeToStr(op_info->imply_type()); return false; } } diff --git a/mindspore/ccsrc/kernel/oplib/oplib.h b/mindspore/ccsrc/kernel/oplib/oplib.h index 0e11e28d58..3d4dcad908 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.h +++ b/mindspore/ccsrc/kernel/oplib/oplib.h @@ -28,23 +28,23 @@ class OpLib { public: OpLib() = default; virtual ~OpLib() = default; - bool RegOp(const std::string& json_string, const std::string& impl_path); - static std::shared_ptr FindOp(const std::string& op_name, OpImplyType imply_type); + bool RegOp(const std::string &json_string, const std::string &impl_path); + static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); protected: static std::vector> op_info_; private: - static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); - static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, - const std::shared_ptr& op_info); - static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, + static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); + static bool DecodeAttr(const 
nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info); + static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, size_t index); - static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info); - static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info, const nlohmann::json& dtype_format); - static bool GetRefInfo(const std::shared_ptr& op_info); - static bool CheckRepetition(const std::shared_ptr& op_info); + static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); + static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format); + static bool GetRefInfo(const std::shared_ptr &op_info); + static bool CheckRepetition(const std::shared_ptr &op_info); }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 3fda554759..44750fab4f 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -33,6 +33,7 @@ static std::map tbe_func_adapter_map = { {"re_lu6", "relu6"}, {"re_lu6_grad", "relu6_grad"}, {"re_lu", "relu"}, + {"re_luv2", "relu_v2"}, {"tensor_add", "add"}, {"reduce_mean", "reduce_mean_d"}, {"reduce_max", "reduce_max_d"}, @@ -57,6 +58,7 @@ static std::map tbe_func_adapter_map = { {"strided_slice", "strided_slice_d"}, {"strided_slice_grad", "strided_slice_grad_d"}, {"transpose", "transpose_d"}, + {"fill", "fill_d"}, {"unsorted_segment_sum", "unsorted_segment_sum_d"}, {"concat", "concat_d"}, {"slice", "slice_d"}, diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc index 496f99df1c..5255cc6450 100644 --- 
a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc @@ -383,6 +383,10 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr &anf_no attr_obj["name"] = attr_name; attr_obj["valid"] = true; (*attrs_json).push_back(attr_obj); + } else { + if (attr_ptr->param_type() == "required" && creater_type_ == SINGLE_BUILD && op_info->impl_path() != "") { + MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name << "is required, but not set."; + } } } return true; @@ -513,36 +517,36 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector &inp return true; } -void TbeKernelBuild::GenDescJson(const shared_ptr &anf_node, size_t out_idx, - nlohmann::json *output_desc) { +void TbeKernelBuild::GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc) { std::string output_desc_name = anf_node->fullname_with_scope(); - if (out_idx > 0) { - output_desc_name = output_desc_name + "_" + std::to_string(out_idx); + if (node_out_idx > 0) { + output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx); } (*output_desc)["name"] = NormalizeFullScopeName(output_desc_name); - auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, out_idx); + auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); (*output_desc)["data_type"] = tbe::TypeIdToString(type_id); - auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, out_idx); + auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); if (ori_shape.empty()) { ori_shape.emplace_back(1); } (*output_desc)["ori_shape"] = ori_shape; - auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, out_idx); + auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx); if (shape.empty()) { shape.emplace_back(1); } (*output_desc)["shape"] = shape; - auto format = AnfAlgo::GetOutputFormat(anf_node, out_idx); + auto format = AnfAlgo::GetOutputFormat(anf_node, 
node_out_idx); if (format == kOpFormat_DEFAULT) { if (ori_shape.size() == 4) { format = kOpFormat_NCHW; } else { - format = "ND"; + format = kOpFormat_ND; } } (*output_desc)["format"] = format; (*output_desc)["ori_format"] = kOpFormat_NCHW; - (*output_desc)["output_index"] = out_idx; + (*output_desc)["output_index"] = desc_output_idx; } void TbeKernelBuild::GenReusedOutputDesc(const shared_ptr &anf_node, size_t index, @@ -605,7 +609,7 @@ bool TbeKernelBuild::GenFusionDataInputJson(const shared_ptr MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx; // "output_desc" nlohmann::json output_desc; - GenDescJson(real_node, real_idx, &output_desc); + GenDescJson(real_node, real_idx, real_idx, &output_desc); output_desc_list.push_back(output_desc); (*data_str)["name"] = NormalizeFullScopeName(real_node->fullname_with_scope()); } @@ -653,9 +657,9 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); } -bool TbeKernelBuild::GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index) { +bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, + std::vector>::iterator *layer_iter, + std::vector *input_desc_list, size_t *index) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(input_desc_list); bool is_dynamic_input = IsDynamicInput(cnode); @@ -666,7 +670,7 @@ bool TbeKernelBuild::GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode size_t real_idx = kernel_idx.second; MS_LOG(INFO) << "real name" << real_node->fullname_with_scope() << "index:" << real_idx; nlohmann::json input_desc; - GenDescJson(real_node, real_idx, &input_desc); + GenDescJson(real_node, real_idx, real_idx, &input_desc); if (is_dynamic_input) { MS_LOG(INFO) << "node has dynamic input."; input_desc["dyn_index"] = (i - 1); @@ -687,6 +691,67 @@ bool 
TbeKernelBuild::GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode return true; } +std::vector TbeKernelBuild::GetDescOutputIndex(const std::vector &output_used_nums) { + std::vector desc_output_index = {}; + bool find_reused = false; + size_t reused_num = 0; + for (size_t idx = 0; idx < output_used_nums.size(); ++idx) { + auto output_use_num_item = output_used_nums[idx]; + MS_LOG(INFO) << "output used num[" << idx << "] = " << output_use_num_item; + if (output_use_num_item == 1 || output_use_num_item == 0) { + desc_output_index.emplace_back(idx); + } else { + if (!find_reused) { + desc_output_index.emplace_back(idx); + } else { + desc_output_index.emplace_back(desc_output_index[idx - 1]); + } + reused_num += (output_use_num_item - 1); + find_reused = true; + } + } + auto pad_value = output_used_nums.size() == 1 ? 0 : desc_output_index[desc_output_index.size() - 1] + 1; + for (size_t i = 0; i < reused_num; ++i) { + desc_output_index.emplace_back(pad_value); + } + return desc_output_index; +} + +bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list) { + auto output_size = AnfAlgo::GetOutputTensorNum(cnode); + if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { + // wait anther pr: auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); + auto output_used_nums = {SizeToInt(AnfAlgo::GetNodeAttr(cnode, kAttrOutputUsedNum))}; + MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope(); + if (output_used_nums.size() != output_size) { + MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")" + << " is not match output used num(" << output_used_nums.size() << ")"; + return false; + } + auto desc_output_index = GetDescOutputIndex(output_used_nums); + for (size_t i = 0; i < output_size; ++i) { + MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; + nlohmann::json output_desc; + 
GenDescJson(cnode, i, desc_output_index[i], &output_desc); + output_desc_list->emplace_back(output_desc); + } + for (size_t j = output_size; j < desc_output_index.size(); ++j) { + MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j]; + nlohmann::json output_desc; + GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc); + output_desc_list->emplace_back(output_desc); + } + } else { + for (size_t i = 0; i < output_size; ++i) { + nlohmann::json output_desc; + GenDescJson(cnode, i, i, &output_desc); + output_desc_list->push_back(output_desc); + } + } + return true; +} + bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, std::vector>::iterator *layer_iter, nlohmann::json *compute_op_str, std::string *fusion_kernel_name, @@ -696,28 +761,14 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n MS_EXCEPTION_IF_NULL(cnode); // gen input desc std::vector input_desc_list; - (void)GenFusionComputeInputeJson(cnode, layer_iter, &input_desc_list, index); + (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index); (*compute_op_str)["input_desc"] = input_desc_list; // gen output desc std::vector output_desc_list; - auto output_size = AnfAlgo::GetOutputTensorNum(cnode); - for (size_t i = 0; i < output_size; ++i) { - nlohmann::json output_desc; - GenDescJson(cnode, i, &output_desc); - output_desc_list.push_back(output_desc); - } - - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimConv2D->name()) { - if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, compute_node)) { - auto output_used_num = AnfAlgo::GetNodeAttr(compute_node, kAttrOutputUsedNum); - for (size_t i = output_size; i < output_used_num; ++i) { - nlohmann::json output_desc; - GenReusedOutputDesc(cnode, i, 0, &output_desc); - output_desc_list.push_back(output_desc); - } - } + if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { + MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node 
full name: " << cnode->fullname_with_scope(); + return false; } - (*compute_op_str)["output_desc"] = output_desc_list; // gen others auto type = AnfAlgo::GetCNodeName(cnode); diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h index de5ed84e41..1a3eee7fd9 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h @@ -53,11 +53,14 @@ class TbeKernelBuild { static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, std::vector>::iterator *layer_iter, nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index); - static bool GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index); - static void GenDescJson(const std::shared_ptr &anf_node, size_t out_idx, - nlohmann::json *output_desc); + static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, + std::vector>::iterator *layer_iter, + std::vector *input_desc_list, size_t *index); + static std::vector GetDescOutputIndex(const std::vector &output_used_nums); + static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list); + static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc); static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, size_t output_index, nlohmann::json *output_desc); static size_t GetIOSizeImpl(const nlohmann::json &desc); diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 338a17ac2d..8718e9b871 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -75,12 +75,9 @@ void BindShardWriter(py::module *m) { .def("set_header_size", &ShardWriter::set_header_size) .def("set_page_size", 
&ShardWriter::set_page_size) .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", - (MSRStatus(ShardWriter::*)(std::map> &, vector> &, bool)) & - ShardWriter::WriteRawData) - .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map> &, - std::map> &, bool)) & - ShardWriter::WriteRawData) + .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, + vector> &, bool, bool)) & + ShardWriter::WriteRawData) .def("commit", &ShardWriter::Commit); } diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h index d31037c8ad..3af4d7f891 100644 --- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h @@ -72,6 +72,8 @@ enum ShardType { enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; +enum ShuffleType { kShuffleCategory, kShuffleSample }; + const double kEpsilon = 1e-7; const int kThreadNumber = 14; diff --git a/mindspore/ccsrc/mindrecord/include/shard_category.h b/mindspore/ccsrc/mindrecord/include/shard_category.h index b8a7611540..b2fe18fbac 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/mindrecord/include/shard_category.h @@ -17,6 +17,8 @@ #ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ #define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ +#include +#include #include #include #include @@ -26,16 +28,34 @@ namespace mindspore { namespace mindrecord { class ShardCategory : public ShardOperator { public: - explicit ShardCategory(const std::vector> &categories); + explicit ShardCategory(const std::vector> &categories, + int64_t num_elements = std::numeric_limits::max(), bool replacement = false); + + ShardCategory(const std::string &category_field, int64_t num_elements, + int64_t num_categories = std::numeric_limits::max(), bool replacement = false); ~ShardCategory() override{}; - const std::vector> &get_categories() const; + 
const std::vector> &get_categories() const { return categories_; } + + const std::string GetCategoryField() const { return category_field_; } + + int64_t GetNumElements() const { return num_elements_; } + + int64_t GetNumCategories() const { return num_categories_; } + + bool GetReplacement() const { return replacement_; } MSRStatus execute(ShardTask &tasks) override; + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + private: std::vector> categories_; + std::string category_field_; + int64_t num_elements_; + int64_t num_categories_; + bool replacement_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index ca4d3bd66f..70cfcdb6b7 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -121,6 +121,10 @@ class ShardHeader { std::vector SerializeHeader(); + MSRStatus PagesToFile(const std::string dump_file_name); + + MSRStatus FileToPages(const std::string dump_file_name); + private: MSRStatus InitializeHeader(const std::vector &headers); diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h index 9f302e5321..7476660a70 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/mindrecord/include/shard_operator.h @@ -43,6 +43,8 @@ class ShardOperator { virtual MSRStatus execute(ShardTask &tasks) = 0; virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } + + virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h new file mode 100644 index 0000000000..df3888dad4 --- /dev/null +++ b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h @@ -0,0 +1,49 @@ 
+/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ + +#include +#include +#include +#include +#include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_shuffle.h" +#include "mindrecord/include/shard_category.h" + +namespace mindspore { +namespace mindrecord { +class ShardPkSample : public ShardCategory { + public: + ShardPkSample(const std::string &category_field, int64_t num_elements); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); + + ~ShardPkSample() override{}; + + MSRStatus suf_execute(ShardTask &tasks) override; + + private: + bool shuffle_; + std::shared_ptr shuffle_op_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 5548473cd7..3263b2006d 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -115,9 +115,10 @@ class ShardReader { /// \brief get the number of rows in database /// \param[in] file_path the path of ONE file, any file in dataset is fine + /// \param[in] op 
smart pointer refer to ShardCategory or ShardSample object /// \param[out] count # of rows /// \return MSRStatus the status of MSRStatus - MSRStatus CountTotalRows(const std::string &file_path, int64_t *count); + MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr &op, int64_t *count); /// \brief shuffle task with incremental seed /// \return void @@ -197,6 +198,9 @@ class ShardReader { /// \brief get NLP flag bool get_nlp_flag(); + /// \brief get all classes + MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); + protected: /// \brief sqlite call back function static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); @@ -249,8 +253,8 @@ class ShardReader { const std::vector> &operators); /// \brief create category-applied task list - int CreateTasksByCategory(const std::vector> &row_group_summary, - const std::vector> &operators); + MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op); /// \brief create task list in row-reader mode MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, @@ -284,6 +288,12 @@ class ShardReader { MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); + /// \brief get classes in one shard + void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); + + /// \brief get number of classes + int64_t GetNumClasses(const std::string &file_path, const std::string &category_field); + protected: uint64_t header_size_; // header size uint64_t page_size_; // page size diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h index 15353fd0ff..b16fc5cc4f 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h @@ -41,8 +41,11 @@ class ShardSample : public ShardOperator { const std::pair 
get_partitions() const; MSRStatus execute(ShardTask &tasks) override; + MSRStatus suf_execute(ShardTask &tasks) override; + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + private: int numerator_; int denominator_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h index 464881aa7a..027a5ad527 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h @@ -24,7 +24,7 @@ namespace mindspore { namespace mindrecord { class ShardShuffle : public ShardOperator { public: - explicit ShardShuffle(uint32_t seed = 0); + explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); ~ShardShuffle() override{}; @@ -32,6 +32,7 @@ class ShardShuffle : public ShardOperator { private: uint32_t shuffle_seed_; + ShuffleType shuffle_type_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_task.h b/mindspore/ccsrc/mindrecord/include/shard_task.h index 30ea352ef3..b276b5150f 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/mindrecord/include/shard_task.h @@ -41,7 +41,9 @@ class ShardTask { std::tuple, std::vector, json> &get_task_by_id(size_t id); - static ShardTask Combine(std::vector &category_tasks); + std::tuple, std::vector, json> &get_random_task(); + + static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); uint32_t categories = 1; diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 6a22f07700..78a434fc97 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -18,6 +18,7 @@ #define MINDRECORD_INCLUDE_SHARD_WRITER_H_ #include +#include #include #include #include @@ -87,7 +88,7 @@ class ShardWriter { /// \param[in] sign validate data or not 
/// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true); + bool sign = true, bool parallel_writer = false); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -95,7 +96,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true); + bool sign = true, bool parallel_writer = false); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -103,7 +104,8 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true); + std::map> &blob_data, bool sign = true, + bool parallel_writer = false); private: /// \brief write shard header data to disk @@ -201,7 +203,34 @@ class ShardWriter { MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, std::map &err_raw_data); + /// \brief Lock writer and save pages info + int LockWriter(bool parallel_writer = false); + + /// \brief Unlock writer and save pages info + MSRStatus UnlockWriter(int fd, bool parallel_writer = false); + + /// \brief Check raw data before writing + MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, + bool sign, int *schema_count, int *row_count); + + /// \brief Get full path from file name + MSRStatus GetFullPathFromFileName(const std::vector &paths); + + /// \brief Open files + MSRStatus OpenDataFiles(bool append); + + /// \brief Remove lock file + MSRStatus RemoveLockFile(); + + /// \brief Remove lock file + MSRStatus InitLockFile(); + private: 
+ const std::string kLockFileSuffix = "_Locker"; + const std::string kPageFileSuffix = "_Pages"; + std::string lock_file_; // lock file for parallel run + std::string pages_file_; // temporary file of pages info for parallel run + int shard_count_; // number of files uint64_t header_size_; // header size uint64_t page_size_; // page size @@ -211,7 +240,7 @@ class ShardWriter { std::vector raw_data_size_; // Raw data size std::vector blob_data_size_; // Blob data size - std::vector file_paths_; // file paths + std::vector file_paths_; // file paths std::vector> file_streams_; // file handles std::shared_ptr shard_header_; // shard headers diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index 5a5cd7cbf3..dc2743cdc7 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc @@ -520,13 +520,16 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std for (int raw_page_id : raw_page_ids) { auto sql = GenerateRawSQL(fields_); if (sql.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw SQL failed"; return FAILED; } auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); if (data.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw data failed"; return FAILED; } if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { + MS_LOG(ERROR) << "Execute SQL failed"; return FAILED; } MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index fd3fede5a2..9cd02d9120 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -315,6 +315,43 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, 
column_values); } +MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); + if (SUCCESS != ret.first) { + return FAILED; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count_); + for (int x = 0; x < shard_count_; x++) { + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count_; x++) { + threads[x].join(); + } + return SUCCESS; +} + +void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, + std::set &categories) { + if (nullptr == db) { + return; + } + std::vector> columns; + char *errmsg = nullptr; + int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); + if (ret != SQLITE_OK) { + sqlite3_free(errmsg); + sqlite3_close(db); + MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; + return; + } + MS_LOG(INFO) << "Get" << static_cast(columns.size()) << " records from shard " << shard_id << " index."; + for (int i = 0; i < static_cast(columns.size()); ++i) { + categories.emplace(columns[i][0]); + } +} + ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; std::vector>> offsets(shard_count_, std::vector>{}); @@ -667,11 +704,64 @@ MSRStatus ShardReader::Finish() { return SUCCESS; } -MSRStatus ShardReader::CountTotalRows(const std::string &file_path, int64_t *count) { +int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) { + ShardHeader sh = ShardHeader(); + if (sh.Build(file_path) == FAILED) { + return -1; + } + auto header = std::make_shared(sh); + auto file_paths = header->get_shard_addresses(); + 
auto shard_count = file_paths.size(); + auto index_fields = header->get_fields(); + + std::map map_schema_id_fields; + for (auto &field : index_fields) { + map_schema_id_fields[field.second] = field.first; + } + auto ret = + ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); + if (SUCCESS != ret.first) { + return -1; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count); + std::set categories; + for (int x = 0; x < shard_count; x++) { + sqlite3 *db = nullptr; + int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); + if (SQLITE_OK != rc) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return -1; + } + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count; x++) { + threads[x].join(); + } + return categories.size(); +} + +MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr &op, + int64_t *count) { if (Init(file_path) == FAILED) { return FAILED; } - *count = num_rows_; + int64_t num_samples = num_rows_; + if (std::dynamic_pointer_cast(op)) { + auto category_op = std::dynamic_pointer_cast(op); + std::string category_field = category_op->GetCategoryField(); + auto num_classes = GetNumClasses(file_path, category_field); + num_samples = category_op->GetNumSamples(num_rows_, num_classes); + } else if (std::dynamic_pointer_cast(op)) { + num_samples = op->GetNumSamples(num_rows_, 0); + } else { + } + if (-1 == num_samples) { + MS_LOG(ERROR) << "Failed to get dataset size."; + return FAILED; + } + *count = num_samples; return SUCCESS; } @@ -793,6 +883,8 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); } } + + MS_LOG(INFO) << "Launch read thread successfully."; return SUCCESS; } 
@@ -828,44 +920,67 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators) { +MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op) { vector columns = GetAllColumns(); CheckIfColumnInIndex(columns); - int category_operator = -1; - for (uint32_t i = 0; i < operators.size(); ++i) { - const auto &op = operators[i]; - if (std::dynamic_pointer_cast(op)) category_operator = static_cast(i); + auto category_op = std::dynamic_pointer_cast(op); + auto categories = category_op->get_categories(); + int64_t num_elements = category_op->GetNumElements(); + if (num_elements <= 0) { + MS_LOG(ERROR) << "Parameter num_element is not positive"; + return FAILED; + } + if (categories.empty() == true) { + std::string category_field = category_op->GetCategoryField(); + int64_t num_categories = category_op->GetNumCategories(); + if (num_categories <= 0) { + MS_LOG(ERROR) << "Parameter num_categories is not positive"; + return FAILED; + } + std::set categories_set; + auto ret = GetAllClasses(category_field, categories_set); + if (SUCCESS != ret) { + return FAILED; + } + int i = 0; + for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { + categories.emplace_back(category_field, *it); + i++; + } } - - if (category_operator == -1) return category_operator; - - auto categories = std::dynamic_pointer_cast(operators[category_operator])->get_categories(); - // Generate task list, a task will create a batch std::vector categoryTasks(categories.size()); for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { + int category_index = 0; for (const auto &rg : row_group_summary) { + if (category_index >= num_elements) break; auto shard_id = std::get<0>(rg); auto group_id = std::get<1>(rg); auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns); if (SUCCESS != std::get<0>(details)) { - 
return -2; + return FAILED; } auto offsets = std::get<4>(details); auto number_of_rows = offsets.size(); for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { - categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], - std::get<5>(details)[iStart]); + if (category_index < num_elements) { + categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], + std::get<5>(details)[iStart]); + category_index++; + } } } MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; } - tasks_ = ShardTask::Combine(categoryTasks); - return category_operator; + tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); + if (SUCCESS != (*category_op)(tasks_)) { + return FAILED; + } + return SUCCESS; } MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, @@ -896,14 +1011,26 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, const std::vector> &operators) { if (block_reader_) { - CreateTasksByBlock(row_group_summary, operators); + if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { + return FAILED; + } } else { - int category_operator = CreateTasksByCategory(row_group_summary, operators); - if (category_operator == -1) { - CreateTasksByRow(row_group_summary, operators); + int category_operator = -1; + for (uint32_t i = 0; i < operators.size(); ++i) { + const auto &op = operators[i]; + if (std::dynamic_pointer_cast(op)) { + category_operator = static_cast(i); + break; + } } - if (category_operator == -2) { - return FAILED; + if (-1 == category_operator) { + if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { + return FAILED; + } + } else { + if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { + return FAILED; + } } } diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc 
index 864e6697d0..2fb5db5503 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -40,17 +40,7 @@ ShardWriter::~ShardWriter() { } } -MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { - shard_count_ = paths.size(); - if (shard_count_ > kMaxShardCount || shard_count_ == 0) { - MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; - return FAILED; - } - if (schema_count_ > kMaxSchemaCount) { - MS_LOG(ERROR) << "The schema Count greater than max value."; - return FAILED; - } - +MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { // Get full path from file name for (const auto &path : paths) { if (!CheckIsValidUtf8(path)) { @@ -60,7 +50,7 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) char resolved_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func failed"; + MS_LOG(ERROR) << "Secure func failed"; return FAILED; } #if defined(_WIN32) || defined(_WIN64) @@ -82,7 +72,10 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) #endif file_paths_.emplace_back(string(resolved_path)); } + return SUCCESS; +} +MSRStatus ShardWriter::OpenDataFiles(bool append) { // Open files for (const auto &file : file_paths_) { std::shared_ptr fs = std::make_shared(); @@ -116,6 +109,67 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) return SUCCESS; } +MSRStatus ShardWriter::RemoveLockFile() { + // Remove temporary file + int ret = std::remove(pages_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove page file."; + } + + ret = std::remove(lock_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove lock file."; + } + return SUCCESS; +} + +MSRStatus ShardWriter::InitLockFile() { + if (file_paths_.size() == 0) { + MS_LOG(ERROR) << "File path not initialized."; + return FAILED; + } + + lock_file_ = 
file_paths_[0] + kLockFileSuffix; + pages_file_ = file_paths_[0] + kPageFileSuffix; + + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove file failed."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { + shard_count_ = paths.size(); + if (shard_count_ > kMaxShardCount || shard_count_ == 0) { + MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; + return FAILED; + } + if (schema_count_ > kMaxSchemaCount) { + MS_LOG(ERROR) << "The schema Count greater than max value."; + return FAILED; + } + + // Get full path from file name + if (GetFullPathFromFileName(paths) == FAILED) { + MS_LOG(ERROR) << "Get full path from file name failed."; + return FAILED; + } + + // Open files + if (OpenDataFiles(append) == FAILED) { + MS_LOG(ERROR) << "Open data files failed."; + return FAILED; + } + + // Init lock file + if (InitLockFile() == FAILED) { + MS_LOG(ERROR) << "Init lock file failed."; + return FAILED; + } + return SUCCESS; +} + MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (!IsLegalFile(path)) { return FAILED; @@ -143,11 +197,28 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { } MSRStatus ShardWriter::Commit() { + // Read pages file + std::ifstream page_file(pages_file_.c_str()); + if (page_file.good()) { + page_file.close(); + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return FAILED; + } + } + if (WriteShardHeader() == FAILED) { MS_LOG(ERROR) << "Write metadata failed"; return FAILED; } MS_LOG(INFO) << "Write metadata successfully."; + + // Remove lock file + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove lock file failed."; + return FAILED; + } + return SUCCESS; } @@ -455,15 +526,75 @@ void ShardWriter::FillArray(int start, int end, std::map> } } -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign) { +int 
ShardWriter::LockWriter(bool parallel_writer) { + if (!parallel_writer) { + return 0; + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Lock file done by python."; + const int fd = 0; +#else + const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); + if (fd >= 0) { + flock(fd, LOCK_EX); + } else { + MS_LOG(ERROR) << "Shard writer failed when locking file"; + return -1; + } +#endif + + // Open files + file_streams_.clear(); + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); + if (fs->fail()) { + MS_LOG(ERROR) << "File could not opened"; + return -1; + } + file_streams_.push_back(fs); + } + + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return -1; + } + return fd; +} + +MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { + if (!parallel_writer) { + return SUCCESS; + } + + if (shard_header_->PagesToFile(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Write pages to file failed"; + return FAILED; + } + + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { + file_streams_[i]->close(); + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Unlock file done by python."; +#else + flock(fd, LOCK_UN); + close(fd); +#endif + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, + std::vector> &blob_data, bool sign, int *schema_count, + int *row_count) { // check the free disk size auto st_space = GetDiskSize(file_paths_[0], kFreeSize); if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { MS_LOG(ERROR) << "IO error / there is no free disk to be used"; return FAILED; } - // Add 4-bytes dummy blob data if no any blob fields if (blob_data.size() == 0 && raw_data.size() > 0) { blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); @@ -479,10 +610,29 @@ MSRStatus 
ShardWriter::WriteRawData(std::map> &raw_d MS_LOG(ERROR) << "Validate raw data failed"; return FAILED; } + *schema_count = std::get<1>(v); + *row_count = std::get<2>(v); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign, bool parallel_writer) { + // Lock Writer if loading data parallel + int fd = LockWriter(parallel_writer); + if (fd < 0) { + MS_LOG(ERROR) << "Lock writer failed"; + return FAILED; + } // Get the count of schemas and rows - int schema_count = std::get<1>(v); - int row_count = std::get<2>(v); + int schema_count = 0; + int row_count = 0; + + // Serialize raw data + if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { + MS_LOG(ERROR) << "Check raw data failed"; + return FAILED; + } if (row_count == kInt0) { MS_LOG(INFO) << "Raw data size is 0."; @@ -516,11 +666,17 @@ MSRStatus ShardWriter::WriteRawData(std::map> &raw_d } MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; + if (UnlockWriter(fd, parallel_writer) == FAILED) { + MS_LOG(ERROR) << "Unlock writer failed"; + return FAILED; + } + return SUCCESS; } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign) { + std::map> &blob_data, bool sign, + bool parallel_writer) { std::map> raw_data_json; std::map> blob_data_json; @@ -554,11 +710,11 @@ MSRStatus ShardWriter::WriteRawData(std::map> MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; return FAILED; } - return WriteRawData(raw_data_json, bin_blob_data, sign); + return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - vector> &blob_data, bool sign) { + vector> &blob_data, bool sign, bool parallel_writer) { std::map> raw_data_json; (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), [](const std::pair> &pair) { @@ -568,7 +724,7 @@ MSRStatus 
ShardWriter::WriteRawData(std::map> [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); return std::make_pair(pair.first, std::move(json_raw_data)); }); - return WriteRawData(raw_data_json, blob_data, sign); + return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); } MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc index 859a3b343f..80816e7a79 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_category.cc @@ -18,11 +18,30 @@ namespace mindspore { namespace mindrecord { -ShardCategory::ShardCategory(const std::vector> &categories) - : categories_(categories) {} +ShardCategory::ShardCategory(const std::vector> &categories, int64_t num_elements, + bool replacement) + : categories_(categories), + category_field_(""), + num_elements_(num_elements), + num_categories_(0), + replacement_(replacement) {} -const std::vector> &ShardCategory::get_categories() const { return categories_; } +ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories, + bool replacement) + : categories_({}), + category_field_(category_field), + num_elements_(num_elements), + num_categories_(num_categories), + replacement_(replacement) {} MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } + +int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (dataset_size == 0) return dataset_size; + if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) { + return std::min(num_categories_, num_classes) * num_elements_; + } + return -1; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 57b2e5fa9e..26008e3ca9 100644 --- 
a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -677,5 +677,43 @@ std::pair, MSRStatus> ShardHeader::GetStatisticByID( } return std::make_pair(statistics_.at(statistic_id), SUCCESS); } + +MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { + // write header content to file, dump whatever is in the file before + std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); + if (page_out_handle.fail()) { + MS_LOG(ERROR) << "Failed in opening page file"; + return FAILED; + } + + auto pages = SerializePage(); + for (const auto &shard_pages : pages) { + page_out_handle << shard_pages << "\n"; + } + + page_out_handle.close(); + return SUCCESS; +} + +MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { + for (auto &v : pages_) { // clean pages + v.clear(); + } + // attempt to open the file contains the page in json + std::ifstream page_in_handle(dump_file_name.c_str()); + + if (!page_in_handle.good()) { + MS_LOG(INFO) << "No page file exists."; + return SUCCESS; + } + + std::string line; + while (std::getline(page_in_handle, line)) { + ParsePage(json::parse(line)); + } + + page_in_handle.close(); + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc new file mode 100644 index 0000000000..8e2e892e63 --- /dev/null +++ b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindrecord/include/shard_pk_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) + : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, + uint32_t seed) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { + shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement +} + +MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) { + if (shuffle_ == true) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc index ef627b0c09..a9cfce0d01 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc @@ -56,6 +56,24 @@ ShardSample::ShardSample(const std::vector &indices, uint32_t seed) shuffle_op_ = std::make_shared(seed); } +int64_t ShardSample::GetNumSamples(int64_t 
dataset_size, int64_t num_classes) { + if (sampler_type_ == kCustomTopNSampler) { + return no_of_samples_; + } + + if (sampler_type_ == kCustomTopPercentSampler) { + if (dataset_size % denominator_ == 0) { + return dataset_size / denominator_ * numerator_; + } else { + return dataset_size / denominator_ * numerator_ + 1; + } + } + if (sampler_type_ == kSubsetRandomSampler) { + return indices_.size(); + } + return -1; +} + const std::pair ShardSample::get_partitions() const { if (numerator_ == 1 && denominator_ > 1) { return std::pair(denominator_, partition_id_); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc index f8ad2c341d..757dcb7b74 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc @@ -20,25 +20,33 @@ namespace mindspore { namespace mindrecord { -ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} +ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) + : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} MSRStatus ShardShuffle::execute(ShardTask &tasks) { if (tasks.categories < 1) { return FAILED; } - uint32_t individual_size = tasks.Size() / tasks.categories; - std::vector> new_permutations(tasks.categories, std::vector(individual_size)); - for (uint32_t i = 0; i < tasks.categories; i++) { - for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); - std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); - } - shuffle_seed_++; - tasks.permutation_.clear(); - for (uint32_t j = 0; j < individual_size; j++) { + if (shuffle_type_ == kShuffleSample) { + if (tasks.permutation_.empty() == true) { + tasks.MakePerm(); + } + std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); + } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) + uint32_t 
individual_size = tasks.Size() / tasks.categories; + std::vector> new_permutations(tasks.categories, std::vector(individual_size)); for (uint32_t i = 0; i < tasks.categories; i++) { - tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); + std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); + } + tasks.permutation_.clear(); + for (uint32_t j = 0; j < individual_size; j++) { + for (uint32_t i = 0; i < tasks.categories; i++) { + tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + } } } + shuffle_seed_++; return SUCCESS; } } // namespace mindrecord diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc index 3744d881a4..be566d1601 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_task.cc @@ -35,8 +35,6 @@ void ShardTask::InsertTask(int shard_id, int group_id, const std::vector, std::vector, json> task) { @@ -44,9 +42,6 @@ void ShardTask::InsertTask(std::tuple, std::vector(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() << ", size of task_list_: " << task_list_.size() << "."; task_list_.push_back(std::move(task)); - MS_LOG(DEBUG) << "Out of insert task, shard_id: " << std::get<0>(std::get<0>(task)) - << ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() - << ", size of task_list_: " << task_list_.size() << "."; } void ShardTask::PopBack() { task_list_.pop_back(); } @@ -69,18 +64,39 @@ std::tuple, std::vector, json> &ShardTask::get_ta return task_list_[id]; } -ShardTask ShardTask::Combine(std::vector &category_tasks) { +std::tuple, std::vector, json> &ShardTask::get_random_task() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> 
dis(0, task_list_.size() - 1); + return task_list_[dis(gen)]; +} +ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements) { ShardTask res; if (category_tasks.empty()) return res; auto total_categories = category_tasks.size(); res.categories = static_cast(total_categories); - auto minTasks = category_tasks[0].Size(); - for (uint32_t i = 1; i < total_categories; i++) { - minTasks = std::min(minTasks, category_tasks[i].Size()); - } - for (uint32_t task_no = 0; task_no < minTasks; task_no++) { + if (replacement == false) { + auto minTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + minTasks = std::min(minTasks, category_tasks[i].Size()); + } + for (uint32_t task_no = 0; task_no < minTasks; task_no++) { + for (uint32_t i = 0; i < total_categories; i++) { + res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast(task_no)))); + } + } + } else { + auto maxTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + maxTasks = std::max(maxTasks, category_tasks[i].Size()); + } + if (num_elements != std::numeric_limits::max()) { + maxTasks = static_cast(num_elements); + } for (uint32_t i = 0; i < total_categories; i++) { - res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast(task_no)))); + for (uint32_t j = 0; j < maxTasks; j++) { + res.InsertTask(category_tasks[i].get_random_task()); + } } } return res; diff --git a/mindspore/ccsrc/mindspore.cc b/mindspore/ccsrc/mindspore.cc index 542814016f..c98f67b51e 100644 --- a/mindspore/ccsrc/mindspore.cc +++ b/mindspore/ccsrc/mindspore.cc @@ -19,6 +19,6 @@ namespace mindspore { // cppcheck-suppress unusedFunction -std::string set_version(const std::string& version) { return version; } +std::string set_version(const std::string &version) { return version; } } // namespace mindspore diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc index 80661a4539..772986d714 
100644 --- a/mindspore/ccsrc/onnx/onnx_exporter.cc +++ b/mindspore/ccsrc/onnx/onnx_exporter.cc @@ -42,11 +42,11 @@ struct OpMergedInfo { }; using GenAttrFuncType = - std::function; + std::function; template -void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { auto casted_value = dyn_cast(value); if (casted_value == nullptr) { MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; @@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy } template -void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { auto tuple_ptr = dyn_cast(value); if (tuple_ptr == nullptr) { MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; @@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib attr_proto->set_type(attr_type); } -void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); auto attr_value = GetValue(value); if (attr_value == "VALID") { @@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType class OpAttrInfo { public: - OpAttrInfo(const std::string& 
attr_name, const string& onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) + OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) : attr_name_(attr_name), onnx_attr_name_(onnx_attr_name), onnx_attr_type_(onnx_attr_type), fn_gen_attr_(fn_gen_attr) {} ~OpAttrInfo() {} - const std::string& attr_name() const { return attr_name_; } - const std::string& onnx_attr_name() const { return onnx_attr_name_; } + const std::string &attr_name() const { return attr_name_; } + const std::string &onnx_attr_name() const { return onnx_attr_name_; } onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } @@ -134,27 +134,27 @@ class OpAttrInfo { class OpNameInfo { public: - OpNameInfo& set_op_type(const std::string& op_type) { + OpNameInfo &set_op_type(const std::string &op_type) { op_type_ = op_type; return *this; } - const std::string& op_type() const { return op_type_; } + const std::string &op_type() const { return op_type_; } - OpNameInfo& set_onnx_type(const std::string& onnx_type) { + OpNameInfo &set_onnx_type(const std::string &onnx_type) { onnx_type_ = onnx_type; return *this; } - const std::string& onnx_type() const { return onnx_type_; } + const std::string &onnx_type() const { return onnx_type_; } - OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) { + OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); return *this; } - const std::vector& op_attrs() const { return op_attrs_; } + const std::vector 
&op_attrs() const { return op_attrs_; } private: std::string op_type_; // operator type of MindSpore @@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE( .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto) .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, - [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto, - const PrimitivePtr& prim) { + [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, + const PrimitivePtr &prim) { attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); auto attr_value = GetValue(value); if (attr_value == "valid") { @@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, SetAttrValueToProto) .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, [](ValuePtr, onnx::AttributeProto_AttributeType, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); attr_proto->set_i(0); })) @@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE( #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name -void RegisterOpConverters(const std::function& fn) { +void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); fn(OP_CONVERT_FUNCTION_NAME(Mul)()); @@ -265,16 +265,16 @@ class OpConvertRegistry { public: ~OpConvertRegistry() { Clear(); } - static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } + static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } - static OpConvertRegistry& GetSingleton() { + static 
OpConvertRegistry &GetSingleton() { static OpConvertRegistry registry = OpConvertRegistry(); return registry; } - static const std::unordered_map& GetOpConvertMap() { return GetSingleton().op_map_; } + static const std::unordered_map &GetOpConvertMap() { return GetSingleton().op_map_; } void Clear() noexcept { op_map_.clear(); } @@ -289,59 +289,59 @@ class OnnxExporter { OnnxExporter() {} ~OnnxExporter() {} - std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); + std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); private: void InitModelInfo(); - void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); - void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); + void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); - size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - const PrimitivePtr& prim, const std::vector& inputs, - onnx::GraphProto* graph_proto); + size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *graph_proto); static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); - void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false); - void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto); - - void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector& nodes, - std::unordered_map* op_merged_infos_ptr); - void ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - - void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - - void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, 
onnx::GraphProto* graph_proto); - void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); - void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - - void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); - - void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - std::string GetNodeInputName(const AnfNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto); - - void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto); - void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto); + void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); + void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); + + void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr); + void ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimReduceMean(const 
FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + + void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + std::string GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto); + + void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); + void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); size_t AllocateNodeIndex() { return ++onnx_node_index_; } void ResetNodeIndex() { onnx_node_index_ = 0; } - static int GetInt32Value(const AnfNodePtr& node) { + static int GetInt32Value(const AnfNodePtr &node) { auto value_node_ptr = dyn_cast(node); MS_EXCEPTION_IF_NULL(value_node_ptr); return GetValue(value_node_ptr->value()); @@ -352,7 +352,7 @@ class OnnxExporter { size_t onnx_node_index_ = 0; }; -std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { +std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ""; } @@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { OpConvertRegistry::GetSingleton().Clear(); OpConvertRegistry::RegisterAllOpConverters(); 
InitModelInfo(); - onnx::GraphProto* graph_proto = model_.mutable_graph(); + onnx::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_.SerializeAsString(); } @@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() { model_.set_ir_version(onnx::IR_VERSION_2019_1_22); model_.set_producer_name("MindSpore"); model_.set_producer_version("1.0"); - onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import(); + onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); opset_proto->set_version(9); } -void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { std::map node_map; onnx_node_index_ = func_graph->parameters().size(); @@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr ExportNodes(func_graph, &node_map, graph_proto); } -void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { - for (auto& param : func_graph->parameters()) { +void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + for (auto ¶m : func_graph->parameters()) { const ParameterPtr param_ptr = dyn_cast(param); if (param_ptr == nullptr) { MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; } - onnx::ValueInfoProto* input_proto = graph_proto->add_input(); + onnx::ValueInfoProto *input_proto = graph_proto->add_input(); input_proto->set_name(param_ptr->ToString()); SetValueInfoType(param_ptr, input_proto); @@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP continue; } // parameter with default value is an ONNX initializer - onnx::TensorProto* initializer_proto = graph_proto->add_initializer(); + onnx::TensorProto *initializer_proto = 
graph_proto->add_initializer(); initializer_proto->set_name(param_ptr->ToString()); SetTensorProtoInfo(param_ptr, initializer_proto); // set value for initializer @@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { return iter->second; } -void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) { +void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { auto dtype = node->Type(); auto shape = node->Shape(); - onnx::TypeProto* type_proto = value_proto->mutable_type(); + onnx::TypeProto *type_proto = value_proto->mutable_type(); if (dtype->isa() && shape->isa()) { auto tensor = dyn_cast(dtype); auto elem_type = tensor->element(); - const auto& dims = dyn_cast(shape)->shape(); + const auto &dims = dyn_cast(shape)->shape(); // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 auto type = is_output ? 
onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); type_proto->mutable_tensor_type()->set_elem_type(type); - for (const auto& dim : dims) { + for (const auto &dim : dims) { type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); } } } -void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) { +void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { auto dtype = param->Type(); auto shape = param->Shape(); if (!dtype->isa() || !shape->isa()) { @@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro auto tensor = dyn_cast(dtype); auto elem_type = tensor->element(); - const auto& dims = dyn_cast(shape)->shape(); + const auto &dims = dyn_cast(shape)->shape(); tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); - for (const auto& dim : dims) { + for (const auto &dim : dims) { tensor_proto->add_dims(dim); } } -void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector& nodes, - std::unordered_map* op_merged_infos_ptr) { - std::unordered_map& op_merged_infos = *op_merged_infos_ptr; +void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr) { + std::unordered_map &op_merged_infos = *op_merged_infos_ptr; - for (auto& node : nodes) { + for (auto &node : nodes) { if (!node->isa()) { continue; } @@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto // if the key `input` does not exist, just create a new one op_merged_infos[cnode].referred_count += 1; } - for (auto& input : cnode->inputs()) { + for (auto &input : cnode->inputs()) { if (!input->isa()) { continue; } @@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto * | +-- Parameter * | `-- ValueNode */ -void 
OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::unordered_map op_merged_infos; MatchAndMark(func_graph, nodes, &op_merged_infos); - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { if (!node->isa()) { continue; } @@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_shape = node->input(2); std::string name_shape; if (input_shape->isa()) { auto const_node_idx = AllocateNodeIndex(); (*node_map_ptr)[input_shape] = const_node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); name_shape = std::to_string(const_node_idx); node_proto->add_output(name_shape); node_proto->set_op_type("Constant"); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); @@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimReshape->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(name_x); 
node_proto->add_input(name_shape); } -void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_axis = node->input(2); auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimReduceMean->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_data); if (input_axis->isa()) { - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("axes"); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); auto axis_value = dyn_cast(input_axis)->value(); @@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons } } -void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_type = node->input(2); auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimCast->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_data); if (input_type->isa()) { - 
onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("to"); attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); auto type_value = dyn_cast(input_type)->value(); @@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod } } -void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); @@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { auto node_idx = AllocateNodeIndex(); - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type("Unsqueeze"); node_proto->add_output(std::to_string(node_idx)); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); attr_proto->set_name("axes"); attr_proto->add_ints(1); @@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type("PRelu"); node_proto->add_output(std::to_string(node_idx)); 
node_proto->add_input(input_x); node_proto->add_input(input_slope); } -void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert if (node->IsApply(prim::kPrimReshape)) { return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); @@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); } -size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map* node_map_ptr, - const PrimitivePtr& prim, const std::vector& inputs, - onnx::GraphProto* const graph_proto) { +size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *const graph_proto) { auto op_map = OpConvertRegistry::GetOpConvertMap(); auto op_iter = op_map.find(prim->name()); if (op_iter == op_map.end()) { MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; } - const OpNameInfo& op_convert_info = op_iter->second; + const OpNameInfo &op_convert_info = op_iter->second; auto node_idx = AllocateNodeIndex(); - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->add_output(std::to_string(node_idx)); node_proto->set_op_type(op_convert_info.onnx_type()); // Set inputs - for (const auto& input : inputs) { + for (const auto &input : inputs) { auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); node_proto->add_input(input_name); } // Set node attribute - for (const OpAttrInfo& attr 
: op_convert_info.op_attrs()) { - const std::string& attr_name = attr.attr_name(); + for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { + const std::string &attr_name = attr.attr_name(); ValuePtr attr_value = nullptr; if (!attr_name.empty()) { attr_value = prim->GetAttr(attr_name); @@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; } } - onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); onnx_attr_proto->set_name(attr.onnx_attr_name()); attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); } return node_idx; } -void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto conv_node = dyn_cast(node->input(1)); auto input_x = conv_node->input(1); // conv input x auto input_w = conv_node->input(2); // conv weight(filter) @@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); } -void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto matmul_node = dyn_cast(node->input(1)); auto input_x = matmul_node->input(1); // matmul input x auto input_y = matmul_node->input(2); // matmul input y @@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& 
func_graph, const CNodePt (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); } -void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { auto batch_norm_node = dyn_cast(node->input(1)); PrimitivePtr prim_batch_norm = dyn_cast((dyn_cast(batch_norm_node->input(0)))->value()); @@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); } -void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { if (node->inputs().size() != 2) { MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; } AnfNodePtr arg = node->input(1); std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); - onnx::ValueInfoProto* output_proto = graph_proto->add_output(); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); output_proto->set_name(name); SetValueInfoType(arg, output_proto, false); } -std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { if (node->isa()) { auto iter = node_map_ptr->find(node); if (iter == node_map_ptr->end()) { @@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::mapadd_node(); 
+ onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->add_output(node_name); SetNodeAttribute(node->cast()->value(), node_proto); @@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::maptype_name(); } -void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) { +void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { auto tuple_ptr = dyn_cast(value); MS_EXCEPTION_IF_NULL(tuple_ptr); if (tuple_ptr->size() == 0) { @@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto } } -void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) { +void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { node_proto->set_op_type("Constant"); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; } -std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { OnnxExporter exporter; return exporter.GetOnnxProtoString(func_graph); } diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc index 49dc3ab791..52b71f410f 100644 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ b/mindspore/ccsrc/operator/cc_implementations.cc @@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown }; // Whether has a T type data in AnyPtrList. 
template -bool HasType(const AnyPtrList& list) { - bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is(); }); +bool HasType(const AnyPtrList &list) { + bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is(); }); return ret; } -DataType InferType(const AnyPtrList& list) { +DataType InferType(const AnyPtrList &list) { if (HasType(list)) { return DataType::kDouble; } else if (HasType(list)) { @@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) { if (std::is_integral::value) { return static_cast(x) % static_cast(y); } - float x_int = std::floor(x); - float y_int = std::ceil(y); - float max = x_int / y_int; + int x_int = std::floor(x); + int y_int = std::ceil(y); + int max = x_int / y_int; float ret = x - y * max; return ret; } @@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) { } #define SCALAR_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList& list) { \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ do { \ if (list.size() < 2) { \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ @@ -223,7 +223,7 @@ SCALAR_OP(Pow) SCALAR_OP(Floordiv) #define LOGIC_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList& list) { \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ if (list.size() < 2) { \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ } \ @@ -274,7 +274,7 @@ LOGIC_OP(Ne) LOGIC_OP(Le) LOGIC_OP(Ge) -ValuePtr ScalarUAdd(const ValuePtrList& list) { +ValuePtr ScalarUAdd(const ValuePtrList &list) { if (list.size() != 1) { MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); } @@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) { return x; } -ValuePtr ScalarUSub(const ValuePtrList& list) { +ValuePtr ScalarUSub(const ValuePtrList &list) { if (list.size() != 1) { MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); } @@ -302,7 +302,7 @@ 
ValuePtr ScalarUSub(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; } -ValuePtr ScalarLog(const ValuePtrList& list) { +ValuePtr ScalarLog(const ValuePtrList &list) { if (list.empty()) { MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; } @@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); } -ValuePtr BoolNot(const ValuePtrList& list) { +ValuePtr BoolNot(const ValuePtrList &list) { if (list.empty()) { MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; } @@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); } -ValuePtr BoolAnd(const ValuePtrList& list) { +ValuePtr BoolAnd(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; } @@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; } -ValuePtr BoolOr(const ValuePtrList& list) { +ValuePtr BoolOr(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; } @@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; } -ValuePtr BoolEq(const ValuePtrList& list) { +ValuePtr BoolEq(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; } diff --git a/mindspore/ccsrc/operator/cc_implementations.h b/mindspore/ccsrc/operator/cc_implementations.h index 69981cea7d..cef34da4f4 100644 --- a/mindspore/ccsrc/operator/cc_implementations.h +++ b/mindspore/ccsrc/operator/cc_implementations.h @@ -29,29 +29,29 @@ namespace prim { using Any = mindspore::Any; using AnyPtrList = 
std::vector>; using ValuePtrList = std::vector; -using OpsFunction = std::function; -using AnfNodeOpsFunction = std::function&)>; +using OpsFunction = std::function; +using AnfNodeOpsFunction = std::function &)>; -ValuePtr ScalarAdd(const ValuePtrList& list); -ValuePtr ScalarSub(const ValuePtrList& list); -ValuePtr ScalarMul(const ValuePtrList& list); -ValuePtr ScalarDiv(const ValuePtrList& list); -ValuePtr ScalarMod(const ValuePtrList& list); -ValuePtr ScalarPow(const ValuePtrList& list); -ValuePtr ScalarFloordiv(const ValuePtrList& list); -ValuePtr ScalarUAdd(const ValuePtrList& list); -ValuePtr ScalarUSub(const ValuePtrList& list); -ValuePtr ScalarLog(const ValuePtrList& list); -ValuePtr ScalarEq(const ValuePtrList& list); -ValuePtr ScalarLt(const ValuePtrList& list); -ValuePtr ScalarGt(const ValuePtrList& list); -ValuePtr ScalarNe(const ValuePtrList& list); -ValuePtr ScalarLe(const ValuePtrList& list); -ValuePtr ScalarGe(const ValuePtrList& list); -ValuePtr BoolNot(const ValuePtrList& list); -ValuePtr BoolAnd(const ValuePtrList& list); -ValuePtr BoolOr(const ValuePtrList& list); -ValuePtr BoolEq(const ValuePtrList& list); +ValuePtr ScalarAdd(const ValuePtrList &list); +ValuePtr ScalarSub(const ValuePtrList &list); +ValuePtr ScalarMul(const ValuePtrList &list); +ValuePtr ScalarDiv(const ValuePtrList &list); +ValuePtr ScalarMod(const ValuePtrList &list); +ValuePtr ScalarPow(const ValuePtrList &list); +ValuePtr ScalarFloordiv(const ValuePtrList &list); +ValuePtr ScalarUAdd(const ValuePtrList &list); +ValuePtr ScalarUSub(const ValuePtrList &list); +ValuePtr ScalarLog(const ValuePtrList &list); +ValuePtr ScalarEq(const ValuePtrList &list); +ValuePtr ScalarLt(const ValuePtrList &list); +ValuePtr ScalarGt(const ValuePtrList &list); +ValuePtr ScalarNe(const ValuePtrList &list); +ValuePtr ScalarLe(const ValuePtrList &list); +ValuePtr ScalarGe(const ValuePtrList &list); +ValuePtr BoolNot(const ValuePtrList &list); +ValuePtr BoolAnd(const ValuePtrList &list); +ValuePtr 
BoolOr(const ValuePtrList &list); +ValuePtr BoolEq(const ValuePtrList &list); std::vector BroadcastShape_(std::vector s1, std::vector s2); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc index 9a665e8a30..11ab31a292 100644 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ b/mindspore/ccsrc/operator/composite/composite.cc @@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractEllipsis; +using mindspore::abstract::AbstractEllipsisPtr; using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunctionPtr; using mindspore::abstract::AbstractList; @@ -66,7 +68,7 @@ const MetaFuncGraphPtr kTail = std::make_shared("tail"); // Apply a function of two arguments cumulatively to the items of a sequence, // from left to right, so as to reduce the sequence to a single value.For example, // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). 
-AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { +AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { std::shared_ptr ret; size_t size = list.size(); if (size < 2) { @@ -88,7 +90,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { return ret; } -AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector& list) { +AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector &list) { size_t size = list.size(); if (size < 2) { MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; @@ -121,7 +123,7 @@ void HyperMap::Init() { {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); } -HyperMap::HyperMap(const std::shared_ptr& fn_leaf) +HyperMap::HyperMap(const std::shared_ptr &fn_leaf) : MetaFuncGraph("hyper_map"), fn_leaf_(fn_leaf), broadcast_(false), @@ -129,13 +131,13 @@ HyperMap::HyperMap(const std::shared_ptr& fn_leaf) Init(); } -HyperMap::HyperMap(const HyperMap& h) +HyperMap::HyperMap(const HyperMap &h) : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { Init(); } -AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); std::vector inputs; if (fn_arg != nullptr) { @@ -145,17 +147,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf } (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), - [](const std::pair& item) { return item.first; }); + [](const std::pair &item) { return item.first; }); return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr 
&type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { auto lhs = std::static_pointer_cast(item.second); MS_EXCEPTION_IF_NULL(lhs); return lhs->elements().size() != size; @@ -179,7 +181,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraph (void)std::transform( arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair& item) { + [&func_graph, i](const std::pair &item) { return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); }); @@ -188,13 +190,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraph return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { auto lhs = std::static_pointer_cast(item.second); MS_EXCEPTION_IF_NULL(lhs); return lhs->elements().size() != size; @@ -226,8 +228,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGrap return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const 
ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(func_graph); @@ -257,11 +259,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGrap return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { bool found = false; TypeId id = kObjectTypeEnd; std::pair pair; - for (auto& item : arg_map) { + for (auto &item : arg_map) { pair = item; id = item.second->type_id(); if (nonleaf_.count(id)) { @@ -272,7 +274,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a if (found) { // In a nonleaf situation, all arguments must have the same generic. - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair &item) { if (item.first != pair.first) { return item.second->type_id() != pair.second->type_id(); } @@ -283,7 +285,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; int idx = 0; - for (auto& item : arg_map) { + for (auto &item : arg_map) { oss << ++idx << ": " << item.second->ToString() << "\n"; } MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); @@ -308,14 +310,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a } } -ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) { +ArgsPairList 
HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { TypePtr type_tensor = std::make_shared(); bool flag = std::any_of( args_spec_list.begin(), args_spec_list.end(), - [type_tensor](const std::pair& item) { return IsSubType(item.second, type_tensor); }); + [type_tensor](const std::pair &item) { return IsSubType(item.second, type_tensor); }); if (flag && broadcast_) { ArgsPairList ret; - for (auto& item : args_spec_list) { + for (auto &item : args_spec_list) { if (!IsSubType(item.second, type_tensor)) { TypePtr type_tensor_ele = std::make_shared(item.second); ret.push_back( @@ -329,7 +331,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL return args_spec_list; } -FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { +FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr ptrGraph = std::make_shared(); ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->debug_info()->set_name("hyper_map"); @@ -353,7 +355,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { return ptrGraph; } -abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const { +abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { if (fn_leaf_ == nullptr) { MS_EXCEPTION_IF_NULL(args_spec_list[0]); // Assert that hypermap's function param does not contain free variables @@ -368,20 +370,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& AbstractBasePtrList broadened; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), - [](const AbstractBasePtr& arg) -> AbstractBasePtr { + [](const AbstractBasePtr &arg) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(arg); return arg->Broaden(); }); return broadened; } -REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) { 
+REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { (void)py::class_>(*m, "HyperMap_") .def(py::init>(), py::arg("leaf")) .def(py::init<>()); })); -FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) { +FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { MS_EXCEPTION_IF_NULL(a_tuple); FuncGraphPtr ret = std::make_shared(); @@ -401,7 +403,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu return ret; } -FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list) { +FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { MS_EXCEPTION_IF_NULL(a_list); FuncGraphPtr ret = std::make_shared(); @@ -421,7 +423,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list return ret; } -FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { if (args_spec_list.size() != 1) { MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; } @@ -441,11 +443,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) } REGISTER_PYBIND_DEFINE( - Tail_, ([](const py::module* m) { - (void)py::class_>(*m, "Tail_").def(py::init()); + Tail_, ([](const py::module *m) { + (void)py::class_>(*m, "Tail_").def(py::init()); })); -FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { int tuple_size = SizeToInt(args_spec_list.size()); std::ostringstream ss; @@ -486,7 +488,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg return fg; } -GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param) +GradOperation::GradOperation(const std::string &name, bool get_all, bool 
get_by_list, bool sens_param) : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { if (get_by_list) { signatures_ = @@ -496,8 +498,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_ } } -FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, - const std::vector& params_list, bool applyJ) { +FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, + const std::vector ¶ms_list, bool applyJ) { FuncGraphPtr ret = std::make_shared(); ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); @@ -537,7 +539,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, return ret; } -void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, +void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, ValueNodePtr opsTupleItem) { MS_EXCEPTION_IF_NULL(func_graph); @@ -590,7 +592,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An } // Generate the graph. 
-FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { if (args_spec_list.size() < 1) { MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " << args_spec_list.size() << "."; @@ -637,21 +639,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp return dfBuilder; } -REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { (void)py::class_>( *m, "GradOperation_") - .def(py::init(), py::arg("fn")) - .def(py::init(), py::arg("fn"), py::arg("get_all"), + .def(py::init(), py::arg("fn")) + .def(py::init(), py::arg("fn"), py::arg("get_all"), py::arg("get_by_list"), py::arg("sens_param")); })); -MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) { +MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { fn_cache_.clear(); signatures_ = std::vector({// def multitype(*args:ref): {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); } -void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) { +void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; auto fn = fn_cache_.find(types); if (fn != fn_cache_.end()) { @@ -660,7 +662,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) fn_cache_[types] = s_fn; } -void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) { +void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; auto fn = fn_cache_.find(types); if (fn != fn_cache_.end()) { @@ -669,9 
+671,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& fn_cache_py_[types] = py_fn; } -void MultitypeFuncGraph::Register(const std::vector& types_name, const py::function& py_fn) { +void MultitypeFuncGraph::Register(const std::vector &types_name, const py::function &py_fn) { TypePtrList types; - for (auto& type_name : types_name) { + for (auto &type_name : types_name) { auto type_ptr = StringToType(type_name); if (type_ptr == nullptr) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error "; @@ -681,7 +683,7 @@ void MultitypeFuncGraph::Register(const std::vector& types_name, co Register(types, py_fn); } -void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) { +void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { std::vector types_name; for (size_t it = 0; it < tuple.size(); ++it) { py::object name_py = tuple[it]; @@ -693,16 +695,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& } Register(types_name, py_fn); } -static TypePtr UnwrapRef(const TypePtr& type) { +static TypePtr UnwrapRef(const TypePtr &type) { if (type->isa()) { return type->cast()->subtype(); } return type; } -FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { +FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { bool find_fn = false; py::function py_fn; - for (auto& item : fn_cache_py_) { + for (auto &item : fn_cache_py_) { TypePtrList sign = item.first; if (sign.size() != types.size()) { continue; @@ -735,7 +737,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ << "`, corresponding location info:\n"; int idx = 0; - for (auto& item : fn_cache_py_) { + for (auto &item : fn_cache_py_) { FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); if (func_graph == 
nullptr) { MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; @@ -747,15 +749,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { << oss.str(); } -REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { (void)py::class_>( *m, "MultitypeFuncGraph_") - .def(py::init()) + .def(py::init()) .def("register_fn", &MultitypeFuncGraph::PyRegister); })); // Generate the ListMap func graph. -FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { size_t args_num = args_spec_list.size(); // args: fn, list1, list2, ... if (args_num < 2) { @@ -821,8 +823,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis return fg_ptr; } -void ListMap::MakeCond(const std::vector& lists, const FuncGraphPtr& fgnext_ptr, - const FuncGraphPtr& fg_ptr) { +void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr &fgnext_ptr, + const FuncGraphPtr &fg_ptr) { MS_EXCEPTION_IF_NULL(fg_ptr); AnfNodePtr fn = fg_ptr->add_parameter(); @@ -858,8 +860,8 @@ void ListMap::MakeCond(const std::vector& lists, const FuncGraphPtr& fgtrue_ptr->set_output(output_cnode); } -void ListMap::MakeNext(const std::vector& lists, const FuncGraphPtr& fgcond_ptr, - const FuncGraphPtr& fg_ptr) { +void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr &fgcond_ptr, + const FuncGraphPtr &fg_ptr) { MS_EXCEPTION_IF_NULL(fg_ptr); AnfNodePtr fn = fg_ptr->add_parameter(); @@ -893,7 +895,7 @@ void ListMap::MakeNext(const std::vector& lists, const FuncGraphPtr& fg_ptr->set_output(output_cnode); } -FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // args: tuple1, tuple2 abstract::CheckArgsSize("TupleAdd", 
args_spec_list, 2); AbstractBasePtr abs_a = args_spec_list[0]; @@ -928,7 +930,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li return ret; } -int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) { +int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { MS_EXCEPTION_IF_NULL(scalar); return GetValue(scalar->BuildValue()); } @@ -942,7 +944,7 @@ int GetPositiveIndex(int index, int length) { return index; } -int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) { +int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { MS_EXCEPTION_IF_NULL(member); if (member->isa()) { @@ -957,8 +959,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std << member->ToString(); } -void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index, - int* stop_index, int* step_value) { +void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index, + int *stop_index, int *step_value) { MS_EXCEPTION_IF_NULL(tuple); MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(start_index); @@ -998,7 +1000,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl } } -FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tuple // args: tuple, start index, end index, step const std::string op_name("TupleSlice"); @@ -1032,7 +1034,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ return ret; } -int ConvertBinaryToDecimal(const std::vector& number_bin) { +int ConvertBinaryToDecimal(const std::vector &number_bin) { unsigned int number_dec = 0; for (size_t index = 0; index < number_bin.size(); 
index++) { number_dec |= number_bin[index] << index; @@ -1040,8 +1042,8 @@ int ConvertBinaryToDecimal(const std::vector& number_bin) { return static_cast(number_dec); } -void ParseSlice(const AbstractSlicePtr& slice, std::vector* begin, std::vector* end, - std::vector* strides, int length) { +void ParseSlice(const AbstractSlicePtr &slice, std::vector *begin, std::vector *end, + std::vector *strides, int length) { MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); @@ -1064,8 +1066,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector* begin, std::vec strides->push_back(step_value); } -int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector& shape, - std::vector* begin, std::vector* end, std::vector* strides) { +int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector &shape, + std::vector *begin, std::vector *end, std::vector *strides) { MS_EXCEPTION_IF_NULL(slice_tuple); MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); @@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, std::vector shrink; auto slice_tuple_eles = slice_tuple->elements(); + size_t ellipsis_num = 0; for (size_t index = 0; index < slice_tuple_size; index++) { if (slice_tuple_eles[index]->isa()) { AbstractSlicePtr slice = dyn_cast(slice_tuple_eles[index]); @@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, continue; } - MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got " + if (slice_tuple_eles[index]->isa()) { + ellipsis_num++; + if (ellipsis_num > 1) { + MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis"; + } + size_t ellipsis_len = shape_size - (slice_tuple_size - 1); + begin->insert(begin->end(), ellipsis_len, 0); + end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len); + 
strides->insert(strides->end(), ellipsis_len, 1); + shrink.insert(shrink.end(), ellipsis_len, 0); + continue; + } + + MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got " << slice_tuple_eles[index]->ToString(); } @@ -1111,8 +1127,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, return ConvertBinaryToDecimal(shrink); } -int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector& shape, - std::vector* begin, std::vector* end, std::vector* strides) { +int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector &shape, + std::vector *begin, std::vector *end, std::vector *strides) { MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(strides); @@ -1132,9 +1148,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const return 0; } -int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector& shape, - std::vector* begin, std::vector* end, - std::vector* strides) { +int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector &shape, + std::vector *begin, std::vector *end, + std::vector *strides) { MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(strides); @@ -1153,13 +1169,18 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co return 1; } -FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tensor // args: tensor, slice or slice tuple const std::string op_name = std::string("TensorSlice"); abstract::CheckArgsSize(op_name, args_spec_list, 2); AbstractTensorPtr tensorPtr = abstract::CheckArg(op_name, args_spec_list, 0); + FuncGraphPtr ret_graph = std::make_shared(); + ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); 
+ AnfNodePtr tensor_node = ret_graph->add_parameter(); + (void)ret_graph->add_parameter(); + auto shape = tensorPtr->shape()->shape(); std::vector begin; std::vector end; @@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides); } else if (args_spec_list[1]->isa()) { AbstractScalarPtr scalar_ptr = dyn_cast(args_spec_list[1]); + if (scalar_ptr->BuildValue()->isa()) { + if (scalar_ptr->BuildValue()->cast()->value()) { + return ExpandADim(ret_graph, tensor_node); + } + } shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); + } else if (args_spec_list[1]->isa()) { + ret_graph->set_output(tensor_node); + return ret_graph; + } else if (args_spec_list[1]->isa()) { + return ExpandADim(ret_graph, tensor_node); } else { std::ostringstream args_info; - for (const auto& arg : args_spec_list) { + for (const auto &arg : args_spec_list) { MS_EXCEPTION_IF_NULL(arg); args_info << arg->ToString() << "\n"; } - MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got " - << args_info.str(); + MS_LOG(EXCEPTION) + << "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got " + << args_info.str(); } - FuncGraphPtr ret_graph = std::make_shared(); - ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); - - AnfNodePtr tensor_node = ret_graph->add_parameter(); - (void)ret_graph->add_parameter(); - auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations"); auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0), NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)}); @@ -1199,19 +1225,25 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec return ret_graph; } 
-REGISTER_PYBIND_DEFINE( - TupleAdd_, ([](const py::module* m) { - (void)py::class_>(*m, "TupleAdd_").def(py::init()); - })); +FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const { + auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); + ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); + return ret_graph; +} + +REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleAdd_") + .def(py::init()); + })); -REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { (void)py::class_>(*m, "TupleSlice_") - .def(py::init()); + .def(py::init()); })); -REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) { (void)py::class_>(*m, "TensorSlice_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h index dc8627ba61..429cf5341a 100644 --- a/mindspore/ccsrc/operator/composite/composite.h +++ b/mindspore/ccsrc/operator/composite/composite.h @@ -47,20 +47,20 @@ using ArgsPairList = std::vector>; class MultitypeFuncGraph : public MetaFuncGraph { public: - explicit MultitypeFuncGraph(const std::string& name); + explicit MultitypeFuncGraph(const std::string &name); ~MultitypeFuncGraph() override = default; MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) - using specialize_fn = FuncGraph* (*)(TypePtrList); + using specialize_fn = FuncGraph *(*)(TypePtrList); // Register a method which specialize based on types vectors; - virtual void Register(const TypePtrList& types, specialize_fn s_fn); - virtual void Register(const TypePtrList& types, const py::function& py_fn); - virtual void Register(const std::vector& types_name, const 
py::function& py_fn); - virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn); + virtual void Register(const TypePtrList &types, specialize_fn s_fn); + virtual void Register(const TypePtrList &types, const py::function &py_fn); + virtual void Register(const std::vector &types_name, const py::function &py_fn); + virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); - FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } - const std::unordered_map& GetPyFunctions() const { + const std::unordered_map &GetPyFunctions() const { return fn_cache_py_; } @@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr; class HyperMap : public MetaFuncGraph { public: - explicit HyperMap(const std::shared_ptr& fn_leaf = nullptr); - HyperMap(const HyperMap& h); + explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); + HyperMap(const HyperMap &h); void Init(); - HyperMap& operator=(const HyperMap& h) { + HyperMap &operator=(const HyperMap &h) { if (this != &h) { fn_leaf_ = h.fn_leaf_; broadcast_ = h.broadcast_; @@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph { ~HyperMap() override = default; MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) - abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override; - FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override; + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } private: - AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, 
const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map); - ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list); + AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); + ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); MultitypeFuncGraphPtr fn_leaf_; bool broadcast_; @@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr; class HyperMapPy : public HyperMap { public: - explicit HyperMapPy(const std::shared_ptr& fn_leaf = nullptr) : HyperMap(fn_leaf) {} + explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} ~HyperMapPy() override = default; MS_DECLARE_PARENT(HyperMapPy, HyperMap) }; @@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap; class Tail : public MetaFuncGraph { public: - explicit Tail(const std::string& name) : MetaFuncGraph(name) {} + explicit Tail(const std::string &name) : MetaFuncGraph(name) {} ~Tail() override = default; MS_DECLARE_PARENT(Tail, MetaFuncGraph) - FuncGraphPtr 
GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple); - FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); + FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); - friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } }; using TailPtr = std::shared_ptr; class MakeTupleGradient : public MetaFuncGraph { public: - explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {} + explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} ~MakeTupleGradient() override = default; MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } }; using MakeTupleGradientPtr = std::shared_ptr; class GradOperation : public MetaFuncGraph { public: - explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false, + explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, bool sens_param = false); ~GradOperation() override = default; MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) - FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector& ptrParams, + FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr 
&weights, const std::vector &ptrParams, bool applyJ = false); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; bool sens_param() const { return sens_param_; } bool get_all_; bool get_by_list_; bool sens_param_; private: - void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, + void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, ValueNodePtr opsTupleItem); }; using GradOperationPtr = std::shared_ptr; class ListMap { public: - explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); } + explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } ~ListMap() = default; - void MakeCond(const std::vector& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr); - void MakeNext(const std::vector& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list); + void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); + void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); private: std::string name_; @@ -181,31 +181,33 @@ class ListMap { class TupleAdd : public MetaFuncGraph { public: - explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {} + explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} ~TupleAdd() override = default; MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList 
&args_spec_list) override; + friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } }; using TupleAddPtr = std::shared_ptr; class TupleSlice : public MetaFuncGraph { public: - explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {} + explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} ~TupleSlice() override = default; MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } }; using TupleSlicePtr = std::shared_ptr; class TensorSlice : public MetaFuncGraph { public: - explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {} + explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {} ~TensorSlice() override = default; MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } + + FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const; }; using TensorSlicePtr = std::shared_ptr; diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index a4a26377f5..c3fe45a48a 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -34,7 +34,7 @@ namespace prim { namespace { using 
PatternListType = std::initializer_list; -const std::vector& GetSignature(const ValuePtr& function) { +const std::vector &GetSignature(const ValuePtr &function) { static const auto empty = std::vector(); if (function->isa()) { return function->cast()->signatures(); @@ -44,8 +44,8 @@ const std::vector& GetSignature(const ValuePtr& function) { return empty; } -void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list, - const std::vector& signature, bool has_var, std::vector* op_inputs) { +void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, + const std::vector &signature, bool has_var, std::vector *op_inputs) { std::size_t sig_size = signature.size(); auto positional_size = sig_size; if (has_var) { @@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg } // Get the largest type of index in the same SignatureEnumDType of arguments. -std::map GetMaxDtypeIndex(const std::vector& dtypes, - const abstract::AbstractBasePtrList& args_spec_list) { +std::map GetMaxDtypeIndex(const std::vector &dtypes, + const abstract::AbstractBasePtrList &args_spec_list) { // record index for signature.dtypes of the same type // eg. 
[T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} std::map> type_indexs; @@ -89,7 +89,7 @@ std::map GetMaxDtypeIndex(const std::vectorisa()) { arg_value = arg_value->cast()->ref(); @@ -104,7 +104,7 @@ std::map GetMaxDtypeIndex(const std::vector& signature, const abstract::AbstractBasePtrList& args_spec_list, - const FuncGraphPtr& graph, std::vector* op_inputs) { +void DoAutoCast(const std::vector &signature, const abstract::AbstractBasePtrList &args_spec_list, + const FuncGraphPtr &graph, std::vector *op_inputs) { std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature& sig) { return sig.dtype; }); + [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { return; @@ -137,16 +137,29 @@ void DoAutoCast(const std::vector& signature, const abstract::Abstrac if (it == dst_type.end() || it->second == i || !arg_value->isa()) { continue; } + // When scalar is of bool type, the type of tensor must also be of bool type, + // otherwise the cast operator will not be added. 
+ auto scalar = arg_value->cast(); + auto scalar_type = scalar->BuildType(); + MS_EXCEPTION_IF_NULL(scalar_type); + if (scalar_type->type_id() == kNumberTypeBool) { + auto tensor = args_spec_list[it->second]->cast(); + auto tensor_type = tensor->element()->BuildType(); + MS_EXCEPTION_IF_NULL(tensor_type); + if (tensor_type->type_id() != kNumberTypeBool) { + continue; + } + } // get source node for cast AnfNodePtr source_node = (*op_inputs)[it->second + 1]; (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], source_node, graph); } } -AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const std::vector& params_list) { +AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const std::vector ¶ms_list) { // args: original inputs - auto& signature = GetSignature(function); + auto &signature = GetSignature(function); std::size_t sig_size = signature.size(); auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); if (sig_size > 0) { @@ -196,13 +209,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func } } // namespace -AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) { +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); return new_cnode; } -FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr 
DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { FuncGraphPtr func_graph = std::make_shared(); for (size_t i = 0; i < args_spec_list.size(); ++i) { diff --git a/mindspore/ccsrc/operator/composite/do_signature.h b/mindspore/ccsrc/operator/composite/do_signature.h index b88053e224..3e1596d63f 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.h +++ b/mindspore/ccsrc/operator/composite/do_signature.h @@ -37,17 +37,17 @@ namespace mindspore { namespace prim { class DoSignatureMetaFuncGraph : public MetaFuncGraph { public: - explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function) + explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) : MetaFuncGraph("S-" + name), function_(function) {} ~DoSignatureMetaFuncGraph() override = default; MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override; + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; const ValuePtr function() const { return function_; } - friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) { + friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { return &lhs == &rhs; } @@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph { }; using RWSignaturePtr = std::shared_ptr; -AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs); +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); } // namespace prim } // namespace mindspore diff --git 
a/mindspore/ccsrc/operator/composite/list_append_operation.cc b/mindspore/ccsrc/operator/composite/list_append_operation.cc index 8621a8a8ba..b5a4fc626e 100644 --- a/mindspore/ccsrc/operator/composite/list_append_operation.cc +++ b/mindspore/ccsrc/operator/composite/list_append_operation.cc @@ -27,7 +27,7 @@ namespace mindspore { // namespace to support composite operators definition namespace prim { -FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) { +FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { abstract::CheckArgsSize("ListAppend", args_list, 2); AbstractBasePtr arg0 = args_list[0]; @@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& return ret; } -REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { (void)py::class_>(*m, "ListAppend_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.h b/mindspore/ccsrc/operator/composite/list_append_operation.h index f34b6b864e..1da3f9a009 100644 --- a/mindspore/ccsrc/operator/composite/list_append_operation.h +++ b/mindspore/ccsrc/operator/composite/list_append_operation.h @@ -28,15 +28,15 @@ namespace mindspore { namespace prim { class ListAppend : public MetaFuncGraph { public: - explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {} + explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {} ~ListAppend() override = default; MS_DECLARE_PARENT(ListAppend, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override; - friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) { + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override; + friend std::ostream &operator<<(std::ostream &os, const 
ListAppend &list_append) { os << list_append.name_; return os; } - friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; } }; using ListAppendPtr = std::shared_ptr; } // namespace prim diff --git a/mindspore/ccsrc/operator/composite/unpack_call.cc b/mindspore/ccsrc/operator/composite/unpack_call.cc index 64d6b3433b..122f276657 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.cc +++ b/mindspore/ccsrc/operator/composite/unpack_call.cc @@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg; using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuplePtr; -FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tensor // args: tensor, slice or slice tuple const std::string op_name = std::string("UnpackCall"); @@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ AnfNodePtr para_dict = ret_graph->add_parameter(); auto dict_elems = arg_dict->elements(); (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), - [ret_graph, para_dict](const AbstractAttribute& item) { + [ret_graph, para_dict](const AbstractAttribute &item) { auto dict_get_item = ret_graph->NewCNode( {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); return ret_graph->NewCNode( @@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ return ret_graph; } -REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { (void)py::class_>(*m, "UnpackCall_") - .def(py::init()); + .def(py::init()); })); } // namespace prim diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h 
b/mindspore/ccsrc/operator/composite/unpack_call.h index 7ec5f9ad33..2f39615c1a 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.h +++ b/mindspore/ccsrc/operator/composite/unpack_call.h @@ -40,11 +40,11 @@ namespace prim { // and generate positional parameters and key-value pairs for function. class UnpackCall : public MetaFuncGraph { public: - explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} + explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} ~UnpackCall() override = default; MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } }; using UnpackCallPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/operator/composite/zip_operation.cc b/mindspore/ccsrc/operator/composite/zip_operation.cc index b87e19b009..4d34163f28 100644 --- a/mindspore/ccsrc/operator/composite/zip_operation.cc +++ b/mindspore/ccsrc/operator/composite/zip_operation.cc @@ -36,7 +36,7 @@ namespace prim { using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractTuple; -FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // zip operation: // input: tuple arguments // output: tuple of items of input iterated on every input @@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe MS_LOG(EXCEPTION) << "zip arguments input should not be empty"; } - auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool { + auto is_all_tuple = 
std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { MS_EXCEPTION_IF_NULL(abs); return abs->isa(); }); @@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe } auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(), - [](const AbstractBasePtr& x, const AbstractBasePtr& y) { + [](const AbstractBasePtr &x, const AbstractBasePtr &y) { return (x->cast()->size() < y->cast()->size()); }); FuncGraphPtr ret_graph = std::make_shared(); @@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe return ret_graph; } -REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { (void)py::class_>(*m, "ZipOperation_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/zip_operation.h b/mindspore/ccsrc/operator/composite/zip_operation.h index e1fb8d60cf..1a3fa1f5fe 100644 --- a/mindspore/ccsrc/operator/composite/zip_operation.h +++ b/mindspore/ccsrc/operator/composite/zip_operation.h @@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr; class ZipOperation : public MetaFuncGraph { public: - explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {} + explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} ~ZipOperation() override = default; MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) { + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { os << op.name_; return os; } - friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ 
== rhs.name_; } + friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } }; using ZipOperationPtr = std::shared_ptr; } // namespace prim diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index b1a8a9b782..91a54e1fdb 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -237,7 +237,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); -ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) { +ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) { py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); ValuePtr node = nullptr; bool succ = parse::ConvertData(obj, &node); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 26c13993e0..d84b2e4738 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -26,8 +26,8 @@ namespace mindspore { // namespace to support primitive operators namespace prim { -ValuePtr GetPythonOps(const std::string& op_name, - const std::string& module_name = "mindspore._extends.parse.standard_method"); +ValuePtr GetPythonOps(const std::string &op_name, + const std::string &module_name = "mindspore._extends.parse.standard_method"); // Arithmetic extern const PrimitivePtr kPrimScalarAdd; @@ -240,7 +240,7 @@ extern const PrimitivePtr kPrimVirtualDataset; class DoSignaturePrimitive : public Primitive { public: - explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function) + explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) : Primitive("S-Prim-" + name), function_(function) {} ~DoSignaturePrimitive() override = default; @@ -256,7 +256,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr; 
class UnpackGraphPrimitive : public Primitive { public: - explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) + explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} ~UnpackGraphPrimitive() override = default; MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) diff --git a/mindspore/ccsrc/operator/prim_to_function.cc b/mindspore/ccsrc/operator/prim_to_function.cc index bdfe48157c..733cdbdb73 100644 --- a/mindspore/ccsrc/operator/prim_to_function.cc +++ b/mindspore/ccsrc/operator/prim_to_function.cc @@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction() {"scalar_sub", kPrimTypeTwoArgs}, {"scalar_floordiv", kPrimTypeTwoArgs}}) {} -bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { +bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { bool result = false; if (func != nullptr) { @@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu return result; } -int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const { +int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { MS_EXCEPTION_IF_NULL(prim); int prim_type = static_cast(kPrimTypeUnknown); diff --git a/mindspore/ccsrc/operator/prim_to_function.h b/mindspore/ccsrc/operator/prim_to_function.h index 71518e4057..285ab8d3ab 100644 --- a/mindspore/ccsrc/operator/prim_to_function.h +++ b/mindspore/ccsrc/operator/prim_to_function.h @@ -41,21 +41,21 @@ class PrimToFunction; class PrimToFunction { public: // Return a thread-safe singleton instance - static PrimToFunction& GetInstance() { + static PrimToFunction &GetInstance() { static PrimToFunction instance; return instance; } - PrimToFunction(const PrimToFunction&) = delete; - PrimToFunction& operator=(const PrimToFunction&) = delete; + 
PrimToFunction(const PrimToFunction &) = delete; + PrimToFunction &operator=(const PrimToFunction &) = delete; ~PrimToFunction() = default; // Get the args and return value for a primitive instance. - bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const; + bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const; private: PrimToFunction(); // Get the number of primitive arguments - int GetPrimType(const PrimitivePtr& prim) const; + int GetPrimType(const PrimitivePtr &prim) const; const std::unordered_map prim_func_type_map_; }; } // namespace prim diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.cc b/mindspore/ccsrc/optimizer/ad/adjoint.cc index 46746b3f44..ed89aba20e 100644 --- a/mindspore/ccsrc/optimizer/ad/adjoint.cc +++ b/mindspore/ccsrc/optimizer/ad/adjoint.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace ad { -Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller) +Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) : primal_(primal), caller_(caller), dout_(nullptr) { if (k != nullptr) { k_ = k; @@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP AnfNodePtr Adjoint::k() { return k_; } -void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } +void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } -void Adjoint::UpdateK(const AnfNodePtr& new_k) { +void Adjoint::UpdateK(const AnfNodePtr &new_k) { MS_EXCEPTION_IF_NULL(new_k); MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); // In recursive case, it needs update. 
- for (auto& user : k_user_) { + for (auto &user : k_user_) { MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" << new_k->ToString(); if (user.first->input(user.second) != k_) { @@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; } AnfNodePtr Adjoint::dout() { return dout_hole_; } -void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) { +void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { dout_user_.emplace_back(std::make_pair(user, index)); } -void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { +void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { if (dout_ != nullptr) { MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); auto add = prim::GetPythonOps("hyper_add"); @@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { void Adjoint::CallDoutHole() { if (dout_ != nullptr) { - for (auto& user : dout_user_) { + for (auto &user : dout_user_) { MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " << dout_->ToString(); if (user.first->input(user.second) != dout_hole_) { diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.h b/mindspore/ccsrc/optimizer/ad/adjoint.h index 673928129b..b2dae8e66f 100644 --- a/mindspore/ccsrc/optimizer/ad/adjoint.h +++ b/mindspore/ccsrc/optimizer/ad/adjoint.h @@ -28,15 +28,15 @@ namespace mindspore { namespace ad { class Adjoint { public: - Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller); + Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); ~Adjoint() = default; AnfNodePtr primal(); AnfNodePtr k(); - void UpdateK(const AnfNodePtr& k); - void RegisterKUser(const CNodePtr& user, size_t index); + void UpdateK(const AnfNodePtr &k); + void RegisterKUser(const CNodePtr &user, size_t index); AnfNodePtr dout(); - void AccumulateDout(const 
AnfNodePtr& dout_factor); - void RegisterDoutUser(const CNodePtr& user, size_t index); + void AccumulateDout(const AnfNodePtr &dout_factor); + void RegisterDoutUser(const CNodePtr &user, size_t index); void CallDoutHole(); private: diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc index 9e713d3425..fe11191546 100644 --- a/mindspore/ccsrc/optimizer/clean.cc +++ b/mindspore/ccsrc/optimizer/clean.cc @@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractTuple; -static AbstractBasePtr Reabs(const AbstractBasePtr& t) { +static AbstractBasePtr Reabs(const AbstractBasePtr &t) { if (t == nullptr) { return nullptr; } @@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { AbstractBasePtrList baselist; auto attributes = abs_class->attributes(); (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), - [](const AbstractAttribute& item) { return item.second; }); + [](const AbstractAttribute &item) { return item.second; }); res = std::make_shared(baselist); } else if (t->isa()) { auto abs_dict = dyn_cast(t); AbstractBasePtrList baselist; auto elements = abs_dict->elements(); (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), - [](const AbstractAttribute& item) { return item.second; }); + [](const AbstractAttribute &item) { return item.second; }); res = std::make_shared(baselist); } else if (t->isa()) { auto abs_dict = dyn_cast(t); @@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { return res; } -AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [getattr, data, attribute] MS_ASSERT(inputs.size() == 3 && 
"GetAttr should have three inputs."); @@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; auto ct = dyn_cast(dt); - const auto& cmap = ct->attributes(); + const auto &cmap = ct->attributes(); int count = 0; - for (auto& item : cmap) { + for (auto &item : cmap) { if (cons_is_str && item.first == cons_str) { break; } @@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); } -AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); // Inputs should be [dict_getitem, dict, item] - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); AnfNodePtr data = inputs[1]; @@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { auto cons_str = cons_is_str ? 
GetValue(GetValueNode(cons)) : ""; auto ct = dyn_cast(dt); - const auto& cmap = ct->elements(); + const auto &cmap = ct->elements(); int count = 0; - for (auto& item : cmap) { + for (auto &item : cmap) { if (cons_is_str && item.first == cons_str) { break; } @@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); } -AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { +AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); @@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { return node->func_graph()->NewCNode(inputs); } -AnfNodePtr ErasePartialNode(const CNodePtr& node) { +AnfNodePtr ErasePartialNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); @@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) { return nullptr; } -AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { +AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); @@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { return node->func_graph()->NewCNode(inputs); } -AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [list_getitem, list, item] if (inputs.size() < 3) { MS_LOG(EXCEPTION) << "Node's input number < 
3."; @@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); } -AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { +AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [list_setitem, list, index, item] if (inputs.size() < 4) { MS_LOG(EXCEPTION) << "Node's input number < 4."; @@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); } -AnfNodePtr EraseMakeDictNode(const CNodePtr& node) { +AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); return inputs[2]; } -AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) { +AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [make_keyword_arg, key, value] MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); return inputs[2]; } -AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) { +AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [extract_keyword_arg, arg, key] MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); return inputs[2]; } -ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) { +ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr 
&value_list, int depth) { const int DEPTH_MAX = 5; if (depth > DEPTH_MAX) { MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; } std::vector elements; - for (const auto& it : value_list->value()) { + for (const auto &it : value_list->value()) { ValuePtr value = nullptr; if (it->isa()) { value = ConvertValueListToValueTuple(it->cast(), depth + 1); @@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d return std::make_shared(elements); } -AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { +AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(node); ValuePtr value = node->value(); auto value_list = value->cast(); @@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { // Convert class to Tuple // Convert getattr to getitem // Convert make_record to make_tuple -void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { +void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var AnfNodeSet all_node = manager->all_nodes(); - for (auto& node : all_node) { + for (auto &node : all_node) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); AnfNodePtr new_node = nullptr; @@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& } } - for (auto& node : manager->all_nodes()) { + for (auto &node : manager->all_nodes()) { auto ret = Reabs(node->abstract()); node->set_abstract(ret); } } // expand tuples in graph parameters -static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph, - const std::vector& params) { +static std::vector ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr 
&func_graph, + const std::vector ¶ms) { MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(func_graph); std::vector new_params; - for (const auto& param : params) { + for (const auto ¶m : params) { MS_EXCEPTION_IF_NULL(param); auto param_abs = param->abstract(); MS_EXCEPTION_IF_NULL(param_abs); @@ -350,7 +350,7 @@ static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, con std::vector new_param; std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; auto abs_tuple = dyn_cast(param_abs); - for (auto& elem : abs_tuple->elements()) { + for (auto &elem : abs_tuple->elements()) { auto np = std::make_shared(func_graph); np->set_abstract(elem); new_param.emplace_back(np); @@ -366,11 +366,11 @@ static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, con } // expand tuples in graph applies -static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const std::vector& inputs) { +static std::vector ExpandTuplesC(const FuncGraphPtr &graph, const std::vector &inputs) { MS_EXCEPTION_IF_NULL(graph); std::vector new_inputs; - for (const auto& input : inputs) { + for (const auto &input : inputs) { MS_EXCEPTION_IF_NULL(input); auto input_abs = input->abstract(); @@ -391,7 +391,7 @@ static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const st int idx = 0; std::vector new_input; auto abs_tuple = dyn_cast(input_abs); - for (auto& elem : abs_tuple->elements()) { + for (auto &elem : abs_tuple->elements()) { auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); AbstractBasePtr aptr = std::make_shared(std::make_shared(idx)); c_node->input(2)->set_abstract(aptr); @@ -416,19 +416,19 @@ static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const st // tuples in Graph's parameters: AbstractTuple (a, b, c) --> // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) // cppcheck-suppress unusedFunction -void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { +void EraseTuple(const 
FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var AnfNodeSet all_node = manager->all_nodes(); - for (auto& node : all_node) { + for (auto &node : all_node) { auto cnode = node->cast(); if (cnode == nullptr) { continue; } - const auto& inputs = cnode->inputs(); + const auto &inputs = cnode->inputs(); // Bypass the first input in inputs as it's fn. if (!IsValueNode(inputs[0])) { @@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { } FuncGraphSet all_graph = manager->func_graphs(); - for (auto& func_graph : all_graph) { + for (auto &func_graph : all_graph) { MS_EXCEPTION_IF_NULL(func_graph); auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); manager->SetParameters(func_graph, expand_p); diff --git a/mindspore/ccsrc/optimizer/control_depend.h b/mindspore/ccsrc/optimizer/control_depend.h index 2a51a24718..076e2c0229 100644 --- a/mindspore/ccsrc/optimizer/control_depend.h +++ b/mindspore/ccsrc/optimizer/control_depend.h @@ -22,7 +22,7 @@ namespace mindspore { namespace opt { // Automatically adding control depend based on effect order and side effect analysis. 
-void AddControlDepend(const FuncGraphPtr& graph); +void AddControlDepend(const FuncGraphPtr &graph); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc index 5daeced3a5..32a42bc16b 100644 --- a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc +++ b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc @@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, Func nodes.push_back(func_node); // {unpackcall, {GradOperation, ...}, args...} std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr& node) { return node; }); + [](const AnfNodePtr &node) { return node; }); unpack_graph_node = func_graph->NewCNode(nodes); } else { auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); @@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, Func nodes.push_back(func_node); // {{GradOperation, ...}, args...} std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr& node) { return node; }); + [](const AnfNodePtr &node) { return node; }); unpack_graph_node = func_graph->NewCNode(nodes); } return unpack_graph_node; } // get metagraph of value node -MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { +MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { ValuePtr value; if (IsValueNode(node)) { value = GetValueNode(node)->cast()->function(); @@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { } // check if node is a specific metafuncgraph op -bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { +bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { if (node != nullptr) { auto meta_func_graph_ptr = 
GetMetaFuncGraphOfValueNode(node); if (meta_func_graph_ptr == nullptr) { @@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr // {{GradOperation, g, w}, Ys} // {UnPackCall, {GradOperation, g, w}, Ys} -AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { +AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 24339ddb84..0dbaf1107f 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -31,20 +31,20 @@ namespace mindspore { /* namespace to support opt */ namespace opt { -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, - const RenormAction& renorm_action) { - auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, + const RenormAction &renorm_action) { + auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const std::vector& prims, const RenormAction& renorm_action) { - auto fn = [prims](const AnfNodePtr& node) -> bool { +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, + const std::vector &prims, const RenormAction &renorm_action) { + auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { return false; } - for (auto& prim : prims) { + for (auto &prim : prims) { if (IsPrimitiveCNode(node, prim)) { return true; } @@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& 
transform, const std:: return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const PredicateFuncType& predicate, const RenormAction& renorm_action) { +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, + const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } -AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { #ifdef ENABLE_PROFILE double t = GetTime(); #endif @@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode return result; } -bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node, - const SubstitutionPtr& transform) const { +bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, + const SubstitutionPtr &transform) const { FuncGraphManagerPtr manager = optimizer->manager(); std::unordered_set seen_node; std::deque todo{root_node}; @@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo } if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); } - auto& node_users = manager->node_users(); + auto &node_users = manager->node_users(); if (change && node_users.find(node) != node_users.end()) { - for (auto& use : node_users[node]) { + for (auto &use : node_users[node]) { auto use_node = use.first; todo.push_back(use_node); if (seen_node.find(use_node) != seen_node.end()) { @@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo return changes; } -bool 
SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const { +bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = optimizer->manager(); @@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize do { loop = false; - for (auto const& transform : list_) { + for (auto const &transform : list_) { auto change = ApplyTransform(optimizer, func_graph->output(), transform); changes = changes || change; loop = loop || change; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc index 03f7d054e0..30173e533c 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) { +std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! 
Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -39,7 +39,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t MS_EXCEPTION_IF_NULL(manager); auto node_set = manager->node_users()[para]; std::unordered_set cnode_set; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { auto cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -54,7 +54,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t (void)cnode_set.emplace(cnode); } else { auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); - for (auto& cnode_sub : cnode_set_sub) { + for (auto &cnode_sub : cnode_set_sub) { (void)cnode_set.emplace(cnode_sub); } } @@ -63,8 +63,8 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t } Status AllreduceFusion::AddNodeToGraph() { - const auto& parameters = root_graph_->parameters(); - for (auto& parameter : parameters) { + const auto ¶meters = root_graph_->parameters(); + for (auto ¶meter : parameters) { if (!ParameterRequireGrad(parameter)) { continue; } @@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() { if (cnode_set.empty()) { continue; } - for (auto& cnode : cnode_set) { + for (auto &cnode : cnode_set) { MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); @@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() { return SUCCESS; } -CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const { +CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! 
Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi return cnode_dist; } else { auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); - for (auto& ele : cnode_dist_next) { + for (auto &ele : cnode_dist_next) { cnode_dist[ele.first] = cost + ele.second; } } } else { auto cnode_dist_next = FindNextCNodes(cnode); - for (auto& ele : cnode_dist_next) { + for (auto &ele : cnode_dist_next) { cnode_dist[ele.first] = ele.second; } } return cnode_dist; } -CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const { +CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; } - const auto& from_inputs = from->inputs(); + const auto &from_inputs = from->inputs(); std::unordered_map dist_map; MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; - for (auto& input_node : from_inputs) { + for (auto &input_node : from_inputs) { auto cnode_dist = FindCNode(input_node, recursive_times + 1); - for (auto& ele : cnode_dist) { + for (auto &ele : cnode_dist) { (void)dist_map.emplace(ele); } } @@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu Status AllreduceFusion::AddEdgeToGraph() { std::unordered_map cnode_state_map; - const auto& cnodes = allreduce_graph_.cnode_set(); - for (auto& cnode : cnodes) { + const auto &cnodes = allreduce_graph_.cnode_set(); + for (auto &cnode : cnodes) { cnode_state_map[cnode] = 0; } - const auto& head_cnode = allreduce_graph_.head_cnode(); + const auto &head_cnode = allreduce_graph_.head_cnode(); std::queue cnode_queue; cnode_queue.emplace(head_cnode); cnode_state_map[head_cnode] 
= 1; @@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() { cnode_queue.pop(); cnode_state_map[cur_cnode] = 2; auto next = FindNextCNodes(cur_cnode); - for (auto& ele : next) { - auto& cnode = ele.first; - auto& dist = ele.second; + for (auto &ele : next) { + auto &cnode = ele.first; + auto &dist = ele.second; if (cnode_state_map[cnode] == 0) { cnode_queue.emplace(cnode); cnode_state_map[cnode] = 1; @@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() { return SUCCESS; } -std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) { +std::vector FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -184,7 +184,7 @@ std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_time MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[para]; std::vector cnode_list; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { auto cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -210,7 +210,7 @@ std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_time return cnode_list; } -void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) { +void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { MS_EXCEPTION_IF_NULL(mirror_cnode); MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; auto node_prim = GetValueNode(mirror_cnode->input(0)); @@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); } -Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { +Status FindMirrorAndSetFusion(const 
AnfNodePtr ¶, int32_t fusion) { auto mirror_cnodes = FindMirror(para); if (mirror_cnodes.empty()) { MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; return SUCCESS; } if (mirror_cnodes.size() > 2) { - for (auto& mirror_cnode : mirror_cnodes) { + for (auto &mirror_cnode : mirror_cnodes) { MS_EXCEPTION_IF_NULL(mirror_cnode); MS_LOG(INFO) << mirror_cnode->DebugString(); } @@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { << "Mirror CNode found."; return FAILED; } - for (auto& mirror_cnode : mirror_cnodes) { + for (auto &mirror_cnode : mirror_cnodes) { auto parameter_name = ParameterName(para); SetMirrorFusion(mirror_cnode, fusion, parameter_name); } return SUCCESS; } -Status FindMirrorAndSetFusion(const std::vector& paras, int32_t fusion) { - for (auto& param_node : paras) { +Status FindMirrorAndSetFusion(const std::vector ¶s, int32_t fusion) { + for (auto ¶m_node : paras) { if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; return FAILED; @@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector& paras, int32_t fusi return SUCCESS; } -Status AllreduceFusion::SetFusion(const std::vector& cost_map) { +Status AllreduceFusion::SetFusion(const std::vector &cost_map) { if (cost_map.size() < 2) { MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); return FAILED; @@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { return SetFusionByBackwardCompAndAllreduceTime(); } -Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) { +Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { if (ret == nullptr) { MS_LOG(ERROR) << "ret is nullptr."; return FAILED; @@ -399,7 +399,12 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) { ret_ = ret; root_graph_ = ret_->func_graph(); MS_EXCEPTION_IF_NULL(root_graph_); - auto forward_graph = 
ForwardGraph(root_graph_); + auto graph_set = ForwardGraph(root_graph_); + if (graph_set.size() > 1) { + MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now."; + return SUCCESS; + } + auto forward_graph = *(graph_set.begin()); MS_EXCEPTION_IF_NULL(forward_graph); forward_ret_ = forward_graph->get_return(); MS_EXCEPTION_IF_NULL(forward_ret_); diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h index 67dc55836a..43a9935095 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h @@ -50,15 +50,15 @@ class AllreduceFusion { allreduce_bandwidth_(0), computation_time_parameter_(0) {} virtual ~AllreduceFusion() = default; - Status ProcessAllreduceFusion(const CNodePtr& ret); + Status ProcessAllreduceFusion(const CNodePtr &ret); private: Status AddNodeToGraph(); - CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const; - CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const; + CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; + CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; Status AddEdgeToGraph(); std::vector GenerateCostMap(int32_t fusion_times, double tail_percent) const; - Status SetFusion(const std::vector& cost_map); + Status SetFusion(const std::vector &cost_map); Status SetFusionByAlgorithm(int32_t algorithm); Status SetFusionByBackwardCompTime(); Status SetFusionByBackwardCompAndAllreduceTime(); diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc index 9e04593c83..2a98a38add 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace 
parallel { -Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { +Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) { AllreduceNodePtr arnode; auto cnode_emplace_return = cnode_set_.emplace(node); if (!cnode_emplace_return.second) { @@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { return SUCCESS; } -Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) { +Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { auto from_arnode_iter = cnode_arnode_map_.find(from); if (from_arnode_iter == cnode_arnode_map_.end()) { MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; @@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double return SUCCESS; } -bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const { +bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { auto cnode_iter = cnode_set_.find(node); return !(cnode_iter == cnode_set_.end()); } std::vector AllreduceGraph::GetParaByCost(double from, double to) { std::vector nodes; - for (auto& cnode_arnode : cnode_arnode_map_) { + for (auto &cnode_arnode : cnode_arnode_map_) { MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() << " curr_para_size: " << cnode_arnode.second->curr_para_size(); @@ -117,7 +117,7 @@ std::pair, double> AllreduceGraph::GetParaByParaSize(dou std::vector nodes; double cur_para_size = 0; double from = to; - for (auto& arnode : arnode_vec_) { + for (auto &arnode : arnode_vec_) { if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { continue; } @@ -135,14 +135,14 @@ std::pair, double> AllreduceGraph::GetParaByParaSize(dou void AllreduceGraph::PrintCNodeSet() const { MS_LOG(INFO) << "CNodeSet:"; - for (auto& cnode : cnode_set_) { + for (auto &cnode : cnode_set_) { 
MS_LOG(INFO) << cnode->DebugString(); } } void AllreduceGraph::PrintAllredueGraphInfo() const { MS_LOG(INFO) << "max: " << max_; - for (auto& cnode_arnode : cnode_arnode_map_) { + for (auto &cnode_arnode : cnode_arnode_map_) { MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); MS_LOG(INFO) << "arnode info: "; cnode_arnode.second->ToString(); @@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const { void AllreduceGraph::PrintArnodeVec() const { MS_LOG(INFO) << "ArnodeVec:"; - for (auto& arnode : arnode_vec_) { + for (auto &arnode : arnode_vec_) { arnode.ToString(); } } void AllreduceGraph::PrintArnodeSet() const { MS_LOG(INFO) << "ArnodeSet:"; - for (auto& arnode : arnode_set_) { + for (auto &arnode : arnode_set_) { arnode->ToString(); } } void AllreduceGraph::SortArnode() { arnode_vec_.clear(); - for (auto& node : arnode_set_) { + for (auto &node : arnode_set_) { arnode_vec_.emplace_back(*node); } std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); @@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() { Status AllreduceGraph::RemoveExtraParas() { std::unordered_set para_map; - for (auto& node : arnode_vec_) { - for (auto& para : node.paras()) { + for (auto &node : arnode_vec_) { + for (auto ¶ : node.paras()) { auto emplac_result = para_map.emplace(para); if (!emplac_result.second) { MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; @@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() { return SUCCESS; } -Status AllreduceGraph::set_head_cnode(const CNodePtr& node) { +Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { auto arnode = std::make_shared(AllreduceNode()); if (arnode->Init(node) != SUCCESS) { MS_LOG(ERROR) << "AllreduceNode Init failed"; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h index f0db78a130..b2084b735c 100644 --- 
a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h @@ -42,9 +42,9 @@ class AllreduceGraph { cnode_arnode_map_(), max_(0) {} virtual ~AllreduceGraph() = default; - Status AddNode(const CNodePtr& node, const AnfNodePtr& para); - Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist); - bool NodeInGraph(const CNodePtr& node) const; + Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); + Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); + bool NodeInGraph(const CNodePtr &node) const; std::vector GetParaByCost(double from, double to); // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is // over para_size. @@ -60,9 +60,9 @@ class AllreduceGraph { void PrintAllredueGraphInfo() const; void PrintArnodeVec() const; void PrintArnodeSet() const; - const std::unordered_set& cnode_set() const { return cnode_set_; } + const std::unordered_set &cnode_set() const { return cnode_set_; } CNodePtr head_cnode() const { return head_cnode_; } - Status set_head_cnode(const CNodePtr& node); + Status set_head_cnode(const CNodePtr &node); double max() const { return max_; } private: diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc index 6be588928a..113d4ec59b 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -Status AllreduceNode::AddNext(const AllreduceNodePtr& next_node) { +Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) { if (next_node == nullptr) { MS_LOG(ERROR) << "next_node is nullptr!"; return FAILED; @@ -30,7 +30,7 @@ Status AllreduceNode::AddNext(const AllreduceNodePtr& next_node) { return SUCCESS; } -Status AllreduceNode::AddPrev(const AllreduceNodePtr& 
prev_node, double dist, double* max) { +Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) { if (prev_node == nullptr) { MS_LOG(ERROR) << "next_node is nullptr!"; return FAILED; @@ -46,7 +46,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do *max = depend_feat_size_; } std::queue next_queue; - for (auto& next : next_) { + for (auto &next : next_) { next_queue.push(next); } while (!next_queue.empty()) { @@ -55,7 +55,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do if (ele->depend_feat_size() > *max) { *max = ele->depend_feat_size(); } - for (auto& next : ele->next()) { + for (auto &next : ele->next()) { next_queue.push(next); } next_queue.pop(); @@ -63,7 +63,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do return SUCCESS; } -Status AllreduceNode::Init(const CNodePtr& cnode_ptr) { +Status AllreduceNode::Init(const CNodePtr &cnode_ptr) { if (cnode_ptr == nullptr) { MS_LOG(ERROR) << "cnode_ptr is nullptr!"; return FAILED; @@ -72,7 +72,7 @@ Status AllreduceNode::Init(const CNodePtr& cnode_ptr) { return SUCCESS; } -Status AllreduceNode::AddPara(const AnfNodePtr& node_ptr) { +Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { if (node_ptr == nullptr) { MS_LOG(ERROR) << "node_ptr is nullptr!"; return FAILED; @@ -99,7 +99,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr& node_ptr) { return SUCCESS; } -Status AllreduceNode::RemovePara(const AnfNodePtr& node_ptr) { +Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) { if (node_ptr == nullptr) { MS_LOG(ERROR) << "node_ptr is nullptr!"; return FAILED; @@ -115,7 +115,7 @@ Status AllreduceNode::RemovePara(const AnfNodePtr& node_ptr) { void AllreduceNode::ToString() const { MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << "para size: " << paras_.size(); - for (auto& para : paras_) { + for (auto ¶ : paras_) { MS_LOG(INFO) << "para name: " << 
para->fullname_with_scope() << " size: " << para_size_map_.at(para); } MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h index d9ba98c3a2..db1c4e3f2e 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h @@ -33,23 +33,23 @@ class AllreduceNode { public: AllreduceNode() : cnode_ptr_(nullptr), prev_(), next_(), paras_(), para_size_map_(), curr_para_size_(0), depend_feat_size_(0) {} - Status Init(const CNodePtr& cnode_ptr); - Status AddPara(const AnfNodePtr& node_ptr); - Status RemovePara(const AnfNodePtr& node_ptr); - const std::unordered_set& paras() const { return paras_; } + Status Init(const CNodePtr &cnode_ptr); + Status AddPara(const AnfNodePtr &node_ptr); + Status RemovePara(const AnfNodePtr &node_ptr); + const std::unordered_set ¶s() const { return paras_; } double curr_para_size() const { return curr_para_size_; } virtual ~AllreduceNode() = default; // Add previous node // prev_node is the previous to be added // max is the current max depend_feat_size of the AllreduceGraph - Status AddPrev(const AllreduceNodePtr& prev_node, double dist, double* max); - Status AddNext(const AllreduceNodePtr& next_node); + Status AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max); + Status AddNext(const AllreduceNodePtr &next_node); double depend_feat_size() const { return depend_feat_size_; } void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; } - const std::vector& next() const { return next_; } + const std::vector &next() const { return next_; } void ToString() const; - bool operator<(const AllreduceNode& node) const { return depend_feat_size_ < node.depend_feat_size(); } - bool operator>(const AllreduceNode& node) const { return depend_feat_size_ > 
node.depend_feat_size(); } + bool operator<(const AllreduceNode &node) const { return depend_feat_size_ < node.depend_feat_size(); } + bool operator>(const AllreduceNode &node) const { return depend_feat_size_ > node.depend_feat_size(); } private: CNodePtr cnode_ptr_; diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc index 190f589bb5..ad3a3a1298 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -void Simplify(CostPtrList* clist_ptrs) { +void Simplify(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method // excludes the cost with greater computation_cost_ and greater communication_cost. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} @@ -44,7 +44,7 @@ void Simplify(CostPtrList* clist_ptrs) { *clist_ptrs = std::move(ret); } -void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { +void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. 
if (!COST_MODEL_SIMPLIFY_CALCULATION) { @@ -66,7 +66,7 @@ void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { *clist_ptrs = std::move(ret); } -void RefineForPracticalCost(const CostPtr& origin_cost, bool is_redistribution) { +void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { MS_EXCEPTION_IF_NULL(origin_cost); if (is_redistribution) { // Redistribution cost diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h index 9e9003848b..2cb24dd7f3 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h @@ -44,7 +44,7 @@ using RedistributionOpListPtr = std::shared_ptr& decision_ = nullptr) + Cost(double computation, double commuication, const std::shared_ptr &decision_ = nullptr) : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { memory_with_reuse_ = 0.0; communication_without_parameter_ = 0.0; @@ -76,8 +76,8 @@ class StrategyWithCost { StrategyWithCost(StrategyPtr strategy, std::vector inputs_, std::vector outputs_) : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} - StrategyWithCost(const StrategyWithCost& swc) = delete; - StrategyWithCost(StrategyWithCost&& swc) + StrategyWithCost(const StrategyWithCost &swc) = delete; + StrategyWithCost(StrategyWithCost &&swc) : strategy_ptr(swc.strategy_ptr), inputs_ptr(swc.inputs_ptr), outputs_ptr(swc.outputs_ptr), @@ -295,9 +295,9 @@ using StarEliminationDecisionPtr = std::shared_ptr; using FinalDecisionPtr = std::shared_ptr; using FinalSingleDecisionPtr = std::shared_ptr; -void Simplify(CostPtrList* clist); -void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist); -void RefineForPracticalCost(const CostPtr&, bool is_redistribution); +void Simplify(CostPtrList *clist); +void 
SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist); +void RefineForPracticalCost(const CostPtr &, bool is_redistribution); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc index dd21096fcc..8d439f1522 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -Status GetStrategy(const CostGraphPtr& graph) { +Status GetStrategy(const CostGraphPtr &graph) { MS_LOG(INFO) << "Searching strategies begins."; MS_EXCEPTION_IF_NULL(graph); std::vector eliminations; @@ -141,7 +141,7 @@ Status RecoverStrategy(std::vector eliminations) { auto elimination = (*rit)->cast(); auto new_edge = elimination->new_edge_; MS_EXCEPTION_IF_NULL(new_edge); - auto& edges = elimination->edges_; + auto &edges = elimination->edges_; auto decision = new_edge->selected_cost()->decision_ptr_->cast(); for (size_t j = 0; j < edges.size(); ++j) { MS_EXCEPTION_IF_NULL(edges[j]); diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h index 6d43218e19..efedba7d10 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h @@ -65,7 +65,7 @@ struct OpElimination : public Elimination { // Edge Elimination struct EdgeElimination : public Elimination { - EdgeElimination(const EdgePtr& n_edge, std::vector eds) + EdgeElimination(const EdgePtr &n_edge, std::vector eds) : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} std::vector edges_; @@ -139,7 +139,7 @@ using TriangleEliminationPtr = std::shared_ptr; using StarEliminationPtr = std::shared_ptr; // Phase 1 and Phase 2 -Status GetStrategy(const CostGraphPtr& graph); +Status 
GetStrategy(const CostGraphPtr &graph); // Phase 3 Status RecoverStrategy(std::vector eliminations); diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc index 21e67f9f7b..6973830779 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc @@ -28,19 +28,19 @@ namespace mindspore { namespace parallel { Status Edge::InitEdgeCost() { bool has_available_cost = false; - for (auto& swc : prev_op_->GetStrategyCost()) { + for (auto &swc : prev_op_->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(swc); pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr)); } - for (auto& swc : next_op_->GetStrategyCost()) { + for (auto &swc : next_op_->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(swc); next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr)); } if (is_identity_edge) { - for (auto& target_output : pre_op_output_) { + for (auto &target_output : pre_op_output_) { auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); auto target_output_str = target_output.first; - for (auto& target_input : next_op_input_) { + for (auto &target_input : next_op_input_) { auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); auto target_input_str = target_input.first; if (target_output_lyt == target_input_lyt) { @@ -57,12 +57,12 @@ Status Edge::InitEdgeCost() { } } } else { - for (auto& target_output : pre_op_output_) { + for (auto &target_output : pre_op_output_) { auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); auto target_output_str = target_output.first; auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_]; auto type = prev_op_->outputs_type()[prev_op_output_index_]; - for (auto& target_input : next_op_input_) { + for (auto &target_input : next_op_input_) { auto target_input_lyt = 
target_input.second[next_op_input_index_].tensor_layout(); auto target_input_str = target_input.first; CostPtr cost; @@ -99,8 +99,8 @@ Status Edge::InitEdgeCost() { return Status::SUCCESS; } -Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, const TensorLayout& next_op_input_layout, - size_t type_length, TypePtr type, CostPtr* cost) { +Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t type_length, TypePtr type, CostPtr *cost) { MS_EXCEPTION_IF_NULL(prev_op_); MS_EXCEPTION_IF_NULL(cost); RankList dev_list = prev_op_->global_device_list(); @@ -148,9 +148,9 @@ CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) { return result; } -CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr, const std::vector& edges, - const StrategyPtr& input_st_ptr) { - std::function LocalGetCostList = [&](const EdgePtr& edge) { +CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector &edges, + const StrategyPtr &input_st_ptr) { + std::function LocalGetCostList = [&](const EdgePtr &edge) { MS_EXCEPTION_IF_NULL(edge); return edge->GetCostList(output_st_ptr, input_st_ptr); }; @@ -174,7 +174,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr result.push_back(new_cost); return; } - for (auto& c : all_cost_list[k]) { + for (auto &c : all_cost_list[k]) { MS_EXCEPTION_IF_NULL(c); selected_cost_list[k] = c; recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, @@ -187,11 +187,11 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr return result; } -void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector& edges, OperatorInfoPtr) { +void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector &edges, OperatorInfoPtr) { bool valid = false; - for (const auto& output_pair : 
pre_op_output_) { + for (const auto &output_pair : pre_op_output_) { StrategyPtr output_st_ptr = output_pair.first; - for (const auto& input_pair : next_op_input_) { + for (const auto &input_pair : next_op_input_) { StrategyPtr input_st_ptr = input_pair.first; CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr); CostPtrKey key = {output_st_ptr, input_st_ptr}; @@ -206,14 +206,14 @@ void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector } } -void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& left_cost_list, - const CostPtrList& middle_cost_list, const CostPtrList& right_cost_list, - CostPtrList* ret_cost_list) { - for (auto& left_cost : left_cost_list) { +void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list) { + for (auto &left_cost : left_cost_list) { MS_EXCEPTION_IF_NULL(left_cost); - for (auto& middle_cost : middle_cost_list) { + for (auto &middle_cost : middle_cost_list) { MS_EXCEPTION_IF_NULL(middle_cost); - for (auto& right_cost : right_cost_list) { + for (auto &right_cost : right_cost_list) { MS_EXCEPTION_IF_NULL(right_cost); double computation = left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; @@ -238,14 +238,14 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr } } -CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr& e1, const StrategyPtr& output_st_ptr, - const OperatorInfoPtr& op, const EdgePtr& e2, - const StrategyPtr& input_st_ptr) { +CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr, + const OperatorInfoPtr &op, const EdgePtr &e2, + const StrategyPtr &input_st_ptr) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(e1); MS_EXCEPTION_IF_NULL(e2); CostPtrList result; - for (const auto& 
op_strategy : op->GetStrategyCost()) { + for (const auto &op_strategy : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_strategy); auto middle_strategy = op_strategy->strategy_ptr; CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy), @@ -255,11 +255,11 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr& e1, const StrategyP return result; } -void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op, const EdgePtr& e2) { +void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) { bool valid = false; - for (const auto& output_pair : pre_op_output_) { + for (const auto &output_pair : pre_op_output_) { StrategyPtr output_st_ptr = output_pair.first; - for (const auto& input_pair : next_op_input_) { + for (const auto &input_pair : next_op_input_) { StrategyPtr input_st_ptr = input_pair.first; CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr); @@ -283,8 +283,8 @@ Status Edge::CalculateMemoryCost() { if (is_output_parameter_involve_ == 0) { // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is // unnecessary to keep them in memory. - for (auto& cost_kv : cost_map_) { - auto& cost_v = cost_kv.second; + for (auto &cost_kv : cost_map_) { + auto &cost_v = cost_kv.second; if (!cost_v.empty()) { cost_v[0]->memory_with_reuse_ = 0; } diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h index f974125749..e760c24c34 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h @@ -37,9 +37,9 @@ using EdgePtr = std::shared_ptr; class Edge { // An 'Edge' connects two Operators in the CostGraph. 
public: - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const size_t& output_index_, const size_t& input_index_, - const bool& is_com) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -49,9 +49,9 @@ class Edge { is_identity_edge = false; } - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const size_t& output_index_, const size_t& input_index_, - const bool& is_com, const bool& is_iden) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com, const bool &is_iden) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -60,9 +60,9 @@ class Edge { is_combined_(is_com), is_identity_edge(is_iden) {} - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const std::vector& output_indexs_, - const std::vector& input_indexs_, const bool& is_com) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const std::vector &output_indexs_, + const std::vector &input_indexs_, const bool &is_com) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -83,13 +83,13 @@ class Edge { // For two operators u--->v, given the output tensor layout of u, // and the input tensor layout of v, return the redistribution cost, // and the op_list to carry out the redistribution. 
- Status GetRedistributionCost(const TensorLayout& prev_op_output_layout, const TensorLayout& next_op_input_layout, - size_t, TypePtr type, CostPtr* cost); + Status GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t, TypePtr type, CostPtr *cost); - void set_pre_op_output(const std::vector, std::vector>>& output_set) { + void set_pre_op_output(const std::vector, std::vector>> &output_set) { pre_op_output_ = output_set; } - void set_next_op_input(const std::vector, std::vector>>& input_set) { + void set_next_op_input(const std::vector, std::vector>> &input_set) { next_op_input_ = input_set; } @@ -109,27 +109,27 @@ class Edge { std::vector prev_op_output_indexs() const { return pre_op_output_indexs_; } std::vector next_op_input_indexs() const { return next_op_input_indexs_; } - CostPtrList CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr, - const std::vector>& edges, - const StrategyPtr& input_st_ptr); + CostPtrList CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, + const std::vector> &edges, + const StrategyPtr &input_st_ptr); // In the Edge Elimination operation in DP algorithm, 'edges' is replaced by a new edge. 
This method is used to // set cost for this new edge - void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector>& edges, + void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector> &edges, std::shared_ptr v); - void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& left_cost_list, - const CostPtrList& middle_cost_list, const CostPtrList& right_cost_list, - CostPtrList* ret_cost_list); + void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list); - CostPtrList CreateOpEliminationCostList(const std::shared_ptr& e1, const StrategyPtr& output_st_ptr, - const std::shared_ptr& op, const std::shared_ptr& e2, - const StrategyPtr& input_st_ptr); + CostPtrList CreateOpEliminationCostList(const std::shared_ptr &e1, const StrategyPtr &output_st_ptr, + const std::shared_ptr &op, const std::shared_ptr &e2, + const StrategyPtr &input_st_ptr); // In the Operation Elimination operation in DP algorithm, 'op', 'e1' and 'e2' are replaced by a new edge. 
// This method is used to set cost for this new edge - void OpEliminationSetNewCost(const std::shared_ptr& e1, const std::shared_ptr& op, - const std::shared_ptr& e2); + void OpEliminationSetNewCost(const std::shared_ptr &e1, const std::shared_ptr &op, + const std::shared_ptr &e2); - void set_selected_cost(const CostPtr& cost) { selected_cost_ = cost; } - const CostPtr& selected_cost() const { return selected_cost_; } + void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } + const CostPtr &selected_cost() const { return selected_cost_; } void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index c56d3a6fbd..501a983a95 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -144,7 +144,7 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { } } -void CostGraph::RemoveOperator(const OperatorInfoPtr& op) { +void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { for (auto it = ops_.begin(); it != ops_.end();) { if ((*it) == op) { it = ops_.erase(it); @@ -154,19 +154,19 @@ void CostGraph::RemoveOperator(const OperatorInfoPtr& op) { } } -bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr& op_test) { +bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) { struct IsInGraph { const OperatorInfoPtr test_; - explicit IsInGraph(const OperatorInfoPtr& n) : test_(n) {} - bool operator()(const OperatorInfoPtr& in) const { return (test_ == in); } + explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {} + bool operator()(const OperatorInfoPtr &in) const { 
return (test_ == in); } }; return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test)); } -bool CostGraph::IsEdgeInCostGraph(const std::string& test_edge_name, size_t output_index, size_t input_index) { - for (auto& edge_pair : edges_) { +bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) { + for (auto &edge_pair : edges_) { auto edges = edge_pair.second; - for (auto& edge : edges) { + for (auto &edge : edges) { MS_EXCEPTION_IF_NULL(edge); bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) && (edge->next_op_input_index() == input_index); @@ -182,12 +182,12 @@ std::vector> CostGraph::ConstructConnectedComponents( std::vector alive_ops) { std::map visited; - for (auto& op : alive_ops) { + for (auto &op : alive_ops) { visited[op] = false; } MS_LOG(INFO) << "visited: " << visited.size() << "."; - for (auto& op : alive_ops) { + for (auto &op : alive_ops) { if ((!visited[op]) && op->is_alive()) { std::shared_ptr new_component = std::make_shared(); MS_EXCEPTION_IF_NULL(new_component); @@ -199,14 +199,14 @@ std::vector> CostGraph::ConstructConnectedComponents( return connected_compoents_; } -void CostGraph::DFS(const OperatorInfoPtr& current_op, std::map* visited, - const std::shared_ptr& component) { +void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map *visited, + const std::shared_ptr &component) { MS_EXCEPTION_IF_NULL(visited); MS_EXCEPTION_IF_NULL(component); visited->at(current_op) = true; component->AddOperator(current_op); - for (auto& edge : current_op->succ_edges()) { + for (auto &edge : current_op->succ_edges()) { bool bool_test = (visited->find(edge->next_operator()) != visited->end()) && (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive(); if (bool_test) { @@ -215,7 +215,7 @@ void CostGraph::DFS(const OperatorInfoPtr& current_op, std::mapprev_edges()) { + for (auto &edge : current_op->prev_edges()) { bool bool_test 
= (visited->find(edge->prev_operator()) != visited->end()) && (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive(); if (bool_test) { @@ -226,14 +226,14 @@ void CostGraph::DFS(const OperatorInfoPtr& current_op, std::map v -CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::shared_ptr& e, - const OperatorInfoPtr& v) { +CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr &e, + const OperatorInfoPtr &v) { MS_EXCEPTION_IF_NULL(u); MS_EXCEPTION_IF_NULL(v); MS_EXCEPTION_IF_NULL(e); CostPtrList ret; - for (const auto& u_strategy : u->GetStrategyCost()) { - for (const auto& v_strategy : v->GetStrategyCost()) { + for (const auto &u_strategy : u->GetStrategyCost()) { + for (const auto &v_strategy : v->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(u_strategy); MS_EXCEPTION_IF_NULL(v_strategy); auto u_strategy_ptr = u_strategy->strategy_ptr; @@ -241,9 +241,9 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: CostPtrList clist1 = u_strategy->cost_list; CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr); CostPtrList clist3 = v_strategy->cost_list; - for (const auto& cost1 : clist1) { - for (const auto& cost2 : clist2) { - for (const auto& cost3 : clist3) { + for (const auto &cost1 : clist1) { + for (const auto &cost2 : clist2) { + for (const auto &cost3 : clist3) { MS_EXCEPTION_IF_NULL(cost1); MS_EXCEPTION_IF_NULL(cost2); MS_EXCEPTION_IF_NULL(cost3); @@ -274,14 +274,14 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: } // Create final cost list for the graph containing a signle node: u -CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { +CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { MS_EXCEPTION_IF_NULL(u); CostPtrList ret; - for (const auto& u_strategy : u->GetStrategyCost()) { + for (const auto &u_strategy : u->GetStrategyCost()) { 
MS_EXCEPTION_IF_NULL(u_strategy); auto u_strategy_ptr = u_strategy->strategy_ptr; CostPtrList clist1 = u_strategy->cost_list; - for (const auto& cost1 : clist1) { + for (const auto &cost1 : clist1) { MS_EXCEPTION_IF_NULL(cost1); auto decision = std::make_shared(u_strategy_ptr, cost1); auto new_cost = std::make_shared(cost1->computation_cost_, cost1->communication_cost_, decision); @@ -299,16 +299,16 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { return ret; } -CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) { +CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) { CostPtrList after_mem_filter; // Filter out the valid costs - for (auto& a_cost : cost_list) { + for (auto &a_cost : cost_list) { if (a_cost->memory_with_reuse_ <= memory) { after_mem_filter.emplace_back(std::move(a_cost)); } } - std::function LocalCompare = [&](CostPtr init, const CostPtr& cost_x) { + std::function LocalCompare = [&](CostPtr init, const CostPtr &cost_x) { MS_EXCEPTION_IF_NULL(cost_x); if (init == nullptr || cost_x->computation_cost_ < memory) { init = cost_x; @@ -319,7 +319,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare); } -CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) { +CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) { // Select the cost with minimum training time. Currently, the training time is modeled as = // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_ if (cost_list.empty()) { @@ -329,7 +329,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d CostPtrList after_mem_filter; double minimum_memory = DBL_MAX; // Filter out the valid costs. 
- for (auto& a_cost : cost_list) { + for (auto &a_cost : cost_list) { if (a_cost->memory_with_reuse_ <= memory) { after_mem_filter.emplace_back(std::move(a_cost)); } else if (a_cost->memory_with_reuse_ < minimum_memory) { @@ -371,7 +371,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d return ret; } -CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector& all_cost_list, +CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_cost_list, double available_memory) { CostPtrList selected_cost_list(all_cost_list.size(), nullptr); double minimum = DBL_MAX, total_memory = 0.0; @@ -418,7 +418,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect } MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; - for (auto& c : all_cost_list[k]) { + for (auto &c : all_cost_list[k]) { selected_cost_list[k] = c; recursive(k + 1); } @@ -427,7 +427,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect return ret; } -Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector& alive_ops) { +Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph."; auto connected_components = ConstructConnectedComponents(alive_ops); MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; @@ -516,7 +516,7 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector alive_ops; - (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr& op) { + (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); if (op->is_alive()) { alive_ops.push_back(op); @@ -620,7 +620,7 @@ Status CostGraph::SearchStrategy() { // Given a graph which contains the 
following subgraph: u --> v --> w, the node v can be eliminated // return the v and the edge u --> v OperatorInfoPtr CostGraph::CheckOpElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1; if (bool_test) { if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) { @@ -633,21 +633,21 @@ OperatorInfoPtr CostGraph::CheckOpElimination() const { // Check the graph whether an EdgeElimination can be performed std::vector> CostGraph::CheckEdgeElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); if (!op->is_alive()) continue; - std::map count; - for (auto& edge : op->GetAliveSuccEdges()) { + std::map count; + for (auto &edge : op->GetAliveSuccEdges()) { MS_EXCEPTION_IF_NULL(edge); auto v = edge->next_operator(); count[v.get()]++; } - for (auto& pair : count) { - auto* op_ptr = pair.first; + for (auto &pair : count) { + auto *op_ptr = pair.first; int op_count = pair.second; if (op_count > 1) { std::vector> ret; - for (auto& edge : op->GetAliveSuccEdges()) { + for (auto &edge : op->GetAliveSuccEdges()) { MS_EXCEPTION_IF_NULL(edge); if (edge->next_operator().get() == op_ptr) { ret.push_back(edge); @@ -662,7 +662,7 @@ std::vector> CostGraph::CheckEdgeElimination() const { // Check the graph whether a MergeElimination can be performed OperatorInfoPtr CostGraph::CheckMergeElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1; if (bool_test) { @@ -678,7 +678,7 @@ OperatorInfoPtr CostGraph::CheckMergeElimination() const { // Check the graph whether a ContractElimination can be performed OperatorInfoPtr CostGraph::CheckContractElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { 
MS_EXCEPTION_IF_NULL(op); bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty(); if (bool_test) { @@ -696,7 +696,7 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const { // Check the graph whether a TriangleElimination can be performed std::pair> CostGraph::CheckTriangleElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2); if (bool_test) { @@ -707,13 +707,13 @@ std::pair> CostGraph::CheckTriangleElimin auto first_op = edge1->next_operator(); auto second_op = edge2->next_operator(); MS_EXCEPTION_IF_NULL(first_op); - for (auto& first_op_succ_edge : first_op->GetAliveSuccEdges()) { + for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) { if (first_op_succ_edge->next_operator() == second_op) { return {op, first_op_succ_edge}; } } MS_EXCEPTION_IF_NULL(second_op); - for (auto& second_op_succ_edge : second_op->GetAliveSuccEdges()) { + for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) { if (second_op_succ_edge->next_operator() == first_op) { return {op, second_op_succ_edge}; } @@ -726,7 +726,7 @@ std::pair> CostGraph::CheckTriangleElimin // Check the graph whether a StarElimination can be performed. // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. OperatorInfoPtr CostGraph::CheckStarElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1); if (bool_test) { @@ -738,7 +738,7 @@ OperatorInfoPtr CostGraph::CheckStarElimination() const { // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace // 'lefe_edge', 'op' and 'right_edge'. 
As a consequence, it creates new costlist for the new edge. -std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr& op) { +std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr &op) { // in this case, the operators are organised in the form of u-->op-->v, and the goal // is to eliminate 'op'. MS_EXCEPTION_IF_NULL(op); @@ -786,7 +786,7 @@ std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr& op) { // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges', // and sets new costlist for the new edge. -std::shared_ptr CostGraph::EliminationEdges(const std::vector>& edges) { +std::shared_ptr CostGraph::EliminationEdges(const std::vector> &edges) { MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges."; MS_EXCEPTION_IF_NULL(edges[0]); auto u = edges[0]->prev_operator(); @@ -796,7 +796,7 @@ std::shared_ptr CostGraph::EliminationEdges(const std::vectorname() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); std::vector output_indexs, input_indexs; - for (auto& edge : edges) { + for (auto &edge : edges) { MS_EXCEPTION_IF_NULL(edge); if (edge->is_combined()) { auto from_output_indexs = edge->prev_op_output_indexs(); @@ -824,18 +824,18 @@ std::shared_ptr CostGraph::EliminationEdges(const std::vectorcomputation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; @@ -862,7 +862,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const // This method is for the 'Merge' operation in DP algorithm. 
It creates new costlist for each strategy in the // target_op -OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { +OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); auto target_op = op->GetAliveSuccEdges()[0]->next_operator(); auto edge_ptr = op->GetAliveSuccEdges()[0]; @@ -871,13 +871,13 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << "."; bool valid = false; - for (auto& tar_stra_cost : target_op->GetStrategyCost()) { + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(tar_stra_cost); auto tar_stra = tar_stra_cost->strategy_ptr; auto tar_clist_origin = tar_stra_cost->cost_list; CostPtrList tar_clist_new; - for (auto& op_stra_cost : op->GetStrategyCost()) { + for (auto &op_stra_cost : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_stra_cost); auto op_stra = op_stra_cost->strategy_ptr; auto op_clist = op_stra_cost->cost_list; @@ -904,17 +904,17 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' // for this contract under the strategy 'contract_op_stra' void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra, - const CostPtrList& contract_op_cost_list, - const CostPtrList& edge_cost_list, StrategyPtr target_op_stra, - const CostPtrList& tar_cost_list, CostPtrList* tar_cost_list_new) { + const CostPtrList &contract_op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr target_op_stra, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) { for (size_t i = 0; i < contract_op_cost_list.size(); ++i) { - auto& contract_op_cost = contract_op_cost_list[i]; + auto &contract_op_cost = contract_op_cost_list[i]; MS_EXCEPTION_IF_NULL(contract_op_cost); for (size_t j = 0; j < 
edge_cost_list.size(); ++j) { - auto& edge_cost = edge_cost_list[j]; + auto &edge_cost = edge_cost_list[j]; MS_EXCEPTION_IF_NULL(edge_cost); for (size_t k = 0; k < tar_cost_list.size(); ++k) { - auto& tar_cost = tar_cost_list[k]; + auto &tar_cost = tar_cost_list[k]; MS_EXCEPTION_IF_NULL(tar_cost); double computation = contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; @@ -941,20 +941,20 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the // target_op -OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { +OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); auto target_op = op->GetAlivePrevEdges()[0]->prev_operator(); auto edge_ptr = op->GetAlivePrevEdges()[0]; MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << "."; bool valid = false; - for (auto& tar_stra_cost : target_op->GetStrategyCost()) { + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(tar_stra_cost); auto tar_stra = tar_stra_cost->strategy_ptr; auto tar_clist_origin = tar_stra_cost->cost_list; CostPtrList tar_clist_new; - for (auto& op_stra_cost : op->GetStrategyCost()) { + for (auto &op_stra_cost : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_stra_cost); auto op_stra = op_stra_cost->strategy_ptr; auto op_clist = op_stra_cost->cost_list; @@ -978,19 +978,19 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { } void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, - StrategyPtr right_op_stra, const CostPtr& right_op_cost, - const CostPtrList& elimi_op_clist, - const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, - const CostPtrList& left_node_clist_origin, - CostPtrList* 
left_node_clist_new) { + StrategyPtr right_op_stra, const CostPtr &right_op_cost, + const CostPtrList &elimi_op_clist, + const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { MS_EXCEPTION_IF_NULL(right_edge_cost); MS_EXCEPTION_IF_NULL(right_op_cost); MS_EXCEPTION_IF_NULL(left_node_clist_new); - for (auto& elimi_op_cost : elimi_op_clist) { + for (auto &elimi_op_cost : elimi_op_clist) { MS_EXCEPTION_IF_NULL(elimi_op_cost); - for (auto& left_edge_cost : left_edge_clist) { + for (auto &left_edge_cost : left_edge_clist) { MS_EXCEPTION_IF_NULL(left_edge_cost); - for (auto& left_node_cost : left_node_clist_origin) { + for (auto &left_node_cost : left_node_clist_origin) { MS_EXCEPTION_IF_NULL(left_node_cost); double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; @@ -1015,16 +1015,16 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, } } -void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist, - const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra, - const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra, - const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, - const CostPtrList& left_node_clist_origin, - CostPtrList* left_node_clist_new) { +void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist, + const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra, + const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra, + const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { MS_EXCEPTION_IF_NULL(elimi_op); - for (auto& right_node_cost : 
right_node_clist) { + for (auto &right_node_cost : right_node_clist) { MS_EXCEPTION_IF_NULL(right_node_cost); - for (auto& right_edge_cost : right_edge_clist) { + for (auto &right_edge_cost : right_edge_clist) { MS_EXCEPTION_IF_NULL(right_edge_cost); CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, @@ -1033,8 +1033,8 @@ void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_o } } -OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, - const std::shared_ptr& edge_left_right) { +OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, + const std::shared_ptr &edge_left_right) { MS_EXCEPTION_IF_NULL(edge_left_right); MS_EXCEPTION_IF_NULL(elimi_op); MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << "."; @@ -1056,19 +1056,19 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, } bool valid = false; - for (auto& left_node_stra_cost : left_node->GetStrategyCost()) { + for (auto &left_node_stra_cost : left_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(left_node_stra_cost); auto left_node_stra = left_node_stra_cost->strategy_ptr; auto left_node_clist_origin = left_node_stra_cost->cost_list; CostPtrList left_node_clist_new; - for (auto& elimi_op_stra_cost : elimi_op->GetStrategyCost()) { + for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(elimi_op_stra_cost); auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr; auto elimi_op_clist = elimi_op_stra_cost->cost_list; auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra); - for (auto& right_node_stra_cost : right_node->GetStrategyCost()) { + for (auto &right_node_stra_cost : right_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(right_node_stra_cost); auto right_node_stra = right_node_stra_cost->strategy_ptr; auto 
right_node_clist = right_node_stra_cost->cost_list; @@ -1095,16 +1095,16 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, return left_node; } -void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_node_stra, - const CostPtrList& first_succ_node_clist, - const CostPtrList& first_succ_edge_clist, - const StrategyPtr& merged_op_stra, const CostPtrList& merged_op_clist, +void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, std::vector succ_nodes_stras, - CostPtrList& succ_edges_costs, CostPtrList& succ_nodes_costs, - CostPtrList* first_succ_node_clist_new) { - for (auto& first_succ_node_cost : first_succ_node_clist) { - for (auto& first_succ_edge_cost : first_succ_edge_clist) { - for (auto& merged_node_cost : merged_op_clist) { + CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs, + CostPtrList *first_succ_node_clist_new) { + for (auto &first_succ_node_cost : first_succ_node_clist) { + for (auto &first_succ_edge_cost : first_succ_edge_clist) { + for (auto &merged_node_cost : merged_op_clist) { MS_EXCEPTION_IF_NULL(merged_node_cost); succ_nodes_stras[0] = first_succ_node_stra; succ_edges_costs[0] = first_succ_edge_cost; @@ -1141,12 +1141,12 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n } } -void CostGraph::CreateStarEliminationCostList(std::vector>& succ_edges, - const StrategyPtr& first_succ_node_stra, - const CostPtrList& first_succ_node_clist, - const CostPtrList& first_succ_edge_clist, - const StrategyPtr& merged_op_stra, const CostPtrList& merged_op_clist, - CostPtrList* first_succ_node_clist_new) { +void CostGraph::CreateStarEliminationCostList(std::vector> &succ_edges, + const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const 
CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, + CostPtrList *first_succ_node_clist_new) { std::vector succ_nodes_stras(succ_edges.size(), nullptr); CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr); std::function recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist, @@ -1167,15 +1167,15 @@ void CostGraph::CreateStarEliminationCostList(std::vector> MS_EXCEPTION_IF_NULL(succ_edge); auto succ_node = succ_edge->next_operator(); MS_EXCEPTION_IF_NULL(succ_node); - for (auto& succ_node_stra_cost : succ_node->GetStrategyCost()) { + for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(succ_node_stra_cost); auto succ_node_stra = succ_node_stra_cost->strategy_ptr; auto succ_node_clist = succ_node_stra_cost->cost_list; auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra); - for (auto& succ_node_cost : succ_node_clist) { + for (auto &succ_node_cost : succ_node_clist) { MS_EXCEPTION_IF_NULL(succ_node_cost); - for (auto& succ_edge_cost : succ_edge_clist) { + for (auto &succ_edge_cost : succ_edge_clist) { MS_EXCEPTION_IF_NULL(succ_edge_cost); succ_nodes_stras[k] = succ_node_stra; succ_edges_costs[k] = succ_edge_cost; @@ -1189,11 +1189,11 @@ void CostGraph::CreateStarEliminationCostList(std::vector> recursive(1); } -std::vector> CostGraph::EliminationStar(const OperatorInfoPtr& merged_op) { +std::vector> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) { MS_EXCEPTION_IF_NULL(merged_op); auto succ_edges = merged_op->GetAliveSuccEdges(); MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << "."; - for (auto& succ_edge : succ_edges) { + for (auto &succ_edge : succ_edges) { MS_EXCEPTION_IF_NULL(succ_edge->next_operator()); MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << "."; } @@ -1205,13 +1205,13 @@ 
std::vector> CostGraph::EliminationStar(const OperatorInfo // 'merged_op' is merged into first_node MS_EXCEPTION_IF_NULL(first_succ_node); - for (auto& first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { + for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost); auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr; auto first_succ_node_clist = first_succ_node_stra_cost->cost_list; CostPtrList first_succ_node_clist_new; - for (auto& merged_op_stra_cost : merged_op->GetStrategyCost()) { + for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(merged_op_stra_cost); auto merged_op_stra = merged_op_stra_cost->strategy_ptr; auto merged_op_clist = merged_op_stra_cost->cost_list; @@ -1238,7 +1238,7 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo } Status CostGraph::InitSelectedStrategy() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); auto result = op->InitSelectedStrategy(op->selected_strategy()); if (result != SUCCESS) { @@ -1249,9 +1249,9 @@ Status CostGraph::InitSelectedStrategy() { } Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); - const auto& output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); + const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); if ((output_parameter != 0) && (output_parameter != 1)) { MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed."; return FAILED; @@ -1261,7 +1261,7 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { } Status CostGraph::CalculateOpsMemoryCost() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); if (op->CalculateMemoryCost() != SUCCESS) { MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; @@ -1272,9 +1272,9 @@ Status 
CostGraph::CalculateOpsMemoryCost() { } Status CostGraph::CalculateEdgesMemoryCost() { - for (auto& edge_pair : edges_) { - const auto& edges = edge_pair.second; - for (auto& one_edge : edges) { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { if (one_edge->CalculateMemoryCost() != SUCCESS) { MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; return FAILED; @@ -1284,7 +1284,7 @@ Status CostGraph::CalculateEdgesMemoryCost() { return SUCCESS; } -OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) const { +OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { for (auto one_op : ops_) { if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { if (one_op->refkey_parameter_name() == p_name) { @@ -1295,7 +1295,7 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c return nullptr; } Status CostGraph::CorrectOpsMemoryCost() { - for (auto& one_op : ops_) { + for (auto &one_op : ops_) { if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) { if (one_op->GetAliveSuccEdges().size() > 1) { // Filter out the case when the TmpIdentity being used by multiple operators diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index b6591c0741..530f67ba45 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -34,7 +34,7 @@ namespace parallel { #define OPERATOR_TO_OPERATOR_CONNECTOR "-" #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) #define DEFAULT_COST_MODEL_ALPHA 1.0 -#define DEFAULT_COST_MODEL_BETA 260.0 +#define DEFAULT_COST_MODEL_BETA 400.0 #define DEFAULT_COST_MODEL_GAMMA 0.001 #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true #define 
DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 @@ -70,7 +70,7 @@ class CostGraph { costmodel_beta_ = DEFAULT_COST_MODEL_BETA; } ~CostGraph() = default; - void AddOperator(const OperatorInfoPtr& op) { ops_.push_back(op); } + void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } OperatorInfoPtr FindOperatorByIndex(size_t index) { if (index >= ops_.size()) { MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << "."; @@ -78,29 +78,29 @@ class CostGraph { } return ops_[index]; } - void RemoveOperator(const OperatorInfoPtr& op); - bool IsOperatorInCostGraph(const OperatorInfoPtr& op); + void RemoveOperator(const OperatorInfoPtr &op); + bool IsOperatorInCostGraph(const OperatorInfoPtr &op); // the edge is in the form: u --> v - void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr& edge) { + void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { std::vector curr_edges(edges_[{u_node, v_node}]); curr_edges.push_back(edge); edges_[{u_node, v_node}] = curr_edges; } // An edge is uniquely identified by its name, and its output index and input index. 
- bool IsEdgeInCostGraph(const std::string&, size_t, size_t); + bool IsEdgeInCostGraph(const std::string &, size_t, size_t); void SetDeviceMemoryAndCostParameter(); std::vector> ConstructConnectedComponents(std::vector); - void DFS(const OperatorInfoPtr& current_op, std::map* visited, - const std::shared_ptr& component); + void DFS(const OperatorInfoPtr ¤t_op, std::map *visited, + const std::shared_ptr &component); - CostPtrList CreateFinalCostList(const OperatorInfoPtr& u, const EdgePtr& e, const OperatorInfoPtr& v); - CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr& u); - CostPtr SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory); - CostPtr SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory); - CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector& all_costlist, double memory); - Status SearchStrategyForMultiNodeFinalGraph(const std::vector&); + CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); + CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); + CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory); + CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); + CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); + Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); std::vector> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { return edges_[{u_node, v_node}]; } @@ -145,36 +145,36 @@ class CostGraph { */ OperatorInfoPtr CheckStarElimination() const; // Applying Operator Elimination in DP algorithm - EdgePtr EliminationOp(const OperatorInfoPtr& op); + EdgePtr EliminationOp(const OperatorInfoPtr &op); // Applying Edge Elimination in DP algorithm - EdgePtr EliminationEdges(const std::vector& edges); + EdgePtr EliminationEdges(const std::vector &edges); // Applying Merge 
Elimination in DP algorithm - OperatorInfoPtr EliminationMerge(const OperatorInfoPtr& op); - void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& op_cost_list, - const CostPtrList& edge_cost_list, StrategyPtr tar_op_strategy, - const CostPtrList& tar_cost_list, CostPtrList* tar_cost_list_new); + OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op); + void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new); // Applying Contract Elimination in DP algorithm - OperatorInfoPtr EliminationContract(const OperatorInfoPtr& op); - void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList&, const CostPtrList&, StrategyPtr, - const CostPtrList&, CostPtrList*); + OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op); + void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr, + const CostPtrList &, CostPtrList *); // Applying Triangle Elimination in DP algorithm. 
return the left_node - OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr& elimi_op, const EdgePtr& edge_left_right); - void CreateTriangleEliminationCostList(const OperatorInfoPtr&, const CostPtrList&, const CostPtrList&, - const StrategyPtr&, const StrategyPtr&, const StrategyPtr&, const CostPtrList&, - const CostPtrList&, const CostPtrList&, CostPtrList*); + OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right); + void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const StrategyPtr &, const StrategyPtr &, + const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *); // Given the relevant costlist, create the TriangleElimination cost - void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, const CostPtrList&, - const CostPtrList&, const CostPtr&, const CostPtrList&, CostPtrList*); + void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &, + const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *); // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. 
- std::vector EliminationStar(const OperatorInfoPtr& op); - void CreateStarEliminationCostList(std::vector&, const StrategyPtr&, const CostPtrList&, const CostPtrList&, - const StrategyPtr&, const CostPtrList&, CostPtrList*); - void CreateStarEliminationSubCostList(const StrategyPtr&, const CostPtrList&, const CostPtrList&, const StrategyPtr&, - const CostPtrList&, std::vector, CostPtrList&, CostPtrList&, - CostPtrList*); + std::vector EliminationStar(const OperatorInfoPtr &op); + void CreateStarEliminationCostList(std::vector &, const StrategyPtr &, const CostPtrList &, + const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *); + void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const CostPtrList &, std::vector, + CostPtrList &, CostPtrList &, CostPtrList *); // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then // the memory cost can be resused. Status CalculateOpsMemoryCost(); @@ -186,16 +186,16 @@ class CostGraph { std::vector GetOperators() const { return ops_; } size_t GetNumPairs() const { return edges_.size(); } Status InitSelectedStrategy(); - OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const; + OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only // once (instead of multiple times), this method is used to correct this. 
Status CorrectOpsMemoryCost(); // Needed by rec_parser - void add_inputs_tensor_name(const std::vector& inputs_tensor_name) { + void add_inputs_tensor_name(const std::vector &inputs_tensor_name) { inputs_tensor_name_list_.push_back(inputs_tensor_name); } const std::vector> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; } - void add_tuple_getitem(const std::pair& tuple_getitem) { + void add_tuple_getitem(const std::pair &tuple_getitem) { auto ret = tuple_getitem_list_.insert(tuple_getitem); if (ret.second == false) { MS_LOG(EXCEPTION) << "The insert item is already exist."; diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index 0192dce8b8..8ad8b46f32 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -23,22 +23,22 @@ namespace mindspore { namespace parallel { -void OperatorCost::set_is_parameter(const std::vector& is_parameter) { is_parameter_ = is_parameter; } +void OperatorCost::set_is_parameter(const std::vector &is_parameter) { is_parameter_ = is_parameter; } -void OperatorCost::set_is_parameter_involve(const std::vector& is_parameter_inv) { +void OperatorCost::set_is_parameter_involve(const std::vector &is_parameter_inv) { is_parameter_involve_ = is_parameter_inv; } void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; } -void OperatorCost::SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths) { +void OperatorCost::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { inputs_type_lengths_ = input_lengths; outputs_type_lengths_ = output_lengths; } -double OperatorCost::GetMemoryCost(const std::vector& inputs, - const std::vector& outputs) const { +double OperatorCost::GetMemoryCost(const std::vector &inputs, + const 
std::vector &outputs) const { double result = 0.0; if (output_parameter_involve_ == 1) { // When this operator has multiple outputs, they all contributes to the memory. @@ -64,7 +64,7 @@ double OperatorCost::GetMemoryCost(const std::vector& inputs, } // return the per device communication cost in the forward phase. -double MatMulCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, +double MatMulCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t) const { TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -80,7 +80,7 @@ double MatMulCost::GetForwardCommCost(const std::vector& inputs, con } // return the per device communication cost in the forward phase. -double MatMulCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double MatMulCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not // fully utilize all devices @@ -107,8 +107,8 @@ double MatMulCost::GetBackwardCommCost(const std::vector& inputs, co // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t) const { +double MatMulCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) double result = 0.0; TensorInfo output0 = outputs[0]; @@ -126,7 +126,7 @@ double MatMulCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, +double MatMulCost::GetBackwardComputationCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; @@ -151,14 +151,14 @@ double MatMulCost::GetBackwardComputationCost(const std::vector& inp } // Return the per device communication cost in the forward phase. -double ActivationCost::GetForwardCommCost(const std::vector&, const std::vector&, +double ActivationCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // ReLU is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // Return the per device communication cost in the backward phase. -double ActivationCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ActivationCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -180,7 +180,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector& inputs // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double ActivationCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { TensorInfo input0_info = inputs[0]; Shape input0_slice_shape = input0_info.slice_shape(); @@ -189,19 +189,20 @@ double ActivationCost::GetForwardComputationCost(const std::vector& // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } // Return the per device communication cost in the forward phase. -double SoftmaxCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // In the forward phase, the communication cost = 0 return 0.0; } // Return the per device communication cost in the backward phase. -double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -223,7 +224,7 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In the forward phase, the computation cost = slice(A) TensorInfo input0 = inputs[0]; @@ -233,46 +234,47 @@ double SoftmaxCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. 
-double TmpIdentityCost::GetForwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double TmpIdentityCost::GetBackwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double TmpIdentityCost::GetForwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetForwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double TmpIdentityCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, + int32_t) const { return 0.0; } // Return the per device PEAK memory cost contributed by this operator in a training iteration. 
-double TmpIdentityCost::GetMemoryCost(const std::vector&, const std::vector&) const { +double TmpIdentityCost::GetMemoryCost(const std::vector &, const std::vector &) const { return 0.0; } -double BatchParallelCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector&, +double BatchParallelCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { double cost = 0.0; for (size_t i = 0; i < inputs.size(); ++i) { @@ -281,13 +283,13 @@ double BatchParallelCost::GetForwardComputationCost(const std::vector&, - const std::vector&, +double BatchParallelCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } -double BatchParallelCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double BatchParallelCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -313,13 +315,13 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector& inp return result; } // return the per device communication cost in the forward phase. -double PReLUCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // prelu does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double PReLUCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double PReLUCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[1]) { @@ -342,7 +344,7 @@ double PReLUCost::GetBackwardCommCost(const std::vector& inputs, con // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double PReLUCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -354,8 +356,8 @@ double PReLUCost::GetForwardComputationCost(const std::vector& input // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetBackwardComputationCost(const std::vector& inputs, - const std::vector&, +double PReLUCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; @@ -380,20 +382,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector&, const std::vector&, int32_t) const { +double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // onehot does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double OneHotCost::GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double OneHotCost::GetBackwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // onehot does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double OneHotCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In onehot's forward phase, the computation cost = slice(A) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -402,29 +405,29 @@ double OneHotCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double OneHotCost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -435,13 +438,13 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double ReshapeCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, +double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -457,7 +460,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector& inputs, co } // return the per device communication cost in the backward phase. -double ReshapeCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ReshapeCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -479,8 +482,8 @@ double ReshapeCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per device computation cost in the forward phase. 
The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReshapeCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); @@ -496,12 +499,12 @@ double ReshapeCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double ReshapeCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } -double ArithmeticCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double ArithmeticCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { double result; result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + @@ -509,8 +512,8 @@ double ArithmeticCost::GetForwardComputationCost(const std::vector& return result; } -double ArithmeticCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, - int32_t stage_id) const { +double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -544,7 +547,7 @@ double ArithmeticCost::GetBackwardComputationCost(const std::vector& return result; } -double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { 
double result = 0.0; CheckGlobalDeviceManager(); @@ -580,7 +583,7 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs return result; } -bool IsDataParallel(const Shape& shape, const Shape& slice_shape, int32_t stage_id) { +bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int32_t stage_id) { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); @@ -589,8 +592,8 @@ bool IsDataParallel(const Shape& shape, const Shape& slice_shape, int32_t stage_ return (total_device_num == IntToSize(strategy0)); } -double ReduceMethodCost::GetForwardCommCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMethodCost::GetForwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -611,7 +614,7 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector& input return result; } -double ReduceMethodCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ReduceMethodCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -634,8 +637,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector& inpu return result; } -double ReduceMethodCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMethodCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -656,8 +659,8 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector return result; } -double ReduceMeanCost::GetForwardComputationCost(const std::vector& inputs, - const 
std::vector& outputs, int32_t stage_id) const { +double ReduceMeanCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -678,7 +681,7 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector& return result; } -double DropOutCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double DropOutCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { if (inputs.empty()) { return 0.0; @@ -689,13 +692,14 @@ double DropOutCost::GetForwardComputationCost(const std::vector& inp } // return the per device communication cost in the forward phase. -double GatherV2Cost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double GatherV2Cost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // GatherV2Cost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. 
-double GatherV2Cost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double GatherV2Cost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -721,7 +725,7 @@ double GatherV2Cost::GetBackwardCommCost(const std::vector& inputs, return result; } -double GatherV2Cost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double GatherV2Cost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -731,12 +735,12 @@ double GatherV2Cost::GetForwardComputationCost(const std::vector& in return result; } -double GatherV2Cost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double GatherV2Cost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } -double LayerNormCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double LayerNormCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_.size() != inputs.size()) { @@ -769,7 +773,7 @@ double LayerNormCost::GetBackwardCommCost(const std::vector& inputs, return result; } -double LayerNormCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double LayerNormCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { double result = 0.0; if (inputs_type_lengths_.size() != inputs.size()) { diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index 37b054aa98..a243f8adfa 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -63,31 +63,31 @@ class 
OperatorCost { } virtual ~OperatorCost() = default; - void set_is_parameter(const std::vector& is_parameter); - void set_is_parameter_involve(const std::vector&); + void set_is_parameter(const std::vector &is_parameter); + void set_is_parameter_involve(const std::vector &); void set_output_parameter_involve(int); - void SetInputAndOutputTypeLength(const std::vector& input_lengths, const std::vector& output_lengths); + void SetInputAndOutputTypeLength(const std::vector &input_lengths, const std::vector &output_lengths); std::vector inputs_type_lengths() const { return inputs_type_lengths_; } std::vector outputs_type_lengths() const { return outputs_type_lengths_; } // per device communication cost - virtual double GetCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; // per device computation cost - virtual double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const = 0; - virtual double GetBackwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const = 0; + virtual double GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; + virtual double 
GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; // per device PEAK memory cost in a training iteration // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), // plus necessary inputs. - virtual double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const; + virtual double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const; protected: // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of @@ -113,23 +113,23 @@ class MatMulCost : public OperatorCost { ~MatMulCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const 
override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using MatMulCostPtr = std::shared_ptr; @@ -140,21 +140,21 @@ class ActivationCost : public OperatorCost { ActivationCost() : OperatorCost(false) {} ~ActivationCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ActivationCostPtr = std::shared_ptr; @@ -167,21 +167,21 @@ class SoftmaxCost : public OperatorCost { SoftmaxCost() : 
OperatorCost(false) {} ~SoftmaxCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t) const override; }; using SoftmaxCostPtr = std::shared_ptr; @@ -192,24 +192,24 @@ class TmpIdentityCost : public OperatorCost { TmpIdentityCost() : OperatorCost(false) {} ~TmpIdentityCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double 
GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override; + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override; }; using TmpIdentityCostPtr = std::shared_ptr; @@ -219,21 +219,21 @@ class BatchParallelCost : public OperatorCost { BatchParallelCost() : OperatorCost(false) {} ~BatchParallelCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double 
GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using BatchParallelCostPtr = std::shared_ptr; @@ -244,30 +244,30 @@ class VirtualDatasetCost : public OperatorCost { VirtualDatasetCost() : OperatorCost(false) {} ~VirtualDatasetCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& 
inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override { + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override { return 0.0; } }; @@ -279,27 +279,27 @@ class GeneratorBaseCost : public OperatorCost { GeneratorBaseCost() : OperatorCost(false) {} ~GeneratorBaseCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return 
GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // Generator ops don't have backward steps. - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -313,23 +313,23 @@ class PReLUCost : public OperatorCost { ~PReLUCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double 
GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using PReLUCostPtr = std::shared_ptr; @@ -341,23 +341,23 @@ class OneHotCost : public OperatorCost { ~OneHotCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using OneHotCostPtr = std::shared_ptr; @@ -369,23 +369,23 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { 
~SoftmaxCrossEntropyWithLogitsCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; @@ -398,27 +398,27 @@ class ReshapeCost : public OperatorCost { ~ReshapeCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, 
stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ReshapeCostPtr = std::shared_ptr; @@ -429,22 +429,22 @@ class ArithmeticCost : public OperatorCost { ArithmeticCost() : OperatorCost(false) {} ~ArithmeticCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, 
int32_t) const override; + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ArithmeticCostPtr = std::shared_ptr; @@ -457,21 +457,21 @@ class ReduceMethodCost : public OperatorCost { ReduceMethodCost() : OperatorCost(true) {} ~ReduceMethodCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + 
GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -488,7 +488,7 @@ class ReduceMeanCost : public ReduceMethodCost { ReduceMeanCost() : ReduceMethodCost(true) {} ~ReduceMeanCost() override = default; - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ReduceMeanCostPtr = std::shared_ptr; @@ -499,27 +499,27 @@ class GetNextCost : public OperatorCost { GetNextCost() : OperatorCost(false) {} ~GetNextCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + 
GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // Generator ops don't have backward steps. - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -532,23 +532,23 @@ class DropOutCost : public OperatorCost { DropOutCost() : OperatorCost(true) {} ~DropOutCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector 
&, const std::vector &, int32_t) const override { return 0.0; } @@ -562,21 +562,21 @@ class LayerNormCost : public OperatorCost { LayerNormCost() : OperatorCost(true) {} ~LayerNormCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -590,21 +590,21 @@ class GatherV2Cost : public OperatorCost { GatherV2Cost() : OperatorCost(true) {} ~GatherV2Cost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, 
outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t) const override; }; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 60f3003a42..b2c34127a1 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -27,44 +27,27 @@ namespace mindspore { namespace parallel { -void GenerateStrategy(const std::shared_ptr graph, std::vector> ops, - const std::shared_ptr> ops_nodes_list, - const std::shared_ptr> index_list, - const std::shared_ptr>> eli_list) { - MaskNoSupportedOps(graph); +void GenerateStrategy(std::shared_ptr graph, bool mask_special_ops, + const std::vector> &ops) { + MS_EXCEPTION_IF_NULL(graph); + if (mask_special_ops) { + MaskSpecialOps(graph); + } for (size_t 
iter_ops = 0; iter_ops < ops.size(); iter_ops++) { - auto type = ops[iter_ops]->type(); - size_t iter_nodes = index_list->at(ops_nodes_list->at(iter_ops)); std::vector> stra; - iter_nodes = IterNodes(ops_nodes_list, index_list, eli_list, iter_ops, iter_nodes); for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - std::vector s = PrepareStrategy(graph, ops, type, iter_ops, iter_nodes, iter_op_inputs); - stra.push_back(s); + stra.push_back(PrepareStrategy(graph, ops, iter_ops, iter_op_inputs)); } StrategyPtr sp = std::make_shared(0, stra); ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); } } -size_t IterNodes(const std::shared_ptr> ops_nodes_list, - const std::shared_ptr> index_list, - const std::shared_ptr>> eli_list, const size_t iter_ops, - size_t iter_nodes) { - if (iter_nodes > SIZE_MAX / 2) { - for (size_t iter_eli = 0; iter_eli < eli_list->size(); iter_eli++) { - if (eli_list->at(iter_eli)[0] == ops_nodes_list->at(iter_ops)) { - iter_nodes = index_list->at(eli_list->at(iter_eli)[1]); - break; - } - } - } - return iter_nodes; -} - -void PrepareMatMul(const std::shared_ptr graph, const std::vector> &ops, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s) { - auto attrs = ops[iter_ops]->attrs(); +std::vector PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_nodes, + const size_t iter_op_inputs) { + std::vector s; + auto attrs = ops[iter_nodes]->attrs(); bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); if (transpose_a && (iter_op_inputs == 0)) { @@ -77,10 +60,12 @@ void PrepareMatMul(const std::shared_ptr graph, const std::vector(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); } + return s; 
} -void PrepareConv2D(const std::shared_ptr graph, const size_t iter_nodes, size_t iter_op_inputs, - std::vector s) { +std::vector PrepareConv2D(const std::shared_ptr &graph, const size_t iter_nodes, + size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c)); @@ -92,20 +77,24 @@ void PrepareConv2D(const std::shared_ptr graph, const size_t iter_nodes, s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_h)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w)); } + return s; } -void PrepareBiasAdd(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s) { +std::vector PrepareBiasAdd(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); } else { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); } + return s; } -void PrepareBN(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s) { +std::vector PrepareBN(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c)); @@ -114,97 +103,133 @@ void PrepareBN(const std::shared_ptr graph, const size_t iter_nodes, cons } else { s.push_back(static_cast(1.0 / 
graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w)); } + return s; } -void PrepareSparse(const size_t iter_op_inputs, std::vector s) { +std::vector PrepareSparse(const size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(g_device_manager->DeviceNum()); s.push_back(1); } else { s.push_back(g_device_manager->DeviceNum()); } + return s; +} + +std::vector MakeOriginalStrategy(const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs) { + std::vector s; + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size()) + MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; + size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size(); + for (size_t dim = 0; dim < input_size; dim++) { + s.push_back(1); + } + return s; } -void RefillOrigin(const std::vector> &ops, const size_t iter_ops, - const size_t iter_op_inputs, std::vector s) { +std::vector MakeRecSearchStrategy(const std::shared_ptr &graph, const size_t iter_ops, + const size_t iter_op_inputs) { + std::vector s; + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n)); + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c)); + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h)); + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w)); + return s; +} + +std::vector MakeDataParallelStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t iter_op_inputs) { + std::vector s; + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: 
Operators' elements out of range."; + } StrategyPtr origin_strategy = ops[iter_ops]->strategy(); - if (iter_op_inputs == 0) { - for (size_t j = 0; j < origin_strategy->GetInputDim()[0].size(); j++) { - s.push_back(1); - } - } else { - for (size_t k = 0; k < origin_strategy->GetInputDim()[iter_op_inputs].size(); k++) { + if (iter_op_inputs >= origin_strategy->GetInputDim().size()) + MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; + size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); + for (size_t dim = 0; dim < input_size; dim++) { + if (dim == 0 && input_size == 4) { + size_t max_device_num = g_device_manager->DeviceNum(); + size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0]; + s.push_back(std::min(max_device_num, target_tensor_batch)); + } else { s.push_back(1); } } + return s; } -std::vector PrepareStrategy(const std::shared_ptr graph, - const std::vector> &ops, const std::string &type, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs) { - std::vector s; +std::vector PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs) { + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + auto type = ops[iter_ops]->type(); if (type == MATMUL) { - PrepareMatMul(graph, ops, iter_ops, iter_nodes, iter_op_inputs, s); + return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs); } else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) { - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_n)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_c)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h)); - 
s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); + return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs); } else if (type == CONV2D) { - PrepareConv2D(graph, iter_nodes, iter_op_inputs, s); + return PrepareConv2D(graph, iter_ops, iter_op_inputs); } else if (type == BIAS_ADD) { - PrepareBiasAdd(graph, iter_nodes, iter_op_inputs, s); + return PrepareBiasAdd(graph, iter_ops, iter_op_inputs); } else if (type == RESHAPE) { - s.push_back(1); - s.push_back(1); - s.push_back(1); - s.push_back(1); + return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs); } else if (type == RELU) { - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_n)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_c)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_h)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_w)); + return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs); } else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) { - PrepareBN(graph, iter_nodes, iter_op_inputs, s); + return PrepareBN(graph, iter_ops, iter_op_inputs); } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { - PrepareSparse(iter_op_inputs, s); + return PrepareSparse(iter_op_inputs); } else { - RefillOrigin(ops, iter_ops, iter_op_inputs, s); + return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs); } - return s; } -void MaskNoSupportedOps(const std::shared_ptr graph) { +void MaskSpecialOps(std::shared_ptr graph) { size_t iter_nodes = graph->nodes.size(); for (size_t i = 0; i < iter_nodes; i++) { - if (0 == graph->nodes[i].info) { - Graph::NodeType &node = graph->nodes[i]; + Graph::NodeType &node = graph->nodes[i]; - if (node.apply.op_type == 1) { // For Convolution - // cover input tensor strategy - node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); - 
node.apply.arguments[0].tensor_str.str_c = 1; - node.apply.arguments[0].tensor_str.str_h = 1; - node.apply.arguments[0].tensor_str.str_w = 1; - // cover filter tensor strategy - node.apply.arguments[1].tensor_str.str_n = 1; - node.apply.arguments[1].tensor_str.str_c = 1; - node.apply.arguments[1].tensor_str.str_h = 1; - node.apply.arguments[1].tensor_str.str_w = 1; - } else if (node.apply.op_type == 8) { // For BN - node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); - node.apply.arguments[0].tensor_str.str_c = 1; - node.apply.arguments[0].tensor_str.str_h = 1; - node.apply.arguments[0].tensor_str.str_w = 1; - // cover 1-d argument blobs - node.apply.arguments[1].tensor_str.str_w = 1; - node.apply.arguments[2].tensor_str.str_w = 1; - node.apply.arguments[3].tensor_str.str_w = 1; - node.apply.arguments[4].tensor_str.str_w = 1; - } else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits - node.tensor_parm.tensor_str.str_h = 1.0 / static_cast(g_device_manager->DeviceNum()); - node.tensor_parm.tensor_str.str_w = 1; - } + if (node.apply.op_type == 1) { // For Convolution + // cover input tensor strategy + node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); + node.apply.arguments[0].tensor_str.str_c = 1; + node.apply.arguments[0].tensor_str.str_h = 1; + node.apply.arguments[0].tensor_str.str_w = 1; + // cover filter tensor strategy + node.apply.arguments[1].tensor_str.str_n = 1; + node.apply.arguments[1].tensor_str.str_c = 1; + node.apply.arguments[1].tensor_str.str_h = 1; + node.apply.arguments[1].tensor_str.str_w = 1; + } else if (node.apply.op_type == 8) { // For BN + node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); + node.apply.arguments[0].tensor_str.str_c = 1; + node.apply.arguments[0].tensor_str.str_h = 1; + node.apply.arguments[0].tensor_str.str_w = 1; + // cover 1-d argument blobs + 
node.apply.arguments[1].tensor_str.str_n = 1; + node.apply.arguments[2].tensor_str.str_c = 1; + node.apply.arguments[3].tensor_str.str_h = 1; + node.apply.arguments[4].tensor_str.str_w = 1; + } else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits + node.tensor_parm.tensor_str.str_h = 1.0 / static_cast(g_device_manager->DeviceNum()); + node.tensor_parm.tensor_str.str_w = 1; } } } diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h index 4abef843a8..f3274e1440 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -27,29 +27,28 @@ namespace mindspore { namespace parallel { -void GenerateStrategy(const std::shared_ptr graph, std::vector> ops, - const std::shared_ptr> ops_nodes_list, - const std::shared_ptr> index_list, - const std::shared_ptr>> eli_list); -void PrepareMatMul(const std::shared_ptr graph, const std::vector> &ops, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs, std::vector s); -void PrepareConv2D(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s); -void PrepareBiasAdd(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s); -void PrepareBN(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s); -void PrepareSparse(const size_t iter_op_inputs, std::vector s); -void RefillOrigin(const std::vector> &ops, const size_t iter_ops, - const size_t iter_op_inputs, std::vector s); -std::vector PrepareStrategy(const std::shared_ptr graph, - const std::vector> &ops, const std::string &type, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs); -size_t IterNodes(const std::shared_ptr> ops_nodes_list, - const 
std::shared_ptr> index_list, - const std::shared_ptr>> eli_list, const size_t iter_ops, - size_t iter_nodes); -void MaskNoSupportedOps(const std::shared_ptr graph); +void GenerateStrategy(std::shared_ptr graph, bool mask_special_ops, + const std::vector> &ops); +std::vector PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareConv2D(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareBiasAdd(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareBN(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareSparse(const size_t iter_op_inputs); +std::vector MakeOriginalStrategy(const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs); +std::vector MakeRecSearchStrategy(const std::shared_ptr &graph, const size_t iter_ops, + const size_t iter_op_inputs); +std::vector MakeDataParallelStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t iter_op_inputs); +std::vector PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs); +void MaskSpecialOps(std::shared_ptr graph); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 3ff3473298..6b438cb670 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -35,308 +35,153 @@ const TensorParam MakeTensor(int n, int c, int h, int w) { new_tensor.tensor_shape.shape_c = c; new_tensor.tensor_shape.shape_h = h; new_tensor.tensor_shape.shape_w = w; - const TensorParam& tensor = 
new_tensor; + const TensorParam &tensor = new_tensor; return tensor; } -bool IsInList(const std::string& name, const std::vector& list) { - return std::find(list.begin(), list.end(), name) != list.end(); -} - Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops) { Graph::NodeType NewOp; - NewOp.name = ops[iter_ops]->cnode_name(); + NewOp.name = ops[iter_ops]->name(); NewOp.info = InfoType::kApplication; auto op_type = ops[iter_ops]->type(); auto idx = DictOpType.find(op_type); if (idx == DictOpType.end()) { NewOp.apply.op_type = OperatorType::kRecUnkownType; - MS_LOG(INFO) << "Unknown type in rec_parse_graph::MakeNewOperator"; + MS_LOG(INFO) << "Unknown operator type."; } else { NewOp.apply.op_type = DictOpType.at(op_type); } - if ((NewOp.apply.op_type == OperatorType::kRecMatMul) || (NewOp.apply.op_type == OperatorType::kRecBiasAdd) || - (NewOp.apply.op_type == OperatorType::kRecReshape)) { - NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], - ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); - } else if ((NewOp.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) || - (NewOp.apply.op_type == OperatorType::kRecUnkownType)) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); - } else { + if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { NewOp.tensor_parm = MakeTensor( ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { + NewOp.tensor_parm = Fill2DTensor(ops, iter_ops, NewOp); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) { + NewOp.tensor_parm = 
MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(ERROR) << "Tensor's shape is unknown."; } + NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp); return NewOp; } -Graph::NodeType MakeNewTensor(std::vector> ops, const size_t iter_ops, - const std::string& input, const size_t iter_input_tensors, std::shared_ptr graph, - size_t current_op_index) { - Graph::NodeType NewTensor; - NewTensor.name = input; - NewTensor.info = InfoType::kConstant; - - if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { - NewTensor.tensor_parm = MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { - Fill2DTensor(ops, iter_ops, graph, iter_input_tensors, current_op_index, NewTensor); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { - NewTensor.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { - NewTensor.tensor_parm = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(ERROR) << "Tensor's shape unknown in rec_parse_graph::MakeNewTensor"; - } - return NewTensor; -} - -void Fill2DTensor(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, const size_t current_op_index, - Graph::NodeType NewTensor) { - if (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecMatMul) { +TensorParam Fill2DTensor(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor) { + if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); bool transpose_a = 
attrs[TRANSPOSE_A]->cast()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); - if (transpose_a && (iter_input_tensors == 0)) { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (transpose_b && (iter_input_tensors == 1)) { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); + if (transpose_a) { + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[1], + ops[iter_ops]->inputs_tensor_info()[0].shape()[0]); + } else if (transpose_b) { + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[1], + ops[iter_ops]->inputs_tensor_info()[0].shape()[0]); } else { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[0], + ops[iter_ops]->inputs_tensor_info()[0].shape()[1]); } } else { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[0], + ops[iter_ops]->inputs_tensor_info()[0].shape()[1]); } + return NewTensor.tensor_parm; } -void CompleteOperatorInputs(std::vector> ops, size_t iter_ops, size_t iter_input_tensors, - size_t current_op_index, std::shared_ptr graph) { - if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = - 
MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { - Complete2DInputs(ops, iter_ops, graph, iter_input_tensors, current_op_index); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(ERROR) << "Tensor's shape unknown in rec_parse_graph::MakeNewTensor"; +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor) { + for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); + iter_input_tensors++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { + NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { + 
NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(ERROR) << "Tensor's shape is unknown."; + } } + return NewTensor.apply; } -void Complete2DInputs(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, - const size_t current_op_index) { - if (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecMatMul) { +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, + const size_t iter_input_tensors, Graph::NodeType NewTensor) { + if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); if (transpose_a && (iter_input_tensors == 0)) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); } else if (transpose_b && (iter_input_tensors == 1)) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); } else { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); } } else { - 
graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); } + return NewTensor.apply.arguments[iter_input_tensors]; } -void MakeEdge(std::shared_ptr graph, const size_t input_index, const size_t current_op_index) { - graph->nodes[input_index].node_out.push_back(current_op_index); - graph->nodes[current_op_index].node_in.push_back(input_index); -} - -void ModifyTensorToOperator(std::shared_ptr graph, const size_t current_op_index, const size_t iter_ops, - std::vector> ops) { - graph->nodes[current_op_index].info = InfoType::kApplication; - std::string op_type = ops[iter_ops]->type(); - auto idx = DictOpType.find(op_type); - if (idx == DictOpType.end()) { - graph->nodes[current_op_index].apply.op_type = OperatorType::kRecUnkownType; - MS_LOG(INFO) << "Unknown type in rec_parse_graph::ModifyTensorToOperator"; - } else { - graph->nodes[current_op_index].apply.op_type = DictOpType.at(op_type); - } - - if ((graph->nodes[current_op_index].apply.op_type == OperatorType::kRecMatMul) || - (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecBiasAdd) || - (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecReshape)) { - graph->nodes[current_op_index].tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], - ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); - } else if ((graph->nodes[current_op_index].apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) || - (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecUnkownType)) { - graph->nodes[current_op_index].tensor_parm = MakeTensor(1, 1, 1, 1); - } else { - graph->nodes[current_op_index].tensor_parm = MakeTensor( - ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], - 
ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); - } -} - -std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names, - const std::shared_ptr>& ops_nodes_list) { - std::vector current_graph; +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names) { std::shared_ptr graph(new Graph); if (ops.size() > SIZE_MAX / 2) { MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2; } - for (size_t iter_ops = ops.size(); iter_ops > 0; iter_ops--) { - if (IsInList(ops[iter_ops - 1]->cnode_name(), current_graph)) { - size_t current_op_index = static_cast(std::distance( - current_graph.begin(), std::find(current_graph.begin(), current_graph.end(), ops[iter_ops]->cnode_name()))); - std::vector::iterator itr = ops_nodes_list->insert(ops_nodes_list->begin(), current_op_index); - if (itr != ops_nodes_list->begin()) { - MS_LOG(EXCEPTION) << "Iterator error."; - } - ModifyTensorToOperator(graph, current_op_index, iter_ops - 1, ops); - LinkOps(graph, ops, input_tensor_names, current_graph, iter_ops - 1, current_op_index); - } else { - Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops - 1); - current_graph.push_back(NewOp.name); - graph->nodes.push_back(NewOp); - size_t current_op_index = graph->nodes.size() - 1; - std::vector::iterator itr = ops_nodes_list->insert(ops_nodes_list->begin(), current_op_index); - if (itr != ops_nodes_list->begin()) { - MS_LOG(EXCEPTION) << "Iterator error."; - } - LinkOps(graph, ops, input_tensor_names, current_graph, iter_ops - 1, current_op_index); - } + for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) { + Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops); + graph->nodes.push_back(NewOp); } - return graph; -} - -void LinkOps(std::shared_ptr graph, std::vector> ops, - const std::vector>& input_tensor_names, std::vector current_graph, - const size_t iter_ops, const size_t 
current_op_index) { - for (size_t iter_input_tensors = 0; - iter_input_tensors < std::min(input_tensor_names[iter_ops].size(), ops[iter_ops]->inputs_tensor_info().size()); - iter_input_tensors++) { - std::string input = input_tensor_names[iter_ops][iter_input_tensors]; - if (IsInList(input, current_graph)) { - size_t input_index = static_cast( - std::distance(current_graph.begin(), std::find(current_graph.begin(), current_graph.end(), input))); - MakeEdge(graph, input_index, current_op_index); - CompleteOperatorInputs(ops, iter_ops, iter_input_tensors, current_op_index, graph); - } else { - Graph::NodeType NewTensor = MakeNewTensor(ops, iter_ops, input, iter_input_tensors, graph, current_op_index); - current_graph.push_back(NewTensor.name); - graph->nodes.push_back(NewTensor); - size_t input_index = graph->nodes.size() - 1; - CompleteOperatorInputs(ops, iter_ops, iter_input_tensors, current_op_index, graph); - MakeEdge(graph, input_index, current_op_index); - } + MakeEdge(input_tensor_names, graph); - if (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecBatchNorm) { - break; - } - } + return graph; } -void Eliminate_Aux(const size_t node_index, std::shared_ptr graph, - const std::shared_ptr>> eli_list) { - if ((graph->nodes[node_index].apply.op_type == OperatorType::kRecUnkownType) || - (graph->nodes[node_index].apply.op_type == OperatorType::kRecReLU)) { - size_t input_index = (graph->nodes[node_index].node_in)[0]; - std::vector outputs = graph->nodes[node_index].node_out; - - std::vector eli; - eli.push_back(node_index); - eli.push_back(input_index); - for (size_t i = 0; i < outputs.size(); i++) { - eli.push_back(i); - } - eli_list->push_back(eli); - - for (size_t i = 1; i < (size_t)graph->nodes[node_index].node_in.size(); i++) { - std::vector tmp; - tmp.push_back(node_index); - tmp.push_back((graph->nodes[node_index].node_in)[i]); - eli_list->push_back(tmp); - } - - auto it = find(graph->nodes[input_index].node_out.begin(), 
graph->nodes[input_index].node_out.end(), node_index); - std::vector::iterator itr = graph->nodes[input_index].node_out.erase(it); - if (itr != it) { - MS_LOG(EXCEPTION) << "Iterator error."; - } - for (auto output : outputs) { - graph->nodes[input_index].node_out.push_back(output); - } - for (auto& output_index : outputs) { - auto itt = find(graph->nodes[output_index].node_in.begin(), graph->nodes[output_index].node_in.end(), node_index); - graph->nodes[output_index] - .node_in[static_cast(std::distance(graph->nodes[output_index].node_in.begin(), itt))] = input_index; +void MakeEdge(const std::vector> &input_tensor_names, std::shared_ptr graph) { + for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { + for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { + size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); + if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) { + graph->nodes[iter_i].node_in.push_back(head_node_index); + graph->nodes[head_node_index].node_out.push_back(iter_i); + } } } } -std::shared_ptr EliminateGraph(const std::shared_ptr graph, - std::shared_ptr>> eli_list, - std::shared_ptr> index_list) { - for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { - if (graph->nodes[node_index].info == InfoType::kApplication) { - Eliminate_Aux(node_index, graph, eli_list); - } - } - - index_list->reserve(graph->nodes.size()); - for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { - index_list->push_back(i); - } - - for (size_t i = 0; i < (size_t)eli_list->size(); i++) { - index_list->at((eli_list->at(i)[0])) = SIZE_MAX; - for (size_t j = eli_list->at(i)[0] + 1; j < (size_t)index_list->size(); j++) { - index_list->at(j)--; +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_name, + const std::string &input_name) { + for (size_t index = 0; index < input_tensor_name.size(); index++) { + if 
(input_tensor_name[index][0] == input_name) { + return index; } } - - std::shared_ptr new_graph(new Graph); - for (size_t i = 0; i < (size_t)(graph->nodes.size() - eli_list->size()); i++) { - Graph::NodeType NewOp; - new_graph->nodes.push_back(NewOp); - } - - for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { - if (index_list->at(i) > SIZE_MAX / 2) continue; - new_graph->nodes[index_list->at(i)] = graph->nodes[i]; - for (size_t j = 0; j < (size_t)new_graph->nodes[index_list->at(i)].node_in.size(); j++) { - new_graph->nodes[index_list->at(i)].node_in[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_in[j]); - } - for (size_t j = 0; j < (size_t)new_graph->nodes[index_list->at(i)].node_out.size(); j++) { - new_graph->nodes[index_list->at(i)].node_out[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_out[j]); - } - } - - return new_graph; + MS_LOG(INFO) << "Get index failed, using SIZE_MAX insted"; + return SIZE_MAX; } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index 7dfca86a21..ae50ced418 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -43,39 +43,24 @@ const std::map DictOpType{ const TensorParam MakeTensor(int n, int c, int h, int w); -bool IsInList(const std::string& name, const std::vector& list); - Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops); -Graph::NodeType MakeNewTensor(std::vector> ops, const size_t iter_ops, - const std::string& input, const size_t iter_input_tensors, std::shared_ptr graph, - size_t current_op_index); -void Fill2DTensor(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, const size_t current_op_index, - Graph::NodeType NewTensor); -void CompleteOperatorInputs(std::vector> ops, 
size_t iter_ops, size_t iter_input_tensors, - size_t current_op_index, std::shared_ptr graph); -void Complete2DInputs(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, - const size_t current_op_index); -void MakeEdge(std::shared_ptr graph, const size_t input_index, const size_t current_op_index); +TensorParam Fill2DTensor(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor); + +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor); -void ModifyTensorToOperator(std::shared_ptr graph, const size_t current_op_index, const size_t iter_ops, - std::vector> ops); +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, + const size_t iter_input_tensor, Graph::NodeType NewTensor); -std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names, - const std::shared_ptr>& ops_nodes_list); +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names); -void LinkOps(std::shared_ptr graph, std::vector> ops, - const std::vector>& input_tensor_names, std::vector current_graph, - const size_t iter_ops, const size_t current_op_index); +void MakeEdge(const std::vector> &input_tensor_names, std::shared_ptr graph); -std::shared_ptr EliminateGraph(const std::shared_ptr graph, - std::shared_ptr>> eli_list, - std::shared_ptr> index_list); -void Eliminate_Aux(const size_t node_index, std::shared_ptr graph, - const std::shared_ptr>> eli_list); +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, + const std::string &input_name); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index ab216cb22c..bc4aca896b 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -73,11 
+73,11 @@ void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_bef void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } -void ParallelContext::set_communication_backend(const std::string& communication_backend) { +void ParallelContext::set_communication_backend(const std::string &communication_backend) { communication_backend_ = communication_backend; } -bool ParallelContext::set_parallel_mode(const std::string& parallel_mode) { +bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); if (iter == PARALLEL_MODE_LIST.end()) { MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode; @@ -87,7 +87,7 @@ bool ParallelContext::set_parallel_mode(const std::string& parallel_mode) { return true; } -bool ParallelContext::set_strategy_search_mode(const std::string& strategy_search_mode) { +bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) { auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode); if (iter == STRATEGY_SEARCH_MODE_LIST.end()) { MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode; diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index 265f5bac71..64261cb964 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -40,8 +40,8 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; class ParallelContext { public: ~ParallelContext() = default; - ParallelContext(const ParallelContext&) = delete; - ParallelContext& operator=(const ParallelContext&) = delete; + ParallelContext(const ParallelContext &) = delete; + ParallelContext &operator=(const ParallelContext &) = delete; static std::shared_ptr GetInstance(); @@ -60,13 +60,13 @@ class ParallelContext { void 
set_global_rank(int32_t global_rank); int32_t global_rank() const { return global_rank_; } - void set_communication_backend(const std::string& communication_backend); + void set_communication_backend(const std::string &communication_backend); std::string communication_backend() const { return communication_backend_; } - bool set_parallel_mode(const std::string& parallel_mode); + bool set_parallel_mode(const std::string &parallel_mode); std::string parallel_mode() const { return parallel_mode_; } - bool set_strategy_search_mode(const std::string& strategy_search_mode); + bool set_strategy_search_mode(const std::string &strategy_search_mode); std::string strategy_search_mode() const { return strategy_search_mode_; } void set_parameter_broadcast(bool parameter_broadcast); diff --git a/mindspore/ccsrc/parallel/costmodel_context.h b/mindspore/ccsrc/parallel/costmodel_context.h index 23c9f7cc8d..9937483051 100644 --- a/mindspore/ccsrc/parallel/costmodel_context.h +++ b/mindspore/ccsrc/parallel/costmodel_context.h @@ -28,8 +28,8 @@ namespace parallel { class CostModelContext { public: ~CostModelContext() = default; - CostModelContext(const CostModelContext&) = delete; - CostModelContext& operator=(const CostModelContext&) = delete; + CostModelContext(const CostModelContext &) = delete; + CostModelContext &operator=(const CostModelContext &) = delete; void ResetCostModel(); void ResetAlgoParameters(); diff --git a/mindspore/ccsrc/parallel/device_manager.cc b/mindspore/ccsrc/parallel/device_manager.cc index 0b34cedc00..45628bec65 100644 --- a/mindspore/ccsrc/parallel/device_manager.cc +++ b/mindspore/ccsrc/parallel/device_manager.cc @@ -30,15 +30,15 @@ namespace mindspore { namespace parallel { DeviceManagerPtr g_device_manager = nullptr; -Stage::Stage(const std::vector& devices, int num, int rank) +Stage::Stage(const std::vector &devices, int num, int rank) : devices_(devices), number_(num), rank_(rank) { gm_ = GroupManager(); } // NOTE: '-1' indicates ERROR -int
Stage::global_rank(Group* g) const { return ((g == nullptr) ? rank_ : -1); } +int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); } -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string& backend) { +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { if (device_num <= 0) { MS_LOG(ERROR) << "'device_num' must be positive."; return false; @@ -87,7 +87,7 @@ void CheckGlobalDeviceManager() { } } -int32_t GetListMemberByIndex(size_t index, const RankList& devices) { +int32_t GetListMemberByIndex(size_t index, const RankList &devices) { size_t i = 0; int32_t result = 0; if ((devices.empty()) || (index >= devices.size())) { @@ -104,7 +104,7 @@ int32_t GetListMemberByIndex(size_t index, const RankList& devices) { return result; } -std::shared_ptr GetListMemberByIndex(size_t index, const std::vector>& device_list) { +std::shared_ptr GetListMemberByIndex(size_t index, const std::vector> &device_list) { size_t i = 0; std::shared_ptr result; if ((device_list.empty()) || (index >= device_list.size())) { @@ -123,8 +123,8 @@ std::shared_ptr GetListMemberByIndex(size_t index, const std::vector DeviceManager::GetStageById(int32_t stage_id) { return res; } int32_t index = 0; - for (auto& stage : stages_) { + for (auto &stage : stages_) { if (index == stage_id) return stage; index++; } @@ -224,7 +224,7 @@ RankList DeviceManager::GetDeviceListByStageId(int32_t stage_id) const { << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); RankList res; int32_t index = 0; - for (auto& stage : stage_devices_) { + for (auto &stage : stage_devices_) { if (index == stage_id) { return stage; } @@ -280,19 +280,19 @@ Device DeviceManager::CreateNewDeviceByRank(int32_t rank) const { return Device( std::vector DeviceManager::CreateDeviceListByRankList(RankList ranks) { std::vector dev_list; - for (auto& rank : ranks) { + for (auto &rank : ranks) { Device one = CreateNewDeviceByRank(rank); 
dev_list.push_back(one); } return dev_list; } -DeviceManager& DeviceManager::GetInstance() { +DeviceManager &DeviceManager::GetInstance() { static DeviceManager instance = DeviceManager(); return instance; } -std::string DeviceManager::FindRankListNameByHashName(const std::string& hash_name) { +std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) { std::string tmp = "WORLD_GROUP"; if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) { return tmp; @@ -305,7 +305,7 @@ std::string DeviceManager::FindRankListNameByHashName(const std::string& hash_na return iter->second; } -std::string HashName(const std::string& origin_name) { return std::to_string(std::hash{}(origin_name)); } +std::string HashName(const std::string &origin_name) { return std::to_string(std::hash{}(origin_name)); } // Group name is generated using the increasing ranks of the devices. // E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name @@ -343,8 +343,8 @@ std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) { // Create the group with the given devices and the given name. The GroupManager // gm_ will create a new group only if there does not exit a group with the same // name. Otherwise, let the pointer g point to that group. -Group DeviceManager::CreateGroup(const std::string& group_name, - const std::vector& devices) { +Group DeviceManager::CreateGroup(const std::string &group_name, + const std::vector &devices) { if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) { MS_LOG(EXCEPTION) << "Do not support sub group for nccl"; } @@ -354,7 +354,7 @@ Group DeviceManager::CreateGroup(const std::string& group_name, } // Create the group with only the given devices' ranks. 
-Group DeviceManager::CreateGroup(const RankList& dev_ranks) { +Group DeviceManager::CreateGroup(const RankList &dev_ranks) { std::unordered_set rank_set(dev_ranks.begin(), dev_ranks.end()); if (dev_ranks.size() != rank_set.size()) { MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list"; diff --git a/mindspore/ccsrc/parallel/device_manager.h b/mindspore/ccsrc/parallel/device_manager.h index e87c1d740f..3afafe6a9c 100644 --- a/mindspore/ccsrc/parallel/device_manager.h +++ b/mindspore/ccsrc/parallel/device_manager.h @@ -53,13 +53,13 @@ class Stage { explicit Stage(std::vector devices) : devices_(std::move(devices)), number_(0), rank_(0) { gm_ = GroupManager(); } - Stage(const std::vector& devices, int num, int rank); + Stage(const std::vector &devices, int num, int rank); ~Stage() = default; int GetStageNum() const { return number_; } size_t GetDevicesNum() const { return devices_.size(); } std::vector GetDevicesList() { return devices_; } - int global_rank(Group* g) const; + int global_rank(Group *g) const; private: std::vector devices_; @@ -70,11 +70,11 @@ class Stage { // This method is used for initializing the global DeviceManager 'g_device_manager', // arguments including 'device_num' and 'global_rank' -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string& backend); +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); void CheckGlobalDeviceManager(); -std::string HashName(const std::string& rank_list_name); +std::string HashName(const std::string &rank_list_name); class DeviceManager { // This class is used to manage the abstract devices, including group-related and stage-related management. 
@@ -82,9 +82,9 @@ class DeviceManager { DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); } ~DeviceManager() = default; - Status Init(const RankList& devices, int32_t local_device, const RankList& stage_map, const std::string& backend); + Status Init(const RankList &devices, int32_t local_device, const RankList &stage_map, const std::string &backend); - static DeviceManager& GetInstance(); + static DeviceManager &GetInstance(); RankList GetDeviceListByStageId(int32_t stage_id) const; RankList global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const; @@ -92,8 +92,8 @@ class DeviceManager { std::vector CreateDeviceListByRankList(RankList ranks); std::string GenerateGroupNameByRanks(RankList dev_ranks); - Group CreateGroup(const std::string& group_name, const std::vector& devices); - Group CreateGroup(const RankList& dev_ranks); + Group CreateGroup(const std::string &group_name, const std::vector &devices); + Group CreateGroup(const RankList &dev_ranks); std::shared_ptr GetStageById(int32_t stage_id); size_t DeviceNum() const { return devices_.size(); } @@ -105,7 +105,7 @@ class DeviceManager { void set_global_rank(int32_t global_rank) { global_rank_ = global_rank; } void Clear(); std::string world_group() const { return gm_.world_group(); } - std::string FindRankListNameByHashName(const std::string& hash_name); + std::string FindRankListNameByHashName(const std::string &hash_name); private: std::vector> devices_; diff --git a/mindspore/ccsrc/parallel/device_matrix.cc b/mindspore/ccsrc/parallel/device_matrix.cc index 3fdc3dd15a..3c9467a223 100644 --- a/mindspore/ccsrc/parallel/device_matrix.cc +++ b/mindspore/ccsrc/parallel/device_matrix.cc @@ -53,7 +53,7 @@ Status DeviceMatrix::CreateGroupList() { return Status::SUCCESS; } -Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) { +Status DeviceMatrix::GetDevicesAlongDim(const uint32_t &dim, RankList *devices) { if (dim >= 
dev_shape_.size()) { MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!"; } @@ -78,7 +78,7 @@ Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) for (int32_t i = 0; i < step; i++) { local_group_list.push_back(group); - (void)std::for_each(group.begin(), group.end(), [](int32_t& a) { a++; }); + (void)std::for_each(group.begin(), group.end(), [](int32_t &a) { a++; }); } // higher than dim @@ -88,19 +88,19 @@ Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) // search rank int32_t target = rank_; for (int32_t i = 0; i < len; i++) { - for (RankList& temp : local_group_list) { + for (RankList &temp : local_group_list) { if (std::any_of(temp.begin(), temp.end(), [target](int32_t a) { return a == target; })) { *devices = temp; return Status::SUCCESS; } - (void)std::for_each(temp.begin(), temp.end(), [step](int32_t& a) { a = a + step; }); + (void)std::for_each(temp.begin(), temp.end(), [step](int32_t &a) { a = a + step; }); } } MS_LOG(ERROR) << "Can't find groups for rank" << rank_ << " in device list!"; return Status::FAILED; } -Shape ConvertRankToCoordinate(int32_t rank, const Shape& dev_shape) { +Shape ConvertRankToCoordinate(int32_t rank, const Shape &dev_shape) { Shape dev_coordinate; for (size_t i = 0; i < dev_shape.size(); ++i) { int32_t size = dev_shape[dev_shape.size() - i - 1]; @@ -115,8 +115,8 @@ Shape ConvertRankToCoordinate(int32_t rank, const Shape& dev_shape) { return dev_coordinate; } -Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* rank_list) { - for (auto& element : tensor_map) { +Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) { + for (auto &element : tensor_map) { // -1 means the corresponding dimension is not split. 
if (element == MAP_NONE) { continue; @@ -127,10 +127,10 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* ra } Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_); - for (auto& tmp_rank : dev_list_) { + for (auto &tmp_rank : dev_list_) { Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_); bool matched = true; - for (auto& map : tensor_map) { + for (auto &map : tensor_map) { if (map == MAP_NONE) { continue; } @@ -148,7 +148,7 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* ra return SUCCESS; } -std::string ShapeToString(const Shape& shape) { +std::string ShapeToString(const Shape &shape) { std::string str = "["; for (size_t i = 0; i < shape.size(); ++i) { str += std::to_string(shape[i]); @@ -159,9 +159,9 @@ std::string ShapeToString(const Shape& shape) { return str + "]"; } -std::string ListToString(const std::vector& list) { +std::string ListToString(const std::vector &list) { std::string str = "["; - for (auto& element : list) { + for (auto &element : list) { str += std::to_string(element) + ", "; } return str + "]"; diff --git a/mindspore/ccsrc/parallel/device_matrix.h b/mindspore/ccsrc/parallel/device_matrix.h index a912000604..236a7fad08 100644 --- a/mindspore/ccsrc/parallel/device_matrix.h +++ b/mindspore/ccsrc/parallel/device_matrix.h @@ -37,8 +37,8 @@ class DeviceMatrix { ~DeviceMatrix() = default; std::vector group_list() const { return group_list_; } Status CreateGroupList(); - Status GetDevicesByTensorMap(const Shape& tensor_map, RankList* rank_list); - Status GetDevicesAlongDim(const uint32_t& dim, RankList* devices); + Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list); + Status GetDevicesAlongDim(const uint32_t &dim, RankList *devices); private: int32_t rank_ = -1; @@ -48,8 +48,8 @@ class DeviceMatrix { std::vector group_list_; }; -std::string ShapeToString(const Shape& shape); -std::string ListToString(const 
std::vector& list); +std::string ShapeToString(const Shape &shape); +std::string ListToString(const std::vector &list); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h index bad947687d..42ba42cf8a 100644 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/parallel/dynamic_creator.h @@ -28,28 +28,28 @@ namespace mindspore { namespace parallel { #define REGISTER(className) \ - OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs& attrs) { \ + OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \ return std::make_shared(name, in, out, attrs); \ } \ RegisterAction className##Register(#className, (CreatFn)objectCreator##className); -typedef OperatorInfoPtr (*CreatFn)(const std::string& name, const Shapes& shape_in, const Shapes shape_out, - const PrimitiveAttrs& attrs); +typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out, + const PrimitiveAttrs &attrs); class DynCreator { public: ~DynCreator() = default; // creat static singleton dyn_creator instance - static DynCreator& Instance() { + static DynCreator &Instance() { static DynCreator fac = DynCreator(); return fac; } // register void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } // creator - OperatorInfoPtr Creat(const std::string& name, const Shapes& shape_in, const Shapes& shape_out, - const PrimitiveAttrs& attrs, size_t count) { + OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, + const PrimitiveAttrs &attrs, size_t count) { std::string op_name = name + std::to_string(count); auto iter = Function_map_.find(name); if (iter == Function_map_.end()) { @@ -66,7 +66,7 @@ class DynCreator { class RegisterAction { public: - RegisterAction(const 
std::string& name, CreatFn creatfn) : name_(name) { + RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { DynCreator::Instance().Regist(name, creatfn); } ~RegisterAction() = default; diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc index 43df9fe802..f5f0fe85cb 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc @@ -25,7 +25,7 @@ using mindspore::tensor::Tensor; namespace mindspore { namespace parallel { -std::string GetOpPythonPath(const OperatorName& op_name) { +std::string GetOpPythonPath(const OperatorName &op_name) { // almost all ops are defined in two main paths const std::string ops_module = OP_PATH; py::module mod = py::module::import(common::SafeCStr(ops_module)); @@ -35,7 +35,7 @@ std::string GetOpPythonPath(const OperatorName& op_name) { return ops_module; } -ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name, const std::string& instance_name) { +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) { std::string op_path = GetOpPythonPath(op_name); py::module mod = py::module::import(common::SafeCStr(op_path)); if (!py::hasattr(mod, common::SafeCStr(op_name))) { @@ -44,7 +44,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name } std::vector arg_list; (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list), - [](const Attr& attr) { return ValuePtrToPyData(attr.second); }); + [](const Attr &attr) { return ValuePtrToPyData(attr.second); }); py::object obj = parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list); ValuePtr op_instance = nullptr; @@ -56,7 +56,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name return op_instance; } -AnfNodePtr 
ValuePtrToAnfNodePtr(const ValuePtr& value_ptr) { +AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) { auto value_node = NewValueNode(value_ptr); MS_EXCEPTION_IF_NULL(value_node); return value_node->cast(); @@ -85,7 +85,7 @@ AnfNodePtr CreatInt32Imm(int32_t value) { return ValuePtrToAnfNodePtr(value_ptr); } -std::string GetInstanceNameByCNode(const CNodePtr& cnode) { +std::string GetInstanceNameByCNode(const CNodePtr &cnode) { PrimitivePtr prim = GetValueNode(cnode->input(0)); if (!prim) { MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr."; @@ -94,7 +94,7 @@ std::string GetInstanceNameByCNode(const CNodePtr& cnode) { return HashInstanceName(instance_name); } -std::string HashInstanceName(const std::string& name) { +std::string HashInstanceName(const std::string &name) { auto using_hash_name = common::GetEnv(USING_HASH_NAME); std::string instance_name; if ((using_hash_name.empty()) || (using_hash_name == "on")) { @@ -105,7 +105,7 @@ std::string HashInstanceName(const std::string& name) { return instance_name; } -Status GenerateGraph::Init(const CNodePtr& cnode) { +Status GenerateGraph::Init(const CNodePtr &cnode) { if (!cnode) { MS_LOG(ERROR) << "Init:cnode is nullptr"; return FAILED; @@ -133,7 +133,7 @@ Status GenerateGraph::Init(const CNodePtr& cnode) { return SUCCESS; } -AnfNodePtr GenerateGraph::PushBack(const std::vector& inputs) { +AnfNodePtr GenerateGraph::PushBack(const std::vector &inputs) { CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode MS_EXCEPTION_IF_NULL(cnode); cnode->set_scope(scope_); @@ -146,7 +146,7 @@ AnfNodePtr GenerateGraph::PushBack(const std::vector& inputs) { return new_anf_node_ptr; } -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name, const OperatorAttrs& attrs) { +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) { name_idx_++; ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + 
std::to_string(name_idx_)); if (pyop_instance == nullptr) { @@ -156,7 +156,7 @@ AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name, const OperatorA return value_node->cast(); } -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name) { +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) { name_idx_++; OperatorAttrs attrs; ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_)); diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/parallel/graph_util/generate_graph.h index c829e67b6a..d5535c7dc2 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.h +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.h @@ -33,25 +33,25 @@ namespace mindspore { namespace parallel { #define USING_HASH_NAME "USING_HASH_NAME" // Get the operator's path where the operator has be defined -std::string GetOpPythonPath(const OperatorName& op_name); +std::string GetOpPythonPath(const OperatorName &op_name); // Init python operator Instance -ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name, const std::string& instance_name); +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name); AnfNodePtr CreatTypeInt(int32_t value); AnfNodePtr CreatInt32Imm(int32_t value); AnfNodePtr CreateInt32Tensor(int32_t value); -std::string HashInstanceName(const std::string& name); +std::string HashInstanceName(const std::string &name); class GenerateGraph { public: GenerateGraph() : name_idx_(0) {} - Status Init(const CNodePtr& cnode); + Status Init(const CNodePtr &cnode); ~GenerateGraph() = default; AnfNodePtr virtual_input_node() { return virtual_input_node_; } - AnfNodePtr NewOpInst(const OperatorName& op_name, const OperatorAttrs& attrs); - AnfNodePtr NewOpInst(const OperatorName& op_name); - AnfNodePtr PushBack(const std::vector& inputs); + AnfNodePtr NewOpInst(const OperatorName 
&op_name, const OperatorAttrs &attrs); + AnfNodePtr NewOpInst(const OperatorName &op_name); + AnfNodePtr PushBack(const std::vector &inputs); private: CNodePtr cnode_; diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc index 3006cb7680..cbffc10e70 100644 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc @@ -29,7 +29,7 @@ namespace mindspore { namespace parallel { -py::dict GetParameterLayout(const FuncGraphPtr& graph) { +py::dict GetParameterLayout(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; std::vector graph_params = graph->parameters(); @@ -50,7 +50,7 @@ py::dict GetParameterLayout(const FuncGraphPtr& graph) { return dict; } -py::dict GetCNodeStrategy(const FuncGraphPtr& graph) { +py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; auto ret = graph->get_return(); @@ -75,7 +75,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr& graph) { return dict; } -py::dict GetAllreduceFusion(const FuncGraphPtr& graph) { +py::dict GetAllreduceFusion(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE); diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h index 78f597b213..e21b81a557 100644 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h +++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h @@ -23,9 +23,9 @@ namespace mindspore { namespace parallel { -py::dict GetParameterLayout(const FuncGraphPtr& graph); -py::dict GetCNodeStrategy(const FuncGraphPtr& graph); -py::dict GetAllreduceFusion(const FuncGraphPtr& graph); +py::dict GetParameterLayout(const FuncGraphPtr &graph); +py::dict GetCNodeStrategy(const FuncGraphPtr &graph); +py::dict GetAllreduceFusion(const FuncGraphPtr &graph); 
} // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/parallel/graph_util/graph_info.cc index 46c9a37960..175413c0fd 100644 --- a/mindspore/ccsrc/parallel/graph_util/graph_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/graph_info.cc @@ -24,12 +24,12 @@ namespace mindspore { namespace parallel { -std::vector FindPrimtive(const FuncGraphPtr& graph, const std::string& name) { +std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name) { AnfNodePtr ret = graph->get_return(); MS_EXCEPTION_IF_NULL(ret); std::vector all_nodes = DeepScopedGraphSearch(ret); std::vector prim_list; - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { if (!IsValueNode(node)) { continue; } @@ -44,7 +44,7 @@ std::vector FindPrimtive(const FuncGraphPtr& graph, const std::str return prim_list; } -void DumpGraph(const FuncGraphPtr& root, const std::string& name) { +void DumpGraph(const FuncGraphPtr &root, const std::string &name) { if (MsContext::GetInstance()->save_graphs_flag()) { draw::Draw(name + ".dot", root); DumpIR(name + ".ir", root); diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.h b/mindspore/ccsrc/parallel/graph_util/graph_info.h index 96deab2906..de800f0981 100644 --- a/mindspore/ccsrc/parallel/graph_util/graph_info.h +++ b/mindspore/ccsrc/parallel/graph_util/graph_info.h @@ -24,8 +24,8 @@ namespace mindspore { namespace parallel { -std::vector FindPrimtive(const FuncGraphPtr& graph, const std::string& name); -void DumpGraph(const FuncGraphPtr& root, const std::string& name); +std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name); +void DumpGraph(const FuncGraphPtr &root, const std::string &name); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.cc b/mindspore/ccsrc/parallel/graph_util/node_info.cc index b2ce8ba432..c085d71240 100644 --- 
a/mindspore/ccsrc/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/node_info.cc @@ -23,13 +23,13 @@ namespace mindspore { namespace parallel { -std::string ParameterName(const AnfNodePtr& node_ptr) { +std::string ParameterName(const AnfNodePtr &node_ptr) { auto para_ptr = node_ptr->cast(); MS_EXCEPTION_IF_NULL(para_ptr); return para_ptr->name(); } -bool ParameterRequireGrad(const AnfNodePtr& node_ptr) { +bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { auto para_ptr = node_ptr->cast(); if (para_ptr == nullptr) { return false; diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.h b/mindspore/ccsrc/parallel/graph_util/node_info.h index f4f46d2149..bda268e582 100644 --- a/mindspore/ccsrc/parallel/graph_util/node_info.h +++ b/mindspore/ccsrc/parallel/graph_util/node_info.h @@ -22,9 +22,9 @@ namespace mindspore { namespace parallel { -std::string ParameterName(const AnfNodePtr& node_ptr); +std::string ParameterName(const AnfNodePtr &node_ptr); -bool ParameterRequireGrad(const AnfNodePtr& node_ptr); +bool ParameterRequireGrad(const AnfNodePtr &node_ptr); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/group_manager.h b/mindspore/ccsrc/parallel/group_manager.h index 430d2f64ed..f763d483cc 100644 --- a/mindspore/ccsrc/parallel/group_manager.h +++ b/mindspore/ccsrc/parallel/group_manager.h @@ -37,11 +37,11 @@ class Group { public: Group(); ~Group() = default; - Status Init(const std::string& name, const std::vector& devices); + Status Init(const std::string &name, const std::vector &devices); std::vector GetDevicesList() const; std::string name() const { return name_; } bool IsInThisGroup(int32_t device_rank); - Status GetIndex(size_t* index); + Status GetIndex(size_t *index); size_t GetDevNum() const { return devices_.size(); } private: @@ -54,14 +54,14 @@ class GroupManager { GroupManager(); ~GroupManager() = default; - Status CreateGroup(const std::string& name, const std::vector& devices, 
Group* group); - Status DestroyGroup(Group* group); + Status CreateGroup(const std::string &name, const std::vector &devices, Group *group); + Status DestroyGroup(Group *group); Status DestroyAllGroups(); - Status GetRankID(const std::string& name, unsigned int* rank_id); - Status GetRankSize(const std::string& name, unsigned int* rank_size); - Status FindGroup(const std::string& name, Group** group); + Status GetRankID(const std::string &name, unsigned int *rank_id); + Status GetRankSize(const std::string &name, unsigned int *rank_size); + Status FindGroup(const std::string &name, Group **group); std::string world_group() const { return world_group_; } - void set_world_group(const std::string& name) { world_group_ = name; } + void set_world_group(const std::string &name) { world_group_ = name; } void Clear(); private: diff --git a/mindspore/ccsrc/parallel/node_check.cc b/mindspore/ccsrc/parallel/node_check.cc index e43d03c29c..7fecd307c7 100644 --- a/mindspore/ccsrc/parallel/node_check.cc +++ b/mindspore/ccsrc/parallel/node_check.cc @@ -80,7 +80,7 @@ const std::set BLACK_LIST = {TUPLE_GETITEM, REF_TO_EMBED, STOP_GRADIENT}; -bool IsInBlackList(const PrimitivePtr& prim) { +bool IsInBlackList(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end()); } diff --git a/mindspore/ccsrc/parallel/node_check.h b/mindspore/ccsrc/parallel/node_check.h index 6e5db37069..8b628f31b1 100644 --- a/mindspore/ccsrc/parallel/node_check.h +++ b/mindspore/ccsrc/parallel/node_check.h @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -bool IsInBlackList(const PrimitivePtr& prim); +bool IsInBlackList(const PrimitivePtr &prim); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc index e659759de2..6bc33677a6 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc +++ 
b/mindspore/ccsrc/parallel/ops_info/activation_info.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -Status Activation::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -41,7 +41,7 @@ Status Activation::SetCostUnderStrategy(const StrategyPtr& strategy) { return SUCCESS; } -Status Activation::CheckStrategy(const StrategyPtr& strategy) { +Status Activation::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -110,7 +110,7 @@ Status Activation::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; @@ -120,7 +120,7 @@ Status Activation::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status Softmax::CheckStrategy(const StrategyPtr& strategy) { +Status Softmax::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -133,7 +133,7 @@ Status Softmax::CheckStrategy(const StrategyPtr& strategy) { std::vector stra = strategy->GetInputDim(); Dimensions input_strategy = stra.at(0); - for (auto& element : axis_) { + for (auto &element : axis_) { int32_t axis_index = element; if (element < 0) { size_t input_dim = inputs_shape_.at(0).size(); @@ -176,7 +176,7 @@ Status Softmax::GetAttrs() { } std::vector value_vector = value_tuple->value(); (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_), - 
[](const ValuePtr& value) { return static_cast(GetValue(value)); }); + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); if (axis_.empty()) { MS_LOG(ERROR) << name_ << " : The axis tuple is empty."; return FAILED; @@ -205,7 +205,7 @@ Status Softmax::GetAttrs() { return SUCCESS; } -Status Softmax::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -231,7 +231,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { is_auto_parallel_ = true; Shape input0_split; (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); - for (auto& element : axis_) { + for (auto &element : axis_) { int32_t axis_index = element; if (element < 0) { size_t input_dim = inputs_shape_.at(0).size(); @@ -247,7 +247,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; @@ -334,7 +334,7 @@ Status ActivationBase::InferTensorInfo() { return SUCCESS; } -Status ActivationBase::Init(const StrategyPtr& strategy) { +Status ActivationBase::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -344,7 +344,7 @@ Status ActivationBase::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status ActivationBase::InitForCostModel(const StrategyPtr& strategy) { +Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -547,7 +547,7 @@ Status 
ExpandDimsInfo::InferMirrorOps() { return SUCCESS; } -Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) { +Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) { std::vector axis; auto axis_list = value_tuple->value(); if (inputs_shape_.empty()) { @@ -568,7 +568,7 @@ Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) { } // convert negative axis to positive. - for (auto& dim : axis_list) { + for (auto &dim : axis_list) { if (!dim->isa()) { MS_LOG(ERROR) << name_ << ": The type of axis is not int"; return FAILED; @@ -595,7 +595,7 @@ Status SqueezeInfo::GetAttrs() { return SUCCESS; } -Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) { +Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) { Attr attr = std::make_pair(AXIS, axis_); OperatorAttrs attrs = {attr}; OperatorParams params; @@ -689,7 +689,7 @@ Status SqueezeInfo::InferTensorInfo() { return SUCCESS; } -Status SqueezeInfo::Init(const StrategyPtr& strategy) { +Status SqueezeInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; } diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h index 887be5ea33..a71c6b6df7 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.h +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h @@ -31,13 +31,13 @@ namespace mindspore { namespace parallel { class ActivationBase : public OperatorInfo { public: - ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} ~ActivationBase() override = default; - Status Init(const 
StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; protected: Status InferMirrorOps() override; @@ -49,21 +49,21 @@ class ActivationBase : public OperatorInfo { class Activation : public ActivationBase { public: - Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~Activation() override = default; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; }; class ActivationInfo : public Activation { public: - ActivationInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~ActivationInfo() override = default; @@ -73,8 +73,8 @@ class ActivationInfo : public Activation { class ActivationOther : public Activation { public: - ActivationOther(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~ActivationOther() override = default; @@ -84,31 +84,31 @@ class 
ActivationOther : public Activation { class GeluInfo : public ActivationOther { public: - GeluInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~GeluInfo() override = default; }; class TanhInfo : public ActivationOther { public: - TanhInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~TanhInfo() override = default; }; class Softmax : public ActivationBase { public: - explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~Softmax() override = default; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; private: @@ -117,32 +117,32 @@ class Softmax : public ActivationBase { class SoftmaxInfo : public Softmax { public: - SoftmaxInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : 
Softmax(name, inputs_shape, outputs_shape, attrs) {} ~SoftmaxInfo() override = default; }; class LogSoftmaxInfo : public Softmax { public: - LogSoftmaxInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Softmax(name, inputs_shape, outputs_shape, attrs) {} ~LogSoftmaxInfo() override = default; }; class ReLUInfo : public ActivationOther { public: - ReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ReLUInfo() override = default; }; class CastInfo : public ActivationOther { public: - CastInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~CastInfo() override = default; @@ -152,23 +152,23 @@ class CastInfo : public ActivationOther { class SqrtInfo : public ActivationOther { public: - SqrtInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SqrtInfo() override = default; }; class NegInfo : public ActivationOther { public: - NegInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes 
&outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~NegInfo() override = default; }; class ExpandDimsInfo : public ActivationOther { public: - ExpandDimsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ExpandDimsInfo() override = default; @@ -187,18 +187,18 @@ class ExpandDimsInfo : public ActivationOther { class SqueezeInfo : public ActivationOther { public: - SqueezeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SqueezeInfo() override = default; protected: - Status InferAxis(const ValueTuplePtr& value_tuple); + Status InferAxis(const ValueTuplePtr &value_tuple); Status GetAttrs() override; - Status InferReplaceOps(const StrategyPtr& strategy); + Status InferReplaceOps(const StrategyPtr &strategy); Status InferTensorMap() override; Status InferTensorInfo() override; - Status Init(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; private: ValueTuplePtr axis_; @@ -206,8 +206,8 @@ class SqueezeInfo : public ActivationOther { class SquareInfo : public ActivationOther { public: - SquareInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SquareInfo() override = default; }; diff --git 
a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h index 78dfc23803..27caacc30c 100644 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h @@ -31,92 +31,92 @@ namespace mindspore { namespace parallel { class ArithmeticBase : public OperatorInfo { public: - ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} ~ArithmeticBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; void ReComputeBatchSplitFlagList() override; protected: Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); Shapes InferExpendShape(); }; class SubInfo : public ArithmeticBase { public: - SubInfo(const std::string& name, 
const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SubInfo() override = default; }; class TensorAddInfo : public ArithmeticBase { public: - TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TensorAddInfo() override = default; }; class MulInfo : public ArithmeticBase { public: - MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MulInfo() override = default; }; class DivInfo : public ArithmeticBase { public: - DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DivInfo() override = default; }; class RealDivInfo : public ArithmeticBase { public: - RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~RealDivInfo() 
override = default; }; class FloorDivInfo : public ArithmeticBase { public: - FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~FloorDivInfo() override = default; }; class PowInfo : public ArithmeticBase { public: - PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PowInfo() override = default; }; class GreaterInfo : public ArithmeticBase { public: - GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~GreaterInfo() override = default; }; class AssignSubInfo : public ArithmeticBase { public: - AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~AssignSubInfo() override = default; }; @@ -124,8 +124,8 @@ class AssignSubInfo : public ArithmeticBase { // All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. 
class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { public: - SigmoidCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SigmoidCrossEntropyWithLogitsInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc index 9d356cd573..dac3b0a675 100644 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status BatchParallelInfo::CheckStrategy(const StrategyPtr& strategy) { +Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -161,7 +161,7 @@ Status BatchParallelInfo::InferTensorInfo() { Status BatchParallelInfo::GetAttrs() { return SUCCESS; } -Status BatchParallelInfo::Init(const StrategyPtr& strategy) { +Status BatchParallelInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -170,7 +170,7 @@ Status BatchParallelInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status BatchParallelInfo::InitForCostModel(const StrategyPtr& strategy) { +Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -184,7 +184,7 @@ Status BatchParallelInfo::InitForCostModel(const 
StrategyPtr& strategy) { return SUCCESS; } -Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h index 4cedb9b7b8..db6cb206d5 100644 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h +++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h @@ -29,22 +29,22 @@ namespace mindspore { namespace parallel { class BatchParallelInfo : public OperatorInfo { public: - BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} - BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), dev_num_(1) {} ~BatchParallelInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status 
CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; @@ -60,8 +60,8 @@ class BatchParallelInfo : public OperatorInfo { class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { public: - SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, - const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, + const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; void ReComputeBatchSplitFlagList() override; diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h index e792858338..37f555a258 100644 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h +++ b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h @@ -32,26 +32,26 @@ namespace mindspore { namespace parallel { class BiasAddInfo : public OperatorInfo { public: - BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~BiasAddInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) 
override; + Status SetCostUnderStrategy(const StrategyPtr &) override; void ReComputeBatchSplitFlagList() override; protected: Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h index 9ea496e0b0..8dd2976b04 100644 --- a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h @@ -30,32 +30,32 @@ namespace mindspore { namespace parallel { class EqualInfo : public ArithmeticBase { public: - EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~EqualInfo() override = default; }; class NotEqualInfo : public ArithmeticBase { public: - NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} 
~NotEqualInfo() override = default; }; class MaximumInfo : public ArithmeticBase { public: - MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MaximumInfo() override = default; }; class MinimumInfo : public ArithmeticBase { public: - MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MinimumInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc index c755cc785d..87b8d15cca 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc @@ -32,7 +32,7 @@ namespace mindspore { namespace parallel { static int32_t SEED_NUM = 1; -Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null"; return FAILED; @@ -129,7 +129,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() { return SUCCESS; } -Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -159,7 +159,7 @@ Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { return 
FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy"; @@ -178,7 +178,7 @@ std::shared_ptr>> DropoutDoMaskInfo::GenerateBa return std::make_shared>>(strategy_v); } -Status DropoutDoMaskInfo::Init(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -188,7 +188,7 @@ Status DropoutDoMaskInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -202,7 +202,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr& cnode) { +PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; @@ -237,7 +237,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr& cnode) { // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask. 
-Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr& cnode) { +Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); PrimitivePtr prim = GetDropoutGenMaskPrim(cnode); MS_EXCEPTION_IF_NULL(prim); diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h index 3b154bd6db..c0d112f52d 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h @@ -31,20 +31,20 @@ namespace mindspore { namespace parallel { class DropoutDoMaskInfo : public OperatorInfo { public: - DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DropoutDoMaskInfo() override = default; - Status Init(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; - Operator GetDropoutGenMaskReplaceOp(const CNodePtr& cnode); + Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorMap() override; diff --git 
a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h index 84b8030f37..2172c5cd89 100644 --- a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h @@ -29,37 +29,37 @@ namespace mindspore { namespace parallel { class ExpInfo : public ActivationOther { public: - ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ExpInfo() override = default; }; class LogInfo : public ActivationOther { public: - LogInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~LogInfo() override = default; }; class CosInfo : public ActivationOther { public: - CosInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~CosInfo() override = default; }; class ACosInfo : public ActivationOther { public: - ACosInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ACosInfo() override = default; }; class LogicalNotInfo : public ActivationOther { public: - LogicalNotInfo(const 
std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~LogicalNotInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc index c315991849..c9e8835f35 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc @@ -70,7 +70,7 @@ Status GatherV2Info::GetAttrs() { return SUCCESS; } -Status GatherV2Info::CheckStrategy(const StrategyPtr& strategy) { +Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " << inputs_shape_.size(); @@ -256,7 +256,7 @@ Status GatherV2Info::InferTensorSubOps() { return SUCCESS; } -Status GatherV2Info::Init(const StrategyPtr& strategy) { +Status GatherV2Info::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -270,7 +270,7 @@ Status GatherV2Info::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status GatherV2Info::InitForCostModel(const StrategyPtr& strategy) { +Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -301,7 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; @@ 
-311,7 +311,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h index 773d46f429..f7aeb6a0d9 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h @@ -38,22 +38,22 @@ constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; // If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. class GatherV2Info : public OperatorInfo { public: - GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), axis_(-1), index_size_(0), axis_strategy_(1) {} ~GatherV2Info() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() 
override { return SUCCESS; } Status InferTensorInfo() override; diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc index ac9acff41b..29d519fda8 100644 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc @@ -39,7 +39,7 @@ Status GetNextInfo::InferTensorMap() { return SUCCESS; } -Status GetNextInfo::InferTensorLayout(TensorLayouts* outputs_layout) { +Status GetNextInfo::InferTensorLayout(TensorLayouts *outputs_layout) { if (outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << " : The layout is null."; return FAILED; @@ -96,7 +96,7 @@ Status GetNextInfo::InferDevMatrixShape() { return SUCCESS; } -Status GetNextInfo::Init(const StrategyPtr& strategy) { +Status GetNextInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed"; return FAILED; @@ -109,7 +109,7 @@ Status GetNextInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status GetNextInfo::CheckStrategy(const StrategyPtr& strategy) { +Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { std::vector stras = strategy->GetInputDim(); for (Dimensions stra : stras) { if (stra.size() != 0) { @@ -135,7 +135,7 @@ Status GetNextInfo::GetAttrTypes() { auto iter_cast = iter->second->cast(); MS_EXCEPTION_IF_NULL(iter_cast); auto types = iter_cast->value(); - for (auto& type : types) { + for (auto &type : types) { MS_EXCEPTION_IF_NULL(type); types_.push_back(type->ToString()); } @@ -143,7 +143,7 @@ Status GetNextInfo::GetAttrTypes() { auto iter_cast = iter->second->cast(); MS_EXCEPTION_IF_NULL(iter_cast); auto types = iter_cast->value(); - for (auto& type : types) { + for (auto &type : types) { MS_EXCEPTION_IF_NULL(type); types_.push_back(type->ToString()); } @@ -189,7 +189,7 @@ Status GetNextInfo::GetAttrs() { return SUCCESS; } -Status GetNextInfo::InferReplaceOps(const StrategyPtr&) { +Status 
GetNextInfo::InferReplaceOps(const StrategyPtr &) { Shapes out_shapes = outputs_shape_; for (size_t i = 0; i < out_shapes.size(); ++i) { if (dev_num_ <= 0) { @@ -214,7 +214,7 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr&) { return SUCCESS; } -Status GetNextInfo::InitForCostModel(const StrategyPtr& strategy) { +Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -227,7 +227,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc index 2955f76506..8716997d9f 100644 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status L2NormalizeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -111,7 +111,7 @@ Status L2NormalizeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h 
b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h index 22ed5a965b..ca063d01d8 100644 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h +++ b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h @@ -31,8 +31,8 @@ namespace mindspore { namespace parallel { class L2NormalizeInfo : public Activation { public: - L2NormalizeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~L2NormalizeInfo() override = default; Status GenerateStrategies(int32_t stage_id) override; @@ -40,7 +40,7 @@ class L2NormalizeInfo : public Activation { protected: Status GetAttrs() override; Status InferMirrorOps() override; - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; private: int32_t axis_ = 0; // Default value = 0 diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h index c52645ade2..50117b8185 100644 --- a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h +++ b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h @@ -38,20 +38,20 @@ constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis"; // arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add. 
class LayerNormInfo : public OperatorInfo { public: - LayerNormInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), begin_norm_axis_(0) {} ~LayerNormInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; protected: Status GetAttrs() override; - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; @@ -61,7 +61,7 @@ class LayerNormInfo : public OperatorInfo { Status CreateTensorMap(size_t input_index); Status CreateTensorInfo(size_t input_index); Status CreateMirrorOp(size_t input_index); - Status GenerateGammaAndBetaStrategies(const std::vector& sp_vector); + Status GenerateGammaAndBetaStrategies(const std::vector &sp_vector); Status InitShapes(); private: diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/parallel/ops_info/loss_info.cc index 28ea19f120..0ba325c0cd 100644 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/loss_info.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status 
SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -152,7 +152,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -162,7 +162,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -205,7 +205,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; @@ -216,7 +216,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { PrintStrategy(strategy); if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.h b/mindspore/ccsrc/parallel/ops_info/loss_info.h index 44fe22ce90..2679c2d62b 100644 --- 
a/mindspore/ccsrc/parallel/ops_info/loss_info.h +++ b/mindspore/ccsrc/parallel/ops_info/loss_info.h @@ -34,20 +34,20 @@ namespace parallel { // output_0 : [a], output_1: [a, b] class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { public: - SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SoftmaxCrossEntropyWithLogitsInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; void ReComputeBatchSplitFlagList() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc index 8d1264482b..3f55efb66c 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc @@ -31,8 +31,8 @@ namespace mindspore { namespace parallel { -void SetDevMatrixShape(const Dimensions& mat_a_strategy, const Dimensions& mat_b_strategy, bool transpose_b, - Shape* dev_matrix_shape) { +void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool 
transpose_b, + Shape *dev_matrix_shape) { MS_EXCEPTION_IF_NULL(dev_matrix_shape); size_t mat_a_size = mat_a_strategy.size(); size_t mat_b_size = mat_b_strategy.size(); @@ -105,7 +105,7 @@ Status MatMulBase::GetAttrs() { return SUCCESS; } -Status CheckRelevantDimension(const Dimensions& long_strategy, const Dimensions& short_strategy) { +Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) { size_t long_size = long_strategy.size(); size_t short_size = short_strategy.size(); if (long_size < short_size) { @@ -126,7 +126,7 @@ Status CheckRelevantDimension(const Dimensions& long_strategy, const Dimensions& return SUCCESS; } -Status MatMul::CheckStrategy(const StrategyPtr& strategy) { +Status MatMul::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -239,7 +239,7 @@ Status MatMulBase::InferForwardCommunication() { } // dev_matrix_shape: [a, b, c, d, e], then output strategy: [a, b, c, e]; -Dimensions GetOutputStrategy(const Shape& dev_matrix_shape, int32_t repeated_calculation_num) { +Dimensions GetOutputStrategy(const Shape &dev_matrix_shape, int32_t repeated_calculation_num) { Dimensions output_strategy = dev_matrix_shape; if (repeated_calculation_num > 1) { // move the first dimension(repeated_calc_num_) @@ -301,7 +301,7 @@ Status MatMulBase::InferTensorMap() { return SUCCESS; } -Status MatMulBase::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { TensorLayout mat_a_layout, mat_b_layout, output_layout; if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || @@ -353,7 +353,7 @@ Status 
MatMulBase::InferTensorInfo() { return SUCCESS; } -Status MatMulBase::Init(const StrategyPtr& strategy) { +Status MatMulBase::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -363,7 +363,7 @@ Status MatMulBase::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status MatMulBase::InitForCostModel(const StrategyPtr& strategy) { +Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -377,7 +377,7 @@ Status MatMulBase::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape* const input) { +Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { if (input->size() < 2) { MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; return FAILED; @@ -463,7 +463,7 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) { Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, mindspore::parallel::StrategyPtr* const sp) { + size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); if (!FULLY_USE_DEVICES) { if (IntToSize(product) > dev_num) { @@ -519,7 +519,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, return SUCCESS; } -void MatMulBase::InitTensorInfoForCost(std::vector* relica_inputs_tensor_vector) { +void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_tensor_vector) { TensorLayout tly; if (transpose_a_) { Shape replica_input0_shape(inputs_tensor_info_[0].shape()); @@ -560,7 +560,7 @@ Status 
MatMulBase::CheckForTensorSliceValid() const { if (inputs_tensor_info_.empty()) { return FAILED; } - for (auto& one_input_tensor : inputs_tensor_info_) { + for (auto &one_input_tensor : inputs_tensor_info_) { auto slice_shape = one_input_tensor.slice_shape(); if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { @@ -570,7 +570,7 @@ Status MatMulBase::CheckForTensorSliceValid() const { return SUCCESS; } -Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (InitForCostModel(strategy) == FAILED) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h index 8a64fb7206..86a74f78f2 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.h @@ -32,21 +32,21 @@ namespace mindspore { namespace parallel { class MatMulBase : public OperatorInfo { public: - MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MatMulBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; // Generate all strategies and the corresponding cost for this MatMul operator Status GenerateStrategies(int32_t stage_id) override; - Status 
SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, StrategyPtr* sp); + size_t input1_shape_size, StrategyPtr *sp); - Status SwapLastTwoElements(Shape* shape); + Status SwapLastTwoElements(Shape *shape); protected: Status InferMirrorOps() override; @@ -54,8 +54,8 @@ class MatMulBase : public OperatorInfo { Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); - void InitTensorInfoForCost(std::vector*); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + void InitTensorInfoForCost(std::vector *); Status CheckForTensorSliceValid() const; Status GetAttrs() override; @@ -67,26 +67,26 @@ class MatMulBase : public OperatorInfo { class MatMul : public MatMulBase { public: - MatMul(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + MatMul(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : MatMulBase(name, inputs_shape, outputs_shape, attrs) {} ~MatMul() override = default; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; }; class MatMulInfo : public MatMul { public: - MatMulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : MatMul(name, inputs_shape, outputs_shape, attrs) {} ~MatMulInfo() override = default; }; class BatchMatMulInfo : public MatMul { public: - BatchMatMulInfo(const 
std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BatchMatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : MatMul(name, inputs_shape, outputs_shape, attrs) {} ~BatchMatMulInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc index e07609d3c4..2c06a1ace9 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc @@ -54,7 +54,7 @@ Status OneHotInfo::GetAttrs() { return SUCCESS; } -Status OneHotInfo::CheckStrategy(const StrategyPtr& strategy) { +Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { if (inputs_shape_.size() != 3) { MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); return FAILED; @@ -185,7 +185,7 @@ Status OneHotInfo::ExtractInputInfo() { return SUCCESS; } -Status OneHotInfo::ComputeReplaceGraph(const CNodePtr& cnode) { +Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { if (dev_matrix_shape_.back() == 1) { replace_graph_ = nullptr; return SUCCESS; @@ -222,7 +222,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr& cnode) { return SUCCESS; } -ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr& cnode) { +ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { if (ComputeReplaceGraph(cnode) != SUCCESS) { MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; return nullptr; @@ -230,7 +230,7 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr& cnode) { return replace_graph_; } -Status OneHotInfo::Init(const StrategyPtr& strategy) { +Status OneHotInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -244,7 +244,7 @@ Status OneHotInfo::Init(const StrategyPtr& strategy) { return 
SUCCESS; } -Status OneHotInfo::InitForCostModel(const StrategyPtr& strategy) { +Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -276,7 +276,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; @@ -287,7 +287,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h index a4f00ea093..3c8a64f954 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.h @@ -31,20 +31,20 @@ namespace mindspore { namespace parallel { class OneHotInfo : public OperatorInfo { public: - OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~OneHotInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; 
- Status SetCostUnderStrategy(const StrategyPtr& strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr& cnode) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; std::shared_ptr>> GenerateBatchStrategies() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } @@ -54,7 +54,7 @@ class OneHotInfo : public OperatorInfo { Status ExtractInputInfo(); private: - Status ComputeReplaceGraph(const CNodePtr& cnode); + Status ComputeReplaceGraph(const CNodePtr &cnode); int axis_ = -1; int32_t rank_ = 0; diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc index c6115a9fa6..8074f2a32e 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc @@ -35,7 +35,7 @@ namespace mindspore { namespace parallel { -Status CheckStrategyValue(const StrategyPtr& strategy, const Shapes& inputs_shape, bool is_auto_parallel) { +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { if (strategy == nullptr) { MS_LOG(ERROR) << "The strategy is null."; return FAILED; @@ -190,7 +190,7 @@ Operator CreateVirtualDivOp(int32_t div_num) { } // use for forward all reduce -Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& group) { +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) { OperatorName operator_name = ALL_REDUCE; ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM ValuePtr attr1_value = MakeValue(group); // group @@ -209,7 +209,7 @@ Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& grou } // 
use for get tensor slice -Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout) { +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { Shape tensor_map = tensor_layout.tensor_map().array(); Shape dev_matrix_shape = tensor_layout.device_arrangement().array(); OperatorName operator_name = GET_TENSOR_SLICE; @@ -228,7 +228,7 @@ Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout) { return op; } -OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num) { +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { if ((dev_num == 0) || (dev_num == 1)) { MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num; } @@ -260,7 +260,7 @@ OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num) { return op_for_weight; } -Status OperatorInfo::CreateGroupByTensorMap(const Shape& tensor_map, std::vector* group) { +Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group) { if (group == nullptr) { MS_LOG(ERROR) << "The group is null."; return FAILED; @@ -283,7 +283,7 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape& tensor_map, std::vector return SUCCESS; } -Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector* group) { +Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { if (group == nullptr) { MS_LOG(ERROR) << "The group is null."; return FAILED; @@ -306,7 +306,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector* group) { return SUCCESS; } -Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy) { +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { Shape slice_shape; if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; @@ -318,7 +318,7 @@ Shape GetSliceShape(const Shape& tensor_shape, const 
Dimensions& strategy) { return slice_shape; } -Status InferSliceShapeByStrategy(const Strategys& strategys, const Shapes& shapes, Shapes* slice_shapes) { +Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { if (slice_shapes == nullptr) { MS_LOG(ERROR) << "The slice_shapes is null."; return FAILED; @@ -357,8 +357,8 @@ Status InferSliceShapeByStrategy(const Strategys& strategys, const Shapes& shape return SUCCESS; } -Status OperatorInfo::InferSliceShape(const Strategys& inputs_strategy, const Strategys& outputs_strategy, - Shapes* inputs_slice_shape, Shapes* outputs_slice_shape) { +Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { MS_LOG(ERROR) << "The slice_shape is null."; return FAILED; @@ -379,7 +379,7 @@ Status OperatorInfo::InferSliceShape(const Strategys& inputs_strategy, const Str } // method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 -Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -437,7 +437,7 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strat } // method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape -Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -485,7 +485,7 @@ Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr& str 
return SUCCESS; } -Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -513,7 +513,7 @@ Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr& strategy) { return SUCCESS; } -Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -543,12 +543,12 @@ Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr& strategy) { std::vector> OperatorInfo::GetAliveSuccEdges() { std::vector> ret; - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { ret.push_back(edge); } } - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) { ret.push_back(edge); } @@ -558,7 +558,7 @@ std::vector> OperatorInfo::GetAliveSuccEdges() { std::vector> OperatorInfo::GetAlivePrevEdges() { std::vector> ret; - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator()->is_alive()) { ret.push_back(edge); } @@ -566,12 +566,12 @@ std::vector> OperatorInfo::GetAlivePrevEdges() { return ret; } -void OperatorInfo::ReplacePreEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null."; return; } - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator() == op) { edge = new_edge; return; @@ -580,12 +580,12 @@ 
void OperatorInfo::ReplacePreEdge(const std::shared_ptr& op, const MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; } -void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null."; return; } - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if (edge->next_operator() == op) { edge = new_edge; return; @@ -594,13 +594,13 @@ void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr& op, cons MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; } -void OperatorInfo::ReplacePreEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null."; return; } std::vector> new_pre_edges; - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator() != op) { new_pre_edges.push_back(edge); } @@ -609,13 +609,13 @@ void OperatorInfo::ReplacePreEdges(const std::shared_ptr& op, cons prev_edges_ = new_pre_edges; } -void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null"; return; } std::vector> new_succ_edges; - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if (edge->next_operator() != op) { new_succ_edges.push_back(edge); } @@ -625,7 +625,7 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr& op, con } std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes& shapes, const std::vector& split_flag_list) { + const 
Shapes &shapes, const std::vector &split_flag_list) { if (shapes.size() != split_flag_list.size()) { MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : " << shapes.size(); @@ -665,14 +665,14 @@ void OperatorInfo::ComputeBatchSplitFlagList() { } // This is a common method for checking whether the generated stragegy has the correct number of devuces. -Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes& inputs_partitions, StrategyPtr* const sp) { +Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { if (sp == nullptr) { MS_LOG(ERROR) << "The strategy is null."; return FAILED; } int32_t product = 1; - for (auto& input_partition : inputs_partitions) { + for (auto &input_partition : inputs_partitions) { product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); } if (!FULLY_USE_DEVICES) { @@ -694,7 +694,7 @@ std::shared_ptr>> OperatorInfo::GenerateBatchSt return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); } -void PrintStrategy(const StrategyPtr& strategy) { +void PrintStrategy(const StrategyPtr &strategy) { if (strategy == nullptr) { return; } @@ -716,8 +716,8 @@ void PrintStrategy(const StrategyPtr& strategy) { } // generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d]) -Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* const sp_vector) { +Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -740,7 +740,7 @@ Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& input return FAILED; } - for (auto& sp : 
*sp_vector) { + for (auto &sp : *sp_vector) { sp->ExpandInputDimFromOneToTwo(); } @@ -749,8 +749,8 @@ Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& input // generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast // such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -770,7 +770,7 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs } // second, get the correct strategy for input0 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { std::vector tmp_strategy; Dimensions input0_strategy = sp->GetInputDim()[0]; size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); @@ -798,8 +798,8 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs // generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast // such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -819,7 +819,7 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& input } // second, get the correct strategy for input1 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) 
{ std::vector tmp_strategy; tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 @@ -848,8 +848,8 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& input // generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast // such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -881,7 +881,7 @@ Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs } // step3: reset the strategy if the dimension is 1 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { Dimensions input0_strategy = sp->GetInputDim()[0]; Dimensions input1_strategy = sp->GetInputDim()[1]; for (size_t i = 0; i < inputs_shape[0].size(); ++i) { @@ -904,9 +904,9 @@ Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs // dimension is splittable. 'inputs_partitions' is the result of partitions. // NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring // specific dimensions in inputs have the identical partition should have individual implementation. 
-Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -932,7 +932,7 @@ Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& in MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size(); Shapes inputs_partitions; size_t global_index = 0; - for (auto& shape : inputs_shape) { + for (auto &shape : inputs_shape) { Shape tmp_partition; for (size_t j = 0; j < shape.size(); ++j) { tmp_partition.push_back(combined_partitions[global_index]); @@ -974,8 +974,8 @@ Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& in // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -1025,7 +1025,7 @@ Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_sh return SUCCESS; } -Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) { +Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { if (InitForCostModel(strategy) == FAILED) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Initialization 
under the strategy failed."; @@ -1063,8 +1063,8 @@ int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { return is_output_parameter_involve_; } is_parameter_involve_ = is_parameter_; - const auto& prev_edges = this->GetAlivePrevEdges(); - for (auto& p_edge : prev_edges) { + const auto &prev_edges = this->GetAlivePrevEdges(); + for (auto &p_edge : prev_edges) { auto input_index = p_edge->next_op_input_index(); auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved(); if (input_index >= is_parameter_involve_.size()) { @@ -1090,7 +1090,7 @@ int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { return is_output_parameter_involve_; } -Status OperatorInfo::set_is_parameter(const std::vector& is_parameter) { +Status OperatorInfo::set_is_parameter(const std::vector &is_parameter) { if (is_parameter.size() != inputs_shape_.size()) { MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size() << " do not have the same number of inputs_shape_: " << inputs_shape_.size(); @@ -1111,7 +1111,7 @@ Status OperatorInfo::CalculateMemoryCost() { operator_cost()->set_is_parameter_involve(is_parameter_involve_); operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); // Set the memory cost in the 'strategy_cost_' - for (auto& swc : strategy_cost_) { + for (auto &swc : strategy_cost_) { auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); swc->cost_list[0]->memory_with_reuse_ = mem_cost; } @@ -1119,7 +1119,7 @@ Status OperatorInfo::CalculateMemoryCost() { } Status OperatorInfo::CorrectMemoryCost(size_t input_index) { - for (auto& swc : strategy_cost_) { + for (auto &swc : strategy_cost_) { double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * static_cast(operator_cost()->inputs_type_lengths()[input_index]); swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost; @@ -1132,13 +1132,13 @@ Status OperatorInfo::CorrectMemoryCost(size_t input_index) { return SUCCESS; } 
-int32_t ComputeRepeatDeviceNumByTensorMap(const Shape& dev_matrix_shape, const Shape& tensor_map) { +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) { int32_t ret = -1; // The number of repetitions is equal to the number of all devices divided by the number of devices use for // tensor map. int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); - for (auto& element : tensor_map) { + for (auto &element : tensor_map) { // -1 means the corresponding dimension is not split. if (element == MAP_NONE) { continue; @@ -1211,8 +1211,8 @@ Status OperatorInfo::InferVirtualDivOps() { return SUCCESS; } -Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths) { +Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { if (input_lengths.size() != inputs_shape_.size()) { MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size() << " do not have the same number of inputs shape: " << inputs_shape_.size(); @@ -1229,7 +1229,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector& inpu return SUCCESS; } -Status OperatorInfo::set_outputs_type(const std::vector& outputs_type) { +Status OperatorInfo::set_outputs_type(const std::vector &outputs_type) { if (outputs_type.size() != outputs_shape_.size()) { MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() << " do not have the same number of outputs shape: " << outputs_shape_.size(); @@ -1239,7 +1239,7 @@ Status OperatorInfo::set_outputs_type(const std::vector& outputs_type) return SUCCESS; } -void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra, const CostPtr& cost) { +void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) { if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) { 
CheckGlobalDeviceManager(); auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h index 19e0eeeda1..347da7e573 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h @@ -69,23 +69,23 @@ class OperatorInfo { virtual ~OperatorInfo() = default; - Status set_is_parameter(const std::vector& is_parameter); - Status SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths); + Status set_is_parameter(const std::vector &is_parameter); + Status SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths); // Set outputs dtype. // If only one output, outputs_type.size() is 1. // If output is tuple, outputs_type.size() is greater than 1. - Status set_outputs_type(const std::vector& outputs_type); - const std::vector& outputs_type() const { return outputs_type_; } - virtual Status Init(const StrategyPtr& strategy) = 0; - virtual Status InitForCostModel(const StrategyPtr& strategy) = 0; // only init the necessary parts + Status set_outputs_type(const std::vector &outputs_type); + const std::vector &outputs_type() const { return outputs_type_; } + virtual Status Init(const StrategyPtr &strategy) = 0; + virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts // Given the stage_id (which indicates the number of devices), // generate all strategies for this operator virtual Status GenerateStrategies(int32_t stage_id) = 0; - const OperatorCostPtr& operator_cost() const { return operator_cost_; } - void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; } - virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0; + const OperatorCostPtr &operator_cost() const { return operator_cost_; } + void set_cost(const OperatorCostPtr 
&cost) { operator_cost_ = cost; } + virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; virtual std::shared_ptr>> GenerateBatchStrategies(); virtual void ReComputeBatchSplitFlagList(); @@ -94,7 +94,7 @@ class OperatorInfo { double GetForwardMemoryCostFromCNode(); // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy // is checked - Status SetCostUnderStrategyBase(const StrategyPtr& strategy); + Status SetCostUnderStrategyBase(const StrategyPtr &strategy); std::vector> GetStrategyCost() { return strategy_cost_; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. @@ -104,61 +104,61 @@ class OperatorInfo { ForwardOp forward_op() const { return forward_op_; } ForwardOp replace_op() const { return replace_op_; } OutPutInfoVector replace_op_info() const { return replace_op_info_; } - virtual ReplaceGraphPtr replace_graph(const CNodePtr&) { return replace_graph_; } + virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; } MirrorOps mirror_ops() const { return mirror_ops_; } Ops sub_ops() const { return sub_ops_; } VirtualDivOp virtual_div_op() const { return virtual_div_op_; } Shape dev_matrix_shape() const { return dev_matrix_shape_; } std::vector inputs_tensor_info() const { return inputs_tensor_info_; } std::vector outputs_tensor_info() const { return outputs_tensor_info_; } - const std::string& name() const { return name_; } - void set_name(const std::string& name) { name_ = name; } + const std::string &name() const { return name_; } + void set_name(const std::string &name) { name_ = name; } RankList global_device_list() const { return global_device_list_; } - void AddSuccEdge(const std::shared_ptr& e) { succ_edges_.push_back(e); } - void AddPrevEdge(const std::shared_ptr& e) { 
prev_edges_.push_back(e); } + void AddSuccEdge(const std::shared_ptr &e) { succ_edges_.push_back(e); } + void AddPrevEdge(const std::shared_ptr &e) { prev_edges_.push_back(e); } std::vector> succ_edges() const { return succ_edges_; } std::vector> prev_edges() const { return prev_edges_; } std::vector> GetAliveSuccEdges(); std::vector> GetAlivePrevEdges(); - void ReplacePreEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplaceSuccEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplacePreEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplaceSuccEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); + void ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); std::vector GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); } - void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) { + void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) { selected_strategy_ = s_strategy; selected_cost_ = cost; } StrategyPtr selected_strategy() const { return selected_strategy_; } CostPtr selected_cost() const { return selected_cost_; } - Status InitSelectedStrategy(const StrategyPtr& s_strategy) { return Init(s_strategy); } - void set_input_value(const std::vector& input_value) { input_value_ = input_value; } - void set_outputs_dtype(const TypePtr& dtype) { outputs_dtype_ = dtype; } - void set_cnode(const CNodePtr& cnode) { cnode_ = cnode; } + Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } + void set_input_value(const std::vector &input_value) { input_value_ = input_value; } + void 
set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; } + void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } bool is_alive() const { return is_alive_; } void SetNotAlive() { is_alive_ = false; } StrategyPtr strategy() const { return strategy_; } - void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; } + void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; } void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } - const std::string& refkey_parameter_name() const { return refkey_parameter_name_; } + const std::string &refkey_parameter_name() const { return refkey_parameter_name_; } // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated // multiple times. This method is to correct this, and makes the cost is calulated only once. Status CorrectMemoryCost(size_t input_index); int is_output_parameter_involve() const { return is_output_parameter_involve_; } int used_devices() const { return used_devices_; } // needed by rec_parser - void set_type(const std::string& type) { type_ = type; } - const std::string& type() const { return type_; } - void set_cnode_name(const std::string& cnode_name) { cnode_name_ = cnode_name; } - const std::string& cnode_name() const { return cnode_name_; } - const std::unordered_map& attrs() const { return attrs_; } + void set_type(const std::string &type) { type_ = type; } + const std::string &type() const { return type_; } + void set_cnode_name(const std::string &cnode_name) { cnode_name_ = cnode_name; } + const std::string &cnode_name() const { return cnode_name_; } + const std::unordered_map &attrs() const { return attrs_; } protected: // needed by rec_parser std::string type_; std::string cnode_name_; - virtual Status CheckStrategy(const StrategyPtr& strategy) = 0; + virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; virtual Status InferTensorMap() = 0; virtual Status 
InferForwardCommunication() = 0; virtual Status InferMirrorOps() = 0; @@ -167,14 +167,14 @@ class OperatorInfo { virtual Status InferDevMatrixShape() = 0; void SetDeviceListByStrategy(); void SetRepeatedCalcDevMatrix(); - Status CreateGroupByTensorMap(const Shape& tensor_map, std::vector* group); - Status CreateGroupByDim(size_t axis, std::vector* group); + Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group); + Status CreateGroupByDim(size_t axis, std::vector *group); Status InferAttrs(); void ResetQueueMember(); - Status InitWithAutoRepeatCalc(const StrategyPtr& strategy); - Status InitWithManualRepeatCalc(const StrategyPtr& strategy); - Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strategy); - Status InitForCostModelWithManualRepeatCalc(const StrategyPtr& strategy); + Status InitWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitWithManualRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy); Status InferRepeatedCalcInfo(); Status InferVirtualDivOps(); @@ -182,9 +182,9 @@ class OperatorInfo { // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output // is used for grad and overload the function. If the output is a scalar, need to override the function too. 
virtual Status InferAsLossDivisor(); - Status InferSliceShape(const Strategys& inputs_strategy, const Strategys& outputs_strategy, - Shapes* inputs_slice_shape, Shapes* outputs_slice_shape); - void BreakingTiesForPerferringDataParallel(const StrategyPtr&, const CostPtr&); + Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape); + void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &); std::string name_; Shapes inputs_shape_; @@ -242,29 +242,29 @@ class OperatorInfo { std::vector outputs_type_; }; -Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy); -Status CheckStrategyValue(const StrategyPtr& strategy, const Shapes& inputs_shape, bool); +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool); Operator CreateVirtualDivOp(int32_t div_num); -Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& group); -Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout); -OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num); -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape& dev_matrix_shape, const Shape& tensor_map); +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes& shapes, const std::vector& split_flag_list); + const Shapes &shapes, const std::vector &split_flag_list); -void PrintStrategy(const StrategyPtr& strategy); +void PrintStrategy(const StrategyPtr &strategy); // generate strategies for that all inputs' 
dimensions are independent, such as: ([a, b, c, d]) -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* sp_vector); +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *sp_vector); // generate strategies for that have two inputs, and input0 or input1 maybe broadcast, // and the corresponding dimensions that are not broadcast are all relevant dimensions // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* sp_vector); +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *sp_vector); -Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc index a4d601dbe9..fed361616b 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc @@ -34,7 +34,7 @@ namespace parallel { * w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. 
* the strategy of w should equal to the channel dimension of strategy of A */ -Status PReLUInfo::CheckStrategy(const StrategyPtr& strategy) { +Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -119,7 +119,7 @@ Dimensions PReLUInfo::GetOutputStrategy() { return output_strategy; } -Status PReLUInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if (inputs_layout == nullptr || outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -181,7 +181,7 @@ Status PReLUInfo::GetAttrs() { return SUCCESS; } -Status PReLUInfo::Init(const StrategyPtr& strategy) { +Status PReLUInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -190,7 +190,7 @@ Status PReLUInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status PReLUInfo::InitForCostModel(const StrategyPtr& strategy) { +Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -224,7 +224,7 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; @@ -234,7 +234,7 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr 
&strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/parallel/ops_info/prelu_info.h index 396407c1ee..28e149fad7 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.h +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.h @@ -33,24 +33,24 @@ namespace parallel { */ class PReLUInfo : public OperatorInfo { public: - PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PReLUInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Dimensions GetOutputStrategy(); diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc index 4cb81ee769..d6e1c277ef 
100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status ReshapeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -137,7 +137,7 @@ Status ReshapeInfo::GetParameterInput() { return FAILED; } - for (auto& element : elements) { + for (auto &element : elements) { MS_EXCEPTION_IF_NULL(element); if (element->isa()) { int32_t axis = element->cast()->value(); @@ -216,7 +216,7 @@ Strategys ReshapeInfo::GetOutputsStrategy() { return outputs_strategy; } -Status ReshapeInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if (inputs_layout == nullptr || outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -302,7 +302,7 @@ void ReshapeInfo::InferTensorInfoByLayout() { */ Status ReshapeInfo::GetAttrs() { return GetParameterInput(); } -void ReshapeInfo::device_number(const StrategyPtr& strategy) { +void ReshapeInfo::device_number(const StrategyPtr &strategy) { int32_t stage = 0; if (strategy != nullptr) { stage = strategy->GetInputStage(); @@ -313,7 +313,7 @@ void ReshapeInfo::device_number(const StrategyPtr& strategy) { MS_ASSERT(dev_num_ > 0); } -Status ReshapeInfo::InferDefaultLayout(const Shape& shape, TensorLayout* const layout) { +Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { std::vector tensor_map_index; for (size_t i = 0; i < shape.size(); i++) { tensor_map_index.push_back(MAP_NONE); @@ -326,7 +326,7 @@ Status ReshapeInfo::InferDefaultLayout(const Shape& shape, TensorLayout* const l 
return Status::SUCCESS; } -Status ReshapeInfo::Init(const StrategyPtr& strategy) { +Status ReshapeInfo::Init(const StrategyPtr &strategy) { ResetQueueMember(); device_number(strategy); if (strategy) { @@ -375,7 +375,7 @@ Status ReshapeInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status ReshapeInfo::InitForCostModel(const StrategyPtr& strategy) { +Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -389,7 +389,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; @@ -423,7 +423,7 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h index 3864d2b93d..99ee014175 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h @@ -34,34 +34,34 @@ namespace parallel { */ class ReshapeInfo : public OperatorInfo { public: - ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, 
std::make_shared(false)), dev_num_(0), input_layout_set_flag_(false), output_layout_set_flag_(false) {} ~ReshapeInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - void SetInputLayout(const TensorLayout& input_layout) { + Status Init(const StrategyPtr &strategy) override; + void SetInputLayout(const TensorLayout &input_layout) { input_layout_ = input_layout; input_layout_set_flag_ = true; } - void SetOutputLayout(const TensorLayout& output_layout) { + void SetOutputLayout(const TensorLayout &output_layout) { output_layout_ = output_layout; output_layout_set_flag_ = true; } - Status InitForCostModel(const StrategyPtr& strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorMap() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Strategys GetOutputsStrategy(); @@ -69,8 +69,8 @@ class ReshapeInfo : public OperatorInfo { Status GetParameterInput(); Status ComputeReplaceOp(); void InferTensorInfoByLayout(); - void device_number(const StrategyPtr& strategy); - Status InferDefaultLayout(const Shape& shape, TensorLayout* const layout); + void device_number(const StrategyPtr &strategy); + Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); int32_t dev_num_; std::vector parameter_input_v_; diff --git 
a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h index 3682fe334f..f7895d0511 100644 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h +++ b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h @@ -32,19 +32,19 @@ class TmpIdentityInfo : public OperatorInfo { // consider this parameter tensor as TmpIdentityInfo operator. TmpIdentityInfo operator tasks as input a tensor, // and outputs the same tensor. After the transformation, subsequent operators can share the output tensor. public: - TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs, - const std::string& name = IDENTITY_INFO) + TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, + const std::string &name = IDENTITY_INFO) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TmpIdentityInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override { return SUCCESS; } Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc index 84333a1337..49bbae0cb4 100644 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc @@ -27,7 +27,7 @@ 
namespace mindspore { namespace parallel { -Status TransposeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -43,7 +43,7 @@ Status TransposeInfo::CheckStrategy(const StrategyPtr& strategy) { Status TransposeInfo::InferDevMatrixShape() { std::vector stra = strategy_->GetInputDim(); input_strategy_ = stra.at(0); - for (auto& iter : input_strategy_) { + for (auto &iter : input_strategy_) { dev_matrix_shape_.push_back(iter); } return SUCCESS; @@ -77,7 +77,7 @@ Status TransposeInfo::ComputeAxis() { return FAILED; } axis_v_.clear(); - for (auto& element : elements) { + for (auto &element : elements) { MS_EXCEPTION_IF_NULL(element); if (element->isa()) { int32_t axis = element->cast()->value(); @@ -130,7 +130,7 @@ Strategys TransposeInfo::GetOutputsStrategy() { return outputs_strategy; } -Status TransposeInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -179,7 +179,7 @@ Status TransposeInfo::InferTensorInfo() { // compute axis_v_ during this method Status TransposeInfo::GetAttrs() { return ComputeAxis(); } -Status TransposeInfo::Init(const StrategyPtr& strategy) { +Status TransposeInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -188,7 +188,7 @@ Status TransposeInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status TransposeInfo::InitForCostModel(const StrategyPtr& strategy) { +Status TransposeInfo::InitForCostModel(const StrategyPtr 
&strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -202,7 +202,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; @@ -234,7 +234,7 @@ Status TransposeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << "strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/parallel/ops_info/transpose_info.h index e4e2b90b7b..50b76bde65 100644 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.h +++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.h @@ -33,23 +33,23 @@ namespace parallel { */ class TransposeInfo : public OperatorInfo { public: - TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TransposeInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& 
strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Strategys GetOutputsStrategy(); diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc index cd3b40315c..4b695ba62d 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -171,7 +171,7 @@ Status VirtualDatasetInfo::InferTensorInfo() { Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; } -Status VirtualDatasetInfo::Init(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) { if (InitWithManualRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -179,7 +179,7 @@ Status VirtualDatasetInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { if 
(is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -199,7 +199,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { } } -Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; @@ -223,7 +223,7 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); StrategyPtr sp; std::vector strategy; - for (auto& shape : inputs_shape_) { + for (auto &shape : inputs_shape_) { Shape temp; temp.emplace_back(SizeToInt(total_dev_num)); (void)temp.insert(temp.end(), shape.size() - 1, 1); diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h index 398bae3585..312ac7a6a4 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h @@ -30,19 +30,19 @@ namespace mindspore { namespace parallel { class VirtualDatasetInfo : public OperatorInfo { public: - VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~VirtualDatasetInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& 
strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; void ReComputeBatchSplitFlagList() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 81aae04c73..8a95232aa4 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -462,7 +462,6 @@ Status ConstructCostGraphNodes(const std::vector &all_nodes, const F // Needed by rec_parser operator_info->set_type(prim->name()); std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); - operator_info->set_cnode_name(cnode->ToString()); entire_costgraph->AddOperator(operator_info); (void)cnode->set_operator_info(operator_info); @@ -932,12 +931,9 @@ Status ParallelStrategyRecSearch(const std::vector &all_nodes, const } std::shared_ptr> ops_nodes_list(new std::vector); - std::shared_ptr> index_list(new std::vector); - std::shared_ptr>> eli_list(new std::vector>); - std::shared_ptr graph = ParseGraph(ops, input_tensor_names, ops_nodes_list); + std::shared_ptr graph = ParseGraph(ops, input_tensor_names); - graph = EliminateGraph(graph, eli_list, index_list); size_t num_device = g_device_manager->DeviceNum(); if (PartitionForAllDevices(num_device, graph) == SUCCESS) { MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; @@ -946,7 +942,8 @@ Status ParallelStrategyRecSearch(const std::vector &all_nodes, const return FAILED; } - GenerateStrategy(graph, ops, ops_nodes_list, index_list, eli_list); + bool mask_special_ops = true; + GenerateStrategy(graph, mask_special_ops, ops); if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { MS_LOG(INFO) << "Init selected strategy succeeded."; diff --git 
a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index bcd4dc3763..c24c14abf6 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -76,7 +76,7 @@ void SetCommunicationOpGroupLabel(std::vector new_node_input) { } } -std::vector CreateInput(const Operator& op, const AnfNodePtr& node, const std::string& instance_name) { +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { MS_EXCEPTION_IF_NULL(node); OperatorArgs arg_forward = op.second; ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name); @@ -85,7 +85,7 @@ std::vector CreateInput(const Operator& op, const AnfNodePtr& node, std::vector new_node_input = {NewValueNode(pyop_instance), node}; if (!params.empty()) { - for (auto& param : params) { + for (auto ¶m : params) { AnfNodePtr val = NewValueNode(param.first.second); MS_EXCEPTION_IF_NULL(val); int32_t position = param.second; @@ -98,8 +98,8 @@ std::vector CreateInput(const Operator& op, const AnfNodePtr& node, return new_node_input; } -void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const AnfNodePtr& pre_node, - const FuncGraphPtr& func_graph, const std::string& instance_name) { +void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node, + const FuncGraphPtr &func_graph, const std::string &instance_name) { // insert new node before the node FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -121,7 +121,7 @@ void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const An manager->SetEdge(node, SizeToInt(index), new_node); } -std::string CreateInstanceName(const CNodePtr& node, size_t index) { +std::string CreateInstanceName(const CNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); if (!IsValueNode(node->input(0))) { MS_LOG(EXCEPTION) << "CreateInstanceName: " << 
node->ToString() << " doesn't have primitive"; @@ -132,7 +132,7 @@ std::string CreateInstanceName(const CNodePtr& node, size_t index) { return instance_name; } -void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); // step1:get graph manager distribute_operator FuncGraphPtr func_graph = node->func_graph(); @@ -141,7 +141,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { MS_EXCEPTION_IF_NULL(manager); auto uses_set = manager->node_users()[node]; CNodePtr node_to_insert = node; - for (auto& uses_pair : uses_set) { + for (auto &uses_pair : uses_set) { auto uses_cnode = uses_pair.first->cast(); MS_EXCEPTION_IF_NULL(uses_cnode); if (!IsValueNode(uses_cnode->input(0))) { @@ -175,7 +175,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { } } -CNodePtr InsertMakeTuple(const AnfNodePtr& prev, uint32_t num, const FuncGraphPtr& func_graph) { +CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint32_t num, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(prev); MS_EXCEPTION_IF_NULL(func_graph); std::vector make_tuple_inputs; @@ -195,8 +195,8 @@ CNodePtr InsertMakeTuple(const AnfNodePtr& prev, uint32_t num, const FuncGraphPt return make_tuple; } -void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_ptr, const CNodePtr& node, - const FuncGraphPtr& func_graph, int pos, const CNodePtr& pre_node) { +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(pre_node); MS_EXCEPTION_IF_NULL(func_graph); @@ -226,8 +226,8 @@ void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_p } } -void InsertGetTensorSliceOp(const Operator& op, const CNodePtr& node, const FuncGraphPtr& 
func_graph, int pos, - const std::string& instance_name) { +void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int pos, + const std::string &instance_name) { if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name; } @@ -244,8 +244,8 @@ void InsertGetTensorSliceOp(const Operator& op, const CNodePtr& node, const Func InsertNode(op, node, IntToSize(pos), pre_node, func_graph, instance_name); } -TensorLayout GetTensorInLayout(const CNodePtr& middle_node, const PrimitivePtr& middle_prim, - const OperatorInfoPtr& distribute_operator) { +TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim, + const OperatorInfoPtr &distribute_operator) { TensorInfo tensorinfo_in; if (middle_prim->name() == TUPLE_GETITEM) { auto value_node = middle_node->input(2)->cast(); @@ -265,7 +265,7 @@ TensorLayout GetTensorInLayout(const CNodePtr& middle_node, const PrimitivePtr& return tensorinfo_in.tensor_layout(); } -OperatorInfoPtr GetDistributeOperator(const CNodePtr& node) { +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!IsParallelCareNode(node)) { return nullptr; @@ -277,9 +277,9 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr& node) { return distribute_operator; } -void Redistribution(const std::pair& node_pair, const OperatorInfoPtr& distribute_operator, - const CNodePtr& middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr& pre_node) { +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node) { FuncGraphPtr func_graph = middle_node->func_graph(); if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "Redistribution:get graph failed"; @@ -333,13 +333,13 @@ bool StrategyFound(std::unordered_map 
attrs) { return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); } -bool IsCommunicationOp(const PrimitivePtr& prim) { +bool IsCommunicationOp(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()); } -bool FindCommunicationOp(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +bool FindCommunicationOp(const std::vector &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -364,7 +364,7 @@ bool FindCommunicationOp(const std::vector& all_nodes) { return false; } -bool IsParallelCareNode(const CNodePtr& cnode) { +bool IsParallelCareNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); ValueNodePtr prim_node = cnode->input(0)->cast(); if (prim_node == nullptr) { @@ -389,8 +389,8 @@ bool IsParallelCareNode(const CNodePtr& cnode) { return cnode->in_forward_flag(); } -void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_operator, const CNodePtr& insert_node, - const TensorRedistribution& tensor_redistribution, const CNodePtr& pre_node) { +void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) { MS_EXCEPTION_IF_NULL(node->func_graph()); FuncGraphManagerPtr manager = node->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -406,7 +406,7 @@ void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_ insert_node_new = insert_node; } MS_EXCEPTION_IF_NULL(insert_node_new); - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(use_cnode); if (!IsValueNode(use_cnode->input(0))) { @@ -429,7 +429,7 @@ void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_ } } -void SplitTensor(const AnfNodePtr& node, const 
CNodePtr& next_node, int index) { +void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(next_node); OperatorInfoPtr op_info = next_node->operator_info(); @@ -474,11 +474,11 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) { } } -void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) { +void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_cnode = node_pair.first->cast(); if (use_cnode == nullptr || !IsValueNode(use_cnode->input(0))) { continue; @@ -496,8 +496,8 @@ void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) } } -std::vector ReplaceOpInput(const Operator& replace_op, const std::string& instance_name, - const CNodePtr& node) { +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node) { OperatorArgs arg_replace_op = replace_op.second; ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name); if (pyop_instance == nullptr) { @@ -518,7 +518,7 @@ std::vector ReplaceOpInput(const Operator& replace_op, const std::st if (first_position == 1) { replace_input.pop_back(); } - for (auto& param : params) { + for (auto ¶m : params) { AnfNodePtr val = NewValueNode(param.first.second); if (val == nullptr) { MS_LOG(EXCEPTION) << "Failure:val is nullptr"; @@ -531,7 +531,7 @@ std::vector ReplaceOpInput(const Operator& replace_op, const std::st return replace_input; } -void ReplaceOneOp(const Operator& replace_op, const CNodePtr& node) { +void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); 
FuncGraphManagerPtr manager = func_graph->manager(); @@ -551,7 +551,7 @@ void ReplaceOneOp(const Operator& replace_op, const CNodePtr& node) { (void)manager->Replace(node, replace_node); } -void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node) { +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { // step1:get graph manager distribute_operator OperatorInfoPtr distribute_operator = node->operator_info(); if (distribute_operator == nullptr) { @@ -599,15 +599,15 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node) { MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name(); } -bool IsSomePrimitive(const CNodePtr& cnode, const std::string& name) { +bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { ValueNodePtr anf_node = cnode->input(0)->cast(); MS_EXCEPTION_IF_NULL(anf_node); PrimitivePtr prim = anf_node->value()->cast(); return (prim->name() == name); } -void StepReplaceGraph(const std::shared_ptr, AnfNodePtr>>& replace_graph, - const CNodePtr& node) { +void StepReplaceGraph(const std::shared_ptr, AnfNodePtr>> &replace_graph, + const CNodePtr &node) { MS_EXCEPTION_IF_NULL(replace_graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(replace_graph->second); @@ -627,7 +627,7 @@ void StepReplaceGraph(const std::shared_ptr, A if (replace_graph->first.size() != 2) { MS_LOG(EXCEPTION) << "Failure:replace_graph->first.size() must be 2 for OneHot Primitive!"; } - for (auto& replace_input : replace_graph->first) { + for (auto &replace_input : replace_graph->first) { MS_EXCEPTION_IF_NULL(replace_input); manager->SetEdge(replace_input, 1, pre_node); CNodePtr replace_input_cnode = replace_input->cast(); @@ -645,7 +645,7 @@ void StepReplaceGraph(const std::shared_ptr, A replace_output_cnode->set_in_forward_flag(true); // mark this new cnode is forward node } -int32_t GetTupleGetItemIndex(const CNodePtr& cnode) { +int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { 
MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().size() != 3) { MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; @@ -666,7 +666,7 @@ int32_t GetTupleGetItemIndex(const CNodePtr& cnode) { // Judge whether the node is a loss, and if there are multiple outputs, // get which output is a grad according to the tuple getitem. // Currently, it is not supported that the sens is a tuple. -LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) { +LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { MS_EXCEPTION_IF_NULL(loss_node); FuncGraphPtr sub_graph = loss_node->func_graph(); MS_EXCEPTION_IF_NULL(sub_graph); @@ -718,7 +718,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) { MS_LOG(EXCEPTION) << "Invalid loss"; } -void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node) { +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = node->inputs().size(); FuncGraphPtr func_graph = node->func_graph(); @@ -742,7 +742,7 @@ void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node } } -std::pair FindParameter(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { if (!node->isa() && !node->isa() && !node->isa()) { return std::make_pair(nullptr, false); } else if (node->isa()) { @@ -790,7 +790,7 @@ std::pair FindParameter(const AnfNodePtr& node, const FuncGrap return std::make_pair(nullptr, false); } -std::pair FindCNode(const AnfNodePtr& anode, const std::string& name, const FuncGraphPtr& func_graph) { +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(anode); MS_EXCEPTION_IF_NULL(anode->func_graph()); FuncGraphManagerPtr manager = anode->func_graph()->manager(); @@ -798,7 +798,7 @@ std::pair FindCNode(const AnfNodePtr& anode, const 
std::string& AnfNodeIndexSet node_set = manager->node_users()[anode]; bool result = false; CNodePtr cnode_return = nullptr; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_apply = node_pair.first->cast(); if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; @@ -820,7 +820,7 @@ std::pair FindCNode(const AnfNodePtr& anode, const std::string& return std::make_pair(result, cnode_return); } -bool IsCastBeforMirror(const CNodePtr& node, size_t index) { +bool IsCastBeforMirror(const CNodePtr &node, size_t index) { // only if cast_before_mirror is true, pre node is cast and type is not float32 return true if (!ParallelContext::GetInstance()->cast_before_mirror()) { return false; @@ -850,7 +850,7 @@ bool IsCastBeforMirror(const CNodePtr& node, size_t index) { return (type_id != kNumberTypeFloat32); } -void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = node->inputs().size(); FuncGraphPtr func_graph = node->func_graph(); @@ -887,7 +887,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { } std::string instance_name = MIRROR_OP; if (IsCastBeforMirror(node, index)) { - for (auto& op : backward_op) { + for (auto &op : backward_op) { // insert new node before the node CNodePtr cnode = node->input(index)->cast(); MS_EXCEPTION_IF_NULL(cnode); @@ -895,7 +895,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); } } else { - for (auto& op : backward_op) { + for (auto &op : backward_op) { AnfNodePtr pre_node = node->input(index); InsertNode(op, node, index, pre_node, func_graph, instance_name); } @@ -903,7 +903,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { } } -void BackwardCommunication(const OperatorInfoPtr& distribute_operator, 
const CNodePtr& node, bool is_loss_node) { +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(node); MirrorOps mirror_ops = distribute_operator->mirror_ops(); @@ -920,7 +920,7 @@ void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNo } } -std::string GetDisOpName(const std::string& prim_name) { +std::string GetDisOpName(const std::string &prim_name) { std::string op_name = prim_name; if (!prim_name.empty() && (prim_name[0] == '_')) { op_name = prim_name.substr(1); @@ -928,8 +928,8 @@ std::string GetDisOpName(const std::string& prim_name) { return op_name + "Info"; } -OperatorInfoPtr OperatorInstanceByName(const std::string& name, const PrimitiveAttrs& attrs, - const std::vector& shape_list) { +OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { if (shape_list.size() != 2) { MS_LOG(ERROR) << "The size of shape list is not 2"; return nullptr; @@ -951,8 +951,8 @@ OperatorInfoPtr OperatorInstanceByName(const std::string& name, const PrimitiveA return operator_; } -OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, - const std::vector& shape_list) { +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { MS_EXCEPTION_IF_NULL(prim); OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list); if (operator_ == nullptr) { @@ -963,7 +963,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& return operator_; } -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, std::vector shape_list) { OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); for 
(size_t i = 0; i < shape_list[0].size(); ++i) { @@ -992,7 +992,7 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { std::vector value_vector = value_tuple->value(); (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), - [](const ValuePtr& value) { return static_cast(GetValue(value)); }); + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); strategy.push_back(dim); } else { MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; @@ -1007,7 +1007,7 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { return strategyPtr; } -Shapes GetNodeShape(const AnfNodePtr& node) { +Shapes GetNodeShape(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); Shapes shapes; BaseShapePtr base_shape_ptr = node->Shape(); @@ -1039,7 +1039,7 @@ Shapes GetNodeShape(const AnfNodePtr& node) { auto tuple_shape_ptr = dyn_cast(base_shape_ptr); if (tuple_shape_ptr != nullptr) { auto tuple_shape = tuple_shape_ptr->shape(); - for (auto& shape : tuple_shape) { + for (auto &shape : tuple_shape) { auto each_shape = dyn_cast(shape); MS_EXCEPTION_IF_NULL(each_shape); shapes.push_back(each_shape->shape()); @@ -1052,7 +1052,7 @@ Shapes GetNodeShape(const AnfNodePtr& node) { return shapes; } -std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(func_graph); std::vector parameters; @@ -1075,7 +1075,7 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr root_g = roots.back(); MS_EXCEPTION_IF_NULL(root_g); - for (auto& param_node : root_g->parameters()) { + for (auto ¶m_node : root_g->parameters()) { auto param = param_node->cast(); if (param && (name == param->name())) { parameters.push_back(param_node); @@ -1088,7 +1088,7 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const return 
parameters; } -Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(func_graph); @@ -1107,7 +1107,7 @@ Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph return input_shapes; } -std::vector ExtractShape(const CNodePtr& node) { +std::vector ExtractShape(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); Shapes shape_inputs, shape_outputs; std::vector shape_all; @@ -1145,14 +1145,14 @@ std::vector ExtractShape(const CNodePtr& node) { return shape_all; } -std::pair FindParallelCareNode(const AnfNodePtr& node) { +std::pair FindParallelCareNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -1174,7 +1174,7 @@ std::pair FindParallelCareNode(const AnfNodePtr& node) { return std::make_pair(nullptr, 0); } -std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNodePtr& parameter) { +std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(parameter); FuncGraphManagerPtr manager = graph->manager(); @@ -1184,7 +1184,7 @@ std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNode return prim_anf_node_pair; } else { AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; - for (auto& param_pair : param_sub_set) { + for (auto ¶m_pair : param_sub_set) { CNodePtr graph_cnode = param_pair.first->cast(); if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa()) { continue; @@ -1208,7 +1208,7 @@ 
std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNode return std::make_pair(nullptr, 0); } -void SetParallelShape(const AnfNodePtr& parameter, const std::pair& res) { +void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res) { MS_EXCEPTION_IF_NULL(parameter); AbstractBasePtr abstract = parameter->abstract(); MS_EXCEPTION_IF_NULL(abstract); @@ -1237,10 +1237,10 @@ void SetParallelShape(const AnfNodePtr& parameter, const std::pairset_tensor_layout(std::make_shared(tensor_layout)); } -void CoverSliceShape(const FuncGraphPtr& root) { +void CoverSliceShape(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); auto parameters = root->parameters(); - for (auto& parameter : parameters) { + for (auto ¶meter : parameters) { MS_EXCEPTION_IF_NULL(parameter->Shape()); auto iter = g_RefMap.find(parameter); if (iter != g_RefMap.end()) { @@ -1258,7 +1258,7 @@ void CoverSliceShape(const FuncGraphPtr& root) { g_RefMap.clear(); } -bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) { +bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_node) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(parameter_node); FuncGraphManagerPtr manager = root->manager(); @@ -1281,9 +1281,9 @@ bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_nod return true; } -void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { +void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); - for (auto& cloned_parameter_node : root->parameters()) { + for (auto &cloned_parameter_node : root->parameters()) { MS_EXCEPTION_IF_NULL(cloned_parameter_node); auto cloned_parameter = cloned_parameter_node->cast(); MS_EXCEPTION_IF_NULL(cloned_parameter); @@ -1300,7 +1300,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { bool found_be_cloned_parameter = false; ParameterPtr cloned_from_parameter = nullptr; AnfNodePtr cloned_from_node = nullptr; - for (auto& 
be_cloned_parameter_node : root->parameters()) { + for (auto &be_cloned_parameter_node : root->parameters()) { MS_EXCEPTION_IF_NULL(be_cloned_parameter_node); auto be_cloned_parameter = be_cloned_parameter_node->cast(); MS_EXCEPTION_IF_NULL(be_cloned_parameter); @@ -1315,7 +1315,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { // get the be cloned index py::list be_cloned_index = parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED_INDEX); - for (auto& index : be_cloned_index) { + for (auto &index : be_cloned_index) { if (cloned_index == py::cast(index)) { found_be_cloned_parameter = true; cloned_from_parameter = be_cloned_parameter; @@ -1341,7 +1341,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { } } -void SetVirtualDatasetStrategy(const CNodePtr& node) { +void SetVirtualDatasetStrategy(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); PrimitivePtr prim = GetValueNode(node->input(0)); MS_EXCEPTION_IF_NULL(prim); @@ -1370,8 +1370,8 @@ void SetVirtualDatasetStrategy(const CNodePtr& node) { } } -void ExtractInformation(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void ExtractInformation(const std::vector &all_nodes) { + for (auto &node : all_nodes) { auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; @@ -1390,7 +1390,7 @@ void ExtractInformation(const std::vector& all_nodes) { if (operator_ == nullptr) { MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); std::vector input_value; for (size_t index = 1; index < inputs.size(); ++index) { if (inputs[index]->isa()) { @@ -1440,7 +1440,7 @@ void ExtractInformation(const std::vector& all_nodes) { } } -TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair) { +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair) { CNodePtr cnode = node_pair.first->cast(); 
MS_EXCEPTION_IF_NULL(cnode); OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); @@ -1456,13 +1456,13 @@ TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair } // if reshape's output connect to several primitive, return the first layout found -std::shared_ptr FindNextLayout(const CNodePtr& cnode) { +std::shared_ptr FindNextLayout(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode->func_graph()); FuncGraphManagerPtr manager = cnode->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[cnode]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_apply = node_pair.first->cast(); if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; @@ -1492,7 +1492,7 @@ std::shared_ptr FindNextLayout(const CNodePtr& cnode) { return nullptr; } -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, size_t output_index) { +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) { MS_EXCEPTION_IF_NULL(cnode); OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); MS_EXCEPTION_IF_NULL(distribute_operator); @@ -1505,7 +1505,7 @@ std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, si return std::make_shared(tensorlayout_out); } -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& node, size_t output_index) { +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) { if (!node->isa()) { return nullptr; } @@ -1523,7 +1523,7 @@ std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& n return nullptr; } -std::shared_ptr FindPrevLayout(const AnfNodePtr& node) { +std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { if (node->isa()) { MS_LOG(EXCEPTION) << "Failure: parameter before reshape is not supported temporary"; } @@ -1567,8 +1567,8 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr& node) { return nullptr; 
} -void ReshapeInit(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void ReshapeInit(const std::vector &all_nodes) { + for (auto &node : all_nodes) { auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; @@ -1607,72 +1607,79 @@ void ReshapeInit(const std::vector& all_nodes) { } } -// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) -bool IsGradSensNode(const AnfNodePtr& node) { - if (!node->isa()) { - return false; +CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + CNodePtr return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->size() < 2) { + MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; } + AnfNodePtr pre_node = return_node->input(1); + MS_EXCEPTION_IF_NULL(pre_node); - // cnode(sens)-->cnode(tuple_getitem) - auto cnode = node->cast(); - AnfNodePtr expect_tuple_getitem = cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem); - if (!expect_tuple_getitem->isa()) { - return false; - } - auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast(); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode); - if (!IsValueNode(expect_tuple_getitem_cnode->input(0))) { - return false; + auto pre_cnode = pre_node->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + auto current_prim = GetValueNode(pre_cnode->input(0)); + + // return -> cast + if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_cnode = pre_cnode->input(1)->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + current_prim = GetValueNode(pre_cnode->input(0)); } - ValueNodePtr expect_tuple_getitem_value_node = expect_tuple_getitem_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem_value_node); - PrimitivePtr expect_tuple_getitem_prim = expect_tuple_getitem_value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem_prim); - 
if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) { - return false; + + // notice: the GetNext op has not input + if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { + MS_LOG(INFO) << "The loss is: " << current_prim->name(); + return pre_cnode; } - // cnode(sens)-->cnode(tuple_getitem)-->cnode - AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1); - MS_EXCEPTION_IF_NULL(expect_anonymous); - if (!expect_anonymous->isa()) { - return false; + // size of common cnode is larger than 1 + if (pre_cnode->size() < 2) { + MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; } - // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) - auto expect_anonymous_cnode = expect_anonymous->cast(); - MS_EXCEPTION_IF_NULL(expect_anonymous_cnode); - AnfNodePtr expect_j = expect_anonymous_cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_j); - if (!expect_j->isa()) { - return false; + // return -> tuple_getitem -> loss + if (current_prim->name() == TUPLE_GETITEM) { + AnfNodePtr pre_pre_node = pre_cnode->input(1); + MS_EXCEPTION_IF_NULL(pre_pre_node); + + auto pre_pre_cnode = pre_pre_node->cast(); + auto value = pre_pre_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(value); + PrimitivePtr prim = value->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(DEBUG) << "The loss name is " << prim->name(); + return pre_pre_cnode; } - auto expect_j_cnode = expect_j->cast(); - MS_EXCEPTION_IF_NULL(expect_j_cnode); - if (!IsValueNode(expect_j_cnode->input(0))) { - return false; + + // return -> make_tuple + if (current_prim->name() == MAKE_TUPLE) { + MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; } - ValueNodePtr expect_j_value_node = expect_j_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(expect_j_value_node); - PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(expect_j_prim); - return (expect_j_prim->name() == J); + + // 
return -> loss + MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); + return pre_cnode; } -TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr& loss_cnode) { +TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + TensorLayouts ret; + if (!IsValueNode(cnode->input(1))) { + MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph."; + } + auto func_graph = GetValueNode(cnode->input(1)); + auto loss_cnode = FindLossCNode(func_graph); MS_EXCEPTION_IF_NULL(loss_cnode); AnfNodePtr node = loss_cnode->cast(); MS_EXCEPTION_IF_NULL(node); LossNodeInfo node_info = GetLossNodeInfo(node); - ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast(); MS_EXCEPTION_IF_NULL(prim_anf_node); PrimitivePtr prim = prim_anf_node->value()->cast(); MS_EXCEPTION_IF_NULL(prim); - - TensorLayouts ret; if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) { MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now"; return ret; @@ -1680,7 +1687,6 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr& loss_cnode) { OperatorInfoPtr operator_info = loss_cnode->operator_info(); MS_EXCEPTION_IF_NULL(operator_info); - TensorInfo loss_grad_tensor_info; size_t op_output_size = operator_info->outputs_tensor_info().size(); MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is " @@ -1700,7 +1706,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr& loss_cnode) { return ret; } -void SplitSens(const AnfNodePtr& grad_sens_node, const TensorLayout& loss_grad_layout) { +void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) { MS_EXCEPTION_IF_NULL(grad_sens_node); auto cnode = grad_sens_node->cast(); @@ -1752,7 +1758,7 @@ void SplitSens(const AnfNodePtr& grad_sens_node, const TensorLayout& loss_grad_l InsertGetTensorSliceOp(op, cnode, func_graph, 1, SPLIT_SENS); } -void InsertForwardOps(const OperatorInfoPtr& 
distribute_operator, const CNodePtr& cnode) { +void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); OperatorVector forward_op = distribute_operator->forward_op(); @@ -1762,7 +1768,7 @@ void InsertForwardOps(const OperatorInfoPtr& distribute_operator, const CNodePtr } } -void StepReplace(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); // StepReplaceOp @@ -1783,7 +1789,7 @@ void StepReplace(const OperatorInfoPtr& distribute_operator, const CNodePtr& cno } } -void HandleDropoutNode(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); @@ -1801,29 +1807,120 @@ void HandleDropoutNode(const OperatorInfoPtr& distribute_operator, const CNodePt ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); } -void HandleSpecialNode(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { HandleDropoutNode(distribute_operator, cnode); } -void ParallelCommunication(const FuncGraphPtr& root, const std::vector& all_nodes, - const FuncGraphManagerPtr& manager) { +std::set FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { + // J->CNode->Graph + std::set graph_set; + for (auto &node : root_all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if ((cnode->size() < 2) || !IsValueNode(cnode->input(0))) { + continue; + } + auto expect_j_prim = GetValueNode(cnode->input(0)); + if (expect_j_prim->name() != J) { 
+ continue; + } + if (IsValueNode(cnode->input(1))) { + auto graph = GetValueNode(cnode->input(1)); + MS_LOG(DEBUG) << "Find the forward graph success"; + graph_set.insert(graph); + } + } + return graph_set; +} + +// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) +void StepSplitSens(const AnfNodePtr &node) { + if (!node->isa()) { + return; + } + + // cnode(sens)-->cnode(tuple_getitem) + auto cnode = node->cast(); + AnfNodePtr expect_tuple_getitem = cnode->input(0); + MS_EXCEPTION_IF_NULL(expect_tuple_getitem); + if (!expect_tuple_getitem->isa()) { + return; + } + auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast(); + MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode); + if (!IsValueNode(expect_tuple_getitem_cnode->input(0))) { + return; + } + auto expect_tuple_getitem_prim = GetValueNode(expect_tuple_getitem_cnode->input(0)); + if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) { + return; + } + + // cnode(sens)-->cnode(tuple_getitem)-->cnode + AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1); + MS_EXCEPTION_IF_NULL(expect_anonymous); + if (!expect_anonymous->isa()) { + return; + } + + // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) + auto expect_anonymous_cnode = expect_anonymous->cast(); + MS_EXCEPTION_IF_NULL(expect_anonymous_cnode); + AnfNodePtr expect_j = expect_anonymous_cnode->input(0); + MS_EXCEPTION_IF_NULL(expect_j); + if (!expect_j->isa()) { + return; + } + auto expect_j_cnode = expect_j->cast(); + MS_EXCEPTION_IF_NULL(expect_j_cnode); + if (!IsValueNode(expect_j_cnode->input(0))) { + return; + } + auto expect_j_prim = GetValueNode(expect_j_cnode->input(0)); + if (expect_j_prim->name() == J) { + auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode); + if (!loss_grad_layout.empty()) { + SplitSens(node, loss_grad_layout[0]); + } + } +} + +std::vector FindLossCNodeFromRoot(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + AnfNodePtr 
root_return_node = root->get_return(); + MS_EXCEPTION_IF_NULL(root_return_node); + std::vector loss_node; + const auto &all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + if (graph_set.empty()) { + loss_node.push_back(FindLossCNode(root)); + } + (void)std::transform(graph_set.begin(), graph_set.end(), std::back_inserter(loss_node), + [](const FuncGraphPtr &graph) { return FindLossCNode(graph); }); + return loss_node; +} + +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(manager); TensorRedistribution tensor_redistribution; AnfNodePtr grad_sens_node = nullptr; - CNodePtr loss_cnode = FindLossCNodeFromRoot(root); - MS_EXCEPTION_IF_NULL(loss_cnode); - // get output layout of loss must before inserting the operators below - TensorLayouts loss_layout = GetLossNodeGradOutputLayout(loss_cnode); - - for (auto& node : all_nodes) { - // find sens node - if ((grad_sens_node == nullptr) && IsGradSensNode(node)) { - grad_sens_node = node; - MS_LOG(INFO) << "Find the sens node success"; - } + std::vector loss_cnode = FindLossCNodeFromRoot(root); + // split sens must before inserting the operators. + for (auto &node : all_nodes) { + // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it. + // If the type of sens node is not Tensor, it is unsupported now, do nothing default. 
+ StepSplitSens(node); + } + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { auto cnode = node->cast(); @@ -1837,7 +1934,8 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector(node); MS_EXCEPTION_IF_NULL(symbolic_key); auto all_upstream_node = root->manager()->node_users()[node]; - for (auto& upstream_node : all_upstream_node) { + for (auto &upstream_node : all_upstream_node) { FuncGraphPtr fg = upstream_node.first->func_graph(); if (symbolic_key->node()->isa()) { - for (auto& param : root->parameters()) { + for (auto ¶m : root->parameters()) { if (*param == *symbolic_key->node()) { AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param}); MS_EXCEPTION_IF_NULL(reverted_node); @@ -1889,9 +1981,9 @@ void RevertSymbolicKeyInstance(const FuncGraphPtr& root, const AnfNodePtr& node) } } // namespace -void HandleSymbolicKeyInstance(const FuncGraphPtr& root, const std::vector& all_nodes) { +void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector &all_nodes) { MS_EXCEPTION_IF_NULL(root); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { // revert back SymbolicKeyInstance to embed() primitive if (IsValueNode(node)) { RevertSymbolicKeyInstance(root, node); @@ -1900,13 +1992,13 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr& root, const std::vectorget_return(); auto all_nodes = DeepScopedGraphSearch(ret); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { @@ -1931,7 +2023,7 @@ void CheckpointStrategy(const FuncGraphPtr& func_graph) { } } -void RestoreStrategy(const FuncGraphPtr& func_graph) { +void RestoreStrategy(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_LOG(INFO) << "Extract strategy from checkpoint begin"; StrategyMap straMap; @@ -1943,7 +2035,7 @@ void RestoreStrategy(const FuncGraphPtr& func_graph) { 
} auto ret = func_graph->get_return(); auto all_nodes = DeepScopedGraphSearch(ret); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { @@ -1968,8 +2060,8 @@ void RestoreStrategy(const FuncGraphPtr& func_graph) { } } -void SetForwardFlag(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void SetForwardFlag(const std::vector &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -1986,8 +2078,8 @@ void SetForwardFlag(const std::vector& all_nodes) { } } -void SetForwardFlag(const AnfNodeSet& all_nodes) { - for (auto& node : all_nodes) { +void SetForwardFlag(const AnfNodeSet &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -2003,134 +2095,57 @@ void SetForwardFlag(const AnfNodeSet& all_nodes) { } } -CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - CNodePtr return_node = func_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; - } - AnfNodePtr pre_node = return_node->input(1); - MS_EXCEPTION_IF_NULL(pre_node); - - auto pre_cnode = pre_node->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - auto current_value = pre_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(current_value); - PrimitivePtr current_prim = current_value->value()->cast(); - MS_EXCEPTION_IF_NULL(current_prim); - - // return -> cast - if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { - pre_cnode = pre_cnode->input(1)->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - current_prim = GetValueNode(pre_cnode->input(0)); - } - - // notice: the GetNext op has not input - if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { - MS_LOG(INFO) 
<< "The loss is: " << current_prim->name(); - return pre_cnode; - } - - // size of common cnode is larger than 1 - if (pre_cnode->inputs().size() < 2) { - MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; - } - - // return -> tuple_getitem -> loss - if (current_prim->name() == TUPLE_GETITEM) { - AnfNodePtr pre_pre_node = pre_cnode->input(1); - MS_EXCEPTION_IF_NULL(pre_pre_node); - - auto pre_pre_cnode = pre_pre_node->cast(); - auto value = pre_pre_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(value); - PrimitivePtr prim = value->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - MS_LOG(INFO) << "The loss name is " << prim->name(); - return pre_pre_cnode; - } else if (current_prim->name() == MAKE_TUPLE) { - MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; - } - - // return -> loss - MS_LOG(INFO) << "The loss name is " << current_prim->name(); - return pre_cnode; +std::set ForwardGraph(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + const auto &all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + return graph_set; } -FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet& root_all_nodes) { - for (auto& node : root_all_nodes) { +std::vector FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { + MS_EXCEPTION_IF_NULL(graph); + auto loss_cnode = FindLossCNode(graph); + MS_EXCEPTION_IF_NULL(loss_cnode); + auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy(); + std::vector root_forward_nodes; + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; } - auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if ((cnode->inputs().size() < 2) || !IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr expect_j_value_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(expect_j_value_node); - PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast(); - 
MS_EXCEPTION_IF_NULL(expect_j_prim); - if (expect_j_prim->name() != J) { - continue; - } - MS_LOG(DEBUG) << "Find J prim: " << expect_j_value_node->DebugString() << "."; - if (IsValueNode(cnode->input(1))) { - auto graph = GetValueNode(cnode->input(1)); - MS_LOG(INFO) << "Find the forward graph success"; - return graph; + auto root_node_id = node->UniqueIdThroughCopy(); + if (loss_cnode_id == root_node_id) { + root_forward_nodes = DeepLinkedGraphSearch(cnode); + break; } } - return nullptr; + return root_forward_nodes; } -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root) { +void MarkForwardCNode(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); - AnfNodePtr root_return_node = root->get_return(); - MS_EXCEPTION_IF_NULL(root_return_node); - const auto& all_nodes = root->nodes(); - FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); - if (func_graph == nullptr) { - return FindLossCNode(root); - } else { - return FindLossCNode(func_graph); - } -} + auto all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); -FuncGraphPtr ForwardGraph(const FuncGraphPtr& root) { - FuncGraphPtr forward_graph = root; - MS_EXCEPTION_IF_NULL(root); - AnfNodePtr root_return_node = root->get_return(); - MS_EXCEPTION_IF_NULL(root_return_node); - const auto& all_nodes = root->nodes(); - FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); - if (func_graph != nullptr) { - forward_graph = func_graph; - } - return forward_graph; -} - -void MarkForwardCNode(const FuncGraphPtr& root) { - MS_EXCEPTION_IF_NULL(root); - AnfNodePtr root_return_node = root->get_return(); - MS_EXCEPTION_IF_NULL(root_return_node); - auto& all_nodes = root->nodes(); - FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); - - if (func_graph == nullptr) { - // Can not find the forward graph, so the ops in root graph are forward. 
+ if (graph_set.empty()) { MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph"; SetForwardFlag(all_nodes); } else { - MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); - AnfNodePtr return_node = func_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - std::vector all_dfs_nodes = DeepLinkedGraphSearch(return_node); - SetForwardFlag(all_dfs_nodes); + for (auto &func_graph : graph_set) { + MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); + auto return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); + SetForwardFlag(all_dfs_nodes); + auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes); + if (root_forward_nodes.empty()) { + continue; + } + // Mark forward flag for the nodes in root graph. + SetForwardFlag(root_forward_nodes); + } } } @@ -2178,7 +2193,7 @@ Status ParallelInit() { return SUCCESS; } -bool StepParallel(const FuncGraphPtr& root, const opt::OptimizerPtr& optimizer) { +bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); @@ -2258,12 +2273,12 @@ bool StepParallel(const FuncGraphPtr& root, const opt::OptimizerPtr& optimizer) } // Needed by rec_parser -std::vector ExtractInputsTensorName(const CNodePtr& node) { +std::vector ExtractInputsTensorName(const CNodePtr &node) { std::vector name_inputs; std::vector all_inputs = node->inputs(); std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; - for (auto& input : node_inputs) { + for (auto &input : node_inputs) { std::string name; if (IsValueNode(input) || input->isa() || input->isa()) { name = input->ToString(); diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h index fd47a59bf5..b0d128f515 100644 --- 
a/mindspore/ccsrc/parallel/step_parallel.h +++ b/mindspore/ccsrc/parallel/step_parallel.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "./common.h" #include "optimizer/opt.h" @@ -41,114 +42,114 @@ struct LossNodeInfo { int dout_index = 0; // now don't support the sens is a tuple }; -std::vector CreateInput(const Operator& op, const AnfNodePtr& node, const std::string& instance_name); -std::string CreateInstanceName(const CNodePtr& node, size_t index); -void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node); +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); +std::string CreateInstanceName(const CNodePtr &node, size_t index); +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node); -void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_ptr, const CNodePtr& node, - const FuncGraphPtr& func_graph, int pos, const CNodePtr& pre_node); +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node); -TensorLayout GetTensorInLayout(const CNodePtr& pre_node, const PrimitivePtr& pre_prim, - const OperatorInfoPtr& distribute_operator_pre); +TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim, + const OperatorInfoPtr &distribute_operator_pre); -OperatorInfoPtr GetDistributeOperator(const CNodePtr& node); +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node); -void Redistribution(const std::pair& node_pair, const OperatorInfoPtr& distribute_operator, - const CNodePtr& middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr& pre_node); +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node); bool 
StrategyFound(std::unordered_map attrs); -bool IsParallelCareNode(const CNodePtr& cnode); +bool IsParallelCareNode(const CNodePtr &cnode); -void MarkForwardCNode(const FuncGraphPtr& root); +void MarkForwardCNode(const FuncGraphPtr &root); -bool FindCommunicationOp(const std::vector& all_nodes); +bool FindCommunicationOp(const std::vector &all_nodes); -void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_operator, const CNodePtr& insert_node, - const TensorRedistribution& tensor_redistribution, const CNodePtr& pre_node); +void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node); -std::vector ReplaceOpInput(const Operator& replace_op, const std::string& instance_name, - const CNodePtr& node); +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node); -void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node); +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); -void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node); +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node); -std::pair FindParameter(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph); -std::pair FindCNode(const AnfNodePtr& anode, const std::string& name, const FuncGraphPtr& func_graph); +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); -void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node); +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); -void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNodePtr& node, bool is_loss_node); +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, 
const CNodePtr &node, bool is_loss_node); // Generate and init parallel operator -OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, - const std::vector& shape_list); +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list); // Generate without initing parallel operator -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, std::vector shape_list); // Extract strategy from attr StrategyPtr ExtractStrategy(std::unordered_map attrs); -Shapes GetNodeShape(const AnfNodePtr& node); +Shapes GetNodeShape(const AnfNodePtr &node); -std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); // Extract shape from anfnode -std::vector ExtractShape(const CNodePtr& node); +std::vector ExtractShape(const CNodePtr &node); -std::pair FindParallelCareNode(const AnfNodePtr& node); +std::pair FindParallelCareNode(const AnfNodePtr &node); // Find finally sub graph -std::pair FindSubGraph(const FuncGraphPtr& func_graph, const AnfNodePtr& parameter); +std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr ¶meter); // Set distribute shape for parameters abstract -void SetParallelShape(const AnfNodePtr& parameter, const std::pair& res); +void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res); // change parameters'shape in resource -void CoverSliceShape(const FuncGraphPtr& root); +void CoverSliceShape(const FuncGraphPtr &root); -void SetVirtualDatasetStrategy(const CNodePtr& node); +void SetVirtualDatasetStrategy(const CNodePtr &node); // Creat parallel operator for primitive node(has strategy) -void ExtractInformation(const std::vector& all_nodes); +void ExtractInformation(const std::vector &all_nodes); 
-TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair); +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair); -std::shared_ptr FindNextLayout(const CNodePtr& node); +std::shared_ptr FindNextLayout(const CNodePtr &node); -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, size_t output_index); +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index); -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& node, size_t output_index); +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index); -std::shared_ptr FindPrevLayout(const AnfNodePtr& node); +std::shared_ptr FindPrevLayout(const AnfNodePtr &node); -void ReshapeInit(const std::vector& all_nodes); +void ReshapeInit(const std::vector &all_nodes); // Add node for whole graph -void ParallelCommunication(const FuncGraphPtr& root, const std::vector& all_nodes, - const FuncGraphManagerPtr& manager); +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager); -void RestoreStrategy(const FuncGraphPtr& func_graph); +void RestoreStrategy(const FuncGraphPtr &func_graph); -void CheckpointStrategy(const FuncGraphPtr& func_graph); +void CheckpointStrategy(const FuncGraphPtr &func_graph); // main step of Parallel -bool StepParallel(const FuncGraphPtr& func_graph, const opt::OptimizerPtr& optimizer); +bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); -int32_t GetTupleGetItemIndex(const CNodePtr& cnode); +int32_t GetTupleGetItemIndex(const CNodePtr &cnode); -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root); +std::vector FindLossCNodeFromRoot(const FuncGraphPtr &root); Status ParallelInit(); -std::vector ExtractInputsTensorName(const CNodePtr& node); +std::vector ExtractInputsTensorName(const CNodePtr &node); -FuncGraphPtr ForwardGraph(const FuncGraphPtr& root); +std::set ForwardGraph(const FuncGraphPtr &root); } // 
namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/strategy.h b/mindspore/ccsrc/parallel/strategy.h index 93d4d4dff1..fce99305a5 100644 --- a/mindspore/ccsrc/parallel/strategy.h +++ b/mindspore/ccsrc/parallel/strategy.h @@ -46,7 +46,7 @@ class Strategy { inputs_.push_back(inputs_[0]); } } - void ResetInputs(const std::vector& input) { inputs_ = input; } + void ResetInputs(const std::vector &input) { inputs_ = input; } private: const int32_t stage_; @@ -55,7 +55,7 @@ class Strategy { std::vector inputs_; }; -inline StrategyPtr NewStrategy(const int32_t stage, const std::vector& inputs) { +inline StrategyPtr NewStrategy(const int32_t stage, const std::vector &inputs) { return std::make_shared(stage, inputs); } } // namespace parallel diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index 9e3573eee2..dd518dc76c 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -StrategyCheckpoint& StrategyCheckpoint::GetInstance() { +StrategyCheckpoint &StrategyCheckpoint::GetInstance() { static StrategyCheckpoint instance = StrategyCheckpoint(); return instance; } @@ -47,7 +47,7 @@ Status StrategyCheckpoint::RemoveCheckPoint() const { return FAILED; } -Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { +Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { if (strategy_map == nullptr) { MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr"; } @@ -82,18 +82,18 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } -Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { straspb::ParallelStrategyMap parallel_strategy_map; 
parallel_strategy_map.set_train_time(IntToUint(++current_train_time_)); - for (auto& node_stra : strategy_map) { - straspb::ParallelStrategyItem* parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); + for (auto &node_stra : strategy_map) { + straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); MS_EXCEPTION_IF_NULL(parallel_strategy_item); parallel_strategy_item->set_node_name(node_stra.first); - straspb::ParallelStrategys* parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); + straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); MS_EXCEPTION_IF_NULL(parallel_strategys); parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); - for (auto& dims : node_stra.second->GetInputDim()) { - straspb::ParallelStrategy* parallel_strategy = parallel_strategys->add_parallel_strategy(); + for (auto &dims : node_stra.second->GetInputDim()) { + straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); MS_EXCEPTION_IF_NULL(parallel_strategy); for (auto dim : dims) { parallel_strategy->add_dim(IntToUint(dim)); diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h index b5d3626f53..c871ea6eef 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h @@ -32,11 +32,11 @@ class StrategyCheckpoint { StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) { train_times_ = 1; checkpoint_on_ = false; - const char* train_times_str = std::getenv("PARALLEL_TRAIN_TIMES"); + const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES"); if (train_times_str != nullptr && std::stoi(train_times_str) > 0) { train_times_ = std::stoi(train_times_str); } 
- const char* checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON"); + const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON"); if (checkpoint_on_str != nullptr) { checkpoint_on_ = (std::string(checkpoint_on_str) == "on"); } @@ -44,10 +44,10 @@ class StrategyCheckpoint { ~StrategyCheckpoint() = default; bool CheckPointExit() const; Status RemoveCheckPoint() const; - Status Load(StrategyMap* strategy_map); - Status Save(const StrategyMap& strategy_map); + Status Load(StrategyMap *strategy_map); + Status Save(const StrategyMap &strategy_map); - static StrategyCheckpoint& GetInstance(); + static StrategyCheckpoint &GetInstance(); int32_t GetTrainTimes() const { return train_times_; } int32_t GetCurrentTrainTime() const { return current_train_time_; } bool CheckPointOn() const { return checkpoint_on_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc index b42ba30242..235ab00302 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Arrangement::Init(const std::vector& array) { +Status Arrangement::Init(const std::vector &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -45,7 +45,7 @@ bool Arrangement::IsValidArrangement() { void Arrangement::ComputeSize() { size_ = 1; - for (auto& value : array_) { + for (auto &value : array_) { size_ *= value; } } @@ -84,7 +84,7 @@ std::vector Arrangement::GetFrontElementByValue(int32_t value) const { } std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft( - const std::vector& expand_list) const { + const std::vector &expand_list) const { if (expand_list.size() != GetDimSize()) { return nullptr; } @@ -108,7 +108,7 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft * array_ = [8, 4], * arrangement_list = [[4, 2], [2, 
2]] */ -std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement& expand_shape) const { +std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { int32_t size = 1; uint32_t ind = 0; std::vector arrangement_list; @@ -140,7 +140,7 @@ std::shared_ptr> Arrangement::GetExpandShapeList(const } std::shared_ptr, Arrangement>> Arrangement::GetExpandShapeListPair( - const Arrangement& expand_shape) const { + const Arrangement &expand_shape) const { std::shared_ptr> expand_shape_list_ptr = GetExpandShapeList(expand_shape); if (expand_shape_list_ptr == nullptr) { return nullptr; @@ -148,7 +148,7 @@ std::shared_ptr, Arrangement>> Arrangement::G std::vector expand_num_list_shape; (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), std::back_inserter(expand_num_list_shape), - [](const Arrangement& arr) { return SizeToInt(arr.GetDimSize()); }); + [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); Arrangement expand_num_list; Status status = expand_num_list.Init(expand_num_list_shape); if (status != Status::SUCCESS) { @@ -169,7 +169,7 @@ std::vector Arrangement::ComputeReverseAccumulateSumInReverseOrder() co } std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLeft( - const std::vector& expand_list) const { + const std::vector &expand_list) const { if (expand_list.size() != GetDimSize()) { return nullptr; } @@ -191,7 +191,7 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLef return std::make_shared(arrangement_new); } -std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement& in2) const { +std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement &in2) const { std::vector in1_accum; Status status = ShapeToAccumulateProduct(array_, &in1_accum); if (status != Status::SUCCESS) { diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h index 2dc13038c1..ca71b05c91 100644 --- 
a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h +++ b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h @@ -32,18 +32,18 @@ class Arrangement : public Array { public: Arrangement() : size_(1) {} ~Arrangement() override = default; - Status Init(const std::vector& array) override; + Status Init(const std::vector &array) override; int32_t size() const { return size_; } std::vector GetFrontElementByValue(int32_t value) const; - std::shared_ptr> GetExpandShapeList(const Arrangement& expand_shape) const; + std::shared_ptr> GetExpandShapeList(const Arrangement &expand_shape) const; std::vector ComputeReverseAccumulateSumInReverseOrder() const; std::shared_ptr GetExpandedShapeByExpandListReserveLeft( - const std::vector& expand_list) const; + const std::vector &expand_list) const; std::shared_ptr GetExpandedShapeByExpandListRemoveLeft( - const std::vector& expand_list) const; + const std::vector &expand_list) const; std::shared_ptr, Arrangement>> GetExpandShapeListPair( - const Arrangement& expand_shape) const; - std::shared_ptr GetUnifiedShape(const Arrangement& in2) const; + const Arrangement &expand_shape) const; + std::shared_ptr GetUnifiedShape(const Arrangement &in2) const; std::vector GetSqueezeIdx() const; Arrangement GetSqueezeArrangement() const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.cc b/mindspore/ccsrc/parallel/tensor_layout/array.cc index ba3858ae00..ef358e7cde 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/array.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/array.cc @@ -24,14 +24,14 @@ namespace parallel { std::string Array::ToString() const { std::ostringstream buffer; buffer << "[ "; - for (auto& element : array_) { + for (auto &element : array_) { buffer << std::to_string(element) + " "; } buffer << "]"; return buffer.str(); } -Status Array::Init(const std::vector& array) { +Status Array::Init(const std::vector &array) { array_ = array; return IsvalidArray() ? 
Status::SUCCESS : Status::FAILED; } @@ -54,7 +54,7 @@ int32_t Array::GetDimByReverseIdx(uint32_t idx) const { return array_[GetDimSize() - 1 - mod_idx]; } -bool Array::operator==(const Array& shape) const { +bool Array::operator==(const Array &shape) const { if (GetDimSize() != shape.GetDimSize()) { return false; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.h b/mindspore/ccsrc/parallel/tensor_layout/array.h index f7d9c3c673..5aa3bdb138 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/array.h +++ b/mindspore/ccsrc/parallel/tensor_layout/array.h @@ -31,13 +31,13 @@ class Array { Array() = default; virtual ~Array() = default; std::string ToString() const; - virtual Status Init(const std::vector& array); + virtual Status Init(const std::vector &array); bool IsvalidArray() const; std::vector array() const { return array_; } size_t GetDimSize() const { return array_.size(); } int32_t GetDimByIdx(uint32_t idx) const; int32_t GetDimByReverseIdx(uint32_t idx) const; - bool operator==(const Array& a1) const; + bool operator==(const Array &a1) const; protected: std::vector array_; diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc index 829c056fc2..b5ca5ed60a 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -Status ConstructOperator::Init(const RankList& dev_list, const Shape& dev_matrix_shape) { +Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) { dev_size_ = dev_matrix_shape.size(); dev_matrix_shape_ = dev_matrix_shape; dev_list_ = dev_list; @@ -46,7 +46,7 @@ Status ConstructOperator::ReshapeOP(Shape shape) { return Status::SUCCESS; } -Operator CreateStridedSliceOp(int32_t value, const Shape& begin, const Shape& end, const Shape& strides) { +Operator CreateStridedSliceOp(int32_t 
value, const Shape &begin, const Shape &end, const Shape &strides) { ValuePtr attr_value = MakeValue(value); Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value); Attr attr_end_mask = std::make_pair(END_MASK, attr_value); @@ -230,7 +230,7 @@ Status ConstructOperator::AlltoAllOP(Args args) { return Status::SUCCESS; } -Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector* group) { +Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector *group) { MS_EXCEPTION_IF_NULL(group); CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h index cf6cff456a..1a69638fb6 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h +++ b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h @@ -34,7 +34,7 @@ class ConstructOperator { const int32_t DEFAULT = 0; ConstructOperator() : dev_size_(0) {} ~ConstructOperator() = default; - Status Init(const RankList& dev_list, const Shape& dev_matrix_shape); + Status Init(const RankList &dev_list, const Shape &dev_matrix_shape); Status ReshapeOP(Shape shape); Status StridedSliceOP(Args args); Status AllGatherOP(int32_t dev_dim); @@ -42,7 +42,7 @@ class ConstructOperator { Status ConcatOP(int32_t concat_dim); Status AlltoAllOP(Args args); Operator GetOperator() const { return op_; } - void UpdateTensorShape(const Shape& tensor_shape) { tensor_shape_ = tensor_shape; } + void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; } private: Operator op_; @@ -50,7 +50,7 @@ class ConstructOperator { Shape tensor_shape_; RankList dev_list_; Shape dev_matrix_shape_; - Status CreateGroupByDim(size_t axis, std::vector* group); + Status CreateGroupByDim(size_t axis, std::vector *group); }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc 
b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc index 190a5846ba..84c0580ba8 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc @@ -29,7 +29,7 @@ std::string LayoutTransfer::ToString() const { LayoutTransfer::~LayoutTransfer() = default; -Status LayoutTransfer::Init(const TensorLayout& from_in, const TensorLayout& to_in) { +Status LayoutTransfer::Init(const TensorLayout &from_in, const TensorLayout &to_in) { from_in_ = from_in; to_in_ = to_in; MS_LOG(DEBUG) << "LayoutTransfer " << this->ToString(); diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h index b05128f5b8..c4da4b728f 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h @@ -28,7 +28,7 @@ class LayoutTransfer { LayoutTransfer() = default; virtual ~LayoutTransfer() = 0; std::string ToString() const; - Status Init(const TensorLayout& from_in, const TensorLayout& to_in); + Status Init(const TensorLayout &from_in, const TensorLayout &to_in); TensorLayout from_in() const { return from_in_; } TensorLayout to_in() const { return to_in_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.cc b/mindspore/ccsrc/parallel/tensor_layout/map.cc index 320dbe6ebd..669920fc44 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/map.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/map.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Map::Init(const std::vector& array) { +Status Map::Init(const std::vector &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -46,7 +46,7 @@ bool Map::IsValidMap() { std::vector sorted_array = array_; std::sort(sorted_array.begin(), sorted_array.end()); int32_t value = MAP_NONE; - for (auto& element : sorted_array) { + for (auto &element : sorted_array) { if (element 
== MAP_NONE) { continue; } @@ -78,7 +78,7 @@ int32_t Map::GetIndexByValue(int32_t value) const { /* * expand.size() should be equal to array_.size() */ -std::shared_ptr Map::ExpandMapByNone(const Arrangement& expand_num_list) const { +std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) const { if (expand_num_list.GetDimSize() != GetDimSize()) { return nullptr; } @@ -105,7 +105,7 @@ std::shared_ptr Map::ExpandMapByNone(const Arrangement& expand_num_list) co /* * expand.size() should be equal to array_.size() */ -std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement& expand_num_list) const { +std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { return nullptr; } @@ -126,7 +126,7 @@ std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement& expand_nu return map_new; } -std::shared_ptr> Map::ReMapVector(const std::vector& input_vector) const { +std::shared_ptr> Map::ReMapVector(const std::vector &input_vector) const { if (GetMaxItem() >= static_cast(input_vector.size())) { return nullptr; } @@ -143,7 +143,7 @@ std::shared_ptr> Map::ReMapVector(const std::vector idx_list) const { - for (auto& value : idx_list) { + for (auto &value : idx_list) { if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { return false; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.h b/mindspore/ccsrc/parallel/tensor_layout/map.h index 3f839ef198..8c8bba2775 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/map.h +++ b/mindspore/ccsrc/parallel/tensor_layout/map.h @@ -34,12 +34,12 @@ class Map : public Array { public: Map() = default; ~Map() override = default; - Status Init(const std::vector& array) override; + Status Init(const std::vector &array) override; int32_t GetMaxItem() const; int32_t GetIndexByValue(int32_t value) const; - std::shared_ptr ExpandMapByNone(const Arrangement& expand_num_list) const; - std::shared_ptr 
ExpandMapByDecreaseNumber(const Arrangement& expand_num_list) const; - std::shared_ptr> ReMapVector(const std::vector& input_vector) const; + std::shared_ptr ExpandMapByNone(const Arrangement &expand_num_list) const; + std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const; + std::shared_ptr> ReMapVector(const std::vector &input_vector) const; bool CheckNoneByIdxList(std::vector idx_list) const; Map SqueezeMapByIdxList(std::vector idx_list) const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc index b4ec6a016f..946620ec4c 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc @@ -22,8 +22,8 @@ namespace mindspore { namespace parallel { -Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, - RankList dev_list) { +Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, + RankList dev_list, bool is_cost_model) { in_tensor_map_ = tensor_layout.tensor_map(); dev_mat_ = tensor_layout.device_arrangement(); @@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons for (int32_t item : map) { map_[key++] = item; } + + is_cost_model_ = is_cost_model; return Status::SUCCESS; } @@ -103,7 +105,7 @@ Status RedistributionOperatorInfer::InferSplitByAxis() { } if (in_dim == NONE && !std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { Args args = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) { MS_LOG(ERROR) << "Insert SplitByAxis Error!"; 
@@ -128,17 +130,28 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { } if (in_dim == NONE && std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); - Args args_allconcat = {cat_dim, out_dim, dev_mat_.GetDimByReverseIdx(IntToUint(out_dim))}; - Args args_allsplit = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; - if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { - MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; - return Status::FAILED; - } - if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { - MS_LOG(ERROR) << "Insert SplitByAxis Error!"; - return Status::FAILED; + int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); + if (is_cost_model_) { + int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); + Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, + dev_num}; + if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { + MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; + return Status::FAILED; + } + } else { + Args args_allconcat = {cat_dim, out_dim, dev_num}; + Args args_allsplit = {dev_num, UintToInt(index), out_dim}; + if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { + MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; + return Status::FAILED; + } + if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { + MS_LOG(ERROR) << "Insert SplitByAxis Error!"; + return Status::FAILED; + } } (void)map_.erase(iter++); map_[IntToSize(cat_dim)] = NONE; diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h index b4ec0c4633..a96097a1d3 100644 
--- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h @@ -40,7 +40,8 @@ class RedistributionOperatorInfer { public: const int NONE = -1; explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} - Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list); + Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, + bool is_cost_model = false); ~RedistributionOperatorInfer() = default; OperatorList operator_list() const { return operator_list_; } OperatorVector operator_vector() const { return operator_vector_; } @@ -67,6 +68,7 @@ class RedistributionOperatorInfer { ConstructOperator constructor_; RankList dev_list_; bool construct_op_flag_; + bool is_cost_model_; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc index 39a6bef92d..f6c90e9d46 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc @@ -104,7 +104,7 @@ std::shared_ptr ReshapeLayoutTransfer::ExchangeFromAndTo( } std::shared_ptr ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement& expand_shape) const { + const Arrangement &expand_shape) const { std::shared_ptr extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); if (extend_tensor_shape_from_ptr == nullptr) { return nullptr; diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h index 8aae71631d..ed62cb59da 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h +++ 
b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h @@ -33,7 +33,7 @@ class ReshapeLayoutTransfer : public LayoutTransfer { std::shared_ptr ExtendFromTensorShapeByExpandedTensorShape() const; std::shared_ptr ExtendToTensorShapeByExpandedTensorShape() const; std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement& expand_shape) const; + const Arrangement &expand_shape) const; std::shared_ptr ExchangeFromAndTo() const; private: diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc index a26627fb3c..e8f208708c 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc @@ -26,7 +26,7 @@ namespace parallel { * shape = [2, 8, 32] * shape_accum = [2, 2 * 8, 2 * 8 * 32] */ -Status ShapeToAccumulateProduct(const std::vector& shape, std::vector* shape_accum) { +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -47,7 +47,7 @@ Status ShapeToAccumulateProduct(const std::vector& shape, std::vector& shape, std::vector* shape_accum) { +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector& shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector& shape_accum, std::vector* shape) { +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -92,7 +92,7 @@ Status AccumulateProductToShape(const std::vector& shape_accum, std::ve * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] * shape = [2, 8, 32] */ -Status AccumulateProductReverseToShape(const std::vector& shape_accum_reverse, 
std::vector* shape) { +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -122,8 +122,8 @@ Status AccumulateProductReverseToShape(const std::vector& shape_accum_r * in2 = [8, 16] * *out = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const std::vector& in1_accum, const std::vector& in2_accum, - std::vector* out_accum) { +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum) { MS_EXCEPTION_IF_NULL(out_accum); out_accum->clear(); auto in1_iter = in1_accum.begin(); @@ -159,7 +159,7 @@ Status UnifyAccumulateProduct(const std::vector& in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector& in1, const std::vector& in2, std::vector* out) { +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out) { MS_EXCEPTION_IF_NULL(out); std::vector in1_accum; Status status = ShapeToAccumulateProduct(in1, &in1_accum); @@ -194,9 +194,9 @@ Status UnifyShape(const std::vector& in1, const std::vector& i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, - const std::vector& expand_accum_reverse, - std::vector* out_accum_reverse) { +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse) { MS_EXCEPTION_IF_NULL(out_accum_reverse); out_accum_reverse->clear(); auto in_riter = in_accum_reverse.rbegin(); @@ -236,7 +236,7 @@ Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector& in, const std::vector& expand, std::vector* out) { +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out) { 
MS_EXCEPTION_IF_NULL(out); std::vector in_accum_reverse; Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse); diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h index e83156500c..2ec21f3881 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h +++ b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h @@ -39,7 +39,7 @@ namespace parallel { * shape_accum = [2, 2 * 8, 2 * 8 * 32] * */ -Status ShapeToAccumulateProduct(const std::vector& shape, std::vector* shape_accum); +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); /* * compute the accumulating product of all the values in shape from right to left, @@ -53,7 +53,7 @@ Status ShapeToAccumulateProduct(const std::vector& shape, std::vector& shape, std::vector* shape_accum); +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum); /* * compute the original shape from the accumulating product shape_accum, @@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector& shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector& shape_accum, std::vector* shape); +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); /* * compute the original shape from the accumulating product shape_accum, @@ -83,7 +83,7 @@ Status AccumulateProductToShape(const std::vector& shape_accum, std::ve * shape = [2, 8, 32] * */ -Status AccumulateProductReverseToShape(const std::vector& shape_accum_reverse, std::vector* shape); +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); /* * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, @@ -101,8 +101,8 @@ Status AccumulateProductReverseToShape(const std::vector& shape_accum_r * in2_accum = [8, 16] * out_accum = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const 
std::vector& in1_accum, const std::vector& in2_accum, - std::vector* out_accum); +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum); /* * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] @@ -117,7 +117,7 @@ Status UnifyAccumulateProduct(const std::vector& in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector& in1, const std::vector& in2, std::vector* out); +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); /* * given two accumulate product in reverse order of in and expand, @@ -141,9 +141,9 @@ Status UnifyShape(const std::vector& in1, const std::vector& i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, - const std::vector& expand_accum_reverse, - std::vector* out_accum_reverse); +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse); /* * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], @@ -165,7 +165,7 @@ Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector& in, const std::vector& expand, std::vector* out); +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h index 4a64ab472c..43286317c5 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h @@ -32,9 +32,9 @@ using Shapes = std::vector; class TensorInfo { 
public: - TensorInfo(const TensorLayout& tensor_layout, Shape shape, Shape slice_shape) + TensorInfo(const TensorLayout &tensor_layout, Shape shape, Shape slice_shape) : tensor_layout_(tensor_layout), shape_(std::move(shape)), slice_shape_(std::move(slice_shape)) {} - explicit TensorInfo(const TensorLayout& tensor_layout) : tensor_layout_(tensor_layout) { + explicit TensorInfo(const TensorLayout &tensor_layout) : tensor_layout_(tensor_layout) { shape_ = tensor_layout.tensor_shape().array(); slice_shape_ = tensor_layout.slice_shape().array(); } @@ -44,7 +44,7 @@ class TensorInfo { TensorLayout tensor_layout() const { return tensor_layout_; } Shape slice_shape() const { return slice_shape_; } Shape shape() const { return shape_; } - void set_reduce_dim(const std::vector& dim) { reduce_dim_ = dim; } + void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } std::vector reduce_dim() const { return reduce_dim_; } private: diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc index 5fbd04431c..f3498065f2 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc @@ -45,8 +45,8 @@ std::string TensorLayout::OriginToString() const { return buffer.str(); } -Status TensorLayout::Init(const Arrangement& device_arrangement, const Map& tensor_map, - const Arrangement& tensor_shape) { +Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map, + const Arrangement &tensor_shape) { device_arrangement_origin_ = device_arrangement; tensor_map_origin_ = tensor_map; tensor_shape_origin_ = tensor_shape; @@ -64,8 +64,8 @@ Status TensorLayout::Init(const Arrangement& device_arrangement, const Map& tens } } -Status TensorLayout::InitFromVector(const std::vector& device_arrangement, - const std::vector& tensor_map, const std::vector& tensor_shape) { +Status TensorLayout::InitFromVector(const std::vector 
&device_arrangement, + const std::vector &tensor_map, const std::vector &tensor_shape) { if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { return FAILED; } @@ -124,7 +124,7 @@ void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { if (idx != -1) { tensor_map_shape[static_cast(idx)] = -1; } - for (auto& value : tensor_map_shape) { + for (auto &value : tensor_map_shape) { if (value >= dev_num_left - 1 - static_cast(i)) { value--; } @@ -153,7 +153,7 @@ int32_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint32_t idx) const { return device_arrangement_.GetDimByIdx(static_cast(GetSliceDeviceDimensionByTensorDimensionIndex(idx))); } -std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement& expanded_shape) const { +std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const { std::shared_ptr expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape); if (expanded_arrangement_ptr == nullptr) { return nullptr; @@ -174,7 +174,7 @@ std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement& * => * out_device_arrangement = [8, 2, 2] */ -std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement& tensor_shape) const { +std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const { std::shared_ptr> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape); if (expand_list_ptr == nullptr) { return nullptr; @@ -204,7 +204,7 @@ std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(con * out_tensor_map = [1, -1, 0, -1], */ std::shared_ptr TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement& expanded_shape) const { + const Arrangement &expanded_shape) const { std::shared_ptr, Arrangement>> expand_list_pair_ptr = tensor_shape_.GetExpandShapeListPair(expanded_shape); if (expand_list_pair_ptr == nullptr) { @@ -259,7 +259,7 @@ std::shared_ptr 
TensorLayout::ExpandTensorShapeWithoutExtendDevice * out_tensor_map = [0, 2, 1], * out_tensor_shape = [512, 4, 256] */ -std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement& expanded_arrangement) const { +std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const { std::shared_ptr, Arrangement>> expand_list_pair_ptr = device_arrangement_.GetExpandShapeListPair(expanded_arrangement); if (expand_list_pair_ptr == nullptr) { @@ -287,7 +287,7 @@ std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrang return std::make_shared(tensor_layout_new); } -bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement& expand_shape) const { +bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { std::vector in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { @@ -296,7 +296,7 @@ bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement& expand_shape) con return (in_expand_shape_shape == tensor_shape_.array()); } -std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement& expand_shape) const { +std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { std::vector in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { @@ -345,7 +345,7 @@ Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { return Status::SUCCESS; } -bool TensorLayout::operator==(const TensorLayout& t1) const { +bool TensorLayout::operator==(const TensorLayout &t1) const { return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); } diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h index e6ddc2a708..f51ed4e3e0 100644 --- 
a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h @@ -37,9 +37,9 @@ class TensorLayout { std::string ToString() const; std::string StandardToString() const; std::string OriginToString() const; - Status Init(const Arrangement& device_arrangement, const Map& tensor_map, const Arrangement& tensor_shape); - Status InitFromVector(const std::vector& device_arrangement, const std::vector& tensor_map, - const std::vector& tensor_shape); + Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); + Status InitFromVector(const std::vector &device_arrangement, const std::vector &tensor_map, + const std::vector &tensor_shape); Arrangement device_arrangement() const { return device_arrangement_; } @@ -49,25 +49,25 @@ class TensorLayout { Map origin_tensor_map() const { return tensor_map_origin_; } - std::shared_ptr ExpandTensorShape(const Arrangement& expanded_shape) const; + std::shared_ptr ExpandTensorShape(const Arrangement &expanded_shape) const; - std::shared_ptr ExpandDeviceArrangement(const Arrangement& expanded_arrangement) const; + std::shared_ptr ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const; - bool IsSameTensorShape(const TensorLayout& tensor_layout) const { + bool IsSameTensorShape(const TensorLayout &tensor_layout) const { return (tensor_shape_ == tensor_layout.tensor_shape()); } - bool IsSameDeviceArrangement(const TensorLayout& tensor_layout) const { + bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const { return (device_arrangement_ == tensor_layout.device_arrangement()); } - bool IsSameTensorMap(const TensorLayout& tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } + bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } - bool operator==(const TensorLayout& t1) const; + bool operator==(const TensorLayout &t1) const; 
- bool TensorShapeCanBeExpanded(const Arrangement& expanded_shape) const; + bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; - std::shared_ptr ComputeExpandedTensorShape(const Arrangement& expand_shape) const; + std::shared_ptr ComputeExpandedTensorShape(const Arrangement &expand_shape) const; Arrangement slice_shape() const; @@ -77,8 +77,8 @@ class TensorLayout { private: std::shared_ptr ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement& expanded_shape) const; - std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement& tensor_shape) const; + const Arrangement &expanded_shape) const; + std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const; bool IsValidTensorLayout() const; void RemoveElementEqualToOneInDeviceArrangement(); int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc index d8eef7e7a5..7824c21f3d 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace parallel { -Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list) { +Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) { from_origin_ = from; to_origin_ = to; if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) { @@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& return Status::SUCCESS; } -RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() { +RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { // Step 1: Match device arrangement between from_ and to_ 
RedistributionLayoutTransfer layout_transfer; Status status = layout_transfer.Init(from_, to_); @@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); // Step 2: Infer redistribution and insert operators RedistributionOperatorInfer operator_infer(construct_op_flag_); - if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) { + if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { MS_LOG(ERROR) << "Init operatorInfer failed!"; return nullptr; } @@ -87,9 +87,9 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL std::make_pair(operator_vector, output_info_vector)); } -Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout, - OperatorVector* const operator_vector, - OutPutInfoVector* const output_info_vector) { +Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, + OutPutInfoVector *const output_info_vector) { MS_EXCEPTION_IF_NULL(operator_vector); MS_EXCEPTION_IF_NULL(output_info_vector); ConstructOperator constructor; @@ -138,27 +138,35 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const } Status TensorRedistribution::ComputeCost() { - RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(); + RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); if (redistribution_oplist_ptr == nullptr) { MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; return Status::FAILED; } // Compute redistribution communication cost and computation cost - for (auto& op_cost : operator_list_) { + for (auto &op_cost : operator_list_) { OperatorR op = op_cost.first; Shape slice_shape = op_cost.second; double prod = 
std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast(1.0), std::multiplies()); std::string str = op.first; if (str == PERMUTE_BY_AXIS) { - // The shape does not change after PermuteByAxis operation. - // communication cost = all_to_all + all_to_all = 2 * slice_shape - // computation cost = slice_shape - forward_comm_cost_ += prod; - backward_comm_cost_ += prod; - comm_cost_ += 2.0 * prod; - computation_cost_ += prod; - memory_cost_ += prod; + // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost. + // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape + forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; + backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; + comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR; + int32_t concat_dim = op.second[2]; + if (concat_dim == 0) { + // memory cost = all_gather + computation_cost_ += prod; + memory_cost_ += prod; + } else { + // memory cost = all_gather + split + concat + int32_t dev_num = op.second[4]; + computation_cost_ += (prod + prod * dev_num + prod * dev_num); + memory_cost_ += (prod * dev_num + prod * dev_num + prod); + } } else if (str == CONCAT_BY_AXIS) { // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape // computation cost = before_slice_shape @@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() { } double dev_num = op.second[2]; // here, communication cost = all_gather + reduce_scatter - forward_comm_cost_ += prod * dev_num; - backward_comm_cost_ += prod; - comm_cost_ += prod * (dev_num + 1.0); + forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; + backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; + comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; int32_t concat_dim = op.second[0]; if (concat_dim == 0) { // computation cost = all_gather diff --git 
a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h index ebaccadf53..e7800909c5 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h @@ -33,6 +33,8 @@ namespace mindspore { namespace parallel { +constexpr double ALLTOALL_SCALE_FACTOR = 2.0; +constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5; class TensorRedistribution { public: explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) @@ -44,9 +46,9 @@ class TensorRedistribution { memory_cost_(0.0), construct_op_flag_(construct_op_flag), keep_reshape_(keep_reshape) {} - Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); + Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list); ~TensorRedistribution() = default; - RedistributionOpListPtr InferTensorRedistributionOperatorList(); + RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); OperatorList operator_list() const { return operator_list_; } bool reshape_flag() const { return reshape_flag_; } Status ComputeCost(); @@ -57,8 +59,8 @@ class TensorRedistribution { double memory_cost() const { return memory_cost_; } private: - Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout, - OperatorVector* const operator_vector, OutPutInfoVector* const output_info_vector); + Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector); TensorLayout from_origin_; TensorLayout to_origin_; diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 3e0f8804e7..e8723e66a4 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -41,8 +41,8 @@ using CompileGraphs = 
compile::CompileGraphs; using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec, bool clear) { +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear) { MS_LOG(DEBUG) << "AbstractAnalyze start"; auto engine = res->engine(); MS_EXCEPTION_IF_NULL(engine); @@ -50,9 +50,9 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraph auto manager = res->manager(); MS_EXCEPTION_IF_NULL(manager); engine->Clear(); - for (auto& node : manager->all_nodes()) { + for (auto &node : manager->all_nodes()) { MS_EXCEPTION_IF_NULL(node); - const AbstractBasePtr& prev_inferred = node->abstract(); + const AbstractBasePtr &prev_inferred = node->abstract(); // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. 
if (!node->isa() || (prev_inferred != nullptr && prev_inferred->isa())) { node->set_abstract(nullptr); @@ -65,8 +65,8 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraph return ret; } -FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AnalysisContextPtr& context) { +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context) { MS_LOG(DEBUG) << "ProgramSpecialize start"; abstract::ProgramSpecializer spc(res->engine()); FuncGraphPtr result = spc.Run(func_graph, context); @@ -77,8 +77,8 @@ FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_ return result; } -FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec) { +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec) { MS_LOG(DEBUG) << "Renormalize start"; #ifdef ENABLE_PROFILE double t1 = GetTime(); @@ -98,7 +98,7 @@ FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, return ret; } -bool ParseAction(const ResourcePtr& res) { +bool ParseAction(const ResourcePtr &res) { if (!res->input()) { MS_LOG(EXCEPTION) << "Parse error"; } @@ -129,11 +129,11 @@ bool ParseAction(const ResourcePtr& res) { // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}-> // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx} // all obj_map's graph shared base_graph -bool CombineLikeGraphs(const ResourcePtr&) { - auto& obj_map = parse::data_converter::GetObjGraphs(); +bool CombineLikeGraphs(const ResourcePtr &) { + auto &obj_map = parse::data_converter::GetObjGraphs(); for (auto it : obj_map) { - auto& graphs = it.second; + auto &graphs = it.second; MS_LOG(DEBUG) << "Start combine like graph:" 
<< it.first << ", size:" << graphs.size(); auto fg = graphs[0]; FuncGraphPtrList func_graphs = {fg}; @@ -147,7 +147,7 @@ bool CombineLikeGraphs(const ResourcePtr&) { continue; } auto mng = Manage(base_graph, false); - for (auto& fv : fg->paramter_obj_nodes()) { + for (auto &fv : fg->paramter_obj_nodes()) { TraceManager::DebugTrace(std::make_shared(fv->debug_info())); auto param = base_graph->add_parameter(); TraceManager::EndTrace(); @@ -156,11 +156,11 @@ bool CombineLikeGraphs(const ResourcePtr&) { } MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size(); - for (auto& g : graphs) { + for (auto &g : graphs) { auto fvs = g->paramter_obj_nodes(); std::vector new_node_inputs; new_node_inputs.push_back(NewValueNode(base_graph)); - for (auto& p : g->parameters()) { + for (auto &p : g->parameters()) { AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p); new_node_inputs.push_back(para_after_cast); } @@ -174,7 +174,7 @@ bool CombineLikeGraphs(const ResourcePtr&) { return true; } -bool SymbolResolveAction(const ResourcePtr& res) { +bool SymbolResolveAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; } @@ -195,7 +195,7 @@ bool SymbolResolveAction(const ResourcePtr& res) { return succ; } -bool InferenceOptPrepareAction(const ResourcePtr& res) { +bool InferenceOptPrepareAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; } @@ -205,7 +205,7 @@ bool InferenceOptPrepareAction(const ResourcePtr& res) { return InferenceOptPreparePass(res); } -bool AbstractSpecializeAction(const ResourcePtr& res) { +bool AbstractSpecializeAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "AbstractSpecialize error"; } @@ -215,7 +215,7 @@ bool AbstractSpecializeAction(const ResourcePtr& res) { // suppose that there is not KeywordArgument for the top graph // get 
the hyper parameter - for (const auto& param : func_graph->parameters()) { + for (const auto ¶m : func_graph->parameters()) { auto param_node = std::static_pointer_cast(param); if (param_node->has_default()) { AbstractBasePtr ptr = @@ -236,8 +236,8 @@ bool AbstractSpecializeAction(const ResourcePtr& res) { return true; } -bool OptimizeAction(const ResourcePtr& res, const std::vector& passes) { - for (auto& pass : passes) { +bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) { + for (auto &pass : passes) { WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() { MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; auto result = pass.second(res); @@ -251,11 +251,11 @@ bool OptimizeAction(const ResourcePtr& res, const std::vector& passes) return true; } -bool GeOptimizeAction(const ResourcePtr& res) { return OptimizeAction(res, kGePasses); } +bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } -bool VmOptimizeAction(const ResourcePtr& res) { return OptimizeAction(res, kVmPasses); } +bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } -bool TaskEmitAction(const ResourcePtr& res) { +bool TaskEmitAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "TaskEmit args error"; } @@ -271,7 +271,7 @@ bool TaskEmitAction(const ResourcePtr& res) { return true; } -bool ExecuteAction(const ResourcePtr& res) { +bool ExecuteAction(const ResourcePtr &res) { if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; } @@ -291,11 +291,11 @@ bool ExecuteAction(const ResourcePtr& res) { // that will result in a syncronization error due to different executing order. // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, // the final solution will be proposed later as a parallel feature. 
-bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& res) { - auto& node_users = res->manager()->node_users(); - auto& users = node_users[value_node]; +bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) { + auto &node_users = res->manager()->node_users(); + auto &users = node_users[value_node]; auto used_by_keep_value_prim = - std::any_of(users.begin(), users.end(), [](const std::pair& user) -> bool { + std::any_of(users.begin(), users.end(), [](const std::pair &user) -> bool { MS_EXCEPTION_IF_NULL(user.first); auto cnode = user.first->cast(); if (cnode == nullptr) { @@ -312,7 +312,7 @@ bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& r return used_by_keep_value_prim; } -bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { +bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "Remove value node duplications error."; } @@ -322,7 +322,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { auto value_nodes = manager->valuenodes()[func_graph]; HashCache hash_cache; HashValue hashes; - for (const auto& value_pair : value_nodes) { + for (const auto &value_pair : value_nodes) { if (KeepValueNodeDuplication(value_pair.first, res)) { continue; } @@ -331,7 +331,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { return true; } -bool ValidateAction(const ResourcePtr& res) { return ValidatePass(res); } +bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } static std::vector CommonPipeline() { std::vector actions; diff --git a/mindspore/ccsrc/pipeline/action.h b/mindspore/ccsrc/pipeline/action.h index 159e494a96..8a651c0038 100644 --- a/mindspore/ccsrc/pipeline/action.h +++ b/mindspore/ccsrc/pipeline/action.h @@ -30,22 +30,22 @@ extern const char kMsConvert[]; namespace pipeline { using ActionItem = std::pair>; -bool ParseAction(const ResourcePtr& res); 
-bool SymbolResolveAction(const ResourcePtr& res); -bool AbstractSpecializeAction(const ResourcePtr& res); -bool GeOptimizeAction(const ResourcePtr& res); -bool VmOptimizeAction(const ResourcePtr& res); -bool TaskEmitAction(const ResourcePtr& res); -bool ExecuteAction(const ResourcePtr& res); +bool ParseAction(const ResourcePtr &res); +bool SymbolResolveAction(const ResourcePtr &res); +bool AbstractSpecializeAction(const ResourcePtr &res); +bool GeOptimizeAction(const ResourcePtr &res); +bool VmOptimizeAction(const ResourcePtr &res); +bool TaskEmitAction(const ResourcePtr &res); +bool ExecuteAction(const ResourcePtr &res); std::vector GePipeline(); std::vector VmPipeline(); -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec, bool clear = false); -FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AnalysisContextPtr& context); -FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec); +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear = false); +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context); +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/base.h b/mindspore/ccsrc/pipeline/base.h index 30524e84f6..8ca153f45b 100644 --- a/mindspore/ccsrc/pipeline/base.h +++ b/mindspore/ccsrc/pipeline/base.h @@ -37,7 +37,7 @@ struct ExecutorInfo { using ExecutorInfoPtr = std::shared_ptr; -inline std::string GetPhasePrefix(const std::string& phase) { +inline std::string GetPhasePrefix(const std::string 
&phase) { auto pos = phase.find('.'); if (pos == std::string::npos) { MS_LOG(EXCEPTION) << "Phase has no . for prefix" << phase; @@ -45,7 +45,7 @@ inline std::string GetPhasePrefix(const std::string& phase) { return phase.substr(0, pos); } -inline std::string GetFilePathName(const std::string& file_name) { +inline std::string GetFilePathName(const std::string &file_name) { std::ostringstream oss; auto ms_context = MsContext::GetInstance(); if (ms_context == nullptr) { diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index b709199c87..86e6d436b7 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -53,10 +53,10 @@ PYBIND11_MODULE(_c_expression, m) { (void)py::class_>(*m, "MetaFuncGraph_") .def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_) - .def(py::init()); + .def(py::init()); auto fns = mindspore::PybindDefineRegister::AllFuncs(); - for (auto& item : fns) { + for (auto &item : fns) { item.second(&m); } @@ -288,7 +288,7 @@ PYBIND11_MODULE(_c_expression, m) { }}); (void)py::class_>(m, "EventWriter_") - .def(py::init()) + .def(py::init()) .def("GetFileName", &EventWriter::GetFileName, "Get the file name.") .def("Open", &EventWriter::Open, "Open the write file.") .def("Write", &EventWriter::Write, "Write the serialize event.") diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc index d25a202afc..861fc0eda8 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/parse/data_converter.cc @@ -38,7 +38,7 @@ using Tensor = mindspore::tensor::Tensor; using TensorPtr = mindspore::tensor::TensorPtr; namespace { -bool ConvertTuple(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python tuple"; py::tuple tuple = obj.cast(); std::vector 
value_list; @@ -55,7 +55,7 @@ bool ConvertTuple(const py::object& obj, ValuePtr* const data, bool use_signatur return true; } -bool ConvertList(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python list"; py::list list = obj.cast(); @@ -72,7 +72,7 @@ bool ConvertList(const py::object& obj, ValuePtr* const data, bool use_signature return true; } -bool ConvertCellList(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting cell list"; py::sequence list = obj; std::vector value_list; @@ -88,7 +88,7 @@ bool ConvertCellList(const py::object& obj, ValuePtr* const data, bool use_signa return true; } -bool ConvertDict(const py::object& obj, ValuePtr* data, bool use_signature) { +bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { MS_LOG(DEBUG) << "Converting python dict"; py::dict dict_values = obj.cast(); @@ -109,14 +109,14 @@ bool ConvertDict(const py::object& obj, ValuePtr* data, bool use_signature) { return true; } -void ConvertNameSpace(const py::object& obj, ValuePtr* const data) { +void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting python module"; py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); *data = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); } -void ConvertDataClass(py::object obj, ValuePtr* const data) { +void ConvertDataClass(py::object obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting dataclass"; // Maybe the obj is dataclass define auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); @@ -124,7 +124,7 @@ void 
ConvertDataClass(py::object obj, ValuePtr* const data) { *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); } -bool ConvertPrimitive(py::object obj, ValuePtr* const data, bool use_signature = false) { +bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { MS_LOG(DEBUG) << "Converting primitive object"; // need check the primitive is class type or instance @@ -155,7 +155,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr* const data, bool use_signature = return true; } -bool ConvertMetaFuncGraph(const py::object& obj, ValuePtr* const data, bool use_signature = false) { +bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; auto meta = obj.cast(); if (meta == nullptr) { @@ -170,7 +170,7 @@ bool ConvertMetaFuncGraph(const py::object& obj, ValuePtr* const data, bool use_ return true; } -bool ConvertDataType(const py::object& obj, ValuePtr* const data) { +bool ConvertDataType(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting type object"; auto typeptr = obj.cast(); if (typeptr == nullptr) { @@ -181,7 +181,7 @@ bool ConvertDataType(const py::object& obj, ValuePtr* const data) { return true; } -bool ConvertTensor(const py::object& obj, ValuePtr* const data) { +bool ConvertTensor(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting tensor object"; auto m_tensor = obj.cast(); @@ -193,7 +193,7 @@ bool ConvertTensor(const py::object& obj, ValuePtr* const data) { return true; } -bool ConvertOtherObj(py::object obj, ValuePtr* const data) { +bool ConvertOtherObj(py::object obj, ValuePtr *const data) { auto obj_type = data_converter::GetObjType(obj); MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { @@ -244,7 +244,7 @@ bool ConvertOtherObj(py::object obj, ValuePtr* 
const data) { } } // namespace -bool ConvertData(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) { // check parameter valid if (data == nullptr) { MS_LOG(ERROR) << "Data is null pointer"; @@ -295,7 +295,7 @@ bool ConvertData(const py::object& obj, ValuePtr* const data, bool use_signature } // convert data to graph -FuncGraphPtr ConvertToFuncGraph(const py::object& obj, const std::string& python_mod_get_parse_method) { +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { std::vector results = data_converter::GetObjKey(obj); std::string obj_id = results[0] + python_mod_get_parse_method; std::string obj_key = results[1]; @@ -331,25 +331,25 @@ static std::unordered_map object_map_ = std::unordered_map> object_graphs_map_ = std::unordered_map>(); -void SetObjGraphValue(const std::string& obj_key, const FuncGraphPtr& data) { +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { object_graphs_map_[obj_key].push_back(data); MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size(); } -const std::unordered_map>& GetObjGraphs() { +const std::unordered_map> &GetObjGraphs() { MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size(); return object_graphs_map_; } -void CacheObjectValue(const std::string& obj_key, const Any& data) { object_map_[obj_key] = data; } -bool GetObjectValue(const std::string& obj_key, Any* const data) { +void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } +bool GetObjectValue(const std::string &obj_key, Any *const data) { if (object_map_.count(obj_key)) { *data = object_map_[obj_key]; return true; } return false; } -std::vector GetObjKey(const py::object& obj) { +std::vector GetObjKey(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::tuple obj_tuple = 
python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj); if (obj_tuple.size() != 2) { @@ -359,7 +359,7 @@ std::vector GetObjKey(const py::object& obj) { } // get obj detail type -ResolveTypeDef GetObjType(const py::object& obj) { +ResolveTypeDef GetObjType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto obj_type = ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast()); @@ -367,7 +367,7 @@ ResolveTypeDef GetObjType(const py::object& obj) { } // get class instance detail type -ClassInstanceTypeDef GetClassInstanceType(const py::object& obj) { +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto class_type = ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast()); @@ -375,14 +375,14 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object& obj) { } // check the object is Cell Instance -bool IsCellInstance(const py::object& obj) { +bool IsCellInstance(const py::object &obj) { auto class_type = GetClassInstanceType(obj); bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); return isCell; } // create the python class instance -py::object CreatePythonObject(const py::object& type, const py::tuple& params) { +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object obj; if (params.size() == 0) { @@ -395,7 +395,7 @@ py::object CreatePythonObject(const py::object& type, const py::tuple& params) { // Generate an appropriate name and set to graph debuginfo // character <> can not used in the dot file, so change to another symbol -void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name) { +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { MS_EXCEPTION_IF_NULL(func_graph); 
MS_EXCEPTION_IF_NULL(func_graph->debug_info()); // set detail name info of function @@ -412,7 +412,7 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name) func_graph->debug_info()->set_full_name(oss.str()); } -ValuePtr PyDataToValue(const py::object& obj) { +ValuePtr PyDataToValue(const py::object &obj) { py::object to_convert = obj; if (py::hasattr(obj, "__parameter__")) { to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); @@ -431,7 +431,7 @@ void ClearObjectCache() { static std::unordered_map g_dataClassToClass = {}; // parse dataclass to mindspore Class type -ClassPtr ParseDataClass(const py::object& cls_obj) { +ClassPtr ParseDataClass(const py::object &cls_obj) { std::string cls_name = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__name__")); std::string cls_module = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__module__")); std::string cls = cls_module + "." + cls_name; @@ -443,7 +443,7 @@ ClassPtr ParseDataClass(const py::object& cls_obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); ClassAttrVector attributes; py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); - for (auto& item : names) { + for (auto &item : names) { TypePtr type_value = item.second.cast(); MS_EXCEPTION_IF_NULL(type_value); MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; @@ -452,7 +452,7 @@ ClassPtr ParseDataClass(const py::object& cls_obj) { std::unordered_map methods_map; py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); - for (auto& item : methods) { + for (auto &item : methods) { std::string fun_name = item.first.cast(); py::object obj = py::cast(item.second); std::shared_ptr method_obj = std::make_shared(obj, fun_name); diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.h b/mindspore/ccsrc/pipeline/parse/data_converter.h index 658360bcee..a8918fa60c 100644 
--- a/mindspore/ccsrc/pipeline/parse/data_converter.h +++ b/mindspore/ccsrc/pipeline/parse/data_converter.h @@ -32,25 +32,25 @@ namespace mindspore { namespace parse { // data convert for parse namespace data_converter { -void CacheObjectValue(const std::string& obj_key, const Any& data); -bool GetObjectValue(const std::string& obj_key, Any* const data); +void CacheObjectValue(const std::string &obj_key, const Any &data); +bool GetObjectValue(const std::string &obj_key, Any *const data); -void SetObjGraphValue(const std::string& obj_key, const FuncGraphPtr& data); +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); -const std::unordered_map>& GetObjGraphs(); +const std::unordered_map> &GetObjGraphs(); -std::vector GetObjKey(const py::object& obj); -ResolveTypeDef GetObjType(const py::object& obj); -ClassInstanceTypeDef GetClassInstanceType(const py::object& obj); +std::vector GetObjKey(const py::object &obj); +ResolveTypeDef GetObjType(const py::object &obj); +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); -bool IsCellInstance(const py::object& obj); -py::object CreatePythonObject(const py::object& type, const py::tuple& params); -void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name); -ValuePtr PyDataToValue(const py::object& obj); +bool IsCellInstance(const py::object &obj); +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); +ValuePtr PyDataToValue(const py::object &obj); void ClearObjectCache(); } // namespace data_converter -ClassPtr ParseDataClass(const py::object& cls_obj); +ClassPtr ParseDataClass(const py::object &cls_obj); void CleanDataClassToClassMap(); diff --git a/mindspore/ccsrc/pipeline/parse/function_block.cc b/mindspore/ccsrc/pipeline/parse/function_block.cc index 423e76c1d8..156f727b9e 100644 --- a/mindspore/ccsrc/pipeline/parse/function_block.cc +++ 
b/mindspore/ccsrc/pipeline/parse/function_block.cc @@ -28,21 +28,21 @@ namespace mindspore { namespace parse { -FunctionBlock::FunctionBlock(const Parser& parser) : parser_(parser) { +FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) { func_graph_ = std::make_shared(); matured_ = false; } -void FunctionBlock::AddPrevBlock(const FunctionBlockPtr& block) { prev_blocks_.push_back(block.get()); } +void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); } // write variable records the variable name to corresponding node -void FunctionBlock::WriteVariable(const std::string& var_name, const AnfNodePtr& node) { +void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { MS_LOG(DEBUG) << "" << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString(); vars_[var_name] = node; } // read variable from predecessors -AnfNodePtr FunctionBlock::ReadVariable(const std::string& var) { +AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { // get var node if it is found if (vars_.count(var)) { AnfNodePtr node = vars_[var]; @@ -82,7 +82,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string& var) { } // Resolve Ast operator node -AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object& op) { +AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) { auto ast = parser_.ast(); MS_EXCEPTION_IF_NULL(ast); TraceGuard trace_guard(parser_.GetLocation(op)); @@ -105,7 +105,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { } // Make a resolve node for symbol string -AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string& value) { +AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { if (value.compare(0, strlen("self."), "self.") == 0) { auto start = value.find_first_of('.') + 1; if (start >= value.size()) { @@ -122,14 +122,14 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const 
std::string& value) { return MakeResolve(name_space, symbol); } -AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string& value) { +AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) { py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value); NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]); SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); return MakeResolve(name_space, symbol); } -AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr& name_space, const SymbolPtr& resolve_symbol) { +AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) { MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , " << ((std::string)resolve_symbol->symbol()); ValueNodePtr module_node = NewValueNode(name_space); @@ -139,10 +139,10 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr& name_space, const Symb } // add input for the block's phi parameter -void FunctionBlock::SetPhiArgument(const ParameterPtr& phi) { +void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; - for (auto& pred : prev_blocks_) { + for (auto &pred : prev_blocks_) { MS_EXCEPTION_IF_NULL(pred); MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); AnfNodePtr arg_node = pred->ReadVariable(var); @@ -161,9 +161,9 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr& phi) { } } -AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string& var, const ParameterPtr& phi) { +AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { AnfNodePtr arg_node = nullptr; - for (auto& prev : prev_blocks_) { + for (auto &prev : prev_blocks_) { 
MS_EXCEPTION_IF_NULL(prev); AnfNodePtr temp_node = prev->ReadVariable(var); MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var @@ -204,7 +204,7 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string& var, const Parame // 2. it's costly to iterate the graph to replace the phi for each phi. // Args : // phi : This parameter node is functioning as a phi node. -void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { +void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { MS_EXCEPTION_IF_NULL(phi); std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); @@ -221,15 +221,15 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { removable_phis_[phi] = arg_node; // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. 
- for (auto& prev : prev_blocks_) { + for (auto &prev : prev_blocks_) { MS_EXCEPTION_IF_NULL(prev); if (!prev->matured_) { continue; } - for (auto& phi_iter : prev->removable_phis_) { + for (auto &phi_iter : prev->removable_phis_) { MS_EXCEPTION_IF_NULL(phi_iter.second); if (phi_iter.second->isa()) { - const auto& param = phi_iter.second->cast(); + const auto ¶m = phi_iter.second->cast(); if (param == phi) { MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); @@ -243,8 +243,8 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { // A block should be marked matured if its predecessor blocks have been processed void FunctionBlock::Mature() { - const auto& graphParamVec = func_graph_->parameters(); - for (auto& paramItr : graphParamVec) { + const auto &graphParamVec = func_graph_->parameters(); + for (auto ¶mItr : graphParamVec) { MS_EXCEPTION_IF_NULL(paramItr); ParameterPtr param = paramItr->cast(); if (phi_nodes_.find(param) != phi_nodes_.cend()) { @@ -255,7 +255,7 @@ void FunctionBlock::Mature() { } // Force the conditIon node to bool using bool operation -CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr& cond) { +CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) { TraceManager::DebugTrace(std::make_shared(cond->debug_info())); CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); TraceManager::EndTrace(); @@ -263,7 +263,7 @@ CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr& cond) { } // Perform a jump from this block to target block -void FunctionBlock::Jump(const FunctionBlockPtr& target_block, AnfNodePtr node) { +void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) { if (func_graph()->get_return() != nullptr) { MS_LOG(EXCEPTION) << "Failure: have return node! 
NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); @@ -283,8 +283,8 @@ void FunctionBlock::Jump(const FunctionBlockPtr& target_block, AnfNodePtr node) // Perform a conditional jump using switch operation. // The first CNode select graph with condition, and than execute this graph -void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& true_block, - const FunctionBlockPtr& false_block) { +void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block, + const FunctionBlockPtr &false_block) { if (func_graph()->get_return() != nullptr) { MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); @@ -297,15 +297,15 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& InsertDependItemsBeforeReturn(); } -void FunctionBlock::SetStateAssgin(const AnfNodePtr& target, const std::string& readid) { +void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { state_assign_[target] = readid; } -void FunctionBlock::AddAutoDepend(const AnfNodePtr& target) { auto_depends_.push_back(target); } +void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } void FunctionBlock::InsertDependItemsBeforeReturn() { if (!prev_blocks_.empty()) { - for (auto& prev_block : prev_blocks_) { + for (auto &prev_block : prev_blocks_) { MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); } } @@ -324,14 +324,14 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { AnfNodePtr state = nullptr; std::vector vec_states; vec_states.emplace_back(make_tuple_op); - for (auto& item : state_assign_) { + for (auto &item : state_assign_) { auto source = ReadVariable(item.second); auto refkey = func_graph()->NewCNode({get_refkey_op, item.first}); auto assign = func_graph()->NewCNode({assign_op, refkey, source}); MS_LOG(INFO) << "SetState 
read " << item.first->ToString() << ", " << item.second; vec_states.emplace_back(assign); } - for (auto& item : auto_depends_) { + for (auto &item : auto_depends_) { MS_LOG(DEBUG) << "auto_depends " << item->ToString(); vec_states.emplace_back(item); } diff --git a/mindspore/ccsrc/pipeline/parse/function_block.h b/mindspore/ccsrc/pipeline/parse/function_block.h index 0be6e472f8..e7842903ee 100644 --- a/mindspore/ccsrc/pipeline/parse/function_block.h +++ b/mindspore/ccsrc/pipeline/parse/function_block.h @@ -43,47 +43,47 @@ using FunctionBlockPtr = std::shared_ptr; // the original source code. class FunctionBlock : public std::enable_shared_from_this { public: - explicit FunctionBlock(const Parser& parser); + explicit FunctionBlock(const Parser &parser); virtual ~FunctionBlock() {} FuncGraphPtr func_graph() { return func_graph_; } - void WriteVariable(const std::string& var_name, const AnfNodePtr& node); - AnfNodePtr ReadVariable(const std::string& var_name); - void AddPrevBlock(const FunctionBlockPtr& block); - void SetPhiArgument(const ParameterPtr& phi); - void CollectRemovablePhi(const ParameterPtr& phi); + void WriteVariable(const std::string &var_name, const AnfNodePtr &node); + AnfNodePtr ReadVariable(const std::string &var_name); + void AddPrevBlock(const FunctionBlockPtr &block); + void SetPhiArgument(const ParameterPtr &phi); + void CollectRemovablePhi(const ParameterPtr &phi); // A block is matured if all its predecessors is generated void Mature(); - CNodePtr ForceToBoolNode(const AnfNodePtr& cond); - void Jump(const FunctionBlockPtr& block, AnfNodePtr node); - AnfNodePtr SearchReplaceNode(const std::string& var, const ParameterPtr& phi); - void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& trueBlock, const FunctionBlockPtr& falseBlock); + CNodePtr ForceToBoolNode(const AnfNodePtr &cond); + void Jump(const FunctionBlockPtr &block, AnfNodePtr node); + AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); + void 
ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock); // record the assign statement of self.xx weight parameter ,which will use state_setitem op - void SetStateAssgin(const AnfNodePtr& target, const std::string& readid); - void AddAutoDepend(const AnfNodePtr& target); + void SetStateAssgin(const AnfNodePtr &target, const std::string &readid); + void AddAutoDepend(const AnfNodePtr &target); void InsertDependItemsBeforeReturn(); - void AddGlobalVar(const std::string& var_name) { (void)global_vars_.insert(var_name); } - bool IsGlobalVar(const std::string& var_name) { return global_vars_.find(var_name) != global_vars_.end(); } - AnfNodePtr MakeResolveAstOp(const py::object& op); + void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } + bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } + AnfNodePtr MakeResolveAstOp(const py::object &op); AnfNodePtr MakeResolveClassMember(std::string attr); - AnfNodePtr MakeResolveSymbol(const std::string& value); - AnfNodePtr MakeResolveOperation(const std::string& value); - AnfNodePtr MakeResolve(const std::shared_ptr& name_space, const std::shared_ptr& resolve_symbol); - const std::unordered_map& removable_phis() const { return removable_phis_; } + AnfNodePtr MakeResolveSymbol(const std::string &value); + AnfNodePtr MakeResolveOperation(const std::string &value); + AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); + const std::unordered_map &removable_phis() const { return removable_phis_; } private: // block graph FuncGraphPtr func_graph_; // the block's parser - const Parser& parser_; + const Parser &parser_; // A block is matured if all its prev_blocks is processed bool matured_; // store the nest-level block // refer to comments in Parser::func_block_list_; - std::vector prev_blocks_; + std::vector prev_blocks_; // store args and 
variable's node std::map vars_; @@ -93,7 +93,7 @@ class FunctionBlock : public std::enable_shared_from_this { // jumps map the successor block and the function call that perform jump // refer to comments in Parser::func_block_list_ that how to break the cyclic reference - std::map jumps_; + std::map jumps_; // keeps all removable phis which will be removed in one pass. std::unordered_map removable_phis_; diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc index 51c4fc17ec..22d6fc9049 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/parse/parse.cc @@ -109,6 +109,7 @@ void Parser::BuildMethodMap() { expr_method_map_["Index"] = &Parser::ParseIndex; expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; expr_method_map_["Dict"] = &Parser::ParseDict; + expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; } void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } @@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, namelist_for_default_value.push_back(arg_name); if (py::isinstance(defaults[i])) { - default_values.push_back(NewValueNode(kNullObj)); + default_values.push_back(NewValueNode(kNull)); } else { default_values.push_back(ParseExprNode(block, defaults[i])); } @@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) { return NewValueNode(kNone); } +AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) { + MS_LOG(DEBUG) << "Process ast Ellipsis"; + return NewValueNode(kEllipsis); +} + AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast Num"; py::object obj = python_adapter::GetPyObjAttr(node, "n"); diff --git a/mindspore/ccsrc/pipeline/parse/parse.h b/mindspore/ccsrc/pipeline/parse/parse.h index 4dd1bc62aa..be6b09600c 100644 --- 
a/mindspore/ccsrc/pipeline/parse/parse.h +++ b/mindspore/ccsrc/pipeline/parse/parse.h @@ -92,6 +92,8 @@ class Parser { AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); // process NoneType AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); + // process Ellipsis + AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node); // process a integer or float number AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); // process a string variable diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h index df2d1968a5..aad8be0d6e 100644 --- a/mindspore/ccsrc/pipeline/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/parse/parse_base.h @@ -128,15 +128,15 @@ enum ClassInstanceTypeDef { }; // Convert python object to ValuePtr -bool ConvertData(const py::object& obj, ValuePtr* data, bool use_signature = false); +bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false); // Convert python obj to graph -FuncGraphPtr ConvertToFuncGraph(const py::object& obj, - const std::string& python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); // Parse the python object to graph -FuncGraphPtr ParsePythonCode(const py::object& obj, - const std::string& python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +FuncGraphPtr ParsePythonCode(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.cc b/mindspore/ccsrc/pipeline/parse/python_adapter.cc index e2c86164d4..df2f7d0d45 100644 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.cc +++ b/mindspore/ccsrc/pipeline/parse/python_adapter.cc @@ -32,7 +32,7 @@ void 
set_use_signature_in_resolve(bool use_signature) noexcept { use_signature_i bool UseSignatureInResolve() { return use_signature_in_resolve_; } void set_python_env_flag(bool python_env) noexcept { python_env_ = python_env; } bool IsPythonEnv() { return python_env_; } -void SetPythonPath(const std::string& path) { +void SetPythonPath(const std::string &path) { // load the python module path (void)python_adapter::set_python_scoped(); py::module sys = py::module::import("sys"); @@ -62,7 +62,7 @@ std::shared_ptr set_python_scoped() { } // return the module of python -py::module GetPyModule(const std::string& module) { +py::module GetPyModule(const std::string &module) { if (!module.empty()) { return py::module::import(module.c_str()); } else { @@ -71,7 +71,7 @@ py::module GetPyModule(const std::string& module) { } // Get the obj of attr -py::object GetPyObjAttr(const py::object& obj, const std::string& attr) { +py::object GetPyObjAttr(const py::object &obj, const std::string &attr) { if (!attr.empty() && !py::isinstance(obj)) { if (py::hasattr(obj, attr.c_str())) { return obj.attr(attr.c_str()); @@ -81,7 +81,7 @@ py::object GetPyObjAttr(const py::object& obj, const std::string& attr) { return py::none(); } -py::object GetPyFn(const std::string& module, const std::string& name) { +py::object GetPyFn(const std::string &module, const std::string &name) { (void)python_adapter::set_python_scoped(); if (!module.empty() && !name.empty()) { py::module mod = py::module::import(module.c_str()); diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.h b/mindspore/ccsrc/pipeline/parse/python_adapter.h index 12cfc27186..98adcd4f73 100644 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.h +++ b/mindspore/ccsrc/pipeline/parse/python_adapter.h @@ -31,10 +31,10 @@ namespace mindspore { namespace parse { // A utility to call python interface namespace python_adapter { -py::module GetPyModule(const std::string& module); -py::object GetPyObjAttr(const py::object& obj, const 
std::string& attr); +py::module GetPyModule(const std::string &module); +py::object GetPyObjAttr(const py::object &obj, const std::string &attr); template -py::object CallPyObjMethod(const py::object& obj, const std::string& method, T... args) { +py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) { if (!method.empty() && !py::isinstance(obj)) { return obj.attr(method.c_str())(args...); } @@ -43,7 +43,7 @@ py::object CallPyObjMethod(const py::object& obj, const std::string& method, T.. // call python function of module template -py::object CallPyModFn(const py::module& mod, const std::string& function, T... args) { +py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) { if (!function.empty() && !py::isinstance(mod)) { return mod.attr(function.c_str())(args...); } @@ -57,12 +57,12 @@ bool UseSignatureInResolve(); std::shared_ptr set_python_scoped(); void ResetPythonScope(); bool IsPythonEnv(); -void SetPythonPath(const std::string& path); +void SetPythonPath(const std::string &path); void set_python_env_flag(bool python_env) noexcept; -py::object GetPyFn(const std::string& module, const std::string& name); +py::object GetPyFn(const std::string &module, const std::string &name); // Call the python function template -py::object CallPyFn(const std::string& module, const std::string& name, T... args) { +py::object CallPyFn(const std::string &module, const std::string &name, T... 
args) { (void)set_python_scoped(); if (!module.empty() && !name.empty()) { py::module mod = py::module::import(module.c_str()); diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc index f90fc5039c..284512c943 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/parse/resolve.cc @@ -71,7 +71,7 @@ bool SymbolResolver::Resolve() { namespace { // argument obj should be python Parameter object // it will be converted to Parameter node here -AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& obj) { +AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { MS_EXCEPTION_IF_NULL(func_graph); // parameter object should not be none @@ -128,7 +128,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& } } -bool ResolveObjectToNode(const FuncGraphPtr& func_graph, const py::object& obj, AnfNodePtr* const node) { +bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { AnfNodePtr output = nullptr; if (py::hasattr(obj, "__parameter__")) { auto param = ResolveParameterObj(func_graph, obj); @@ -171,12 +171,12 @@ bool ResolveObjectToNode(const FuncGraphPtr& func_graph, const py::object& obj, } // transform the ValueTuple or ValueList of graph node to make tuple of const graph node -bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const AnfNodePtr& node, - const ValueNodePtr& value_node, AnfNodePtr* const transformed) { +bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, + const ValueNodePtr &value_node, AnfNodePtr *const transformed) { MS_EXCEPTION_IF_NULL(value_node); - const auto& value_vec = GetValue>(value_node->value()); + const auto &value_vec = GetValue>(value_node->value()); bool has_graph_in_list = false; - for (auto& elemv : value_vec) { + for (auto &elemv : value_vec) { 
MS_EXCEPTION_IF_NULL(elemv); if (elemv->isa()) { FuncGraphPtr new_fg = elemv->cast(); @@ -196,10 +196,10 @@ bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const Anf auto make_list_op = NewValueNode(prim::kPrimMakeTuple); list_vec.emplace_back(make_list_op); (void)std::transform(std::begin(value_vec), std::end(value_vec), std::back_inserter(list_vec), - [](const ValuePtr& value) { return NewValueNode(value); }); + [](const ValuePtr &value) { return NewValueNode(value); }); FuncGraphPtr cnode_graph = nullptr; auto users = manager->node_users()[node]; - for (auto& use : users) { + for (auto &use : users) { auto use_node = use.first; MS_EXCEPTION_IF_NULL(use_node); if (use_node->isa()) { @@ -220,8 +220,8 @@ bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const Anf } } // namespace -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& name_space, const SymbolPtr& symbol, - const AnfNodePtr& node) { +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node) { if (node->func_graph() == nullptr || manager == nullptr) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; } @@ -253,7 +253,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& } namespace { -opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib& irpass) { +opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { opt::OptPassGroupMap map({ {"resolve", { @@ -266,7 +266,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib& ir } } // namespace -bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBasePtr& res, bool use_profile) { +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) { if (func_graph == nullptr || res == nullptr) 
{ MS_LOG(ERROR) << "func_graph or resource is null"; return false; @@ -282,7 +282,7 @@ bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBa return true; } -bool ResolveAll(const FuncGraphManagerPtr& manager) { +bool ResolveAll(const FuncGraphManagerPtr &manager) { if (manager == nullptr) { MS_LOG(ERROR) << "func graph manager is null"; return false; @@ -301,7 +301,7 @@ bool ResolveAll(const FuncGraphManagerPtr& manager) { res->set_manager(manager); auto roots = manager->roots(); - for (auto& fg : roots) { + for (auto &fg : roots) { bool ret = ResolveFuncGraph(fg, res, false); if (!ret) { MS_EXCEPTION_IF_NULL(fg); diff --git a/mindspore/ccsrc/pipeline/parse/resolve.h b/mindspore/ccsrc/pipeline/parse/resolve.h index ccc22c72dc..acabfaf54b 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/parse/resolve.h @@ -39,7 +39,7 @@ namespace parse { // NameSpace class for resolving python code. class NameSpace : public Named { public: - NameSpace(const std::string& module, const py::object& obj) : Named(module), module_(module), obj_(obj) {} + NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {} ~NameSpace() override = default; MS_DECLARE_PARENT(NameSpace, Named); @@ -60,8 +60,8 @@ using NameSpacePtr = std::shared_ptr; // Symbol in NameSpace or Class which shall be resolved. 
class Symbol : public Named { public: - explicit Symbol(const std::string& symbol) : Named(symbol), symbol_(symbol) {} - explicit Symbol(const std::string& symbol, const std::string& name) : Named(name), symbol_(symbol) {} + explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} + explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} ~Symbol() override = default; MS_DECLARE_PARENT(Symbol, Named); @@ -79,7 +79,7 @@ using SymbolPtr = std::shared_ptr; // PyObjectWrapper class wrappers resolved python object for further processing. class PyObjectWrapper : public Named { public: - explicit PyObjectWrapper(const py::object& obj, const std::string name = "Python object") : Named(name), obj_(obj) {} + explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} ~PyObjectWrapper() override = default; MS_DECLARE_PARENT(PyObjectWrapper, Named); py::object obj() { return obj_; } @@ -92,7 +92,7 @@ class PyObjectWrapper : public Named { // ClassObject class wrappers dataclass class ClassObject : public PyObjectWrapper { public: - explicit ClassObject(const py::object& obj, const std::string name = "Python dataclass") + explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") : PyObjectWrapper(obj, name) {} ~ClassObject() override = default; MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); @@ -102,7 +102,7 @@ class ClassObject : public PyObjectWrapper { // ClassType class wrappers class name in python class ClassType : public PyObjectWrapper { public: - explicit ClassType(const py::object& obj, const std::string name = "Python class type") + explicit ClassType(const py::object &obj, const std::string name = "Python class type") : PyObjectWrapper(obj, name) {} ~ClassType() override = default; MS_DECLARE_PARENT(ClassType, PyObjectWrapper); @@ -112,7 +112,7 @@ class ClassType : public PyObjectWrapper { // SymbolResolver 
class for resolving symbol extracted from AnfNode. class SymbolResolver { public: - SymbolResolver(const NameSpacePtr& name_space, const SymbolPtr& symbol, const AnfNodePtr& node) + SymbolResolver(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) : namespace_(name_space), symbol_(symbol), resolved_node_(node) {} ~SymbolResolver() = default; @@ -124,7 +124,7 @@ class SymbolResolver { SymbolPtr symbol() { return symbol_; } - py::object& result() { return result_; } + py::object &result() { return result_; } AnfNodePtr resolved_node() { return resolved_node_; } @@ -141,15 +141,15 @@ class SymbolResolver { }; using SymbolResolverPtr = std::shared_ptr; // Resolve symbol in namespace. -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& name_space, const SymbolPtr& symbol, - const AnfNodePtr& node); +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node); // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). -bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBasePtr& res, bool use_profile = true); +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); // Resolve all graphs in manager which is defined outside of pipeline::Resource. // Mainly used for test cases or resolve graphs which will not be managed by manager. 
-bool ResolveAll(const FuncGraphManagerPtr& manager); +bool ResolveAll(const FuncGraphManagerPtr &manager); } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index b3eda4c37b..6cdf641443 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -48,7 +48,7 @@ using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; using mindspore::validator::Validate; -bool SimplifyDataStructuresPass(const ResourcePtr& res) { +bool SimplifyDataStructuresPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); @@ -57,7 +57,7 @@ bool SimplifyDataStructuresPass(const ResourcePtr& res) { abstract::AbstractBasePtrList args_spec; auto parameters = func_graph->parameters(); (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), - [](const AnfNodePtr& p) -> AbstractBasePtr { return p->abstract(); }); + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); res->set_func_graph(new_fg); res->set_args_spec(args_spec); @@ -65,7 +65,7 @@ bool SimplifyDataStructuresPass(const ResourcePtr& res) { } namespace { -OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_1 = opt::OptPassConfig({ irpass.switch_simplify_, @@ -133,7 +133,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { return map_a; } -OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig b_1 = opt::OptPassConfig({ irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, @@ -157,7 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib& irpass) 
{ return map; } -OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}); OptPassGroupMap map({ {"control_group", control_group}, @@ -173,7 +173,7 @@ OptPassGroupMap GetInferenceOptPreparePhases() { return prepare_map; } -OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); OptPassGroupMap map({{"prepare_group", prepare_group}}); return map; @@ -181,7 +181,7 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { static std::unordered_map> g_pass_opts = {}; -void InitOpt(const ResourcePtr& res) { +void InitOpt(const ResourcePtr &res) { if (g_pass_opts.size() == 0) { opt::irpass::OptimizeIRPassLib irpass; g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); @@ -193,13 +193,13 @@ void InitOpt(const ResourcePtr& res) { } // namespace void ReclaimOptimizer() { - for (auto& opt : g_pass_opts) { + for (auto &opt : g_pass_opts) { opt.second = nullptr; } g_pass_opts.clear(); } -bool OptPassGroup(const ResourcePtr& res, const std::string& name) { +bool OptPassGroup(const ResourcePtr &res, const std::string &name) { if (res->func_graph() == nullptr) { MS_LOG(ERROR) << "Opt passes int error"; return false; @@ -216,12 +216,12 @@ bool OptPassGroup(const ResourcePtr& res, const std::string& name) { return true; } -bool OptPassAGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_a"); } -bool OptPassBGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_b"); } -bool ControlGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_control"); } -bool PrepareGroup(const ResourcePtr& res) { return OptPassGroup(res, 
"opt_prepare"); } +bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } +bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } +bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } +bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } -bool AddControlDependPass(const ResourcePtr& res) { +bool AddControlDependPass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -237,7 +237,7 @@ bool AddControlDependPass(const ResourcePtr& res) { return true; } -bool CconvPass(const ResourcePtr& res) { +bool CconvPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr new_fg = LiftingClone(func_graph); @@ -245,14 +245,14 @@ bool CconvPass(const ResourcePtr& res) { return true; } -bool ValidatePass(const ResourcePtr& res) { +bool ValidatePass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); Validate(func_graph); return true; } -bool InferenceOptPreparePass(const ResourcePtr& res) { +bool InferenceOptPreparePass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); abstract::AbstractBasePtrList args_spec = res->args_spec(); diff --git a/mindspore/ccsrc/pipeline/pass.h b/mindspore/ccsrc/pipeline/pass.h index 3731d7e524..2636879d01 100644 --- a/mindspore/ccsrc/pipeline/pass.h +++ b/mindspore/ccsrc/pipeline/pass.h @@ -30,11 +30,11 @@ using PassItem = std::pair>; extern std::vector kGePasses; extern std::vector kVmPasses; -bool CconvPass(const ResourcePtr& res); -bool ValidatePass(const ResourcePtr& res); -bool ConvertPrepareAdapt(const ResourcePtr& res); -bool AddControlDependPass(const ResourcePtr& res); -bool InferenceOptPreparePass(const ResourcePtr& res); +bool CconvPass(const ResourcePtr &res); +bool 
ValidatePass(const ResourcePtr &res); +bool ConvertPrepareAdapt(const ResourcePtr &res); +bool AddControlDependPass(const ResourcePtr &res); +bool InferenceOptPreparePass(const ResourcePtr &res); void ReclaimOptimizer(); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index cd4fe28db9..fca105d13c 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -67,7 +67,7 @@ std::unordered_map& defaults) { +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); abstract::AbstractBasePtrList args_spec; @@ -147,7 +147,7 @@ py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple i ExecutorPy::ExecutorPy() {} -ResourcePtr ExecutorPy::GetResource(const std::string& phase) { +ResourcePtr ExecutorPy::GetResource(const std::string &phase) { MS_LOG(DEBUG) << "Phase size:" << info_.size(); if (info_.count(phase) == 0) { return nullptr; @@ -155,21 +155,21 @@ ResourcePtr ExecutorPy::GetResource(const std::string& phase) { return info_[phase]->resource; } -FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string& phase) { +FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { if (info_.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } return info_[phase]->func_graph; } -std::size_t ExecutorPy::ArgListSize(const std::string& phase) { +std::size_t ExecutorPy::ArgListSize(const std::string &phase) { if (info_.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } return info_[phase]->arg_list_size; } -compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string& phase) { +compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { ResourcePtr res = GetResource(phase); MS_EXCEPTION_IF_NULL(res); if (res->results().find(kOutput) != 
res->results().end() && res->results()[kOutput].is()) { @@ -179,17 +179,17 @@ compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string& phase) { return nullptr; } -bool ExecutorPy::HasCompiled(const std::string& phase) const { +bool ExecutorPy::HasCompiled(const std::string &phase) const { if (info_.count(phase) == 0) { return false; } return true; } -py::bytes ExecutorPy::GetFuncGraphProto(const std::string& phase, const std::string& ir_type) { +py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) { FuncGraphPtr fg_ptr = GetFuncGraph(phase); if (fg_ptr == nullptr) { - for (auto& item : info_) { + for (auto &item : info_) { MS_LOG(DEBUG) << "Phase key is: " << item.first; } MS_LOG(EXCEPTION) << "Can not find func graph " << phase; @@ -214,34 +214,34 @@ py::bytes ExecutorPy::GetFuncGraphProto(const std::string& phase, const std::str MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; } -py::dict ExecutorPy::GetParameterLayout(const std::string& phase) { +py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { MS_LOG(DEBUG) << "GetParameterLayout!"; std::string layout_graph = phase + kStepParallelGraph; auto graph = GetFuncGraph(layout_graph); return mindspore::parallel::GetParameterLayout(graph); } -py::dict ExecutorPy::GetCNodeStrategy(const std::string& phase) { +py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { MS_LOG(DEBUG) << "GetCNodeStrategy!"; std::string layout_graph = phase + kStepParallelGraph; auto graph = GetFuncGraph(layout_graph); return mindspore::parallel::GetCNodeStrategy(graph); } -py::dict ExecutorPy::GetAllreduceFusion(const std::string& phase) { +py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { MS_LOG(INFO) << "GetAllreduceFusion!"; auto graph = GetFuncGraph(phase); return mindspore::parallel::GetAllreduceFusion(graph); } -void ExecutorPy::DelNetRes(const std::string& id) { +void ExecutorPy::DelNetRes(const std::string &id) { #ifdef ENABLE_GE 
FinalizeGe(); #endif if (executor_ != nullptr) { bool flag = false; auto tmp_info = info_; - for (auto& item : tmp_info) { + for (auto &item : tmp_info) { if (item.first.find(id) != string::npos) { MS_LOG(INFO) << "Delete network res:" << item.first; (void)info_.erase(item.first); @@ -271,7 +271,7 @@ ExecutorPy::~ExecutorPy() { ConfigManager::GetInstance().ResetConfig(); } -void ExecutorPy::SaveCompiledGraph(const std::string& phase_s) { +void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { // save the graph to ExecutorPy FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -294,7 +294,7 @@ void ExecutorPy::SaveCompiledGraph(const std::string& phase_s) { MS_LOG(INFO) << "End save compiled func graph!"; } -bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string& phase_s) const { +bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { std::string phase_prefix = GetPhasePrefix(phase_s); if (use_vm && phase_prefix == "export") { @@ -313,7 +313,7 @@ void ExecutorPy::GetGeBackendPolicy() const { } } -bool ExecutorPy::CompileInner(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm) { +bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { MS_LOG(DEBUG) << "Start ExecutorPy compile!"; if ((!py::isinstance(phase))) { MS_LOG(ERROR) << "Arg phase must be string."; @@ -376,7 +376,7 @@ bool ExecutorPy::CompileInner(const py::object& obj, const py::tuple& args, cons return true; } -void ExecutorPy::ReleaseResource(const py::object& phase) { +void ExecutorPy::ReleaseResource(const py::object &phase) { ResourcePtr res = GetResource(py::cast(phase)); if (res != nullptr) { res->Clean(); @@ -385,18 +385,18 @@ void ExecutorPy::ReleaseResource(const py::object& phase) { ReclaimOptimizer(); } -static std::string PrintArgs(const py::tuple& args) { +static std::string 
PrintArgs(const py::tuple &args) { py::print(args); return ""; } -bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm) { +bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { bool ret_value = false; try { MS_LOG(DEBUG) << PrintArgs(args); ret_value = CompileInner(obj, args, phase, use_vm); - } catch (const py::error_already_set& ex) { + } catch (const py::error_already_set &ex) { // print function call stack info before release std::ostringstream oss; trace::TraceGraphInfer(); @@ -409,13 +409,13 @@ bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py: // re-throw this exception to Python interpreter to handle it throw(py::error_already_set(ex)); - } catch (const py::type_error& ex) { + } catch (const py::type_error &ex) { ReleaseResource(phase); throw py::type_error(ex); - } catch (const py::value_error& ex) { + } catch (const py::value_error &ex) { ReleaseResource(phase); throw py::value_error(ex); - } catch (const std::exception& ex) { + } catch (const std::exception &ex) { ReleaseResource(phase); // re-throw this exception to Python interpreter to handle it throw(std::runtime_error(ex.what())); @@ -432,7 +432,7 @@ bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py: // get MindSpore Intermediate Representation File std::string GetMsIrFile(void) { std::string file; - const char* path = getenv("MS_IR_FILE"); + const char *path = getenv("MS_IR_FILE"); if (path == nullptr) { return file; } @@ -446,7 +446,7 @@ std::string GetMsIrFile(void) { return file; } -void RunPipelineAction(const ActionItem& action, pipeline::ResourcePtr resource, bool* result) { +void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) { MS_EXCEPTION_IF_NULL(resource); MS_EXCEPTION_IF_NULL(result); @@ -472,7 +472,7 @@ void RunPipelineAction(const ActionItem& action, pipeline::ResourcePtr 
resource, } auto manager = resource->manager(); MS_EXCEPTION_IF_NULL(manager); - for (auto& graph : graphs) { + for (auto &graph : graphs) { manager->AddFuncGraph(graph); } resource->set_func_graph(graphs[0]); @@ -491,9 +491,9 @@ void Pipeline::Run() { WITH(MsProfile::GetProfile())[&user_graph, this]() { int i = 0; - for (auto& action : actions_) { + for (auto &action : actions_) { #ifdef ENABLE_TIMELINE - DumpTime& dump_time = DumpTime::GetInstance(); + DumpTime &dump_time = DumpTime::GetInstance(); dump_time.Record(action.first, GetTime(), true); #endif bool result = true; @@ -575,7 +575,7 @@ void Pipeline::Run() { MS_LOG(INFO) << "End"; } -void ExecutorPy::ProcessVmArg(const py::tuple& args, const std::string& phase, VectorRef* arg_list) { +void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) { std::size_t size = args.size(); for (std::size_t i = 0; i < size; i++) { @@ -584,7 +584,7 @@ void ExecutorPy::ProcessVmArg(const py::tuple& args, const std::string& phase, V if (ms_context->backend_policy() == kMsConvert && py::isinstance(arg)) { MS_LOG(EXCEPTION) << "Args[" << i << "] is numpy array, not tensor"; } - (*arg_list).push_back(arg); + arg_list->push_back(arg); } ResourcePtr res = GetResource(phase); @@ -604,7 +604,7 @@ void ExecutorPy::ProcessVmArg(const py::tuple& args, const std::string& phase, V } } -py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { +py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { std::size_t size = args.size(); if (!py::isinstance(phase)) { MS_LOG(EXCEPTION) << "Run failed, phase input is not a str"; @@ -649,8 +649,8 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { return BaseRefToPyData(value); } -FuncGraphPtr ExecutorPy::BuildGraph(const py::dict& init_params, const std::string& phase, - const py::object& broadcast_params) { +FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const 
std::string &phase, + const py::object &broadcast_params) { #if (ENABLE_GE || ENABLE_D) return BuildDFGraph(info_, init_params, phase, broadcast_params); #else @@ -658,15 +658,15 @@ FuncGraphPtr ExecutorPy::BuildGraph(const py::dict& init_params, const std::stri #endif } -void ExecutorPy::RunInitGraph(const py::dict& init_params, const std::string& phase) { +void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { #if ENABLE_GE RunGEInitGraph(init_params, phase); #endif } -bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase) { +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { std::string name = MsContext::GetInstance()->backend_policy(); if (name == kMsConvert || name == kMsVm) { return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes); @@ -682,16 +682,16 @@ bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t ba return false; } -bool InitExecDatasetVm(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes) { +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes) { MS_LOG(INFO) << "Start InitDataSet Entry"; std::vector int_input_indexes; (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), [](int64_t item) { return static_cast(item); }); std::vector> int_shapes; (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), - [](const std::vector& item) { + [](const std::vector 
&item) { std::vector vector_item; (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), [](int64_t inner_item) { return static_cast(inner_item); }); @@ -774,7 +774,7 @@ void FinalizeHccl() { #endif } -void ExportGraph(const std::string& file_name, const std::string&, const std::string& phase) { +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { #if (ENABLE_GE || ENABLE_D) ExportDFGraph(file_name, phase); #endif diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h index a0d7a19198..865c961ac1 100644 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ b/mindspore/ccsrc/pipeline/pipeline.h @@ -43,7 +43,7 @@ namespace py = pybind11; class Pipeline { public: - Pipeline(const ResourcePtr& res, const std::vector& actions) : resource_(res), actions_(actions) {} + Pipeline(const ResourcePtr &res, const std::vector &actions) : resource_(res), actions_(actions) {} ~Pipeline() = default; @@ -69,35 +69,35 @@ class ExecutorPy : public std::enable_shared_from_this { ~ExecutorPy(); - void SaveCompiledGraph(const std::string& phase_s); - bool CompileInner(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm); - bool Compile(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm); + void SaveCompiledGraph(const std::string &phase_s); + bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); + bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); - void ProcessVmArg(const py::tuple& args, const std::string& phase, VectorRef* arg_list); + void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list); // for pynative mode when use_vm is on - py::object Run(const py::tuple& args, const py::object& phase); - ResourcePtr GetResource(const std::string& phase); - FuncGraphPtr GetFuncGraph(const std::string& phase); - 
py::bytes GetFuncGraphProto(const std::string& phase, const std::string& type); - std::size_t ArgListSize(const std::string& phase); - compile::VmEvalFuncPtr GetVmEvalFunc(const std::string& phase); - bool HasCompiled(const std::string& phase) const; - - FuncGraphPtr BuildGraph(const py::dict& init_params, const std::string& phase, - const py::object& broadcast_params = {}); - void RunInitGraph(const py::dict& init_params, const std::string& phase); - py::dict GetParameterLayout(const std::string& phase); - py::dict GetCNodeStrategy(const std::string& phase); - py::dict GetAllreduceFusion(const std::string& phase); - void DelNetRes(const std::string& id); - void ReleaseResource(const py::object& phase); + py::object Run(const py::tuple &args, const py::object &phase); + ResourcePtr GetResource(const std::string &phase); + FuncGraphPtr GetFuncGraph(const std::string &phase); + py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); + std::size_t ArgListSize(const std::string &phase); + compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); + bool HasCompiled(const std::string &phase) const; + + FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params = {}); + void RunInitGraph(const py::dict &init_params, const std::string &phase); + py::dict GetParameterLayout(const std::string &phase); + py::dict GetCNodeStrategy(const std::string &phase); + py::dict GetAllreduceFusion(const std::string &phase); + void DelNetRes(const std::string &id); + void ReleaseResource(const py::object &phase); static void ClearRes(); private: ExecutorPy(); - void ConvertObjectToTensors(const py::dict& dict, std::map* tensors); - bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string& phase_s) const; + void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); + bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; void GetGeBackendPolicy() const; std::map 
info_; @@ -107,10 +107,10 @@ class ExecutorPy : public std::enable_shared_from_this { using ExecutorPyPtr = std::shared_ptr; // Generate a key for mapping function graph -py::tuple GenerateKey(const std::string& name, const std::unordered_map& defaults); +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults); py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs); -bool InitDistribute(const std::map& options); +bool InitDistribute(const std::map &options); void ResetOpId(); void InitHccl(); @@ -121,17 +121,17 @@ void FinalizeGe(); void ClearResAtexit(); void ReleaseGeTsd(); -void ExportGraph(const std::string& file_name, const std::string&, const std::string& phase); +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); // init and exec dataset sub graph -bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase); +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase); // Build and run dataset subgraph for ms backend -bool InitExecDatasetVm(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes); +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.cc b/mindspore/ccsrc/pipeline/pipeline_ge.cc index 6ce0ea5316..1da85b5699 100644 --- a/mindspore/ccsrc/pipeline/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/pipeline_ge.cc @@ -46,7 +46,7 @@ using 
mindspore::transform::MeTensorPtr; using mindspore::transform::Status; using mindspore::transform::TransformUtil; -void DoExecNonInputGraph(const std::string& phase) { +void DoExecNonInputGraph(const std::string &phase) { std::vector ge_tensors; std::vector ge_outputs; transform::RunOptions run_options; @@ -68,7 +68,7 @@ void DoExecNonInputGraph(const std::string& phase) { } } -void SetGeOption(const std::map& options) { +void SetGeOption(const std::map &options) { ConfigManager::GetInstance().set_ge_initialize_options(options); } @@ -108,11 +108,11 @@ Status CreateSessionAndGraphRunner(bool is_training = true) { return Status::SUCCESS; } -bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase) { +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { std::vector ge_types; - (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr& i) -> int64_t { + (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr &i) -> int64_t { return transform::TransformUtil::ConvertDataType(i->type_id()); }); @@ -145,7 +145,7 @@ bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batc return true; } -void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) { +void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) { for (auto item : dict) { if ((!py::isinstance(item.first))) { MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; @@ -156,11 +156,11 @@ void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) if (py::isinstance(item.second.attr("default_input"))) { // convert float to tensor with 
shape([1]) tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); - *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); + *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); } else if (py::isinstance(item.second.attr("default_input"))) { // convert int to tensor with shape([1]) tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); - *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); + *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { // cast tensor tensor = py::cast>(item.second.attr("default_input")); @@ -173,8 +173,8 @@ void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) } } -bool AddDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params) { +bool AddDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { FuncGraphPtr anf_graph = info.at(phase)->func_graph; DfGraphConvertor convertor(anf_graph); @@ -237,8 +237,8 @@ bool AddDFGraph(const std::map& info, const py::di return true; } -FuncGraphPtr BuildDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params) { +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { if (info.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } @@ -268,13 +268,13 @@ FuncGraphPtr BuildDFGraph(const std::map& info, co return anf_graph; } -void RunGEInitGraph(const py::dict& init_params, const std::string& phase) { +void RunGEInitGraph(const py::dict &init_params, const std::string &phase) { MS_LOG(DEBUG) << "ExecInitGraph start."; TensorOrderMap 
inputs_with_name{}; ConvertObjectToTensors(init_params, &inputs_with_name); std::vector inputs; (void)std::transform(inputs_with_name.begin(), inputs_with_name.end(), std::back_inserter(inputs), - [](const std::pair& item) { return item.second; }); + [](const std::pair &item) { return item.second; }); std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); if (ge_tensors.size() != inputs.size()) { @@ -317,7 +317,7 @@ void RunGEInitGraph(const py::dict& init_params, const std::string& phase) { } } -py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) { +py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) { MS_EXCEPTION_IF_NULL(cnode_data); if (*count >= data.size()) { MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() @@ -350,7 +350,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::t return std::move(tp); } -py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) { +py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) { MS_EXCEPTION_IF_NULL(output_node); if (output_node->isa()) { @@ -387,8 +387,8 @@ py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, return ExtractGeneralCnodeRet(output_c->abstract(), data, count); } -std::shared_ptr DoExecGraph(const FuncGraphPtr& graph, const std::vector& inputs, - const std::string& phase) { +std::shared_ptr DoExecGraph(const FuncGraphPtr &graph, const std::vector &inputs, + const std::string &phase) { std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); if (ge_tensors.size() != inputs.size()) { MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; @@ -438,8 +438,8 @@ std::shared_ptr DoExecGraph(const FuncGraphPtr& graph, const std::ve return ret; } -void ProcessGeArg(const std::map& 
info, const py::tuple& args, const std::string& phase, - std::vector* inputs) { +void ProcessGeArg(const std::map &info, const py::tuple &args, const std::string &phase, + std::vector *inputs) { // check the arg and use the ExecutorPy args std::size_t size = args.size(); @@ -462,7 +462,7 @@ void ProcessGeArg(const std::map& info, const py:: MS_LOG(EXCEPTION) << "Args convert error"; } if (converted->isa()) { - (*inputs).push_back(converted->cast()); + inputs->push_back(converted->cast()); } else { MS_LOG(EXCEPTION) << "Args " << converted->ToString() << " is not tensor"; } @@ -470,8 +470,8 @@ void ProcessGeArg(const std::map& info, const py:: } } -py::object ExecDFGraph(const std::map& info, const py::tuple& args, - const std::string& phase) { +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase) { std::string phase_prefix = GetPhasePrefix(phase); if (phase_prefix == "save") { @@ -514,7 +514,7 @@ py::object ExecDFGraph(const std::map& info, const MS_LOG(EXCEPTION) << "Exec graph failed"; } } -void ExportDFGraph(const std::string& file_name, const std::string& phase) { +void ExportDFGraph(const std::string &file_name, const std::string &phase) { MS_LOG(DEBUG) << "ExportGraph Begin"; transform::DfGraphWrapperPtr wrap_ptr = DfGraphManager::GetInstance().GetGraphByName(phase); if (wrap_ptr == nullptr) { diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.h b/mindspore/ccsrc/pipeline/pipeline_ge.h index c3779fd982..9dc1524682 100644 --- a/mindspore/ccsrc/pipeline/pipeline_ge.h +++ b/mindspore/ccsrc/pipeline/pipeline_ge.h @@ -34,22 +34,22 @@ namespace pipeline { namespace py = pybind11; -void SetGeOption(const std::map& options); +void SetGeOption(const std::map &options); -void RunGEInitGraph(const py::dict& init_params, const std::string& phase); +void RunGEInitGraph(const py::dict &init_params, const std::string &phase); -py::object ExecDFGraph(const std::map& info, const py::tuple& args, - const std::string& phase = "train"); 
+py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase = "train"); -FuncGraphPtr BuildDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params = {}); +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params = {}); // init and exec dataset sub graph for GE backend -bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase); +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase); -void ExportDFGraph(const std::string& file_name, const std::string& phase); +void ExportDFGraph(const std::string &file_name, const std::string &phase); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc index 7937c3e55f..0b7401345a 100644 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc +++ b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc @@ -24,9 +24,9 @@ namespace mindspore { namespace pipeline { -void TryToDoReplace(FuncGraphManager* const manager, const AnfNodePtr& node, HashCache* const hash_cache, - HashValue* const hash_value) { - const auto& to_check_value = GetValueNode(node); +void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, + HashValue *const hash_value) { + const auto &to_check_value = GetValueNode(node); MS_EXCEPTION_IF_NULL(to_check_value); // Calculate hash value. 
@@ -46,14 +46,14 @@ void TryToDoReplace(FuncGraphManager* const manager, const AnfNodePtr& node, Has return; } - auto& bucket = bucket_iter->second; + auto &bucket = bucket_iter->second; // Check if need to replace node with value node already met. - for (const auto& v : bucket) { + for (const auto &v : bucket) { // Already met and cached. if (v == node) { return; } - const auto& existed_value = GetValueNode(v); + const auto &existed_value = GetValueNode(v); MS_EXCEPTION_IF_NULL(existed_value); auto equal = [&]() -> bool { if (existed_value->isa() && to_check_value->isa()) { diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/remove_value_node_dup.h index 8fbb3f2755..8f670c7dcf 100644 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.h +++ b/mindspore/ccsrc/pipeline/remove_value_node_dup.h @@ -27,7 +27,7 @@ namespace pipeline { using HashCache = std::unordered_map>; using HashValue = std::unordered_map; -void TryToDoReplace(FuncGraphManager* manager, const AnfNodePtr& node, HashCache* hash_cache, HashValue* hash_value); +void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/resource.cc b/mindspore/ccsrc/pipeline/resource.cc index 18695518be..50ccef2f44 100644 --- a/mindspore/ccsrc/pipeline/resource.cc +++ b/mindspore/ccsrc/pipeline/resource.cc @@ -32,7 +32,7 @@ namespace mindspore { // namespace to support opmap definition namespace pipeline { -MethodMap& GetMethodMap() { +MethodMap &GetMethodMap() { static MethodMap method_map = {{kObjectTypeString, { {"__bool__", std::string("str_bool")} // C.str_bool @@ -178,7 +178,7 @@ MethodMap& GetMethodMap() { return method_map; } -Resource::Resource(const py::object& obj) +Resource::Resource(const py::object &obj) : engine_(std::make_shared(abstract::GetPrimEvaluatorConstructors(), manager_)), input_(obj), is_cleaned_(false) {} @@ 
-197,7 +197,7 @@ Resource::~Resource() { if (!is_cleaned_) { try { Clean(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what(); } catch (...) { MS_LOG(ERROR) << "Exception when cleaning resource."; @@ -205,9 +205,9 @@ Resource::~Resource() { } } -bool Resource::IsTypeInMethodMap(const TypeId& type) { +bool Resource::IsTypeInMethodMap(const TypeId &type) { TypeId type_id = NormalizeTypeId(type); - const MethodMap& method_map = GetMethodMap(); + const MethodMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter != method_map.end()) { return true; @@ -215,9 +215,9 @@ bool Resource::IsTypeInMethodMap(const TypeId& type) { return false; } -Any Resource::GetMethodPtr(const TypeId& type, const std::string& name) { +Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { TypeId type_id = NormalizeTypeId(type); - const MethodMap& method_map = GetMethodMap(); + const MethodMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter == method_map.end()) { MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; diff --git a/mindspore/ccsrc/pipeline/resource.h b/mindspore/ccsrc/pipeline/resource.h index 15ab70db14..0c1348fd94 100644 --- a/mindspore/ccsrc/pipeline/resource.h +++ b/mindspore/ccsrc/pipeline/resource.h @@ -46,7 +46,7 @@ class InferenceResource; using MethodMap = std::unordered_map>; -MethodMap& GetMethodMap(); +MethodMap &GetMethodMap(); class ResourceBase { public: @@ -56,20 +56,20 @@ class ResourceBase { FuncGraphManagerPtr manager() { return manager_; } // set a manager defined outside which will not manage the graphs. 
- void set_manager(const FuncGraphManagerPtr& manager) { manager_ = manager; } + void set_manager(const FuncGraphManagerPtr &manager) { manager_ = manager; } - std::unordered_map& results() { return results_; } + std::unordered_map &results() { return results_; } - void SetResult(const std::string& key, const Any& value) { results_[key] = value; } + void SetResult(const std::string &key, const Any &value) { results_[key] = value; } - Any GetResult(const std::string& key) { + Any GetResult(const std::string &key) { if (results_.count(key) == 0) { MS_LOG(EXCEPTION) << "this key is not in resource list:" << key; } return results_[key]; } - bool HasResult(const std::string& key) const { return results_.count(key) != 0; } + bool HasResult(const std::string &key) const { return results_.count(key) != 0; } std::unordered_map results_; @@ -81,23 +81,23 @@ using ResourceBasePtr = std::shared_ptr; class Resource : public ResourceBase { public: - explicit Resource(const py::object& obj = py::none()); + explicit Resource(const py::object &obj = py::none()); ~Resource() override; abstract::AnalysisEnginePtr engine() { return engine_; } - static bool IsTypeInMethodMap(const TypeId& type); + static bool IsTypeInMethodMap(const TypeId &type); - static Any GetMethodPtr(const TypeId& type, const std::string& name); + static Any GetMethodPtr(const TypeId &type, const std::string &name); - const py::object& input() const { return input_; } + const py::object &input() const { return input_; } FuncGraphPtr func_graph() const { return func_graph_; } - void set_func_graph(const FuncGraphPtr& func_graph) { func_graph_ = func_graph; } + void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } - const abstract::AbstractBasePtrList& args_spec() const { return args_spec_; } - void set_args_spec(const abstract::AbstractBasePtrList& args_spec) { args_spec_ = args_spec; } + const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; } + void 
set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; } // Reclaim resource and clear the cache. // ExecutorPy::Compile() can be called multiple times, so cache diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc index 555a6d87c0..210257ea53 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc @@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const { std::string AbstractNull::ToString() const { std::ostringstream buffer; - buffer << type_name() << "(" - << "Value: " - << "Null" - << ")"; + buffer << type_name() << "(Value: Null)"; + return buffer.str(); +} + +bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; } + +bool AbstractEllipsis::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + if (other.isa()) { + auto other_none = static_cast(&other); + return *this == *other_none; + } else { + return false; + } +} + +std::string AbstractEllipsis::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "(Value: Ellipsis)"; return buffer.str(); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h index 9e0dd82003..7608d0bec7 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h @@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr; // the un assigned state value for variable, which means the variable is not assigned class AbstractNull : public AbstractBase { public: - AbstractNull() : AbstractBase(kNullObj) { set_type(std::make_shared()); } + AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared()); } ~AbstractNull() override = default; MS_DECLARE_PARENT(AbstractNull, AbstractBase) @@ 
-510,6 +510,20 @@ class AbstractNull : public AbstractBase { }; using AbstractNullPtr = std::shared_ptr; +class AbstractEllipsis : public AbstractBase { + public: + AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared()); } + ~AbstractEllipsis() override = default; + MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + bool operator==(const AbstractEllipsis &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override; +}; +using AbstractEllipsisPtr = std::shared_ptr; + class AbstractRefKey : public AbstractBase { public: AbstractRefKey() : AbstractBase() { set_type(std::make_shared()); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc b/mindspore/ccsrc/pipeline/static_analysis/dshape.cc index 15aa71ba1e..183ec772ff 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/dshape.cc @@ -26,31 +26,31 @@ namespace mindspore { namespace abstract { // used for print BaseShape content -std::ostream& operator<<(std::ostream& os, const BaseShape& bs) { +std::ostream &operator<<(std::ostream &os, const BaseShape &bs) { os << bs.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr bs) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr bs) { MS_EXCEPTION_IF_NULL(bs); os << bs->ToString(); return os; } -bool BaseShape::operator==(const BaseShape& other) const { +bool BaseShape::operator==(const BaseShape &other) const { if (tid() != other.tid()) { return false; } return true; } -bool BaseShape::operator!=(const BaseShape& other) const { return !(*this == other); } +bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == other); } std::string Shape::ToString() const { std::ostringstream buffer; bool f_begin = true; buffer 
<< "("; - for (auto& x : shape_) { + for (auto &x : shape_) { if (!f_begin) { buffer << ", "; } else { @@ -72,11 +72,11 @@ std::string Shape::DumpText() const { return buffer.str(); } -bool Shape::operator==(const BaseShape& other) const { +bool Shape::operator==(const BaseShape &other) const { if (tid() != other.tid()) { return false; } - return shape_ == static_cast(other).shape_; + return shape_ == static_cast(other).shape_; } const int Shape::SHP_ANY; @@ -111,11 +111,11 @@ BaseShapePtrList SequeueShape::ElementsClone() const { } template -bool SequeueShape::SequeueEqual(const BaseShape& other) const { +bool SequeueShape::SequeueEqual(const BaseShape &other) const { if (tid() != other.tid()) { return false; } - auto other_shapes = static_cast(other).p_shapes_; + auto other_shapes = static_cast(other).p_shapes_; if (other_shapes.size() != p_shapes_.size()) { return false; } @@ -126,8 +126,8 @@ bool SequeueShape::SequeueEqual(const BaseShape& other) const { } return true; } -template bool SequeueShape::SequeueEqual(const BaseShape&) const; -template bool SequeueShape::SequeueEqual(const BaseShape&) const; +template bool SequeueShape::SequeueEqual(const BaseShape &) const; +template bool SequeueShape::SequeueEqual(const BaseShape &) const; const std::shared_ptr kNoShape = std::make_shared(); } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.h b/mindspore/ccsrc/pipeline/static_analysis/dshape.h index 6debe061c8..3e850e309b 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.h +++ b/mindspore/ccsrc/pipeline/static_analysis/dshape.h @@ -41,8 +41,8 @@ class BaseShape : public Base { ~BaseShape() override = default; MS_DECLARE_PARENT(BaseShape, Base) - virtual bool operator==(const BaseShape& other) const; - bool operator!=(const BaseShape& other) const; + virtual bool operator==(const BaseShape &other) const; + bool operator!=(const BaseShape &other) const; std::size_t hash() const override { return tid(); } // return a deep 
copy @@ -62,16 +62,16 @@ class Shape : public BaseShape { public: static const int SHP_ANY = -1; Shape() : shape_() {} - Shape(const std::initializer_list& list) : shape_(list) {} - explicit Shape(const std::vector& list) : shape_(list) {} + Shape(const std::initializer_list &list) : shape_(list) {} + explicit Shape(const std::vector &list) : shape_(list) {} ~Shape() override = default; MS_DECLARE_PARENT(Shape, BaseShape) std::string ToString() const override; std::string DumpText() const override; - bool operator==(const BaseShape& other) const override; + bool operator==(const BaseShape &other) const override; BaseShapePtr Clone() const override { return std::make_shared(shape_); } void Broaden() override; - std::vector& shape() { return shape_; } + std::vector &shape() { return shape_; } std::vector shape_; // use SHP_ANY to implement the any shape in python }; @@ -81,7 +81,7 @@ using ShapePtrList = std::vector; class SequeueShape : public BaseShape { public: SequeueShape() : p_shapes_() {} - explicit SequeueShape(const BaseShapePtrList& shapes) : p_shapes_(shapes) {} + explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {} ~SequeueShape() override = default; MS_DECLARE_PARENT(SequeueShape, BaseShape) @@ -89,9 +89,9 @@ class SequeueShape : public BaseShape { BaseShapePtrList ElementsClone() const; template - bool SequeueEqual(const BaseShape& other) const; + bool SequeueEqual(const BaseShape &other) const; - const BaseShapePtrList& shape() const { return p_shapes_; } + const BaseShapePtrList &shape() const { return p_shapes_; } size_t size() const { return p_shapes_.size(); } const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; } @@ -103,7 +103,7 @@ using SequeueShapePtr = std::shared_ptr; class TupleShape : public SequeueShape { public: TupleShape() : SequeueShape() {} - explicit TupleShape(const BaseShapePtrList& shapes) : SequeueShape(shapes) {} + explicit TupleShape(const BaseShapePtrList &shapes) : 
SequeueShape(shapes) {} ~TupleShape() override = default; MS_DECLARE_PARENT(TupleShape, SequeueShape) @@ -111,14 +111,14 @@ class TupleShape : public SequeueShape { BaseShapePtr Clone() const override { return std::make_shared(ElementsClone()); } - bool operator==(const BaseShape& other) const override { return SequeueEqual(other); } + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } }; using TupleShapePtr = std::shared_ptr; class ListShape : public SequeueShape { public: ListShape() : SequeueShape() {} - explicit ListShape(const BaseShapePtrList& shapes) : SequeueShape(shapes) {} + explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} ~ListShape() override = default; MS_DECLARE_PARENT(ListShape, SequeueShape) @@ -126,7 +126,7 @@ class ListShape : public SequeueShape { BaseShapePtr Clone() const override { return std::make_shared(SequeueShape::ElementsClone()); } - bool operator==(const BaseShape& other) const override { return SequeueEqual(other); } + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } }; using ListShapePtr = std::shared_ptr; } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/validator.cc b/mindspore/ccsrc/pipeline/validator.cc index 0fe3218813..73a54bb180 100644 --- a/mindspore/ccsrc/pipeline/validator.cc +++ b/mindspore/ccsrc/pipeline/validator.cc @@ -39,7 +39,7 @@ using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractType; -void ValidateOperation(const AnfNodePtr& node) { +void ValidateOperation(const AnfNodePtr &node) { if (!IsValueNode(node)) { return; } @@ -60,7 +60,7 @@ void ValidateOperation(const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); } -void ValidateAbstract(const AnfNodePtr& node) { +void ValidateAbstract(const AnfNodePtr &node) { if (node == nullptr) { MS_LOG(WARNING) << "Node to validate is invalid"; return; @@ -105,11 
+105,11 @@ void ValidateAbstract(const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); } -void Validate(const FuncGraphPtr& fg) { +void Validate(const FuncGraphPtr &fg) { FuncGraphManagerPtr mgr = Manage(fg, false); MS_EXCEPTION_IF_NULL(mgr); - AnfNodeSet& all_nodes = mgr->all_nodes(); - for (const auto& anf_node : all_nodes) { + AnfNodeSet &all_nodes = mgr->all_nodes(); + for (const auto &anf_node : all_nodes) { ValidateOperation(anf_node); ValidateAbstract(anf_node); } diff --git a/mindspore/ccsrc/pipeline/validator.h b/mindspore/ccsrc/pipeline/validator.h index 9944078e6c..61f7470349 100644 --- a/mindspore/ccsrc/pipeline/validator.h +++ b/mindspore/ccsrc/pipeline/validator.h @@ -29,9 +29,9 @@ namespace mindspore { namespace validator { -void Validate(const FuncGraphPtr& func_graph); -void ValidateAbstract(const AnfNodePtr& node); -void ValidateOperation(const AnfNodePtr& node); +void Validate(const FuncGraphPtr &func_graph); +void ValidateAbstract(const AnfNodePtr &node); +void ValidateOperation(const AnfNodePtr &node); } // namespace validator } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 7a35627e25..a2d82525e9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -45,6 +45,7 @@ #include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" #include "pre_activate/ascend/format_type/insert_trans_op.h" #include "pre_activate/pass/getitem_tuple.h" #include "pre_activate/pass/optimize_dependence.h" @@ -61,6 +62,7 @@ #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" #include 
"pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" #include "pre_activate/ascend/ir_fission/addn_fission.h" +#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" #include "utils/context/ms_context.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" @@ -68,6 +70,35 @@ namespace mindspore { namespace opt { +namespace { +void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { + MS_EXCEPTION_IF_NULL(ir_fusion_pm); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); +} +} // namespace + void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); auto optimizer = std::make_shared(); @@ -113,6 +144,7 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph) data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); 
data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); @@ -161,29 +193,13 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); if (context_ptr->ir_fusion_flag()) { - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); + AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); + } + + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); } optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); @@ -213,6 +229,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr(); auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); diff --git 
a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc index f6eb6aca64..b9a86f7bcb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc @@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ new_addn->set_scope(origin_addn_cnode->scope()); new_addn->set_abstract(origin_addn_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); return new_addn; } } // namespace @@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN } CNodePtr new_cnode = cnode; while (origin_input_size > inputs_divisor_) { + MS_EXCEPTION_IF_NULL(new_cnode); std::vector base_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; size_t cur_input_index = 1; - // Divide the inputs of addn by 63. - while (origin_input_size - cur_input_index + 1 > inputs_divisor_) { + // Divide the inputs of addn by inputs_divisor_. 
+ while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); cur_input_index += inputs_divisor_; } - base_addn_inputs.push_back( - CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1)); - + for (size_t i = cur_input_index; i <= origin_input_size; i++) { + base_addn_inputs.push_back(new_cnode->input(i)); + } CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); MS_EXCEPTION_IF_NULL(base_addn); - MS_EXCEPTION_IF_NULL(new_cnode); base_addn->set_scope(new_cnode->scope()); base_addn->set_abstract(new_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); + std::vector dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); new_cnode = base_addn; origin_input_size = base_addn->inputs().size() - 1; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc index 83c58ab547..a5e4675c8f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc @@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const auto prim = std::make_shared(kFusedMulAddNOpName); std::vector inputs = {NewValueNode(prim)}; inputs.push_back(mul->input(kMulInputNum - lossscale_input_index)); - inputs.push_back(addn->input(1)); + inputs.push_back(addn->input(2)); // scalar input should be 3rd input inputs.push_back(mul->input(lossscale_input_index)); auto fusion_node = graph->NewCNode(inputs); @@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const { VarPtr Z = std::make_shared(); VectorRef mul({prim::kPrimMul, X, Z}); - VectorRef addn({prim::kPrimAddN, Y, mul}); + VectorRef 
addn({prim::kPrimAddN, mul, Y}); return addn; } @@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode if (addn == nullptr || addn->inputs().size() != kAddNInputNum) { return nullptr; } - auto mul_anf = addn->input(2); + auto mul_anf = addn->input(1); if (mul_anf == nullptr) { return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc index faa1308f8b..fe9b35a5e9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc @@ -26,6 +26,7 @@ namespace mindspore { namespace opt { +namespace { const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag, std::vector *trans_road) { if (node == nullptr) { @@ -59,6 +60,24 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr return nullptr; } +kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, + TypeId output_type) { + MS_EXCEPTION_IF_NULL(cast); + auto kernel_info = cast->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto cast_build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(cast_build_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsFormat({format}); + builder.SetInputsFormat({format}); + builder.SetInputsDeviceType({input_type}); + builder.SetOutputsDeviceType({output_type}); + builder.SetKernelType(cast_build_info->kernel_type()); + builder.SetFusionType(cast_build_info->fusion_type()); + builder.SetProcessor(cast_build_info->processor()); + return builder.Build(); +} +} // namespace bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { MS_LOG(ERROR) << "Func graph is nullptr"; @@ -95,17 
+114,7 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); auto cast = trans_road[1]; - auto cast_format = AnfAlgo::GetOutputFormat(cast, 0); - auto cast_build_info = cast->kernel_info()->select_kernel_build_info(); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsFormat({format}); - builder.SetInputsFormat({format}); - builder.SetInputsDeviceType({param_dtype}); - builder.SetOutputsDeviceType({dtype}); - builder.SetKernelType(cast_build_info->kernel_type()); - builder.SetFusionType(cast_build_info->fusion_type()); - builder.SetProcessor(cast_build_info->processor()); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); + AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); if (param_format == format && param_dtype != dtype) { manager->Replace(trans_road[2], final_node); manager->Replace(cur_transop, cast); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc new file mode 100644 index 0000000000..5e265f2cf1 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" +#include +#include "session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef RemoveReshapePair::DefinePattern() const { + const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); + VectorRef reshape({prim_reshape, input_varptr_}); + + return VectorRef({prim::kPrimReshape, reshape}); +} + +const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_1); + // If the reshape operator is used by more than one other operator, it cannot be deleted directly + auto users = manager->node_users()[reshape_op_1]; + if (users.size() > 1) { + return nullptr; + } + auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_2); + users = manager->node_users()[reshape_op_2]; + if (users.size() > 1) { + return nullptr; + } + auto input_node = reshape_op_2->input(1); + return input_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h new file mode 100644 index 0000000000..a284f4eaa9 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ + +#include +#include +#include "ir/anf.h" +#include "pre_activate/common/pattern_engine.h" +#include "pre_activate/common/helper.h" +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class RemoveReshapePair : public PatternProcessPass { + public: + explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) { + input_varptr_ = std::make_shared(); + } + ~RemoveReshapePair() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_varptr_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc index f0077ef6cd..b7280f52ae 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc @@ -36,6 +36,37 @@ DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { return device_addr; } +std::vector DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, + std::vector size_list) { + // Pre-alloc the one whole piece memory. 
+ auto device_addr = AllocTensorMem(total_size); + MS_EXCEPTION_IF_NULL(device_addr); + // Remove the pre-alloc memory. + auto mem_block = FindMemBlock(device_addr); + MS_EXCEPTION_IF_NULL(mem_block); + auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); + if (iter == mem_block->block_all_mem_buf_map_.end()) { + MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; + } + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + auto rest_size = mem_buf->size_ - total_size; + (void)mem_block->block_all_mem_buf_map_.erase(iter); + // Split the pre-alloc memory into continuous memory by the size list. + DynamicMemBufPtr continuous_mem_buf; + std::vector device_addr_list; + auto buf_addr = device_addr; + for (size_t i = 0; i < size_list.size(); i++) { + continuous_mem_buf = std::make_shared(buf_addr, kMemBufUsed, size_list[i]); + (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); + device_addr_list.emplace_back(buf_addr); + buf_addr = AddressOffset(buf_addr, size_list[i]); + } + // Update the size of the last memory buf. 
+ continuous_mem_buf->size_ += rest_size; + return device_addr_list; +} + size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { if (size == 0) { return DYNAMIC_MEM_ALIGN_SIZE; @@ -121,7 +152,7 @@ bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) co return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE; } -void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr& mem_buf) { +void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) { MS_EXCEPTION_IF_NULL(mem_buf); auto mem_block = FindMemBlock(mem_buf->device_addr_); MS_EXCEPTION_IF_NULL(mem_block); @@ -160,7 +191,7 @@ void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { CombineMemBuf(mem_block, device_addr); } -void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr& mem_block, const DeviceMemPtr device_addr) { +void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) { MS_EXCEPTION_IF_NULL(mem_block); MS_EXCEPTION_IF_NULL(device_addr); auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h index dcf735814c..07efa267aa 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h @@ -61,7 +61,7 @@ class DynamicMemBlock { DynamicMemBlock() = default; DynamicMemBlock(DeviceMemPtr addr_base, size_t size) : device_addr_base_(addr_base), mem_block_size_(size) {} ~DynamicMemBlock() { block_all_mem_buf_map_.clear(); } - const DeviceMemPtr& device_addr() const { return device_addr_base_; } + const DeviceMemPtr &device_addr() const { return device_addr_base_; } size_t size() const { return mem_block_size_; } // The map of all memory buf in this memory block by device address. 
DeviceAddrMapMemBuf block_all_mem_buf_map_; @@ -79,6 +79,8 @@ class DynamicMemPoolBestFit { virtual ~DynamicMemPoolBestFit(); // The main program entry of memory alloc. DeviceMemPtr AllocTensorMem(size_t size); + // The main program entry of continuous memory alloc. + std::vector AllocContinuousTensorMem(size_t total_size, std::vector size_list); // The main program entry of memory free. void FreeTensorMem(const DeviceMemPtr device_addr); // Release the real device memory. @@ -92,8 +94,8 @@ class DynamicMemPoolBestFit { size_t used_mem_peak_statistics() const { return used_mem_peak_statistics_; } // The related interface of device memory real operation, needs override by device type. - virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) = 0; - virtual bool FreeDeviceMem(const DeviceMemPtr& addr) = 0; + virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) = 0; + virtual bool FreeDeviceMem(const DeviceMemPtr &addr) = 0; virtual size_t free_mem_size() = 0; virtual size_t total_mem_size() = 0; @@ -113,14 +115,14 @@ class DynamicMemPoolBestFit { // Judge whether need divide the memory buf by alloc size and memory buf size. bool IsDivide(size_t tensor_size, size_t mem_buf_size) const; // Divide the memory buf by alloc size. - void DivideMemBuf(size_t size, const DynamicMemBufPtr& mem_buf); + void DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf); // Find the memory block by deivce address. DynamicMemBlockPtr FindMemBlock(const DeviceMemPtr device_addr); // The Comparator of memory block by device address, because memory blocks are arranged in order by device address. static bool CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block); // Combine the memory buf when memory free, to avoid the memory fragmentation. 
- void CombineMemBuf(const DynamicMemBlockPtr& mem_block, const DeviceMemPtr device_addr); + void CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr); // Erase the idle memory buf by size and device address when idle memory buf is combined. void EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index d25b60003f..952dfe97e4 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -162,10 +162,6 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr if (iter == kernel_def_ptr->inputs_.end()) { kernel_def_ptr->inputs_[key].push_back(ref_ptr); } else { - if (std::any_of(iter->second.begin(), iter->second.end(), - [ref_ptr](const KernelRefCountPtr &it) { return (it.get() == ref_ptr.get()); })) { - break; - } iter->second.push_back(ref_ptr); } } @@ -185,10 +181,6 @@ void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_pt if (iter == kernel_def_ptr->outputs_.end()) { kernel_def_ptr->outputs_[key].push_back(kernel_ref); } else { - if (std::any_of(iter->second.begin(), iter->second.end(), - [kernel_ref](const KernelRefCountPtr &it) { return (it == kernel_ref); })) { - break; - } iter->second.push_back(kernel_ref); } } diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc index d1409cdedd..77f6f96cec 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc @@ -20,8 +20,8 @@ namespace mindspore { namespace memreuse { void StreamReuse::SetStreamReuseResource() { #ifdef ENABLE_D - auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().GetPhysicMap(); - auto logic_independent_map = 
device::ascend::AscendStreamAssign::GetInstance().GetIndependentMap(); + auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_physic_map(); + auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_independent_map(); MS_LOG(INFO) << "stream mem reuse for Davici"; if (!logic_independent_map.empty() && !logic_physic_map.empty()) { set_logic_physic_map(logic_physic_map); diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc index c2f96e54c6..fb47c9fc2a 100644 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc @@ -53,6 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(kExpandDimsOpName, {1}); Register(kSplitOpName, {0}); Register(kTopKOpName, {1}); + Register(kErfOpName, {1}); Register(kSparseApplyAdagradOpName, {2}); Register(kResizeNearestNeighborGrad, {1}); } diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc index 3f283e5d24..93c1b73038 100644 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc @@ -68,9 +68,8 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimTupleGetItem->name()) { return nullptr; } - if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { - return AnfAlgo::IsTupleOutput(node) && AnfAlgo::GetCNodeName(node) != prim::kPrimMakeTuple->name(); - })) { + if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), + [](const AnfNodePtr &node) { return AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); })) { return 
ConvertTupleInputToMakeTuple(func_graph, cnode); } return nullptr; diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc index db32354abf..86a90a4dfe 100644 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc @@ -28,8 +28,7 @@ namespace mindspore { namespace opt { constexpr auto kSingleInputIndex = 1; namespace { -AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); +AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { return nullptr; @@ -41,15 +40,6 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { return nullptr; } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - // Check whether the node has only one output node. - if (manager->node_users().find(cnode) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The node should be used by at least another node's input"; - } - if (manager->node_users()[cnode].size() > 1) { - return nullptr; - } CheckCNodeInputSize(cnode, kSingleInputIndex + 1); return cnode->input(kSingleInputIndex); } @@ -63,7 +53,7 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { std::vector new_make_tuple_inputs; bool need_update = false; for (const auto &input : cnode->inputs()) { - AnfNodePtr replace_input = GetReplaceNode(func_graph, input); + AnfNodePtr replace_input = GetReplaceNode(input); // If replace input is not null, it will be the input of the TransData or Cast. 
if (replace_input == nullptr) { new_make_tuple_inputs.push_back(input); @@ -119,7 +109,7 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con if (ReplaceMakeTuple(func_graph, replacing_cnode)) { return nullptr; } - AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode); + AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); if (replace_node == nullptr) { MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); return nullptr; diff --git a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc index f186758de5..e6fec3d540 100644 --- a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc +++ b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc @@ -48,8 +48,8 @@ OpAttrFactory::OpAttrFactory() { {"Softsign", ActivationPacker}, {"Softplus", ActivationPacker}, {"Tanh", ActivationPacker}, - {"Hswish", ActivationPacker}, - {"Hsigmoid", ActivationPacker}, + {"HSwish", ActivationPacker}, + {"HSigmoid", ActivationPacker}, {"MaxPool", PoolingPacker}, {"MaxPool2D", PoolingPacker}, {"MeanPool", PoolingPacker}, diff --git a/mindspore/ccsrc/predict/generator/ir/ir_model.h b/mindspore/ccsrc/predict/generator/ir/ir_model.h index bf1c057b5f..82bd2aad3f 100644 --- a/mindspore/ccsrc/predict/generator/ir/ir_model.h +++ b/mindspore/ccsrc/predict/generator/ir/ir_model.h @@ -23,7 +23,7 @@ namespace mindspore { namespace generator { class IRModel { public: - void SetIrTaskInfos(const std::vector& ir_tasks); + void SetIrTaskInfos(const std::vector &ir_tasks); IRModel() = default; ~IRModel(); diff --git a/mindspore/ccsrc/pybind_api/api_register.h b/mindspore/ccsrc/pybind_api/api_register.h index 2c1b622f31..8bab751267 100644 --- a/mindspore/ccsrc/pybind_api/api_register.h +++ b/mindspore/ccsrc/pybind_api/api_register.h @@ -29,19 +29,19 @@ namespace py = pybind11; namespace mindspore { 
-using PybindDefineFunc = std::function; +using PybindDefineFunc = std::function; class PybindDefineRegister { public: - static void Register(const std::string& name, const PybindDefineFunc& fn) { + static void Register(const std::string &name, const PybindDefineFunc &fn) { return GetSingleton().RegisterFn(name, fn); } - PybindDefineRegister(const PybindDefineRegister&) = delete; + PybindDefineRegister(const PybindDefineRegister &) = delete; - PybindDefineRegister& operator=(const PybindDefineRegister&) = delete; + PybindDefineRegister &operator=(const PybindDefineRegister &) = delete; - static std::map& AllFuncs() { return GetSingleton().fns_; } + static std::map &AllFuncs() { return GetSingleton().fns_; } std::map fns_; @@ -50,14 +50,14 @@ class PybindDefineRegister { virtual ~PybindDefineRegister() = default; - static PybindDefineRegister& GetSingleton(); + static PybindDefineRegister &GetSingleton(); - void RegisterFn(const std::string& name, const PybindDefineFunc& fn) { fns_[name] = fn; } + void RegisterFn(const std::string &name, const PybindDefineFunc &fn) { fns_[name] = fn; } }; class PybindDefineRegisterer { public: - PybindDefineRegisterer(const std::string& name, const PybindDefineFunc& fn) { + PybindDefineRegisterer(const std::string &name, const PybindDefineFunc &fn) { PybindDefineRegister::Register(name, fn); } ~PybindDefineRegisterer() = default; diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h index 7405f621cb..37ff000b04 100644 --- a/mindspore/ccsrc/pynative/base.h +++ b/mindspore/ccsrc/pynative/base.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -57,9 +58,9 @@ struct OpExecInfo { py::dict op_attrs; }; using OpExecInfoPtr = std::shared_ptr; -OpExecInfoPtr GenerateOpExecInfo(const py::args& args); +OpExecInfoPtr GenerateOpExecInfo(const py::args &args); -const std::unordered_set ignore_infer_prim = {"partial"}; +const std::set ignore_infer_prim = {"partial", "make_ref"}; } // namespace 
pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc index 6a1ddf6a7e..0d18dfb577 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pynative/pynative_execute.cc @@ -43,7 +43,7 @@ const std::unordered_set vm_operators = {"partial", "depend", "make namespace mindspore { namespace pynative { -inline ValuePtr PyAttrValue(const py::object& obj) { +inline ValuePtr PyAttrValue(const py::object &obj) { ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret); if (!converted) { @@ -52,11 +52,11 @@ inline ValuePtr PyAttrValue(const py::object& obj) { return converted_ret; } -py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) { +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) { auto signature = prim->signatures(); std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature& sig) { return sig.dtype; }); + [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.size() == 0 || static_cast(dtypes.size()) == empty_dtype_count) { return py_args; @@ -103,7 +103,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) { return py_inputs; } -void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) { +void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecInfo *const op_exec_info) { size_t size = py_args.size(); AbstractBasePtrList args_spec_list; for (size_t i = 0; i < size; i++) { @@ -118,7 +118,7 @@ void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecI op_exec_info->abstract = infer_res; } -OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { +OpExecInfoPtr 
GenerateOpExecInfo(const py::args &args) { if (args.size() != PY_ARGS_NUM) { MS_LOG(ERROR) << "Four args are needed by RunOp"; return nullptr; @@ -147,7 +147,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { return op_exec_info; } -std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) { +std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info) { MS_EXCEPTION_IF_NULL(op_exec_info); std::string graph_info; MS_EXCEPTION_IF_NULL(op_exec_info->abstract); @@ -167,7 +167,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) { return graph_info; } -py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_LOG(INFO) << "RunOpInVM start"; MS_EXCEPTION_IF_NULL(status); @@ -188,7 +188,7 @@ py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat return std::move(result); } -py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_EXCEPTION_IF_NULL(op_exec_info); MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; auto ms_context = MsContext::GetInstance(); @@ -212,7 +212,7 @@ py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat } py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info, - PynativeStatusCode* const status) { + PynativeStatusCode *const status) { MS_EXCEPTION_IF_NULL(status); py::object result; switch (backend_policy) { @@ -248,7 +248,7 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn return result; } -py::tuple RunOp(const py::args& args) { +py::tuple RunOp(const py::args &args) { py::object result; // returns a null py::tuple on error py::tuple err_ret(0); diff --git 
a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h index 17b5610bfd..c64c6b4b25 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pynative/pynative_execute.h @@ -33,9 +33,9 @@ namespace pynative { namespace py = pybind11; -py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status); +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); -py::tuple RunOp(const py::args& args); +py::tuple RunOp(const py::args &args); } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pynative/pynative_execute_ge.cc index 180b0006ff..0bf2a391f9 100644 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.cc +++ b/mindspore/ccsrc/pynative/pynative_execute_ge.cc @@ -43,7 +43,7 @@ using transform::GraphRunner; using transform::GraphRunnerOptions; using transform::OperatorPtr; static std::shared_ptr session = nullptr; -inline ValuePtr PyAttrValue(const py::object& obj) { +inline ValuePtr PyAttrValue(const py::object &obj) { ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret); if (!converted) { @@ -52,7 +52,7 @@ inline ValuePtr PyAttrValue(const py::object& obj) { return converted_ret; } -MeTensorPtr ConvertPyObjToTensor(const py::object& obj) { +MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { MeTensorPtr me_tensor_ptr = nullptr; if (py::isinstance(obj)) { me_tensor_ptr = py::cast(obj); @@ -72,8 +72,8 @@ MeTensorPtr ConvertPyObjToTensor(const py::object& obj) { return me_tensor_ptr; } -bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const OperatorPtr& op, std::vector* graph_input_nodes) { +bool SetInputsForSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const OperatorPtr &op, std::vector *graph_input_nodes) { MS_EXCEPTION_IF_NULL(op_exec_info); 
MS_EXCEPTION_IF_NULL(graph_input_nodes); auto op_inputs = op_exec_info->op_inputs; @@ -103,7 +103,7 @@ bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vec auto pointer_cast_const_op = std::static_pointer_cast(const_op); MS_EXCEPTION_IF_NULL(pointer_cast_const_op); (void)pointer_cast_const_op->update_output_desc_y(*const_op_desc); - auto& input_map = adapter->getInputMap(); + auto &input_map = adapter->getInputMap(); if (input_map.find(op_input_idx) == input_map.end()) { continue; } @@ -116,8 +116,8 @@ bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vec return true; } -bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const std::unordered_map& attrs, const GeGraphPtr& graph) { +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph) { MS_EXCEPTION_IF_NULL(op_exec_info); std::string op_name = op_exec_info->op_name; auto op_inputs = op_exec_info->op_inputs; @@ -145,8 +145,8 @@ bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vectorsetAttr(op, attr.first, attr.second); } // set input attributes - auto& input_attr_map = adapter->getInputAttrMap(); - for (auto& it : input_attr_map) { + auto &input_attr_map = adapter->getInputAttrMap(); + for (auto &it : input_attr_map) { if (op_inputs.size() < it.first) { continue; } @@ -165,7 +165,7 @@ bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector* const inputs) { +void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector *const inputs) { MS_EXCEPTION_IF_NULL(inputs); MS_EXCEPTION_IF_NULL(op_exec_info); auto op_inputs = op_exec_info->op_inputs; @@ -185,12 +185,12 @@ void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector* con } } -PynativeStatusCode ConvertAttributes(const OpExecInfoPtr& op_exec_info, const std::vector& inputs) { +PynativeStatusCode ConvertAttributes(const OpExecInfoPtr 
&op_exec_info, const std::vector &inputs) { MS_EXCEPTION_IF_NULL(op_exec_info); auto op_attrs = op_exec_info->op_attrs; std::unordered_map attrs{}; - for (auto& item : op_attrs) { + for (auto &item : op_attrs) { if (!py::isinstance(item.first)) { MS_LOG(ERROR) << "Type error in py dict convert"; return PYNATIVE_OP_ATTRS_ERR; @@ -218,8 +218,8 @@ PynativeStatusCode ConvertAttributes(const OpExecInfoPtr& op_exec_info, const st return PYNATIVE_SUCCESS; } -std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, - const std::vector& ge_tensors) { +std::vector ConvertOutputTensors(const OpExecInfoPtr &op_exec_info, + const std::vector &ge_tensors) { std::vector outputs; AbstractBasePtr abs_base = op_exec_info->abstract; std::vector> shapes; @@ -242,7 +242,7 @@ std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); return outputs; } - for (auto& it : ge_tensors) { + for (auto &it : ge_tensors) { auto tensor = transform::TransformUtil::ConvertGeTensor(it); if (tensor != nullptr) { outputs.emplace_back(tensor); @@ -251,7 +251,7 @@ std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, return outputs; } -py::object RunOpInGE(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_LOG(INFO) << "RunOpInGe start"; MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(status); diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pynative/pynative_execute_ge.h index af0efec3e3..2dca3df018 100644 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.h +++ b/mindspore/ccsrc/pynative/pynative_execute_ge.h @@ -36,10 +36,10 @@ using GeGraphPtr = std::shared_ptr; namespace mindspore { namespace pynative { -bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const std::unordered_map& attrs, const GeGraphPtr& graph); 
+bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph); -py::object RunOpInGE(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status); +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 2591f763c5..525ff44dd8 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -111,12 +111,12 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr MS_EXCEPTION_IF_NULL(value_node); int item_idx = GetValue(value_node->value()); return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), - visit_nop_node); + visit_nop_node, return_types); } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node); + return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types); } else if (opt::IsNopNode(cnode) && visit_nop_node) { if (cnode->inputs().size() == 2) { - return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node); + return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types); } else { MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; } diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index ad6c58bc93..11ae3da6f7 100755 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -610,7 +610,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && 
ConfigManager::GetInstance().iter_num() > 1) { // insert active in true graph, another active will be inserted in kernel adjust - InsertStreamActiveToGraph(true_last_id, kInvalidDistincLabel - 1); + InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); } break; } diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 2270e6719b..162ef8f0c1 100755 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -237,15 +237,15 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameSquare), ADPT_DESC(Square)}, {prim::kPrimTanh->name(), ADPT_DESC(Tanh)}, {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)}, - {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborD)}, - {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborGrad)}, + {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)}, + {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)}, {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)}, {string(kNameReLU6), ADPT_DESC(Relu6)}, {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)}, {string(kNameElu), ADPT_DESC(Elu)}, {string(kNameEluGrad), ADPT_DESC(EluGrad)}, - {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearGrad)}, - {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearD)}, + {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)}, + {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, {string(kNameOnesLike), ADPT_DESC(OnesLike)}, {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, @@ -264,7 +264,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameArgMinWithValue), ADPT_DESC(ArgMinWithValue)}, {prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSumD)}, {prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMeanD)}, - {prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAll)}, + {prim::kPrimReduceAll->name(), 
ADPT_DESC(ReduceAllD)}, {prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD)}, {prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD)}, {string(kNameLARSUpdate), ADPT_DESC(LarsV2Update)}, diff --git a/mindspore/ccsrc/transform/convert.h b/mindspore/ccsrc/transform/convert.h index 556db5acee..5596e20f19 100644 --- a/mindspore/ccsrc/transform/convert.h +++ b/mindspore/ccsrc/transform/convert.h @@ -51,16 +51,16 @@ class OpAdapterDesc { public: OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} - OpAdapterDesc(const OpAdapterPtr& train, const OpAdapterPtr& infer) : train_(train), infer_(infer) {} + OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} - explicit OpAdapterDesc(const OpAdapterPtr& common) : train_(common), infer_(common) {} + explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} - OpAdapterDesc(const OpAdapterDesc& desc) { + OpAdapterDesc(const OpAdapterDesc &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; } - OpAdapterDesc(OpAdapterDesc&& desc) { + OpAdapterDesc(OpAdapterDesc &&desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; desc.train_ = nullptr; @@ -71,7 +71,7 @@ class OpAdapterDesc { OpAdapterPtr Get(bool train) const { return train ? 
train_ : infer_; } - OpAdapterDesc& operator=(const OpAdapterDesc& desc) { + OpAdapterDesc &operator=(const OpAdapterDesc &desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; @@ -79,7 +79,7 @@ class OpAdapterDesc { return *this; } - OpAdapterDesc& operator=(OpAdapterDesc&& desc) { + OpAdapterDesc &operator=(OpAdapterDesc &&desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; @@ -99,7 +99,7 @@ using TensorOrderMap = std::map>; class DfGraphConvertor { public: - explicit DfGraphConvertor(const AnfGraphPtr& anf_graph) + explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { #if (!defined ENABLE_GE) || (defined ENABLE_INFER) auto it_training = anf_graph->flags().find("training"); @@ -125,14 +125,14 @@ class DfGraphConvertor { ~DfGraphConvertor() {} - static void RegisterAdapter(const std::string& name, OpAdapterPtr adpt) { + static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { get_adpt_map()[name] = std::make_shared(adpt); } - static void RegisterAdapter(const std::string& name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { + static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); } - void DrawComputeGraph(const std::string& name) { + void DrawComputeGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -141,7 +141,7 @@ class DfGraphConvertor { fout << compute_sout_.str(); fout.close(); } - void DrawInitGraph(const std::string& name) { + void DrawInitGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -150,7 +150,7 @@ class DfGraphConvertor { fout << init_sout_.str(); fout.close(); } - void 
DrawSaveCheckpointGraph(const std::string& name) { + void DrawSaveCheckpointGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -160,74 +160,74 @@ class DfGraphConvertor { fout.close(); } - DfGraphConvertor& ConvertAllNode(); - DfGraphConvertor& BuildGraph(); - DfGraphConvertor& InitParam(const TensorOrderMap& tensors); - DfGraphConvertor& GenerateCheckpointGraph(); - DfGraphConvertor& GenerateBroadcastGraph(const TensorOrderMap& tensors); - void InitParamWithData(const TensorOrderMap& tensors); - void SetOpInput(const OpAdapterPtr& adpt, const CNodePtr& node); - void SetupBroadcast(const std::shared_ptr& broadcast, const std::vector& broadcast_desc, - const DfGraphPtr& broadcast_graph, std::vector broadcast_input); - void MakeDatasetHandler(const std::string& name, const size_t& input_idx, const AnfNodePtr& it); - void SetupParamInitSubGraph(const TensorOrderMap& tensors, std::vector* init_input); - void DrawParamInitSubGraph(const std::string& name, const AnfNodePtr& it); + DfGraphConvertor &ConvertAllNode(); + DfGraphConvertor &BuildGraph(); + DfGraphConvertor &InitParam(const TensorOrderMap &tensors); + DfGraphConvertor &GenerateCheckpointGraph(); + DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); + void InitParamWithData(const TensorOrderMap &tensors); + void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); + void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, + const DfGraphPtr &broadcast_graph, std::vector broadcast_input); + void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); + void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); + void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); DfGraphPtr GetComputeGraph(); DfGraphPtr GetInitGraph(); DfGraphPtr GetSaveCheckpointGraph(); DfGraphPtr 
GetBroadcastGraph(); - static OpAdapterPtr FindAdapter(const std::string& op_name, bool train = false); + static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); int ErrCode() const { return static_cast(error_); } - static std::unordered_map& get_adpt_map(); + static std::unordered_map &get_adpt_map(); bool is_training() const { return training_; } void set_training(bool is_training) { training_ = is_training; } protected: - void InitLoopVar(std::vector* init_input); + void InitLoopVar(std::vector *init_input); private: std::ostringstream compute_sout_; std::ostringstream init_sout_; std::ostringstream checkpoint_sout_; std::ostringstream restore_checkpoint_sout_; - std::unordered_map op_draw_name_; + std::unordered_map op_draw_name_; - AnfNodePtr TraceTupleGetItem(const CNodePtr& node, unsigned int* index); - AnfNodePtr TraceMakeTuple(const CNodePtr& node, unsigned int index); - AnfNodePtr TraceDepend(const CNodePtr& node); + AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); + AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); + AnfNodePtr TraceDepend(const CNodePtr &node); OutHandler TraceRealOp(AnfNodePtr node); - OutHandler GetHandler(const AnfNodePtr& node, const std::stack& index_stack, AnfNode* const draw_index); + OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); OperatorPtr Convert(AnfNodePtr node); OperatorPtr ConvertCNode(CNodePtr node); std::vector ConvertDependNode(AnfNodePtr node); AnfNodePtr GetRealOpNode(AnfNodePtr node); - std::vector GetDependNodes(const AnfNodePtr& node); + std::vector GetDependNodes(const AnfNodePtr &node); OperatorPtr ConvertParameter(AnfNodePtr node); Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node); void ConvertTupleGetItem(const CNodePtr node); - void 
GetDependOnParameterUse(const CNodePtr& node, const AnfNodePtr& src_node, const AnfNodePtr& dest_node, - const std::shared_ptr>& src_ops_list, - const std::shared_ptr>& dst_ops_list); - bool GetControlDependList(const CNodePtr& node, const std::shared_ptr>& src_ops_list, - const std::shared_ptr>& dst_ops_list); - void DrawControlDepend(const AnfNodePtr& src_node, const AnfNodePtr& dest_node); + void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); void ConvertControlDependNode(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node); - bool CheckCNode(const std::string& name, const CNodePtr node); + bool CheckCNode(const std::string &name, const CNodePtr node); void TraceOutput(AnfNodePtr node); - void TraceOutputFromParameter(const AnfNodePtr& anf_out); - void TraceOutputFromTupleGetItem(const AnfNodePtr& anf_out); + void TraceOutputFromParameter(const AnfNodePtr &anf_out); + void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); void SetNodeInput(AnfNodePtr node); void SetOpControlInput(const AnfNodePtr node); void UpdateOpDesc(AnfNodePtr node); void BuildSaveCheckpointGraph(); void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); - void UpdateDataOpDesc(const AnfNodePtr& it, const OperatorPtr& op) const; - void AddGraphConstInput(const OperatorPtr& op); + void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; + void AddGraphConstInput(const OperatorPtr &op); std::shared_ptr anf_graph_{nullptr}; std::shared_ptr df_graph_{nullptr}; @@ -235,12 +235,12 @@ class DfGraphConvertor { std::shared_ptr save_ckp_graph_{nullptr}; std::shared_ptr restore_ckp_graph_{nullptr}; std::shared_ptr 
broadcast_graph_{nullptr}; - std::unordered_map op_cache_; - std::unordered_map> control_depend_cache_; + std::unordered_map op_cache_; + std::unordered_map> control_depend_cache_; /* record "tuple_getitem"<->"out_handler" mapping */ - std::unordered_map out_handle_cache_; + std::unordered_map out_handle_cache_; /* record "make_tuple"<->"out_handler vector" mapping */ - std::unordered_map>> tuple_out_handle_cache_; + std::unordered_map>> tuple_out_handle_cache_; std::unordered_map params_; std::unordered_map vars_; std::vector> graph_outputs_; diff --git a/mindspore/ccsrc/transform/df_graph_manager.cc b/mindspore/ccsrc/transform/df_graph_manager.cc index bfe4d9f5d2..f62c386587 100644 --- a/mindspore/ccsrc/transform/df_graph_manager.cc +++ b/mindspore/ccsrc/transform/df_graph_manager.cc @@ -31,8 +31,8 @@ namespace mindspore { namespace transform { -DfGraphWrapper::DfGraphWrapper(const std::string& name, const int& id, const DfGraphPtr& graph_ptr, - const OptionMap& options) +DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, + const OptionMap &options) : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {} DfGraphManager::DfGraphManager() { @@ -49,7 +49,7 @@ DfGraphManager::~DfGraphManager() { parse::python_adapter::set_python_env_flag(false); } -DfGraphManager& DfGraphManager::GetInstance() { +DfGraphManager &DfGraphManager::GetInstance() { static DfGraphManager instance; return instance; } @@ -63,7 +63,7 @@ int DfGraphManager::GenerateId() { return graph_id_; } -Status DfGraphManager::AddGraph(const std::string& name, const DfGraphPtr& graph_ptr, const OptionMap& options) { +Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options) { std::lock_guard lg(lock_); if (name.empty()) { MS_LOG(ERROR) << "The graph name is null, add graph failed"; @@ -101,9 +101,9 @@ std::vector DfGraphManager::GetAllGraphs() { } std::set DfGraphManager::GetSavedGraphs() { 
return saved_graphs_; } -void DfGraphManager::AddSavedGraphs(const std::string& id) { saved_graphs_.insert(id); } +void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); } -DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string& name) { +DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) { std::lock_guard lg(lock_); if (name.empty()) { MS_LOG(ERROR) << "The graph name is null"; @@ -126,7 +126,7 @@ void DfGraphManager::ClearGraph() noexcept { MS_LOG(INFO) << "Remove all graphs in GraphManager"; } -void DfGraphManager::SetAnfGraph(const std::string& name, const AnfGraphPtr& anf_graph_ptr) { +void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) { DfGraphWrapperPtr df_graph = GetGraphByName(name); if (df_graph == nullptr) { MS_LOG(ERROR) << "Can't found graph name: " << name; @@ -152,7 +152,7 @@ void DfGraphManager::EraseAnfGraph() { anf_graphs_.clear(); } -void DfGraphManager::SetGeSession(const std::shared_ptr& sess_ptr) { +void DfGraphManager::SetGeSession(const std::shared_ptr &sess_ptr) { std::lock_guard lg(lock_); if (sess_ptr == nullptr) { MS_LOG(WARNING) << "You are adding a empty Ge Session"; @@ -182,7 +182,7 @@ void DfGraphManager::DeleteGeSession() noexcept { } } -void DfGraphManager::SetGraphRunner(const std::shared_ptr& graph_runner_ptr) noexcept { +void DfGraphManager::SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept { std::lock_guard lg(lock_); if (graph_runner_ptr == nullptr) { MS_LOG(WARNING) << "You are adding a empty GraphRunner"; diff --git a/mindspore/ccsrc/transform/df_graph_manager.h b/mindspore/ccsrc/transform/df_graph_manager.h index 97137ae94b..2ca43d1f07 100644 --- a/mindspore/ccsrc/transform/df_graph_manager.h +++ b/mindspore/ccsrc/transform/df_graph_manager.h @@ -35,7 +35,7 @@ using OptionMap = std::map; struct DfGraphWrapper { public: - DfGraphWrapper(const std::string& name, const int& id, const DfGraphPtr& graph_ptr, 
const OptionMap& options); + DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); ~DfGraphWrapper() {} std::string name_; @@ -51,19 +51,19 @@ class DfGraphManager { ~DfGraphManager(); void ClearGraph() noexcept; - static DfGraphManager& GetInstance(); - Status AddGraph(const std::string& name, const DfGraphPtr& graph, const OptionMap& options = {}); + static DfGraphManager &GetInstance(); + Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}); std::vector GetAllGraphs(); std::set GetSavedGraphs(); - void AddSavedGraphs(const std::string& id); - DfGraphWrapperPtr GetGraphByName(const std::string& name); - DfGraphManager(const DfGraphManager&) = delete; - void SetAnfGraph(const std::string& name, const AnfGraphPtr& anf_graph_ptr); + void AddSavedGraphs(const std::string &id); + DfGraphWrapperPtr GetGraphByName(const std::string &name); + DfGraphManager(const DfGraphManager &) = delete; + void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); AnfGraphPtr GetAnfGraph(uint32_t graph_id); std::shared_ptr GetGraphRunner(); - void SetGraphRunner(const std::shared_ptr& graph_runner_ptr) noexcept; + void SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept; void DeleteGraphRunner() noexcept; - void SetGeSession(const std::shared_ptr& sess_ptr); + void SetGeSession(const std::shared_ptr &sess_ptr); std::shared_ptr GetGeSession(); void DeleteGeSession() noexcept; void EraseAnfGraph(); diff --git a/mindspore/ccsrc/transform/graph_builder.cc b/mindspore/ccsrc/transform/graph_builder.cc index 9c05969fb0..785c5c7f3a 100644 --- a/mindspore/ccsrc/transform/graph_builder.cc +++ b/mindspore/ccsrc/transform/graph_builder.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace transform { -DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam& param) { +DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam ¶m) { MS_LOG(INFO) << "BuildMDDatasetGraph."; // 
InitData @@ -37,7 +37,7 @@ DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam& param) { return dataset_graph; } -Status BuildDatasetGraph(const DatasetGraphParam& param, const std::string& phase) { +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase) { Status ret; std::string graph_name = phase; diff --git a/mindspore/ccsrc/transform/graph_builder.h b/mindspore/ccsrc/transform/graph_builder.h index 30b891460b..3d959f5a85 100644 --- a/mindspore/ccsrc/transform/graph_builder.h +++ b/mindspore/ccsrc/transform/graph_builder.h @@ -27,7 +27,7 @@ namespace mindspore { namespace transform { -Status BuildDatasetGraph(const DatasetGraphParam& param, const std::string& phase = "dataset"); +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase = "dataset"); } // namespace transform } // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_runner.cc b/mindspore/ccsrc/transform/graph_runner.cc index 8b0ddfd18d..52d0d8e17f 100644 --- a/mindspore/ccsrc/transform/graph_runner.cc +++ b/mindspore/ccsrc/transform/graph_runner.cc @@ -30,7 +30,7 @@ #ifdef NO_GE_CLIENT namespace ge { -Session::Session(const std::map& options) { +Session::Session(const std::map &options) { if (options.empty()) { MS_LOG(ERROR) << "session input options is empty"; } @@ -42,7 +42,7 @@ Session::~Session() {} namespace mindspore { namespace transform { -std::shared_ptr GraphRunner::NewSession(const SessionOptions& sess_options) { +std::shared_ptr GraphRunner::NewSession(const SessionOptions &sess_options) { std::shared_ptr ret = std::make_shared(sess_options); if (ret == nullptr) { MS_LOG(ERROR) << "Create GE session failed"; @@ -52,7 +52,7 @@ std::shared_ptr GraphRunner::NewSession(const SessionOptions& sess_ return ret; } -GraphRunner::GraphRunner(const GraphRunnerOptions& options) +GraphRunner::GraphRunner(const GraphRunnerOptions &options) : options_(options), graph_manager_(DfGraphManager::GetInstance()) { if 
(ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) { MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode"; @@ -88,7 +88,7 @@ GraphRunner::GraphRunner(const GraphRunnerOptions& options) } #ifdef ENABLE_GE - for (auto& it : wrappers) { + for (auto &it : wrappers) { std::set saved_graph = graph_manager_.GetSavedGraphs(); auto iter_find = saved_graph.find(std::to_string(it->id_)); if (iter_find != saved_graph.end()) { @@ -101,8 +101,8 @@ GraphRunner::GraphRunner(const GraphRunnerOptions& options) #endif } -Status GraphRunner::RunGraph(const RunOptions& options, const std::vector& inputs, - std::vector* outputs) { +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *outputs) { std::string name = options.name; if (name.empty()) { MS_LOG(ERROR) << "The graph name is null"; @@ -125,7 +125,7 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vector ge_outputs; (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs), - [](const GeTensorPtr& i) { return *i; }); + [](const GeTensorPtr &i) { return *i; }); MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs"; @@ -161,19 +161,19 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vector(ge_tensor); }); + [](const GeTensor &ge_tensor) { return std::make_shared(ge_tensor); }); return Status::SUCCESS; } -Status GraphRunner::RunGraph(const RunOptions& options, const std::vector& inputs, - std::vector* const outputs) { +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *const outputs) { std::vector ge_inputs; for (auto it : inputs) { MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize(); auto shape = (*it).shape(); std::string shape_str; - for (const auto& elem : shape) { + for (const auto &elem : shape) { shape_str += std::to_string(elem); shape_str += " "; } @@ -199,7 +199,7 @@ Status 
GraphRunner::RunGraph(const RunOptions& options, const std::vectoremplace_back(tensor); diff --git a/mindspore/ccsrc/transform/graph_runner.h b/mindspore/ccsrc/transform/graph_runner.h index a9aa9fbc59..728a1a25a2 100644 --- a/mindspore/ccsrc/transform/graph_runner.h +++ b/mindspore/ccsrc/transform/graph_runner.h @@ -46,16 +46,16 @@ struct RunOptions { class GraphRunner { public: - explicit GraphRunner(const GraphRunnerOptions& options); + explicit GraphRunner(const GraphRunnerOptions &options); ~GraphRunner() { sess_ = nullptr; } - Status RunGraph(const RunOptions& options, const std::vector& inputs, std::vector* outputs); - Status RunGraph(const RunOptions& options, const std::vector& inputs, std::vector* outputs); - static std::shared_ptr NewSession(const SessionOptions& sess_options); + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + static std::shared_ptr NewSession(const SessionOptions &sess_options); private: std::shared_ptr sess_; transform::GraphRunnerOptions options_; - DfGraphManager& graph_manager_; + DfGraphManager &graph_manager_; }; } // namespace transform } // namespace mindspore diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h index 421e4c4569..ae678606a4 100644 --- a/mindspore/ccsrc/transform/op_adapter.h +++ b/mindspore/ccsrc/transform/op_adapter.h @@ -26,17 +26,17 @@ #include "utils/utils.h" namespace mindspore { namespace transform { -static uint32_t CustomInferFunc(const Operator&) { return 0; } +static uint32_t CustomInferFunc(const Operator &) { return 0; } template class OpAdapter : public BaseOpAdapter { public: using OpType = T; OpAdapter() {} - explicit OpAdapter(const ExtraAttr& extra_attr) : extra_attr_(extra_attr) {} + explicit OpAdapter(const ExtraAttr &extra_attr) : extra_attr_(extra_attr) {} ~OpAdapter() override {} - bool IsCustomOp(const 
OperatorPtr& op) { + bool IsCustomOp(const OperatorPtr &op) { MS_EXCEPTION_IF_NULL(op); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { @@ -45,7 +45,7 @@ class OpAdapter : public BaseOpAdapter { return true; } - Status GenerateCustomOpInputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) { + Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(prim); // Create the map of custom op from input index to input name. @@ -69,7 +69,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - Status GenerateCustomOpOutputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) { + Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(prim); // Create the map of custom op from output index to output name. @@ -122,7 +122,7 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr GenerateNormalOp(const AnfNodePtr& anf) { + OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { OperatorPtr op = nullptr; // There are duplicate names in ANF graph, do not assign ANF node name to GE // GE will generate unique name automatically @@ -148,7 +148,7 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr generate(const AnfNodePtr& anf) override { + OperatorPtr generate(const AnfNodePtr &anf) override { OperatorPtr op = nullptr; if (IsCustomCNode(anf)) { op = GenerateCustomOp(anf); @@ -158,21 +158,21 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr generate(const std::string& op_name) override { return std::make_shared(op_name); } + OperatorPtr generate(const std::string &op_name) override { return std::make_shared(op_name); } - const std::unordered_map& getInputMap() override { return input_map_; } - const std::unordered_map& getInputAttrMap() override { return input_attr_map_; } - const std::unordered_map& getDynInputMap() override { 
return dyn_input_map_; } - const std::unordered_map& getOutputMap() override { return output_map_; } + const std::unordered_map &getInputMap() override { return input_map_; } + const std::unordered_map &getInputAttrMap() override { return input_attr_map_; } + const std::unordered_map &getDynInputMap() override { return dyn_input_map_; } + const std::unordered_map &getOutputMap() override { return output_map_; } - Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OperatorPtr& input) { + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(input); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { return NOT_FOUND; } - std::unordered_map& input_map = it->second; + std::unordered_map &input_map = it->second; if ((input_map.find(index) != input_map.end())) { MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index]; @@ -182,7 +182,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - Status SetNormalOpInput(const OperatorPtr& op, int index, const OperatorPtr& input) { + Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { MS_EXCEPTION_IF_NULL(op); auto it = input_map_.find(index); if (it != input_map_.end()) { @@ -194,7 +194,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - int setInput(const OperatorPtr& op, int index, const OperatorPtr& input) override { + int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { if (IsCustomOp(op)) { auto cus_op = std::dynamic_pointer_cast(op); return static_cast(SetCustomOpInput(cus_op, index, input)); @@ -203,14 +203,14 @@ class OpAdapter : public BaseOpAdapter { } } - Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OutHandler& handle) { + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { 
MS_EXCEPTION_IF_NULL(op); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { return NOT_FOUND; } - std::unordered_map& input_map = it->second; + std::unordered_map &input_map = it->second; if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) { if (handle.out.empty()) { MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index]; @@ -225,7 +225,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - Status SetNormalOpInput(const OperatorPtr& op, int index, const OutHandler& handle) { + Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { MS_EXCEPTION_IF_NULL(op); auto it = input_map_.find(index); if ((handle.op != nullptr) && (it != input_map_.end())) { @@ -242,7 +242,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - int setInput(const OperatorPtr& op, int index, const OutHandler& handle) override { + int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { if (IsCustomOp(op)) { auto cus_op = std::dynamic_pointer_cast(op); return static_cast(SetCustomOpInput(cus_op, index, handle)); @@ -251,7 +251,7 @@ class OpAdapter : public BaseOpAdapter { } } - int setInput(const OperatorPtr& op, int index, const std::shared_ptr>& handler_vec) override { + int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) override { MS_EXCEPTION_IF_NULL(handler_vec); if (IsCustomOp(op)) { MS_LOG(ERROR) << "Custom Op do not support dynamic input"; @@ -278,7 +278,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(NOT_FOUND); } - OutHandler getOutput(const OperatorPtr& op, int index) override { + OutHandler getOutput(const OperatorPtr &op, int index) override { MS_EXCEPTION_IF_NULL(op); if (IsCustomOp(op)) { return getCustomOutput(op, index); @@ -286,7 +286,7 @@ class OpAdapter : public BaseOpAdapter { return getNormalOutput(op, index); } - OutHandler 
getCustomOutput(const OperatorPtr& op, int index) { + OutHandler getCustomOutput(const OperatorPtr &op, int index) { MS_EXCEPTION_IF_NULL(op); auto it = cus_output_map_.find(op->GetOpType()); if (it == cus_output_map_.end()) { @@ -294,7 +294,7 @@ class OpAdapter : public BaseOpAdapter { return OutHandler(); } - std::unordered_map& output_map = it->second; + std::unordered_map &output_map = it->second; if ((output_map.find(index) != output_map.end())) { return OutHandler(op, output_map[index]); @@ -303,7 +303,7 @@ class OpAdapter : public BaseOpAdapter { return OutHandler(); } - OutHandler getNormalOutput(const OperatorPtr& op, int index) { + OutHandler getNormalOutput(const OperatorPtr &op, int index) { MS_EXCEPTION_IF_NULL(op); if (!dyn_output_map_.empty() && !output_map_.empty()) { MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; @@ -320,7 +320,7 @@ class OpAdapter : public BaseOpAdapter { } } - Status UpdateSingleOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) { + Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { MS_EXCEPTION_IF_NULL(type); std::string format = "NCHW"; if (op->GetOpType() == kExtractImagePatchesOpName) { @@ -353,7 +353,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - size_t GetCustomOpOutputSize(const CusOperatorPtr& cus_op) { + size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { MS_EXCEPTION_IF_NULL(cus_op); if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) { MS_LOG(ERROR) << "This op does not create custom output map"; @@ -363,8 +363,8 @@ class OpAdapter : public BaseOpAdapter { return output_size; } - std::shared_ptr CreateOutputDesc(const abstract::ShapePtr& shape_ptr, const TypePtr& type, - const std::string& format) { + std::shared_ptr CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, + const std::string &format) 
{ if (shape_ptr == nullptr) { MS_LOG(ERROR) << "Shape ptr is nullptr"; return nullptr; @@ -383,7 +383,7 @@ class OpAdapter : public BaseOpAdapter { return desc; } - Status UpdateMultiOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) { + Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { auto tuple_shp = dyn_cast(shp); MS_EXCEPTION_IF_NULL(tuple_shp); @@ -432,7 +432,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - std::shared_ptr CreateNodeDesc(const AnfNodePtr& node) { + std::shared_ptr CreateNodeDesc(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); TypeId me_type = node->Type()->type_id(); if (kObjectTypeTensorType == me_type) { @@ -456,7 +456,7 @@ class OpAdapter : public BaseOpAdapter { return desc; } - void UpdateNormalOpInputDesc(const OperatorPtr& op, const AnfNodePtr node) { + void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -479,7 +479,7 @@ class OpAdapter : public BaseOpAdapter { } } - void UpdateCustomOpInputDesc(const CusOperatorPtr& op, const AnfNodePtr& node) { + void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -491,7 +491,7 @@ class OpAdapter : public BaseOpAdapter { return; } - std::unordered_map& input_map = cus_input_map_[op->GetOpType()]; + std::unordered_map &input_map = cus_input_map_[op->GetOpType()]; auto inputs = node->cast()->inputs(); for (size_t i = 1; i < inputs.size(); ++i) { if (input_map.find(i) != input_map.end()) { @@ -504,7 +504,7 @@ class OpAdapter : public BaseOpAdapter { } } - void updateInputDesc(const OperatorPtr& op, const AnfNodePtr& node) { + void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(node); if (IsCustomOp(op)) { @@ -515,8 +515,8 @@ class 
OpAdapter : public BaseOpAdapter { } } - void updateOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type, - const AnfNodePtr& node) override { + void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) override { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -548,7 +548,7 @@ class OpAdapter : public BaseOpAdapter { updateInputDesc(op, node); } - int setAttr(const OperatorPtr& op, const std::string& attrKey, const ValuePtr& attrValue) override { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { auto it = attr_map_.find(attrKey); if (it != attr_map_.end()) { // switch case for each avalilable attribute type @@ -560,7 +560,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(NOT_FOUND); } - int SetCustomOpAttr(const CusOperatorPtr& op, const PrimitivePtr& prim) { + int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { enum ValueType { SINGLE_VALUE = 0, SEQUEUE_VALUE, @@ -611,11 +611,11 @@ class OpAdapter : public BaseOpAdapter { return 0; } - int SetNormalOpAttr(const OperatorPtr& op, const PrimitivePtr& prim) { + int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { int ret = 0; MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(op); - for (auto& it : attr_map_) { + for (auto &it : attr_map_) { auto value = prim->GetAttr(it.first); if (value != nullptr) { // set attr from primitive @@ -637,7 +637,7 @@ class OpAdapter : public BaseOpAdapter { return 0; } - int setAttr(const OperatorPtr& op, const PrimitivePtr& prim) override { + int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { int ret = 0; if (IsCustomPrim(prim)) { auto cus_op = std::dynamic_pointer_cast(op); @@ -648,7 +648,7 @@ class OpAdapter : public BaseOpAdapter { return ret; } - int setAttr(const OperatorPtr& op, const AnfNodePtr& node) override { + int setAttr(const 
OperatorPtr &op, const AnfNodePtr &node) override { // no attribute for lonely node MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { @@ -660,7 +660,7 @@ class OpAdapter : public BaseOpAdapter { return 0; } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); if (inputs.empty()) { return 0; } @@ -691,7 +691,7 @@ class OpAdapter : public BaseOpAdapter { } // set attr from const input - for (auto& it : input_attr_map_) { + for (auto &it : input_attr_map_) { if (inputs.size() <= it.first || !inputs[it.first]->isa()) { continue; } @@ -711,38 +711,38 @@ class OpAdapter : public BaseOpAdapter { private: template - static S ConvertAny(const ValuePtr& value, const AnyTraits&) { + static S ConvertAny(const ValuePtr &value, const AnyTraits &) { return GetValue(value); } // specialization for reverse bool - static bool ConvertAny(const ValuePtr& value, const AnyTraits&, bool reverse) { + static bool ConvertAny(const ValuePtr &value, const AnyTraits &, bool reverse) { return reverse != GetValue(value); } template - static Q ConvertAny(const ValuePtr& value, const AnyTraits

& traits_from, const AnyTraits& traits_to) { + static Q ConvertAny(const ValuePtr &value, const AnyTraits

&traits_from, const AnyTraits &traits_to) { return ConvertAnyUtil(value, traits_from, traits_to); } // specialization for tensor - static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits& traits) { + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits &traits) { // To-DO the format may read from ME tensor return ConvertAnyUtil(value, traits); } // specialization for int - static int64_t ConvertAny(const ValuePtr& value, const AnyTraits) { + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { return static_cast(GetValue(value)); } - // specialization for int to Vector - static std::vector ConvertAny(const ValuePtr& value, const std::string& name, + // specialization for int or tuple broadcast to Vector + static std::vector ConvertAny(const ValuePtr &value, const std::string &name, const AnyTraits> anyTraitsInt) { return ConvertAnyUtil(value, name, anyTraitsInt); } - static std::vector> ConvertAny(const ValuePtr& value, + static std::vector> ConvertAny(const ValuePtr &value, const AnyTraits>>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(INFO) << "Value: " << value->type_name(); @@ -752,14 +752,14 @@ class OpAdapter : public BaseOpAdapter { } auto vec = value->cast(); MS_EXCEPTION_IF_NULL(vec); - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { MS_EXCEPTION_IF_NULL(it); if (!it->isa()) { MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); } auto sub_vector = it->cast(); std::vector sublist; - for (auto& item : sub_vector->value()) { + for (auto &item : sub_vector->value()) { sublist.push_back(static_cast(GetValue(item))); } list.push_back(sublist); @@ -767,7 +767,7 @@ class OpAdapter : public BaseOpAdapter { return list; } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits>>, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>>, const AnyTraits>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(DEBUG) << "Value: " << value->type_name(); @@ -776,20 +776,20 @@ class 
OpAdapter : public BaseOpAdapter { } auto vec = value->cast(); std::vector list; - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { MS_EXCEPTION_IF_NULL(it); if (!it->isa()) { MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); } auto sub_vector = it->cast(); - for (auto& item : sub_vector->value()) { + for (auto &item : sub_vector->value()) { list.push_back(static_cast(GetValue(item))); } } return list; } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits>, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>, const AnyTraits>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(INFO) << "Value: " << value->type_name(); @@ -797,7 +797,7 @@ class OpAdapter : public BaseOpAdapter { if (value->isa()) { auto vec = value->cast(); MS_EXCEPTION_IF_NULL(vec); - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { list.push_back(static_cast(GetValue(it))); } return list; @@ -809,17 +809,17 @@ class OpAdapter : public BaseOpAdapter { MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); } - static std::string ConvertAny(const ValuePtr& value, const AnyTraits> anyTraitsVec, + static std::string ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsStr) { return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits> anyTraitsVec, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsFlo) { return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo); } - static std::vector ConvertAny(const ValuePtr& value, const std::string& format, + static std::vector ConvertAny(const ValuePtr &value, const std::string &format, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsInt) { return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); @@ -827,12 +827,12 @@ class OpAdapter : public BaseOpAdapter { // 
convert value list for value tuple to vector template - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits

& anyTraitsP, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits

&anyTraitsP, const AnyTraits> anyTraitsQ) { return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); } - static int64_t ConvertAny(const ValuePtr& value, const AnyTraits) { + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { auto name = GetValue(value); auto it = enum_map_.find(name); int v = 0; @@ -842,12 +842,12 @@ class OpAdapter : public BaseOpAdapter { return v; } - static GeDataType ConvertAny(const ValuePtr& value, const AnyTraits anyTraitsGE) { + static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsGE) { return ConvertAnyUtil(value, anyTraitsGE); } // convert any value to tensor - static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits anyTraitsValue) { + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsValue) { return ConvertAnyUtil(value, anyTraitsValue); } diff --git a/mindspore/ccsrc/transform/op_adapter_base.h b/mindspore/ccsrc/transform/op_adapter_base.h index 99106b8761..01f96e251d 100644 --- a/mindspore/ccsrc/transform/op_adapter_base.h +++ b/mindspore/ccsrc/transform/op_adapter_base.h @@ -48,15 +48,17 @@ namespace ge { class CustomOperator : public Operator { public: - CustomOperator(const string& name, const string& type) : Operator(name, type) {} + CustomOperator(const string &name, const string &type) : Operator(name, type) {} ~CustomOperator() override{}; - void CustomInputRegister(const string& name) { Operator::InputRegister(name); } + void CustomInputRegister(const string &name) { Operator::InputRegister(name); } - void CustomOutputRegister(const string& name) { Operator::OutputRegister(name); } + void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } - void CustomInferFuncRegister(const std::function& func) { Operator::InferFuncRegister(func); } + void CustomInferFuncRegister(const std::function &func) { + Operator::InferFuncRegister(func); + } }; } // namespace ge @@ -69,7 +71,7 @@ struct OutHandler { OperatorPtr op; std::string out; 
OutHandler() : op(nullptr), out("") {} - OutHandler(const OperatorPtr& op, const std::string out) : op(op), out(out) {} + OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} }; struct ControlEdge { @@ -119,33 +121,33 @@ struct DynOutputDesc { class BaseOpAdapter { public: virtual ~BaseOpAdapter() {} - virtual OperatorPtr generate(const AnfNodePtr& anf) = 0; - virtual OperatorPtr generate(const std::string& type) { return std::make_shared(type); } - virtual int setInput(const OperatorPtr& op, int index, const OperatorPtr& input) = 0; - virtual int setInput(const OperatorPtr& op, int index, const OutHandler& handle) = 0; - virtual int setInput(const OperatorPtr& op, int index, - const std::shared_ptr>& handler_vec) = 0; - virtual int setAttr(const OperatorPtr& op, const std::string& attrKey, const ValuePtr& attrValue) = 0; - virtual int setAttr(const OperatorPtr& op, const PrimitivePtr& prim) = 0; - virtual int setAttr(const OperatorPtr& op, const AnfNodePtr& node) = 0; + virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; + virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } + virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; + virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; + virtual int setInput(const OperatorPtr &op, int index, + const std::shared_ptr> &handler_vec) = 0; + virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; + virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; + virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; virtual std::unordered_map GetExtraAttr() = 0; template ::value>::type> - int setAttr(const OperatorPtr& op, const std::string& attrKey, const std::shared_ptr& attrValue) { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr &attrValue) { return setAttr(op, attrKey, 
MakeValue(attrValue)); } template ::value>::type> - int setAttr(const OperatorPtr& op, const std::string& attrKey, const T& attrValue) { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { return setAttr(op, attrKey, MakeValue(attrValue)); } - virtual OutHandler getOutput(const OperatorPtr& op, int index) = 0; - virtual void updateOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type, - const AnfNodePtr& node) = 0; - virtual const std::unordered_map& getInputMap() = 0; - virtual const std::unordered_map& getInputAttrMap() = 0; - virtual const std::unordered_map& getDynInputMap() = 0; - virtual const std::unordered_map& getOutputMap() = 0; - void AddAttrToDrawGraph(const std::string& attr_str) { attrs_vec_.push_back(attr_str); } - const std::vector& GetAttrsFromDrawGraph() const { return attrs_vec_; } + virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; + virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) = 0; + virtual const std::unordered_map &getInputMap() = 0; + virtual const std::unordered_map &getInputAttrMap() = 0; + virtual const std::unordered_map &getDynInputMap() = 0; + virtual const std::unordered_map &getOutputMap() = 0; + void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } + const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } void clearAttrVect() { attrs_vec_.clear(); } private: diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index d52699fa8f..203acac10f 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace transform { -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits&) { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &) { // To-DO the 
format may read from ME tensor MS_EXCEPTION_IF_NULL(value); auto me_tensor = value->cast(); @@ -33,24 +33,30 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits ConvertAnyUtil(const ValuePtr& value, const std::string& name, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>) { - int64_t data = GetValue(value); + MS_EXCEPTION_IF_NULL(value); std::vector list; - int size = 2; // 2 int in list if (name == "pad") { - size = 4; // 4 int in list - list = TransformUtil::ConvertIntToList(data, size); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); + } + auto vec = value->cast(); + list.resize(vec->value().size()+2); list[0] = 1; list[1] = 1; + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin()+2, + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); } else { + int64_t data = GetValue(value); + int size = 2; // 2 int in list list = TransformUtil::ConvertIntToList(data, size); } return list; } -std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits) { +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); if (nullptr == vec) { @@ -58,7 +64,7 @@ std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraitsvalue()) { + for (auto &it : vec->value()) { if (i != 0) { buffer << ","; } @@ -68,7 +74,7 @@ std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits) { +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); if (nullptr == vec) { @@ -77,11 +83,11 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const AnyTraits list; list.resize(vec->value().size()); (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const 
ValuePtr& val) { return static_cast(GetValue(val)); }); + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); return list; } -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& format, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); @@ -91,7 +97,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& fo std::vector list; list.resize(vec->value().size()); (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr& val) { return static_cast(GetValue(val)); }); + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); if (format == kOpFormat_NHWC) { if (list.size() < 4) { MS_LOG(EXCEPTION) << "The size of list is less than 4"; @@ -105,7 +111,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& fo return list; } -GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); if (!value->isa()) { MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString() @@ -120,7 +126,7 @@ GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { return TransformUtil::ConvertDataType(me_type); } -GeTensor VectorToTensorUtil(const ValuePtr& value) { +GeTensor VectorToTensorUtil(const ValuePtr &value) { // convert tuple or list to ge tensor, only supported one dim for now MS_EXCEPTION_IF_NULL(value); auto vec = value->isa() ? 
value->cast()->value() : value->cast()->value(); @@ -136,7 +142,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); } else if (vec[0]->isa()) { MS_LOG(INFO) << "convert value to tensor with data type = Float32"; auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); @@ -144,7 +150,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); } else if (vec[0]->isa()) { MS_LOG(INFO) << "convert value to tensor with data type = Bool"; // We use uint8_t to save bool type data @@ -153,7 +159,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); + return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); } else { MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name(); } @@ -161,7 +167,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { return GeTensor(); } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); if (value->isa()) { // convert me tensor to ge tensor @@ -174,28 +180,28 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), 
sizeof(int32_t)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); } else if (value->isa()) { // convert scalar Int64 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = Int64"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); } else if (value->isa()) { // convert scalar FP32 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = FP32"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); } else if (value->isa()) { // convert scalar FP32 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = Bool"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); } else if (value->isa()) { // convert String to GeTensor MS_LOG(INFO) << "convert string to tensor with data type = String"; @@ -213,7 +219,7 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { return GeTensor(); } -bool IsCustomPrim(const PrimitivePtr& prim) { +bool IsCustomPrim(const PrimitivePtr &prim) { if (prim == nullptr) { return false; } @@ -232,7 +238,7 @@ bool IsCustomPrim(const PrimitivePtr& prim) { return is_custom_op; } -bool IsCustomCNode(const AnfNodePtr& anf) { +bool IsCustomCNode(const AnfNodePtr &anf) { if (anf == nullptr) { return false; } diff --git a/mindspore/ccsrc/transform/op_adapter_util.h b/mindspore/ccsrc/transform/op_adapter_util.h index 0cb6c763b2..fcabc732d5 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.h +++ b/mindspore/ccsrc/transform/op_adapter_util.h @@ 
-25,42 +25,42 @@ namespace mindspore { namespace transform { template -static Q ConvertAnyUtil(const ValuePtr& value, const AnyTraits

&, const AnyTraits&) { +static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits

&, const AnyTraits &) { return static_cast(GetValue

(value)); } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits& traits); +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &traits); -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& name, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>); -std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits); +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); -std::vector ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits); +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& format, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, const AnyTraits>, const AnyTraits); -GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits); +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits); template -std::vector ConvertAnyUtil(const ValuePtr& value, AnyTraits

, const AnyTraits>) { +std::vector ConvertAnyUtil(const ValuePtr &value, AnyTraits

, const AnyTraits>) { if (!value->isa() && !value->isa()) { MS_LOG(EXCEPTION) << "error convert Value to vector for value: " << value->ToString() << ", type: " << value->type_name() << ", value should be a tuple or list"; } auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); std::vector data; - for (auto& it : vec) { + for (auto &it : vec) { data.push_back(ConvertAnyUtil(it, AnyTraits

(), AnyTraits())); } return data; } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits); +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits); -bool IsCustomPrim(const PrimitivePtr& prim); -bool IsCustomCNode(const AnfNodePtr& node); +bool IsCustomPrim(const PrimitivePtr &prim); +bool IsCustomCNode(const AnfNodePtr &node); } // namespace transform } // namespace mindspore #endif // TRANSFORM_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 858b9b6b39..5ec54b2037 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -155,13 +155,14 @@ OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(batch_mean)}, {2, OUTPUT_DESC(batch_variance)}, {3, OUTPUT_DESC(reserve_space_1)}, - {4, OUTPUT_DESC(reserve_space_2)}, - {5, OUTPUT_DESC(reserve_space_3)}}; + {4, OUTPUT_DESC(reserve_space_2)}}; // BatchNormGrad -INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, {2, INPUT_DESC(x)}, - {3, INPUT_DESC(scale)}, {4, INPUT_DESC(reserve_space_1)}, - {5, INPUT_DESC(reserve_space_2)}, {6, INPUT_DESC(reserve_space_3)}}; +INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, + {2, INPUT_DESC(x)}, + {3, INPUT_DESC(scale)}, + {4, INPUT_DESC(reserve_space_1)}, + {5, INPUT_DESC(reserve_space_2)}}; ATTR_MAP(BatchNormGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}, {"epsilon", ATTR_DESC(epsilon, AnyTraits())}, {"is_training", ATTR_DESC(is_training, AnyTraits())}}; @@ -266,11 +267,6 @@ INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_D ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP; OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}}; -// ReduceSum -INPUT_MAP(ReduceSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}}; -ATTR_MAP(ReduceSum) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; -OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}}; - // ReduceSumD INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; 
INPUT_ATTR_MAP(ReduceSumD) = { @@ -451,17 +447,17 @@ INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}}; ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}}; -// ResizeNearestNeighborD -INPUT_MAP(ResizeNearestNeighborD) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(ResizeNearestNeighborD) = { +// ResizeNearestNeighborV2D +INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeNearestNeighborV2D) = { {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeNearestNeighborD) = {{0, OUTPUT_DESC(y)}}; +OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}}; -// ResizeNearestNeighborGrad -INPUT_MAP(ResizeNearestNeighborGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; -ATTR_MAP(ResizeNearestNeighborGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeNearestNeighborGrad) = {{0, OUTPUT_DESC(y)}}; +// ResizeNearestNeighborV2Grad +INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; +ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}}; // ApplyAdam INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, @@ -486,17 +482,17 @@ INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP; OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}}; -// ResizeBilinearGrad -INPUT_MAP(ResizeBilinearGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; -ATTR_MAP(ResizeBilinearGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeBilinearGrad) = {{0, OUTPUT_DESC(y)}}; +// ResizeBilinearV2Grad +INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; 
+ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}}; -// ResizeBilinearD -INPUT_MAP(ResizeBilinearD) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(ResizeBilinearD) = { +// ResizeBilinearV2D +INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeBilinearV2D) = { {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeBilinearD) = {{0, OUTPUT_DESC(y)}}; +OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}}; // ZerosLike INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}}; @@ -609,10 +605,12 @@ ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, {"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}}; -// ReduceAll -INPUT_MAP(ReduceAll) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}}; -ATTR_MAP(ReduceAll) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; -OUTPUT_MAP(ReduceAll) = {{0, OUTPUT_DESC(y)}}; +// ReduceAllD +INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceAllD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}}; // ReduceMeanD INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}}; @@ -720,11 +718,13 @@ OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; // Conv2D INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; -ATTR_MAP(Conv2D) = {{"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, - {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, - {"data_format", ATTR_DESC(data_format, AnyTraits())}, - {"group", ATTR_DESC(groups, AnyTraits())}}; +ATTR_MAP(Conv2D) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pad_list", ATTR_DESC(pads, 
AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, +}; OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; // Conv2DBackpropInputD @@ -734,9 +734,10 @@ INPUT_ATTR_MAP(Conv2DBackpropInputD) = { ATTR_MAP(Conv2DBackpropInputD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, - {"group", ATTR_DESC(groups, AnyTraits())}}; + {"group", ATTR_DESC(groups, AnyTraits())}, +}; OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; // Conv2DBackpropFilterD @@ -746,17 +747,18 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { ATTR_MAP(Conv2DBackpropFilterD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, - {"group", ATTR_DESC(groups, AnyTraits())}}; + {"group", ATTR_DESC(groups, AnyTraits())}, +}; OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; // DepthwiseConv2D INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; ATTR_MAP(DepthwiseConv2D) = { - {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, }; OUTPUT_MAP(DepthwiseConv2D) = {{0, OUTPUT_DESC(y)}}; @@ -766,9 +768,9 @@ INPUT_MAP(DepthwiseConv2DBackpropInputD) = {{2, 
INPUT_DESC(filter)}, {3, INPUT_D INPUT_ATTR_MAP(DepthwiseConv2DBackpropInputD) = { {1, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(DepthwiseConv2DBackpropInputD) = { - {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, }; OUTPUT_MAP(DepthwiseConv2DBackpropInputD) = {{0, OUTPUT_DESC(input_grad)}}; @@ -777,9 +779,9 @@ INPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{1, INPUT_DESC(input)}, {3, INPUT_D INPUT_ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { {2, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { - {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, }; OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}}; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index 9fbc97f3c9..59c95df8e1 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -116,20 +116,20 @@ DECLARE_OP_ADAPTER(Reshape) DECLARE_OP_USE_OUTPUT(Reshape) DECLARE_OP_ADAPTER(Iou) DECLARE_OP_USE_OUTPUT(Iou) -DECLARE_OP_ADAPTER(ResizeNearestNeighborD) -DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborD) -DECLARE_OP_ADAPTER(ResizeNearestNeighborGrad) -DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborGrad) +DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) +DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D) +DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad) +DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad) DECLARE_OP_ADAPTER(ApplyAdam) DECLARE_OP_USE_OUTPUT(ApplyAdam) 
DECLARE_OP_ADAPTER(Relu6) DECLARE_OP_USE_OUTPUT(Relu6) DECLARE_OP_ADAPTER(Relu6Grad) DECLARE_OP_USE_OUTPUT(Relu6Grad) -DECLARE_OP_ADAPTER(ResizeBilinearD) -DECLARE_OP_USE_OUTPUT(ResizeBilinearD) -DECLARE_OP_ADAPTER(ResizeBilinearGrad) -DECLARE_OP_USE_OUTPUT(ResizeBilinearGrad) +DECLARE_OP_ADAPTER(ResizeBilinearV2D) +DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D) +DECLARE_OP_ADAPTER(ResizeBilinearV2Grad) +DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad) DECLARE_OP_ADAPTER(ZerosLike) DECLARE_OP_USE_OUTPUT(ZerosLike) DECLARE_OP_ADAPTER(OnesLike) @@ -340,10 +340,9 @@ DECLARE_OP_USE_OUTPUT(Sin) DECLARE_OP_ADAPTER(Exp) DECLARE_OP_USE_OUTPUT(Exp) -DECLARE_OP_ADAPTER(ReduceAll) -DECLARE_OP_USE_OUTPUT(ReduceAll) -DECLARE_OP_ADAPTER(ReduceSum) -DECLARE_OP_USE_OUTPUT(ReduceSum) +DECLARE_OP_ADAPTER(ReduceAllD) +DECLARE_OP_USE_INPUT_ATTR(ReduceAllD) +DECLARE_OP_USE_OUTPUT(ReduceAllD) DECLARE_OP_ADAPTER(ReduceSumD) DECLARE_OP_USE_INPUT_ATTR(ReduceSumD) DECLARE_OP_USE_OUTPUT(ReduceSumD) diff --git a/mindspore/ccsrc/transform/util.cc b/mindspore/ccsrc/transform/util.cc index 0a18763d12..b1120ade6d 100644 --- a/mindspore/ccsrc/transform/util.cc +++ b/mindspore/ccsrc/transform/util.cc @@ -53,7 +53,7 @@ static std::map datatype_trans_map = { {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32}, {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}}; -GeDataType TransformUtil::ConvertDataType(const MeDataType& type) { +GeDataType TransformUtil::ConvertDataType(const MeDataType &type) { MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type"; if (datatype_trans_map.find(type) != datatype_trans_map.end()) { return datatype_trans_map[type]; @@ -70,7 +70,7 @@ static std::map datatype_size_map = { {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)}, {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, 
{MeDataType::kNumberTypeBool, sizeof(bool)}}; -size_t TransformUtil::GetDataTypeSize(const MeDataType& type) { +size_t TransformUtil::GetDataTypeSize(const MeDataType &type) { if (datatype_size_map.find(type) != datatype_size_map.end()) { return datatype_size_map[type]; } else { @@ -79,7 +79,7 @@ size_t TransformUtil::GetDataTypeSize(const MeDataType& type) { } } -GeFormat TransformUtil::ConvertFormat(const string& format) { +GeFormat TransformUtil::ConvertFormat(const string &format) { if (format == kOpFormat_NCHW) { return GeFormat::FORMAT_NCHW; } else if (format == kOpFormat_NC1HWC0) { @@ -95,8 +95,8 @@ GeFormat TransformUtil::ConvertFormat(const string& format) { static int64_t IntegerCastFunc(size_t temp) { return static_cast(temp); } -std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector& me_shape, - const MeDataType& me_type, const std::string& format) { +std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector &me_shape, + const MeDataType &me_type, const std::string &format) { // convert me shape to ge shape std::vector ge_shape; @@ -135,8 +135,8 @@ std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector TransformUtil::ConvertInputTensors(const std::vector& me_tensors, - const std::string& format) { +std::vector TransformUtil::ConvertInputTensors(const std::vector &me_tensors, + const std::string &format) { std::vector ge_tensors; for (size_t index = 0; index < me_tensors.size(); index++) { @@ -163,7 +163,7 @@ std::vector TransformUtil::ConvertInputTensors(const std::vectordata_type()); @@ -192,15 +192,15 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr& tensor, const std::s MS_LOG(ERROR) << "Failed to get Tensor Desc"; return nullptr; } - GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); + GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); if (tensor_ptr != nullptr) { MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!"; } 
return tensor_ptr; } -std::vector TransformUtil::ConvertGeTensors(const std::vector& ge_tensors, - const std::vector>& request_dims) { +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims) { std::vector outputs; for (size_t index = 0; index < ge_tensors.size(); index++) { @@ -222,7 +222,7 @@ std::vector TransformUtil::ConvertGeTensors(const std::vector TransformUtil::ConvertGeTensors(const std::vector& ge_tensors) { +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors) { std::vector outputs; for (size_t index = 0; index < ge_tensors.size(); index++) { @@ -237,7 +237,7 @@ std::vector TransformUtil::ConvertGeTensors(const std::vector& request_dims) { +bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); @@ -311,20 +311,20 @@ bool IsGeShapeCompatible(const GeShape& ge_shape, const std::vector& reques } } // namespace -GeShape TransformUtil::ConvertMeShape(const std::vector& me_dims) { +GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { std::vector ge_dims; (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); return GeShape(ge_dims); } -std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape) { +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { std::vector me_dims; std::vector ge_dims = ge_shape.GetDims(); (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); return me_dims; } -std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape, const std::vector& request_dims) { +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { vector ret; if (ge_shape.GetDimNum() == 0) { MS_LOG(DEBUG) << "GeTensor's shape is scalar"; @@ -340,12 +340,12 @@ 
std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape, const st return ret; } -MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr& ge_tensor, const std::vector& me_dims, - const TypeId& me_type) { +MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type) { MeTensor me_tensor(me_type, me_dims); // Get the writable data pointer of the tensor and cast it to its data type - auto me_data_ptr = reinterpret_cast(me_tensor.data_c(true)); + auto me_data_ptr = reinterpret_cast(me_tensor.data_c(true)); size_t me_data_size = static_cast(me_tensor.data().nbytes()); MS_EXCEPTION_IF_NULL(me_data_ptr); MS_EXCEPTION_IF_NULL(ge_tensor); @@ -369,7 +369,7 @@ MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr& ge_tensor, const return make_shared(me_tensor); } -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr& ge_tensor) { +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { MS_EXCEPTION_IF_NULL(ge_tensor); GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); vector me_dims = ConvertGeShape(ge_shape); @@ -384,7 +384,7 @@ MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr& ge_tensor) { } // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector& request_dims) { +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { MS_EXCEPTION_IF_NULL(ge_tensor); GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); vector me_dims = ConvertGeShape(ge_shape, request_dims); diff --git a/mindspore/ccsrc/transform/util.h b/mindspore/ccsrc/transform/util.h index 9bcd8dc115..0f5d79f6a1 100644 --- a/mindspore/ccsrc/transform/util.h +++ b/mindspore/ccsrc/transform/util.h @@ -47,7 +47,7 @@ class TransformUtil { * Return: * [GeDataType] the data type for ge tensor * */ - static GeDataType 
ConvertDataType(const MeDataType& type); + static GeDataType ConvertDataType(const MeDataType &type); /* * Parameters: @@ -55,7 +55,7 @@ class TransformUtil { * Return: * [GeFormat] the data format for ge tensor * */ - static GeFormat ConvertFormat(const std::string& format); + static GeFormat ConvertFormat(const std::string &format); /* * Parameters: @@ -63,7 +63,7 @@ class TransformUtil { * Return: * [size_t] the buff size for the type in ME * */ - static size_t GetDataTypeSize(const MeDataType& type); + static size_t GetDataTypeSize(const MeDataType &type); /* * Parameters: @@ -73,8 +73,8 @@ class TransformUtil { * Return: * [shared_ptr] the shared pointer of ge tensor description * */ - static std::shared_ptr GetGeTensorDesc(const std::vector& shape, const MeDataType& me_type, - const std::string& format); + static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, + const std::string &format); /* * Parameters: @@ -84,7 +84,7 @@ class TransformUtil { * Return: * [GeTensor] the data tensor in GE * */ - static GeTensorPtr ConvertTensor(const MeTensorPtr& tensor, const std::string& format); + static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format); /* * Parameters: @@ -93,8 +93,8 @@ class TransformUtil { * Return: * [std::vector] the data tensors in GE * */ - static std::vector ConvertInputTensors(const std::vector& me_tensors, - const std::string& format); + static std::vector ConvertInputTensors(const std::vector &me_tensors, + const std::string &format); /* * Parameters: @@ -102,7 +102,7 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr ConvertGeTensor(const GeTensorPtr& tensor); + static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); /* * Parameters: @@ -111,7 +111,7 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector& request_dims); + static 
MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE @@ -119,15 +119,15 @@ class TransformUtil { * Return: * [std::vector] the data tensor in ME * */ - static std::vector ConvertGeTensors(const std::vector& ge_tensors, - const std::vector>& request_dims); + static std::vector ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE * Return: * [std::vector] the data tensor in ME * */ - static std::vector ConvertGeTensors(const std::vector& ge_tensors); + static std::vector ConvertGeTensors(const std::vector &ge_tensors); /* * Parameters: * ge_tensor: [GeTensor] the data tensor in GE @@ -136,15 +136,15 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr GenerateMeTensor(const GeTensorPtr& ge_tensor, const std::vector& me_dims, - const TypeId& me_type); + static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type); /* * Parameters: * type: [GeDataType] the ge tensor data type * Return: * [MeDataType] the me tensor data type * */ - static MeDataType ConvertGeDataType(const GeDataType& type); + static MeDataType ConvertGeDataType(const GeDataType &type); /* * Parameters: @@ -152,7 +152,7 @@ class TransformUtil { * Return: * [GeShape] the ge shape * */ - static GeShape ConvertMeShape(const std::vector& me_dims); + static GeShape ConvertMeShape(const std::vector &me_dims); /* * Parameters: @@ -160,7 +160,7 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const GeShape& ge_shape); + static std::vector ConvertGeShape(const GeShape &ge_shape); /* Function: * Convert GeShape to Me request shape, Support pattern: @@ -176,7 +176,7 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const 
GeShape& ge_shape, const std::vector& request_dims); + static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); /* * Parameters: @@ -185,7 +185,7 @@ class TransformUtil { * [string] value string * */ template ::value>::type> - static std::string PrintVector(const std::vector& vec) { + static std::string PrintVector(const std::vector &vec) { const int MAX_PRINT_NUM = 100; std::stringstream ss; ss << "{ "; @@ -222,7 +222,7 @@ class TransformUtil { * [shared_ptr] vector pointer * */ template ::value>::type> - static std::vector MakeVector(const uint8_t* const data, size_t size) { + static std::vector MakeVector(const uint8_t *const data, size_t size) { auto dest = std::vector(size / sizeof(T)); if (data == nullptr) { return dest; diff --git a/mindspore/ccsrc/utils/any.cc b/mindspore/ccsrc/utils/any.cc index 31ee1fd302..3cb89f5dd7 100644 --- a/mindspore/ccsrc/utils/any.cc +++ b/mindspore/ccsrc/utils/any.cc @@ -21,7 +21,7 @@ namespace mindspore { // only support (int, float, bool) as Literal -bool AnyIsLiteral(const Any& any) { +bool AnyIsLiteral(const Any &any) { static const std::type_index typeid_int = std::type_index(typeid(int)); static const std::type_index typeid_float = std::type_index(typeid(float)); static const std::type_index typeid_bool = std::type_index(typeid(bool)); @@ -30,12 +30,12 @@ bool AnyIsLiteral(const Any& any) { return typeid_int == typeid_any || typeid_float == typeid_any || typeid_bool == typeid_any; } -std::ostream& operator<<(std::ostream& os, const pybind11::object& obj) { +std::ostream &operator<<(std::ostream &os, const pybind11::object &obj) { os << "[py::object]"; return os; } -Any& Any::operator=(const Any& other) { +Any &Any::operator=(const Any &other) { if (m_ptr == other.m_ptr || &other == this) { return *this; } @@ -44,9 +44,9 @@ Any& Any::operator=(const Any& other) { return *this; } -bool Any::operator<(const Any& other) const { return this < &other; } +bool Any::operator<(const Any &other) 
const { return this < &other; } -Any& Any::operator=(Any&& other) { +Any &Any::operator=(Any &&other) { if (this != &other) { if (m_ptr == other.m_ptr || &other == this) { return *this; diff --git a/mindspore/ccsrc/utils/any.h b/mindspore/ccsrc/utils/any.h index ce691f1c12..b4edf602ac 100644 --- a/mindspore/ccsrc/utils/any.h +++ b/mindspore/ccsrc/utils/any.h @@ -35,23 +35,23 @@ namespace mindspore { // usage:AnyPtr sp = std::make_shared(aname); template -std::string type(const T& t) { +std::string type(const T &t) { return demangle(typeid(t).name()); } -std::ostream& operator<<(std::ostream& os, const pybind11::object& obj); +std::ostream &operator<<(std::ostream &os, const pybind11::object &obj); class Any { public: // constructors Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {} - Any(const Any& other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} - Any(Any&& other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} + Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} + Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} - Any& operator=(Any&& other); + Any &operator=(Any &&other); // right reference constructor template ::type, Any>::value, T>::type> - Any(T&& t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT + Any(T &&t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT BasePtr new_val(new Derived::type>(std::forward(t))); std::swap(m_ptr, new_val); } @@ -67,7 +67,7 @@ class Any { return m_tpIndex == std::type_index(typeid(T)); } - const std::type_info& type() const { return m_ptr ? m_ptr->type() : typeid(void); } + const std::type_info &type() const { return m_ptr ? 
m_ptr->type() : typeid(void); } std::size_t Hash() const { std::stringstream buffer; @@ -79,7 +79,7 @@ class Any { } template - bool Apply(const std::function& fn) { + bool Apply(const std::function &fn) { if (type() == typeid(T)) { T x = cast(); fn(x); @@ -96,23 +96,23 @@ class Any { } } - friend std::ostream& operator<<(std::ostream& os, const Any& any) { + friend std::ostream &operator<<(std::ostream &os, const Any &any) { os << any.GetString(); return os; } // type cast template - T& cast() const { + T &cast() const { if (!is() || !m_ptr) { // Use MS_LOGFATAL replace throw std::bad_cast() MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name(); } - auto ptr = static_cast*>(m_ptr.get()); + auto ptr = static_cast *>(m_ptr.get()); return ptr->m_value; } - bool operator==(const Any& other) const { + bool operator==(const Any &other) const { if (m_tpIndex != other.m_tpIndex) { return false; } @@ -125,11 +125,11 @@ class Any { return *m_ptr == *other.m_ptr; } - bool operator!=(const Any& other) const { return !(operator==(other)); } + bool operator!=(const Any &other) const { return !(operator==(other)); } - Any& operator=(const Any& other); + Any &operator=(const Any &other); - bool operator<(const Any& other) const; + bool operator<(const Any &other) const; std::string ToString() const { std::ostringstream buffer; @@ -154,26 +154,26 @@ class Any { // type base definition struct Base { - virtual const std::type_info& type() const = 0; + virtual const std::type_info &type() const = 0; virtual BasePtr clone() const = 0; virtual ~Base() = default; - virtual bool operator==(const Base& other) const = 0; + virtual bool operator==(const Base &other) const = 0; virtual std::string GetString() = 0; }; template struct Derived : public Base { template - explicit Derived(Args&&... args) : m_value(std::forward(args)...), serialize_cache_("") {} + explicit Derived(Args &&... 
args) : m_value(std::forward(args)...), serialize_cache_("") {} - bool operator==(const Base& other) const override { + bool operator==(const Base &other) const override { if (typeid(*this) != typeid(other)) { return false; } - return m_value == static_cast&>(other).m_value; + return m_value == static_cast &>(other).m_value; } - const std::type_info& type() const override { return typeid(T); } + const std::type_info &type() const override { return typeid(T); } BasePtr clone() const override { return BasePtr(new Derived(m_value)); } @@ -204,14 +204,14 @@ class Any { using AnyPtr = std::shared_ptr; struct AnyHash { - std::size_t operator()(const Any& c) const { return c.Hash(); } + std::size_t operator()(const Any &c) const { return c.Hash(); } }; struct AnyLess { - bool operator()(const Any& a, const Any& b) const { return a.Hash() < b.Hash(); } + bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); } }; -bool AnyIsLiteral(const Any& any); +bool AnyIsLiteral(const Any &any); } // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref.cc b/mindspore/ccsrc/utils/base_ref.cc index e50f0003b8..aa38c8a6a0 100644 --- a/mindspore/ccsrc/utils/base_ref.cc +++ b/mindspore/ccsrc/utils/base_ref.cc @@ -17,17 +17,17 @@ #include "utils/base_ref.h" namespace mindspore { -iterator ConstIteratorCast(std::vector* v, const const_iterator iter) { +iterator ConstIteratorCast(std::vector *v, const const_iterator iter) { return std::next(v->begin(), std::distance(v->cbegin(), iter)); } -BaseRef::BaseRef(const BaseRef& other) : Base(other), m_ptr(other.m_ptr) { +BaseRef::BaseRef(const BaseRef &other) : Base(other), m_ptr(other.m_ptr) { if (!m_ptr) { m_ptr = other.copy(); } } -bool BaseRef::operator==(const BaseRef& other) const { +bool BaseRef::operator==(const BaseRef &other) const { if (m_ptr == other.m_ptr) { return true; } @@ -55,7 +55,7 @@ bool BaseRef::operator==(const BaseRef& other) const { } // left reference -BaseRef& BaseRef::operator=(const 
BaseRef& other) { +BaseRef &BaseRef::operator=(const BaseRef &other) { if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { return *this; } @@ -64,7 +64,7 @@ BaseRef& BaseRef::operator=(const BaseRef& other) { } // right reference -BaseRef& BaseRef::operator=(BaseRef&& other) { +BaseRef &BaseRef::operator=(BaseRef &&other) { if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { return *this; } @@ -88,7 +88,7 @@ uint32_t BaseRef::type() const { } // left reference -SetRef& SetRef::operator=(const SetRef& other) { +SetRef &SetRef::operator=(const SetRef &other) { if (elements_ == other.elements_ || this == &other) { return *this; } @@ -100,7 +100,7 @@ std::string SetRef::ToString() const { std::ostringstream buffer; bool begin = true; buffer << "set["; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -113,7 +113,7 @@ std::string SetRef::ToString() const { } // left reference -VectorRef& VectorRef::operator=(const VectorRef& other) { +VectorRef &VectorRef::operator=(const VectorRef &other) { if (elements_ == other.elements_ || this == &other) { return *this; } @@ -125,7 +125,7 @@ std::string VectorRef::ToString() const { std::ostringstream buffer; bool begin = true; buffer << "vector["; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -137,14 +137,14 @@ std::string VectorRef::ToString() const { return buffer.str(); } -bool VectorRef::operator==(const BaseRef& other) const { +bool VectorRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool VectorRef::operator==(const VectorRef& other) const { +bool VectorRef::operator==(const VectorRef &other) const { if (elements_.size() != other.elements_.size()) { return false; } @@ -156,14 +156,14 @@ bool VectorRef::operator==(const VectorRef& other) const { return true; } -bool SetRef::operator==(const 
BaseRef& other) const { +bool SetRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool SetRef::operator==(const SetRef& other) const { +bool SetRef::operator==(const SetRef &other) const { if (elements_.size() != other.elements_.size()) { return false; } @@ -177,21 +177,21 @@ bool SetRef::operator==(const SetRef& other) const { return true; } -bool RunFunctionRef::operator==(const BaseRef& other) const { +bool RunFunctionRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool RunFunctionRef::operator==(const RunFunctionRef& other) const { return func_ == other.func_; } +bool RunFunctionRef::operator==(const RunFunctionRef &other) const { return func_ == other.func_; } -bool PyObjectRef::operator==(const BaseRef& other) const { +bool PyObjectRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool PyObjectRef::operator==(const PyObjectRef& other) const { return object_ == other.object_; } +bool PyObjectRef::operator==(const PyObjectRef &other) const { return object_ == other.object_; } } // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref.h b/mindspore/ccsrc/utils/base_ref.h index ed00d8280c..6e7911d0d9 100644 --- a/mindspore/ccsrc/utils/base_ref.h +++ b/mindspore/ccsrc/utils/base_ref.h @@ -40,7 +40,7 @@ using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; using const_reverse_iterator = std::vector::const_reverse_iterator; -using RunFunc = std::function; +using RunFunc = std::function; using RunFuncPtr = std::shared_ptr; template @@ -54,9 +54,9 @@ using is_value = std::is_base_of>; template using is_base_ref = std::is_base_of>; -iterator ConstIteratorCast(std::vector* v, const_iterator iter); +iterator ConstIteratorCast(std::vector *v, const_iterator iter); -inline std::shared_ptr 
MakeNode(const std::vector& elements) { +inline std::shared_ptr MakeNode(const std::vector &elements) { return std::make_shared(elements); } @@ -68,34 +68,34 @@ inline std::shared_ptr MakeNode(std::initializer_list elemen template >::value && is_base::value, int>::type = 0> -inline BasePtr MakeNode(const T& v) { +inline BasePtr MakeNode(const T &v) { return v; } template >::value && !is_base_ref::value, int>::type = 0> -inline BasePtr MakeNode(const T& v) { +inline BasePtr MakeNode(const T &v) { return MakeValue(v); } -inline std::shared_ptr MakeNode(const VectorRef& a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const AnfNodePtrList& a) { +inline std::shared_ptr MakeNode(const VectorRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const AnfNodePtrList &a) { std::vector ret; - (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr& v) { return v; }); + (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; }); return std::make_shared(ret); } -inline std::shared_ptr MakeNode(const SetRef& a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const RunFuncPtr& a) { return std::make_shared(a); } -inline std::shared_ptr MakeNode(const py::object& a) { return std::make_shared(a); } -inline std::shared_ptr MakeNode(const py::tuple& a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const SetRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const RunFuncPtr &a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const py::object &a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const py::tuple &a) { return std::make_shared(a); } class BaseRef : public Base { public: BaseRef() : m_ptr(nullptr) {} - BaseRef(const BaseRef& other); + BaseRef(const BaseRef &other); virtual std::shared_ptr copy() const { return m_ptr; } - 
BaseRef(BaseRef&& other) : Base(other) { + BaseRef(BaseRef &&other) : Base(other) { m_ptr = other.m_ptr; other.m_ptr = nullptr; } @@ -103,7 +103,7 @@ class BaseRef : public Base { // right reference constructor template ::type, BaseRef>::value, T>::type> - BaseRef(T&& t) { // NOLINT + BaseRef(T &&t) { // NOLINT m_ptr = MakeNode(t); } @@ -111,14 +111,14 @@ class BaseRef : public Base { MS_DECLARE_PARENT(BaseRef, Base) - bool operator!=(const BaseRef& other) const { return !(operator==(other)); } + bool operator!=(const BaseRef &other) const { return !(operator==(other)); } - virtual bool operator==(const BaseRef& other) const; + virtual bool operator==(const BaseRef &other) const; // left reference - virtual BaseRef& operator=(const BaseRef& other); + virtual BaseRef &operator=(const BaseRef &other); // right reference - virtual BaseRef& operator=(BaseRef&& other); + virtual BaseRef &operator=(BaseRef &&other); std::size_t hash() const override { if (m_ptr == nullptr) { @@ -139,18 +139,18 @@ class BaseRef : public Base { using BaseRefPtr = std::shared_ptr; struct BaseRefHash { - std::size_t operator()(const BaseRef& c) const { return c.hash(); } + std::size_t operator()(const BaseRef &c) const { return c.hash(); } }; struct BaseRefLess { - bool operator()(const BaseRef& a, const BaseRef& b) const { return a.hash() < b.hash(); } + bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); } }; namespace utils { // judge isa relation // examples: isa(handle), isa(handle) template ::value && !is_base_ref::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (!handle.m_ptr) { return false; } @@ -160,7 +160,7 @@ bool isa(const BaseRef& handle) { // noderef isa ptr isa(x) or isa() template ::value, typename T::element_type>::type, typename std::enable_if::value || is_base_ref::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { 
return typeid(handle.m_ptr) == typeid(T); } @@ -175,7 +175,7 @@ bool isa(const BaseRef& handle) { // isa(handle) template ::type::element_type> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { return false; } @@ -184,7 +184,7 @@ bool isa(const BaseRef& handle) { // isa(handle), judge reference or ptr template ::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { static const uint32_t tid = Base::GetTypeId(typeid(T).name()); return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa()); } @@ -192,7 +192,7 @@ bool isa(const BaseRef& handle) { // valueref -> C++ type // cast(handle) template ::value && !is_shared_ptr::value, int>::type = 0> -T cast(const BaseRef& handle) { +T cast(const BaseRef &handle) { T ret = GetValue(std::static_pointer_cast(handle.m_ptr)); return std::move(ret); } @@ -200,12 +200,12 @@ T cast(const BaseRef& handle) { // valueref -> valueref type // cast(handle) template ::value, int>::type = 0> -const T& cast(const BaseRef& handle) { +const T &cast(const BaseRef &handle) { if (handle.m_ptr) { - return static_cast(*handle.m_ptr); + return static_cast(*handle.m_ptr); } - return std::move(static_cast(handle)); + return std::move(static_cast(handle)); } // valueref -> nodeptr type @@ -213,7 +213,7 @@ const T& cast(const BaseRef& handle) { template ::value, typename T::element_type>::type, typename std::enable_if::value && std::is_base_of::value, int>::type = 0> -T cast(const BaseRef& handle) { +T cast(const BaseRef &handle) { if (!handle.m_ptr) { MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null"; } @@ -229,11 +229,11 @@ T cast(const BaseRef& handle) { class VectorRef : public BaseRef { public: VectorRef() {} - explicit VectorRef(const std::vector& elements) : elements_(elements) {} - VectorRef(const const_iterator& begin, const const_iterator& end) : elements_(begin, end) {} + explicit VectorRef(const std::vector 
&elements) : elements_(elements) {} + VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {} // left reference - virtual VectorRef& operator=(const VectorRef& other); + virtual VectorRef &operator=(const VectorRef &other); ~VectorRef() override = default; @@ -244,7 +244,7 @@ class VectorRef : public BaseRef { std::size_t size() const { return elements_.size(); } MS_DECLARE_PARENT(VectorRef, BaseRef) - const BaseRef& operator[](const std::size_t& dim) const { + const BaseRef &operator[](const std::size_t &dim) const { if (dim >= size()) { MS_LOG(EXCEPTION) << "Out of the size of the tuple."; } @@ -253,17 +253,17 @@ class VectorRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override; - std::vector& elements() { return elements_; } + std::vector &elements() { return elements_; } void clear() { elements_.clear(); } - bool operator==(const BaseRef& other) const override; - bool operator==(const VectorRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const VectorRef &other) const; - void push_back(const BaseRef& value) { elements_.push_back(value); } - void push_back(BaseRef&& value) { elements_.push_back(value); } + void push_back(const BaseRef &value) { elements_.push_back(value); } + void push_back(BaseRef &&value) { elements_.push_back(value); } - void emplace_back(const BaseRef& value) { elements_.emplace_back(value); } - void emplace_back(BaseRef&& value) { elements_.emplace_back(value); } + void emplace_back(const BaseRef &value) { elements_.emplace_back(value); } + void emplace_back(BaseRef &&value) { elements_.emplace_back(value); } template void insert(const iterator pos, const InputIt first, const InputIt last) { @@ -308,21 +308,21 @@ using set_iterator = std::set::iterator; using const_set_iterator = std::set::const_iterator; struct VectorRefHash { - std::size_t operator()(const VectorRef& c) const { return c.hash(); } + 
std::size_t operator()(const VectorRef &c) const { return c.hash(); } }; class SetRef : public BaseRef { public: SetRef() {} - explicit SetRef(const std::set& elements) : elements_(elements) {} + explicit SetRef(const std::set &elements) : elements_(elements) {} SetRef(const std::initializer_list elements) : elements_(elements.begin(), elements.end()) {} - SetRef(const const_set_iterator& begin, const const_set_iterator& end) : elements_(begin, end) {} + SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {} // left reference - virtual SetRef& operator=(const SetRef& other); + virtual SetRef &operator=(const SetRef &other); - bool operator==(const BaseRef& other) const override; - bool operator==(const SetRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const SetRef &other) const; ~SetRef() override = default; @@ -335,10 +335,10 @@ class SetRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override; - std::set& elements() { return elements_; } + std::set &elements() { return elements_; } void clear() { elements_.clear(); } - void insert(const BaseRef& elem) { (void)elements_.insert(elem); } + void insert(const BaseRef &elem) { (void)elements_.insert(elem); } const_set_iterator begin() const { return elements_.begin(); } const_set_iterator end() const { return elements_.end(); } @@ -348,8 +348,8 @@ class SetRef : public BaseRef { (void)elements_.insert(first, last); } - std::size_t count(const BaseRef& elem) const { return elements_.count(elem); } - const_set_iterator find(const BaseRef& elem) const { return elements_.find(elem); } + std::size_t count(const BaseRef &elem) const { return elements_.count(elem); } + const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); } std::set elements_; }; @@ -358,8 +358,8 @@ using SetRefPtr = std::shared_ptr; class PyObjectRef : public BaseRef { public: - explicit 
PyObjectRef(const py::object& py_object) : object_(py_object) {} - explicit PyObjectRef(const py::tuple& tuple_obj) : object_(tuple_obj) {} + explicit PyObjectRef(const py::object &py_object) : object_(py_object) {} + explicit PyObjectRef(const py::tuple &tuple_obj) : object_(tuple_obj) {} ~PyObjectRef() override = default; @@ -368,8 +368,8 @@ class PyObjectRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override { return py::str(object_); } - bool operator==(const BaseRef& other) const override; - bool operator==(const PyObjectRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const PyObjectRef &other) const; py::object object_; }; @@ -377,15 +377,15 @@ class PyObjectRef : public BaseRef { class RunFunctionRef : public BaseRef { public: RunFunctionRef() {} - explicit RunFunctionRef(const RunFuncPtr& ref_func) : func_(ref_func) {} + explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {} ~RunFunctionRef() override = default; MS_DECLARE_PARENT(RunFunctionRef, BaseRef) uint32_t type() const override { return tid(); } std::string ToString() const override { return std::string("RunFunctionRef"); } - bool operator==(const BaseRef& other) const override; - bool operator==(const RunFunctionRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const RunFunctionRef &other) const; RunFuncPtr func_; }; diff --git a/mindspore/ccsrc/utils/callbacks.cc b/mindspore/ccsrc/utils/callbacks.cc index 03c6322afe..06bf1c73ab 100644 --- a/mindspore/ccsrc/utils/callbacks.cc +++ b/mindspore/ccsrc/utils/callbacks.cc @@ -37,14 +37,14 @@ const int ONE_SHAPE = 1; // Cache the summary callback data from ME session // Remove the GE module on new architecture // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...] 
-uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; py::list summary_list = py::list(); MS_LOG(INFO) << "The Summary save callback function for graph " << graph_id << ", Param list size = " << params_list.size() << "."; - for (auto& item : params_list) { + for (auto &item : params_list) { std::string tag_name = item.first; auto tensor_ptr = item.second; if (tensor_ptr == nullptr) { diff --git a/mindspore/ccsrc/utils/callbacks.h b/mindspore/ccsrc/utils/callbacks.h index a1e4e75d5b..9f46df0414 100644 --- a/mindspore/ccsrc/utils/callbacks.h +++ b/mindspore/ccsrc/utils/callbacks.h @@ -39,9 +39,9 @@ extern const std::string kPythonCheckpointFuncName; const int kCallbackOk = 0; const int kCallbackFalied = 1; -bool GetParameterShape(const FuncGraphPtr& anf_graph, const std::string& param_name, - const std::shared_ptr>& shape); -uint32_t SummarySaveCallback(uint32_t, const std::map&); +bool GetParameterShape(const FuncGraphPtr &anf_graph, const std::string ¶m_name, + const std::shared_ptr> &shape); +uint32_t SummarySaveCallback(uint32_t, const std::map &); } // namespace callbacks } // namespace mindspore diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 36bbcbf297..b4c9fda634 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -35,15 +35,15 @@ const int ONE_SHAPE = 1; using mindspore::transform::Status; using mindspore::transform::TransformUtil; -bool GetParameterShape(const FuncGraphPtr& graph, const std::string& param_name, - const std::shared_ptr>& shape) { +bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name, + const std::shared_ptr> &shape) { if (graph == nullptr) { MS_LOG(ERROR) << "Graph is null, can not get graph parameter"; return false; } auto 
parameter_nodes = graph->parameters(); - for (auto& node : parameter_nodes) { + for (auto &node : parameter_nodes) { ParameterPtr param_node = std::static_pointer_cast(node); if (param_node == nullptr) { MS_LOG(ERROR) << "Parameter node is null, can not get graph parameter"; @@ -65,8 +65,8 @@ bool GetParameterShape(const FuncGraphPtr& graph, const std::string& param_name, return false; } -static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string& parameter_name, - const std::shared_ptr& ge_tensor_ptr) { +static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string ¶meter_name, + const std::shared_ptr &ge_tensor_ptr) { FuncGraphPtr anf_graph = transform::DfGraphManager::GetInstance().GetAnfGraph(graph_id); if (anf_graph == nullptr) { MS_LOG(ERROR) << "Get anf graph failed during callback"; @@ -82,13 +82,13 @@ static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string& pa return TransformUtil::ConvertGeTensor(ge_tensor_ptr, *parameter_shape_ptr); } -uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; MS_LOG(DEBUG) << "Start the checkpoint save callback function in checkpoint save process."; py::list parameter_list = py::list(); - for (auto& item : params_list) { + for (auto &item : params_list) { std::string name = item.first; std::shared_ptr ge_tensor_ptr = std::make_shared(item.second); TensorPtr tensor_ptr = GetMeTensorTransformed(graph_id, name, ge_tensor_ptr); @@ -112,7 +112,7 @@ uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map& ge_tensor_ptr) { +static TensorPtr GetMeTensorForSummary(const std::string &name, const std::shared_ptr &ge_tensor_ptr) { // confirm the type by name // Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor] if (name.empty()) { @@ -149,14 +149,14 @@ static TensorPtr GetMeTensorForSummary(const 
std::string& name, const std::share // Cache the summary callback data // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...] -uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; MS_LOG(DEBUG) << "Start the summary save callback function for graph " << graph_id << "."; py::list summary_list = py::list(); MS_LOG(DEBUG) << "Param list size = " << params_list.size(); - for (auto& item : params_list) { + for (auto &item : params_list) { std::string tag_name = item.first; std::shared_ptr ge_tensor_ptr = std::make_shared(item.second); TensorPtr tensor_ptr = GetMeTensorForSummary(tag_name, ge_tensor_ptr); diff --git a/mindspore/ccsrc/utils/callbacks_ge.h b/mindspore/ccsrc/utils/callbacks_ge.h index 750ec74666..08f5bb59db 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.h +++ b/mindspore/ccsrc/utils/callbacks_ge.h @@ -29,8 +29,8 @@ namespace callbacks { using mindspore::tensor::TensorPtr; -uint32_t CheckpointSaveCallback(uint32_t, const std::map&); -uint32_t SummarySaveCallback(uint32_t, const std::map&); +uint32_t CheckpointSaveCallback(uint32_t, const std::map &); +uint32_t SummarySaveCallback(uint32_t, const std::map &); } // namespace callbacks } // namespace mindspore diff --git a/mindspore/ccsrc/utils/config_manager.cc b/mindspore/ccsrc/utils/config_manager.cc index 6d66b37436..7dc559b20e 100644 --- a/mindspore/ccsrc/utils/config_manager.cc +++ b/mindspore/ccsrc/utils/config_manager.cc @@ -22,12 +22,12 @@ namespace mindspore { -ConfigManager& ConfigManager::GetInstance() noexcept { +ConfigManager &ConfigManager::GetInstance() noexcept { static ConfigManager instance; return instance; } -void ConfigManager::SetDatasetModeConfig(const std::string& mode) { +void ConfigManager::SetDatasetModeConfig(const std::string &mode) { static const 
std::map mode_map = {{"normal", DS_NORMAL_MODE}, {"sink", DS_SINK_MODE}}; if (mode_map.find(mode) == mode_map.end()) { MS_LOG(ERROR) << "Invalid dataset mode:" << mode; diff --git a/mindspore/ccsrc/utils/config_manager.h b/mindspore/ccsrc/utils/config_manager.h index db7d7d0c14..635f24792a 100644 --- a/mindspore/ccsrc/utils/config_manager.h +++ b/mindspore/ccsrc/utils/config_manager.h @@ -37,8 +37,8 @@ enum DatasetMode { DS_NORMAL_MODE = 0, DS_SINK_MODE }; class DatasetGraphParam { public: - DatasetGraphParam(const std::string& name, int64_t size, int64_t batch_size, const std::vector& ge_types, - const std::vector>& shapes, const std::vector& input_indexes) + DatasetGraphParam(const std::string &name, int64_t size, int64_t batch_size, const std::vector &ge_types, + const std::vector> &shapes, const std::vector &input_indexes) : queue_name_(name), loop_size_(size), batch_size_(batch_size), @@ -72,15 +72,15 @@ class DatasetGraphParam { class ConfigManager { public: - ConfigManager(const ConfigManager&) = delete; - ConfigManager& operator=(const ConfigManager&) = delete; - static ConfigManager& GetInstance() noexcept; + ConfigManager(const ConfigManager &) = delete; + ConfigManager &operator=(const ConfigManager &) = delete; + static ConfigManager &GetInstance() noexcept; ParallelStrategy parallel_strategy() const { return parallel_strategy_; } void set_parallel_strategy(ParallelStrategy strategy) { parallel_strategy_ = strategy; } - const std::map& ge_initialize_options() const { return ge_initialize_options_; } - void set_ge_initialize_options(const std::map& options) { + const std::map &ge_initialize_options() const { return ge_initialize_options_; } + void set_ge_initialize_options(const std::map &options) { ge_initialize_options_ = options; } @@ -90,12 +90,12 @@ class ConfigManager { void set_iter_num(const int64_t num) { iter_num_ = num; } std::string dataset_phase() const { return dataset_phase_; } - void set_dataset_phase(const std::string& phase) { 
dataset_phase_ = phase; } + void set_dataset_phase(const std::string &phase) { dataset_phase_ = phase; } DatasetGraphParam dataset_param() const { return dataset_param_; } - void set_dataset_param(const DatasetGraphParam& param) { dataset_param_ = param; } + void set_dataset_param(const DatasetGraphParam ¶m) { dataset_param_ = param; } - static void SetDatasetModeConfig(const std::string& mode); + static void SetDatasetModeConfig(const std::string &mode); void ResetConfig() noexcept; diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index bee5875f60..0a2f065140 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -45,7 +45,7 @@ std::map MsContext::policy_map_ = {{"ge", kMsBacke {"ge_only", kMsBackendGeOnly}, {"vm_prior", kMsBackendVmPrior}}; -MsContext::MsContext(const std::string& policy, const std::string& target) { +MsContext::MsContext(const std::string &policy, const std::string &target) { save_graphs_flag_ = false; save_graphs_path_ = "."; save_ms_model_flag_ = false; @@ -97,7 +97,7 @@ std::shared_ptr MsContext::GetInstance() { return inst_context_; } -bool MsContext::set_backend_policy(const std::string& policy) { +bool MsContext::set_backend_policy(const std::string &policy) { if (policy_map_.find(policy) == policy_map_.end()) { MS_LOG(ERROR) << "invalid backend policy name: " << policy; return false; @@ -110,7 +110,7 @@ bool MsContext::set_backend_policy(const std::string& policy) { std::string MsContext::backend_policy() const { auto res = std::find_if( policy_map_.begin(), policy_map_.end(), - [&, this](const std::pair& item) { return item.second == backend_policy_; }); + [&, this](const std::pair &item) { return item.second == backend_policy_; }); if (res != policy_map_.end()) { return res->first; } @@ -124,7 +124,7 @@ void MsContext::set_execution_mode(int execution_mode) { execution_mode_ = execution_mode; } -bool 
MsContext::set_device_target(const std::string& target) { +bool MsContext::set_device_target(const std::string &target) { if (kTargetSet.find(target) == kTargetSet.end()) { MS_LOG(ERROR) << "invalid device target name: " << target; return false; @@ -218,7 +218,7 @@ bool MsContext::CloseTsd(bool force) { MS_LOG(INFO) << "join tdt host receive process"; tdt_print_.join(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); } #endif @@ -241,7 +241,7 @@ bool MsContext::OpenTsd() { return true; } bool MsContext::CloseTsd(bool) { return true; } #endif -void MsContext::SetHcclOptions(std::map* ge_options) const { +void MsContext::SetHcclOptions(std::map *ge_options) const { auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); auto env_rank_id = common::GetEnv("RANK_ID"); auto env_device_id = std::to_string(device_id_); @@ -274,7 +274,7 @@ void MsContext::SetHcclOptions(std::map* ge_options) c } } -void MsContext::GetGeOptions(std::map* ge_options) const { +void MsContext::GetGeOptions(std::map *ge_options) const { #ifdef ENABLE_GE (*ge_options)["device_id"] = "0"; (*ge_options)["ge.exec.enableDump"] = std::to_string(enable_dump_); @@ -365,7 +365,7 @@ void MsContext::GetGeOptions(std::map* ge_options) con #endif } -void MsContext::SetDisableReuseMemoryFlag(std::map* ge_options) const { +void MsContext::SetDisableReuseMemoryFlag(std::map *ge_options) const { auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY"); if (!env_disable_reuse_memory.empty()) { (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory; @@ -412,7 +412,7 @@ bool MsContext::FinalizeGe(bool force) { try { DfGraphManager::GetInstance().DeleteGraphRunner(); DfGraphManager::GetInstance().DeleteGeSession(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. 
Error: " << e.what(); } catch (...) { std::string exName(abi::__cxa_current_exception_type()->name()); diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 06704ff9c6..1d84061a8a 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -48,13 +48,13 @@ const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, class MsContext { public: ~MsContext() = default; - MsContext(const MsContext&) = delete; - MsContext& operator=(const MsContext&) = delete; + MsContext(const MsContext &) = delete; + MsContext &operator=(const MsContext &) = delete; static std::shared_ptr GetInstance(); std::string backend_policy() const; - bool set_backend_policy(const std::string& policy); + bool set_backend_policy(const std::string &policy); int execution_mode() const { return execution_mode_; } void set_execution_mode(int execution_mode); @@ -69,7 +69,7 @@ class MsContext { bool precompile_only() const { return precompile_only_; } std::string device_target() const { return device_target_; } - bool set_device_target(const std::string& target); + bool set_device_target(const std::string &target); uint32_t device_id() const { return device_id_; } bool set_device_id(uint32_t device_id); @@ -78,7 +78,7 @@ class MsContext { void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } std::string save_graphs_path() const { return save_graphs_path_; } - void set_save_graphs_path(const std::string& save_paths) { save_graphs_path_ = save_paths; } + void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; } bool OpenTsd(); bool CloseTsd(bool force = false); @@ -101,7 +101,7 @@ class MsContext { void set_save_ms_model_flag(bool save_ms_model_flag) { save_ms_model_flag_ = save_ms_model_flag; } std::string save_ms_model_path() const { return save_ms_model_path_; } - void set_save_ms_model_path(const std::string& 
save_ms_model_path) { save_ms_model_path_ = save_ms_model_path; } + void set_save_ms_model_path(const std::string &save_ms_model_path) { save_ms_model_path_ = save_ms_model_path; } void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } bool enable_gpu_summary() const { return enable_gpu_summary_; } @@ -117,7 +117,7 @@ class MsContext { void set_enable_dump(bool flag) { enable_dump_ = flag; } bool enable_dump() const { return enable_dump_; } - void set_save_dump_path(const std::string& path) { save_dump_path_ = path; } + void set_save_dump_path(const std::string &path) { save_dump_path_ = path; } std::string save_dump_path() const { return save_dump_path_; } bool IsTsdOpened() const { return tsd_ref_ > 0; } @@ -128,19 +128,19 @@ class MsContext { void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; } bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; } - void set_graph_memory_max_size(const std::string& graph_memory_max_size) { + void set_graph_memory_max_size(const std::string &graph_memory_max_size) { graph_memory_max_size_ = graph_memory_max_size; } - void set_variable_memory_max_size(const std::string& variable_memory_max_size) { + void set_variable_memory_max_size(const std::string &variable_memory_max_size) { variable_memory_max_size_ = variable_memory_max_size; } private: - MsContext(const std::string& backend_policy, const std::string& target); - void GetGeOptions(std::map* ge_options) const; - void SetDisableReuseMemoryFlag(std::map* ge_options) const; - void SetHcclOptions(std::map* ge_options) const; + MsContext(const std::string &backend_policy, const std::string &target); + void GetGeOptions(std::map *ge_options) const; + void SetDisableReuseMemoryFlag(std::map *ge_options) const; + void SetHcclOptions(std::map *ge_options) const; static std::shared_ptr inst_context_; static std::map policy_map_; diff --git 
a/mindspore/ccsrc/utils/contract.h b/mindspore/ccsrc/utils/contract.h index fc257b3e24..6ef9928241 100644 --- a/mindspore/ccsrc/utils/contract.h +++ b/mindspore/ccsrc/utils/contract.h @@ -28,6 +28,7 @@ class ContractError : public std::logic_error { public: explicit ContractError(const std::string &msg) : std::logic_error(msg) {} explicit ContractError(const char *msg) : std::logic_error(msg) {} + ~ContractError() override = default; }; struct Signatory { @@ -60,6 +61,7 @@ class Ensures : public EnsuresAccess { } template >> Ensures(const Ensures &other) : value_(other.get()) {} + ~Ensures() = default; T get() const { return value_; } T &get() { return value_; } diff --git a/mindspore/ccsrc/utils/counter.h b/mindspore/ccsrc/utils/counter.h index 891f9c7a35..ead0ad84f2 100644 --- a/mindspore/ccsrc/utils/counter.h +++ b/mindspore/ccsrc/utils/counter.h @@ -29,17 +29,17 @@ class Counter { Counter() = default; ~Counter() = default; - Counter(const Counter& other) { data = other.data; } - Counter& operator=(const Counter& other) { + Counter(const Counter &other) { data = other.data; } + Counter &operator=(const Counter &other) { if (this != &other) { data = other.data; } return *this; } - int& operator[](const T& t) { return data[t]; } + int &operator[](const T &t) { return data[t]; } - counter_type operator-(const counter_type& other) { + counter_type operator-(const counter_type &other) { counter_type new_counter; for (auto iter = begin(); iter != end(); ++iter) { auto key = iter->first; @@ -58,7 +58,7 @@ class Counter { return new_counter; } - counter_type operator+(const counter_type& other) { + counter_type operator+(const counter_type &other) { counter_type new_counter; for (auto iter = begin(); iter != end(); ++iter) { auto key = iter->first; @@ -84,7 +84,7 @@ class Counter { std::size_t size() const { return data.size(); } - bool contains(const T& t) const { return data.find(t) != data.end(); } + bool contains(const T &t) const { return data.find(t) != 
data.end(); } typename OrderedMap::iterator begin() { return data.begin(); } diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc index 55ef8dc3d5..0801622549 100644 --- a/mindspore/ccsrc/utils/graph_utils.cc +++ b/mindspore/ccsrc/utils/graph_utils.cc @@ -39,10 +39,10 @@ using SymbolicKeyTypePtr = std::shared_ptr; namespace { class DeepFirstSearcher : public AnfVisitor { public: - explicit DeepFirstSearcher(const IncludeFunc& include) : include_(include) {} + explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {} ~DeepFirstSearcher() override = default; - std::vector Search(const AnfNodePtr& root) { + std::vector Search(const AnfNodePtr &root) { if (root == nullptr) { return res_; } @@ -50,7 +50,7 @@ class DeepFirstSearcher : public AnfVisitor { return res_; } - void Visit(const AnfNodePtr& node) override { + void Visit(const AnfNodePtr &node) override { MS_EXCEPTION_IF_NULL(node); if (seen_.count(node) != 0) { return; @@ -77,10 +77,10 @@ class DeepFirstSearcher : public AnfVisitor { class DeepScopedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepScopedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepScopedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { + void Visit(const CNodePtr &cnode) override { if (cnode->func_graph() == nullptr) { return; } @@ -90,13 +90,13 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { DeepFirstSearcher::Visit(ret); } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr& vnode) override { + void Visit(const ValueNodePtr &vnode) override { if (!IsValueNode(vnode)) { return; } @@ -108,7 +108,7 @@ class DeepScopedGraphSearcher : public 
DeepFirstSearcher { } } - void Visit(const ParameterPtr& param) override { + void Visit(const ParameterPtr ¶m) override { if (param->func_graph() == nullptr) { return; } @@ -122,17 +122,17 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { class DeepUsedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepUsedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepUsedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { - auto& inputs = cnode->inputs(); + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr& vnode) override { + void Visit(const ValueNodePtr &vnode) override { if (!IsValueNode(vnode)) { return; } @@ -147,33 +147,33 @@ class DeepUsedGraphSearcher : public DeepFirstSearcher { class DeepLinkedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepLinkedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepLinkedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { - auto& inputs = cnode->inputs(); + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr&) override {} + void Visit(const ValueNodePtr &) override {} }; } // namespace -std::vector DeepScopedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include) { +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepScopedGraphSearcher(include).Search(root); } -std::vector DeepUsedGraphSearch(const AnfNodePtr& 
root, const IncludeFunc& include) { +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepUsedGraphSearcher(include).Search(root); } -std::vector DeepLinkedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include) { +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepLinkedGraphSearcher(include).Search(root); } -std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ, const IncludeFunc& include) { +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { std::unordered_set done; std::list todo(1, root); std::unordered_map rank; @@ -222,7 +222,7 @@ std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ, c return res; } -std::vector SuccDeeper(const AnfNodePtr& node) { +std::vector SuccDeeper(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; @@ -237,7 +237,7 @@ std::vector SuccDeeper(const AnfNodePtr& node) { return vecs; } else if (node->func_graph() != nullptr) { if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } auto graph = node->func_graph(); @@ -250,7 +250,7 @@ std::vector SuccDeeper(const AnfNodePtr& node) { return vecs; } -std::vector SuccDeeperSimple(const AnfNodePtr& node) { +std::vector SuccDeeperSimple(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; @@ -265,39 +265,39 @@ std::vector SuccDeeperSimple(const AnfNodePtr& node) { return vecs; } else { if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } return vecs; } } -std::vector SuccIncoming(const AnfNodePtr& node) { +std::vector SuccIncoming(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; } if (node->isa()) { - auto& inputs = 
node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } return vecs; } -std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& node) { +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; } if (node->isa()) { auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); // Check if free variables used. - for (const auto& input : inputs) { + for (const auto &input : inputs) { auto input_fg = GetValueNode(input); if (input_fg) { - for (auto& fv : input_fg->free_variables_nodes()) { + for (auto &fv : input_fg->free_variables_nodes()) { if (fv->func_graph() == fg && fg->nodes().contains(fv)) { vecs.push_back(fv); } @@ -309,9 +309,9 @@ std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& return vecs; } -IncludeType AlwaysInclude(const AnfNodePtr&) { return FOLLOW; } +IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } -IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node) { +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { if (node->func_graph() == fg) { return FOLLOW; } else { @@ -319,12 +319,12 @@ IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node) { } } -FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search, const IncludeFunc& include) { +FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) { MS_EXCEPTION_IF_NULL(fg); Acquire(fg); auto vec = search(fg->get_return(), include); - for (auto& node : vec) { + for (auto &node : vec) { MS_EXCEPTION_IF_NULL(node); Acquire(node); if (node->func_graph() != nullptr) { @@ -333,7 +333,7 @@ FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search, } } -std::set FuncGraphIndex::GetFuncGraphs(const std::string& key) { +std::set 
FuncGraphIndex::GetFuncGraphs(const std::string &key) { std::set func_graphs; if (index_func_graph_.find(key) != index_func_graph_.end()) { func_graphs = index_func_graph_[key]; @@ -341,7 +341,7 @@ std::set FuncGraphIndex::GetFuncGraphs(const std::string& key) { return func_graphs; } -std::set FuncGraphIndex::GetNodes(const std::string& key) { +std::set FuncGraphIndex::GetNodes(const std::string &key) { if (index_node_.find(key) != index_node_.end()) { return index_node_[key]; } @@ -349,7 +349,7 @@ std::set FuncGraphIndex::GetNodes(const std::string& key) { return std::set(); } -FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string& key) { +FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) { if (GetFuncGraphs(key).empty()) { return nullptr; } @@ -358,7 +358,7 @@ FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string& key) { return fg; } -AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string& key) { +AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) { if (GetNodes(key).empty()) { return nullptr; } @@ -367,14 +367,14 @@ AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string& key) { return node; } -void FuncGraphIndex::Acquire(const FuncGraphPtr& key) { +void FuncGraphIndex::Acquire(const FuncGraphPtr &key) { std::string name = label_manage::Label(key->debug_info()); if (!name.empty()) { (void)index_func_graph_[name].insert(key); } } -void FuncGraphIndex::Acquire(const AnfNodePtr& key) { +void FuncGraphIndex::Acquire(const AnfNodePtr &key) { std::string name = label_manage::Label(key->debug_info()); if (!name.empty()) { (void)index_node_[name].insert(key); @@ -382,8 +382,8 @@ void FuncGraphIndex::Acquire(const AnfNodePtr& key) { } // Isomorphism -static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv 
*equiv_func_graph, + NodeMapEquiv *const equiv_node) { if (equiv_node == nullptr) { MS_LOG(ERROR) << "Invalid equiv_node"; return false; @@ -419,13 +419,13 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu return false; } -static bool SameNode(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { MS_EXCEPTION_IF_NULL(node1); MS_EXCEPTION_IF_NULL(node2); if (node1->isa() && node2->isa()) { - auto& inputs1 = node1->cast()->inputs(); - auto& inputs2 = node2->cast()->inputs(); + auto &inputs1 = node1->cast()->inputs(); + auto &inputs2 = node2->cast()->inputs(); for (std::size_t i = 0; i < inputs1.size(); ++i) { if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) { return false; @@ -436,8 +436,8 @@ static bool SameNode(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraph return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node); } -static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { std::unordered_set done; std::stack> todo; @@ -479,8 +479,8 @@ static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEqu return true; } -bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { auto fg1_fg2 = std::make_pair(fg1, fg2); if (equiv_func_graph == nullptr) { MS_LOG(ERROR) << "equiv_func_graph not init"; @@ -511,7 +511,7 @@ bool 
Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv* equiv return false; } -tensor::TensorPtr ScalarToTensor(const ScalarPtr& scalar) { +tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { if (scalar == nullptr) { MS_EXCEPTION(ArgumentError) << "Nullptr Error!"; } diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h index 57bc0e42fc..d01335af82 100644 --- a/mindspore/ccsrc/utils/graph_utils.h +++ b/mindspore/ccsrc/utils/graph_utils.h @@ -38,42 +38,42 @@ namespace mindspore { enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE }; -using IncludeFunc = std::function; +using IncludeFunc = std::function; using SuccFunc = std::function(AnfNodePtr)>; -using SearchFunc = std::function(const AnfNodePtr&, const IncludeFunc&)>; +using SearchFunc = std::function(const AnfNodePtr &, const IncludeFunc &)>; -std::vector SuccDeeper(const AnfNodePtr& node); -std::vector SuccDeeperSimple(const AnfNodePtr& node); -std::vector SuccIncoming(const AnfNodePtr& node); -std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& node); +std::vector SuccDeeper(const AnfNodePtr &node); +std::vector SuccDeeperSimple(const AnfNodePtr &node); +std::vector SuccIncoming(const AnfNodePtr &node); +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); -IncludeType AlwaysInclude(const AnfNodePtr& node); -IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node); +IncludeType AlwaysInclude(const AnfNodePtr &node); +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); -std::vector DeepScopedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); -std::vector DeepUsedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); -std::vector DeepLinkedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); 
+std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); -std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ = SuccIncoming, - const IncludeFunc& include = AlwaysInclude); +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, + const IncludeFunc &include = AlwaysInclude); class FuncGraphIndex { public: - explicit FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search = DeepScopedGraphSearch, - const IncludeFunc& include = AlwaysInclude); - FuncGraphIndex(const FuncGraphIndex&) = delete; - FuncGraphIndex& operator=(const FuncGraphIndex&) = delete; + explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, + const IncludeFunc &include = AlwaysInclude); + FuncGraphIndex(const FuncGraphIndex &) = delete; + FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; virtual ~FuncGraphIndex() {} - std::set GetFuncGraphs(const std::string& key); - std::set GetNodes(const std::string& key); - FuncGraphPtr GetFirstFuncGraph(const std::string& key); - AnfNodePtr GetFirstNode(const std::string& key); + std::set GetFuncGraphs(const std::string &key); + std::set GetNodes(const std::string &key); + FuncGraphPtr GetFirstFuncGraph(const std::string &key); + AnfNodePtr GetFirstNode(const std::string &key); private: - void Acquire(const FuncGraphPtr& key); - void Acquire(const AnfNodePtr& key); + void Acquire(const FuncGraphPtr &key); + void Acquire(const AnfNodePtr &key); std::map> index_func_graph_; std::map> index_node_; @@ -83,7 +83,7 @@ class FuncGraphIndex { struct PairHasher { template - std::size_t operator()(const std::pair& p) const { + std::size_t operator()(const std::pair &p) const { auto h1 = std::hash{}(p.first); auto h2 = std::hash{}(p.second); return h1 ^ h2; @@ -95,9 +95,9 @@ enum EquivState { kNotEquiv = 0, kEquiv = 1, 
kPending = 2 }; using FuncGraphPairMapEquiv = std::unordered_map, EquivState, PairHasher>; using NodeMapEquiv = std::unordered_map; -bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv* equiv_func_graph, NodeMapEquiv* equiv_node); +bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); -tensor::TensorPtr ScalarToTensor(const ScalarPtr& scalar); +tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ diff --git a/mindspore/ccsrc/utils/hashing.h b/mindspore/ccsrc/utils/hashing.h index 730657ce7a..cc8cc5b991 100644 --- a/mindspore/ccsrc/utils/hashing.h +++ b/mindspore/ccsrc/utils/hashing.h @@ -25,7 +25,7 @@ inline std::size_t hash_combine(std::size_t hash_sum, std::size_t hash_val) { return ((hash_sum << 6) + (hash_sum >> 2) + 0x9e3779b9 + hash_val) ^ hash_sum; } -inline std::size_t hash_combine(const std::initializer_list& hash_vals) { +inline std::size_t hash_combine(const std::initializer_list &hash_vals) { std::size_t hash_sum = 0; for (auto hash_val : hash_vals) { hash_sum = hash_combine(hash_sum, hash_val); diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index 704ab24d52..0cd9b64a9b 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -179,7 +179,7 @@ void LogWriter::operator^(const LogStream &stream) const { std::ostringstream oss; oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; - if (exception_type_ != NoExceptionType) { + if (exception_type_ != NoExceptionType && exception_type_ != TypeError && exception_type_ != ValueError) { oss << ExceptionTypeToString(exception_type_) << " "; } oss << msg.str(); @@ -229,19 +229,29 @@ static void InitMsLogLevel() { extern "C" { // shared lib init hook +#if defined(_WIN32) || defined(_WIN64) +__attribute__((constructor)) void mindspore_log_init(void) 
{ +#else void mindspore_log_init(void) { +#endif #ifdef USE_GLOG // do not use glog predefined log prefix FLAGS_log_prefix = false; static bool is_glog_initialzed = false; if (!is_glog_initialzed) { +#if !defined(_WIN32) && !defined(_WIN64) google::InitGoogleLogging("mindspore"); +#endif is_glog_initialzed = true; } // set default log level to WARNING if (mindspore::GetEnv("GLOG_v").empty()) { FLAGS_v = mindspore::WARNING; } + // set default log file mode to 0640 + if (mindspore::GetEnv("GLOG_logfile_mode").empty()) { + FLAGS_logfile_mode = 0640; + } // default print log to screen if (mindspore::GetEnv("GLOG_logtostderr").empty()) { FLAGS_logtostderr = true; diff --git a/mindspore/ccsrc/utils/misc.cc b/mindspore/ccsrc/utils/misc.cc index 47e675a341..a9eb8071ef 100644 --- a/mindspore/ccsrc/utils/misc.cc +++ b/mindspore/ccsrc/utils/misc.cc @@ -23,9 +23,9 @@ const int RET_FAILED = 1; const int RET_CONTINUE = 2; const int RET_BREAK = 3; -std::string demangle(const char* name) { +std::string demangle(const char *name) { int status = -1; - std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; + std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; return (status == 0) ? res.get() : name; } } // namespace mindspore diff --git a/mindspore/ccsrc/utils/misc.h b/mindspore/ccsrc/utils/misc.h index 66e8937f9c..e2cdebe98a 100644 --- a/mindspore/ccsrc/utils/misc.h +++ b/mindspore/ccsrc/utils/misc.h @@ -33,7 +33,7 @@ extern const int RET_CONTINUE; extern const int RET_BREAK; // demangle the name to make it human reablable. 
-extern std::string demangle(const char* name); +extern std::string demangle(const char *name); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_MISC_H_ diff --git a/mindspore/ccsrc/utils/ordered_set.h b/mindspore/ccsrc/utils/ordered_set.h index b22053f196..f393ce74f2 100644 --- a/mindspore/ccsrc/utils/ordered_set.h +++ b/mindspore/ccsrc/utils/ordered_set.h @@ -53,28 +53,28 @@ class OrderedSet { // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion, // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use // traversal to build elements. - OrderedSet(const OrderedSet& os) { - for (auto& item : os.ordered_data_) { + OrderedSet(const OrderedSet &os) { + for (auto &item : os.ordered_data_) { add(item); } } - explicit OrderedSet(const sequential_type& other) { - for (auto& item : other) { + explicit OrderedSet(const sequential_type &other) { + for (auto &item : other) { add(item); } } // Explicitly construct an OrderedSet use vector - explicit OrderedSet(const vector_type& other) { - for (auto& item : other) { + explicit OrderedSet(const vector_type &other) { + for (auto &item : other) { add(item); } } - OrderedSet& operator=(const OrderedSet& os) { + OrderedSet &operator=(const OrderedSet &os) { if (this != &os) { - for (auto& item : os.ordered_data_) { + for (auto &item : os.ordered_data_) { add(item); } } @@ -82,14 +82,14 @@ class OrderedSet { } // Add an element to the OrderedSet, without judging return value - void add(const element_type& e) { (void)insert(e); } + void add(const element_type &e) { (void)insert(e); } // insert an element to the OrderedSet - std::pair insert(const element_type& e) { + std::pair insert(const element_type &e) { iterator empty_itr; std::pair map_pair = std::make_pair(e, empty_itr); auto result = mapped_data_.insert(map_pair); - auto& seq_idx = result.first->second; + auto &seq_idx = result.first->second; // if insert 
success; if (result.second) { auto it = ordered_data_.insert(ordered_data_.end(), e); @@ -99,7 +99,7 @@ class OrderedSet { } // Remove an element, if removed return true, otherwise return false - bool erase(const element_type& e) { + bool erase(const element_type &e) { auto pos = mapped_data_.find(e); if (pos == mapped_data_.end()) { return false; @@ -119,7 +119,7 @@ class OrderedSet { std::string toString() { std::ostringstream res; res << "orderset content:\n"; - for (auto& item : ordered_data_) { + for (auto &item : ordered_data_) { res << std::to_string(reinterpret_cast(item.get())) << " "; } return res.str(); @@ -132,7 +132,7 @@ class OrderedSet { } // Compare two orderedset, if the order is not equal shall return false - bool operator==(const OrderedSet& other) const { return ordered_data_ == other.ordered_data_; } + bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } // Remove and return the first element in the OrderedSet T pop() { @@ -153,8 +153,8 @@ class OrderedSet { } // Return true if there are no common elements - bool is_disjoint(const OrderedSet& other) { - for (auto& item : other.ordered_data_) { + bool is_disjoint(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { if (mapped_data_.find(item) != mapped_data_.end()) { return false; } @@ -163,8 +163,8 @@ class OrderedSet { } // Test whether this is subset of other - bool is_subset(const OrderedSet& other) { - for (auto& item : ordered_data_) { + bool is_subset(const OrderedSet &other) { + for (auto &item : ordered_data_) { if (other.mapped_data_.find(item) == other.mapped_data_.end()) { return false; } @@ -173,51 +173,51 @@ class OrderedSet { } // Add elements in other to this orderedset - void update(const OrderedSet& other) { - for (auto& item : other.ordered_data_) { + void update(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { add(item); } } - void update(const std::shared_ptr& other) { update(*other); } + void 
update(const std::shared_ptr &other) { update(*other); } - void update(const sequential_type& other) { - for (auto& item : other) { + void update(const sequential_type &other) { + for (auto &item : other) { add(item); } } - void update(const vector_type& other) { - for (auto& item : other) { + void update(const vector_type &other) { + for (auto &item : other) { add(item); } } - ordered_set_type get_union(const OrderedSet& other) { + ordered_set_type get_union(const OrderedSet &other) { ordered_set_type res(ordered_data_); res.update(other); return res; } // Get the union with other set, this operator may cost time because of copy - ordered_set_type operator|(const OrderedSet& other) { return get_union(other); } + ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } // Return the intersection of two sets - ordered_set_type intersection(const OrderedSet& other) { + ordered_set_type intersection(const OrderedSet &other) { ordered_set_type res(ordered_data_); - for (auto& item : ordered_data_) { + for (auto &item : ordered_data_) { if (other.mapped_data_.find(item) == other.mapped_data_.end()) { (void)res.erase(item); } } return res; } - ordered_set_type operator&(const OrderedSet& other) { return intersection(other); } + ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } // Return the symmetric difference of two sets - ordered_set_type symmetric_difference(const OrderedSet& other) { + ordered_set_type symmetric_difference(const OrderedSet &other) { ordered_set_type res(ordered_data_); - for (auto& item : other.ordered_data_) { + for (auto &item : other.ordered_data_) { if (mapped_data_.find(item) != mapped_data_.end()) { (void)res.erase(item); } else { @@ -227,40 +227,40 @@ class OrderedSet { return res; } - ordered_set_type operator^(const OrderedSet& other) { return symmetric_difference(other); } + ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } // Remove 
elements which is also in others. - void difference_update(const OrderedSet& other) { + void difference_update(const OrderedSet &other) { // use vector traversal, to keep ordrer - for (auto& item : other.ordered_data_) { + for (auto &item : other.ordered_data_) { (void)erase(item); } } - void difference_update(const sequential_type& other) { - for (auto& item : other) { + void difference_update(const sequential_type &other) { + for (auto &item : other) { (void)erase(item); } } - void difference_update(const vector_type& other) { - for (auto& item : other) { + void difference_update(const vector_type &other) { + for (auto &item : other) { (void)erase(item); } } // Return the set with elements that are not in the others - ordered_set_type difference(const OrderedSet& other) { + ordered_set_type difference(const OrderedSet &other) { ordered_set_type res(ordered_data_); res.difference_update(other); return res; } - ordered_set_type operator-(const OrderedSet& other) { return difference(other); } + ordered_set_type operator-(const OrderedSet &other) { return difference(other); } - bool contains(const element_type& e) const { return (mapped_data_.find(e) != mapped_data_.end()); } + bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } // Return the count of an element in set - std::size_t count(const element_type& e) const { return mapped_data_.count(e); } + std::size_t count(const element_type &e) const { return mapped_data_.count(e); } iterator begin() { return ordered_data_.begin(); } iterator end() { return ordered_data_.end(); } diff --git a/mindspore/ccsrc/utils/profile.cc b/mindspore/ccsrc/utils/profile.cc index ba490549f8..e9e7920e0c 100644 --- a/mindspore/ccsrc/utils/profile.cc +++ b/mindspore/ccsrc/utils/profile.cc @@ -33,37 +33,43 @@ namespace { constexpr size_t TIME_INFO_PREFIX_NUM_LEN = 4; const char KEY_PROF_TOTAL[] = "__total__"; -void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent 
= 0, - std::map* sums = nullptr, const std::string& prefix = ""); - -void PrintTimeInfoMap(std::ostringstream& oss, const TimeInfoMap& dict, int indent = 0, - std::map* sums = nullptr, const std::string& prefix = "") { - for (auto iter = dict.begin(); iter != dict.end(); ++iter) { - if (iter->second == nullptr) { +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = ""); + +void PrintTimeInfoMap(std::ostringstream &oss, const TimeInfoMap &dict, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = "") { + size_t count = 0; + for (const auto &iter : dict) { + count++; + if (iter.second == nullptr) { continue; } // indent by multiples of 4 spaces. - auto name = iter->first.substr(TIME_INFO_PREFIX_NUM_LEN); + if (iter.first.size() < TIME_INFO_PREFIX_NUM_LEN) { + MS_LOG(EXCEPTION) << "In TimeInfoMap, the " << count << "th string key is " << iter.first + << ", but the length is less than " << TIME_INFO_PREFIX_NUM_LEN; + } + auto name = iter.first.substr(TIME_INFO_PREFIX_NUM_LEN); oss << std::setw(indent * 4) << "" - << "[" << name << "]: " << iter->second->time_; - if (iter->second->dict_ != nullptr) { - oss << ", [" << iter->second->dict_->size() << "]"; + << "[" << name << "]: " << iter.second->time_; + if (iter.second->dict_ != nullptr) { + oss << ", [" << iter.second->dict_->size() << "]"; } oss << "\n"; std::string newPrefix = prefix; - if (iter->first.find("Cycle ") == std::string::npos) { - newPrefix = prefix.empty() ? iter->first : prefix + "." + iter->first; + if (iter.first.find("Cycle ") == std::string::npos) { + newPrefix = prefix.empty() ? iter.first : prefix + "." 
+ iter.first; } - PrintProfile(oss, *iter->second, indent + 1, sums, newPrefix); - if (iter->second->dict_ == nullptr) { - (*sums)[newPrefix] += iter->second->time_; + PrintProfile(oss, *iter.second, indent + 1, sums, newPrefix); + if (iter.second->dict_ == nullptr) { + (*sums)[newPrefix] += iter.second->time_; } } } -void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent, std::map* sums, - const std::string& prefix) { +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent, std::map *sums, + const std::string &prefix) { bool need_free = false; if (sums == nullptr) { sums = new (std::nothrow) std::map(); @@ -95,7 +101,7 @@ void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent } oss << "Sums\n"; if (total >= 0.0 + DBL_EPSILON) { - for (auto& iter : *sums) { + for (auto &iter : *sums) { std::string name = iter.first; name.erase(0, TIME_INFO_PREFIX_NUM_LEN); std::size_t pos = 0; @@ -159,7 +165,7 @@ void Profile::Print(void) { // Start a step in the current context with the given name. // Nomes must be unique otherwise the previous record will be overwritten. -ProfContext* Profile::Step(const std::string& name) { +ProfContext *Profile::Step(const std::string &name) { ctx_ptr_ = new (std::nothrow) ProfContext(name, this); if (ctx_ptr_ == nullptr) { MS_LOG(ERROR) << "memory allocation failed"; @@ -170,7 +176,7 @@ ProfContext* Profile::Step(const std::string& name) { // Creates subcontext for a repeated action. // Count should be monotonically increasing. 
-ProfContext* Profile::Lap(int count) { +ProfContext *Profile::Lap(int count) { std::ostringstream oss; oss << "Cycle " << count; ctx_ptr_ = new (std::nothrow) ProfContext(oss.str(), this); @@ -188,7 +194,7 @@ void Profile::Pop(void) noexcept { ctx_ptr_ = ctx_ptr_->parent_; } -ProfContext::ProfContext(const std::string& name, ProfileBase* const prof) : name_(name), prof_(prof) { +ProfContext::ProfContext(const std::string &name, ProfileBase *const prof) : name_(name), prof_(prof) { // Initialize a subcontext. time_info_ = nullptr; if (prof == nullptr || IsTopContext()) { @@ -227,7 +233,7 @@ void ProfContext::SetTime(double time) noexcept { time_info_->time_ = time; } -void ProfContext::Insert(const std::string& name, const TimeInfo* time) noexcept { +void ProfContext::Insert(const std::string &name, const TimeInfo *time) noexcept { if (time_info_ == nullptr) { time_info_ = new (std::nothrow) TimeInfo(); if (time_info_ == nullptr) { @@ -266,7 +272,7 @@ void ProfContext::Insert(const std::string& name, const TimeInfo* time) noexcept bool ProfContext::IsTopContext() const noexcept { return (prof_ != nullptr) && (this == &prof_->context_); } -ProfTransaction::ProfTransaction(const ProfileBase* prof) { ctx_ = (prof != nullptr ? prof->ctx_ptr_ : nullptr); } +ProfTransaction::ProfTransaction(const ProfileBase *prof) { ctx_ = (prof != nullptr ? 
prof->ctx_ptr_ : nullptr); } ProfTransaction::~ProfTransaction() { if (ctx_ != nullptr && !ctx_->IsTopContext()) { @@ -275,7 +281,7 @@ ProfTransaction::~ProfTransaction() { ctx_ = nullptr; } -void DumpTime::Record(const std::string& step_name, const double time, const bool is_start) { +void DumpTime::Record(const std::string &step_name, const double time, const bool is_start) { file_ss_ << " {" << std::endl; file_ss_ << " \"name\": " << "\"" << step_name << "\"," << std::endl; @@ -298,7 +304,7 @@ void DumpTime::Record(const std::string& step_name, const double time, const boo void DumpTime::Save() { try { file_out_.open(file_path_, std::ios::trunc | std::ios::out); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "Cannot open file in " << (file_path_); } file_out_ << "{\n"; @@ -317,10 +323,10 @@ struct TimeInfoGroup { std::list::const_iterator> items; }; -static void PrintTimeStat(std::ostringstream& oss, const TimeInfoGroup& group, const std::string& prefix) { +static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, const std::string &prefix) { oss << "------[" << prefix << "] " << std::setw(10) << std::fixed << std::setprecision(6) << group.total_time << std::setw(6) << group.total_count << "\n"; - for (const auto& iter : group.items) { + for (const auto &iter : group.items) { oss << std::setw(5) << std::fixed << std::setprecision(2) << iter->second.time_ / group.total_time * 100 << "% : " << std::setw(12) << std::fixed << std::setprecision(6) << iter->second.time_ << "s : " << std::setw(6) << iter->second.count_ << ": " << iter->first << "\n"; @@ -332,7 +338,7 @@ void MsProfile::Print() { std::vector items = {"substitution.", "renormalize.", "replace.", "match.", "func_graph_cloner_run.", "meta_graph.", "manager."}; std::vector groups(items.size() + 1); - const auto& stat = GetSingleton().time_stat_; + const auto &stat = GetSingleton().time_stat_; // group all time infos for (auto iter = 
stat.cbegin(); iter != stat.cend(); ++iter) { auto matched_idx = items.size(); diff --git a/mindspore/ccsrc/utils/profile.h b/mindspore/ccsrc/utils/profile.h index 6892b0b4f6..bd3723d5bb 100644 --- a/mindspore/ccsrc/utils/profile.h +++ b/mindspore/ccsrc/utils/profile.h @@ -27,7 +27,7 @@ namespace mindspore { struct TimeInfo; -using TimeInfoMap = std::map; +using TimeInfoMap = std::map; extern double GetTime(); @@ -35,11 +35,11 @@ class ProfileBase; struct TimeInfo { explicit TimeInfo(double time = -1.0) : time_(time), dict_(nullptr), actionNum_(0) {} - TimeInfo(const TimeInfo&) = delete; + TimeInfo(const TimeInfo &) = delete; ~TimeInfo(); double time_; - TimeInfoMap* dict_; + TimeInfoMap *dict_; size_t actionNum_; }; @@ -50,21 +50,21 @@ class ProfContext { friend class ProfTransaction; public: - ProfContext(const std::string& name, ProfileBase* prof); + ProfContext(const std::string &name, ProfileBase *prof); ~ProfContext(); - ProfContext(const ProfContext&) = delete; - ProfContext& operator=(const ProfContext&) = delete; + ProfContext(const ProfContext &) = delete; + ProfContext &operator=(const ProfContext &) = delete; void SetTime(double time) noexcept; - void Insert(const std::string& name, const TimeInfo* time) noexcept; + void Insert(const std::string &name, const TimeInfo *time) noexcept; bool IsTopContext() const noexcept; private: std::string name_; - ProfileBase* prof_; - ProfContext* parent_; - TimeInfo* time_info_; + ProfileBase *prof_; + ProfContext *parent_; + TimeInfo *time_info_; }; class ProfileBase { @@ -76,38 +76,38 @@ class ProfileBase { virtual ~ProfileBase(); virtual void Print(void) {} - virtual ProfContext* Step(const std::string&) { return nullptr; } - virtual ProfContext* Lap(int) { return nullptr; } + virtual ProfContext *Step(const std::string &) { return nullptr; } + virtual ProfContext *Lap(int) { return nullptr; } virtual void Pop(void) {} // top level profile context ProfContext context_; // profile context pointer, act as a stack 
pointer - ProfContext* ctx_ptr_ = nullptr; + ProfContext *ctx_ptr_ = nullptr; }; class Profile : public ProfileBase { public: Profile() = default; ~Profile() override = default; - Profile(const Profile&) = delete; - Profile& operator=(const Profile&) = delete; + Profile(const Profile &) = delete; + Profile &operator=(const Profile &) = delete; void Print(void) override; - ProfContext* Step(const std::string& name) override; - ProfContext* Lap(int count) override; + ProfContext *Step(const std::string &name) override; + ProfContext *Lap(int count) override; void Pop(void) noexcept override; }; class ProfTransaction { public: - explicit ProfTransaction(const ProfileBase* prof); - explicit ProfTransaction(ProfContext* const ctx) : ctx_(ctx) {} - ProfTransaction(const ProfTransaction&) = delete; + explicit ProfTransaction(const ProfileBase *prof); + explicit ProfTransaction(ProfContext *const ctx) : ctx_(ctx) {} + ProfTransaction(const ProfTransaction &) = delete; ~ProfTransaction(); template - void operator-(const Function& func) { + void operator-(const Function &func) { double start_time = GetTime(); func(); double end_time = GetTime(); @@ -117,17 +117,17 @@ class ProfTransaction { } private: - ProfContext* ctx_ = nullptr; + ProfContext *ctx_ = nullptr; }; class NoProfTransaction { public: - explicit NoProfTransaction(ProfileBase* prof) {} - explicit NoProfTransaction(ProfContext* ctx) {} + explicit NoProfTransaction(ProfileBase *prof) {} + explicit NoProfTransaction(ProfContext *ctx) {} ~NoProfTransaction() = default; template - void operator-(const Function& func) { + void operator-(const Function &func) { func(); } }; @@ -137,20 +137,20 @@ class DumpTime { ~DumpTime() { try { Save(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Cannot save file by profile::DumpTime::save"; } catch (...) 
{ MS_LOG(ERROR) << "Uncaught exception"; } } - DumpTime(const DumpTime&) = delete; - DumpTime& operator=(const DumpTime&) = delete; - static DumpTime& GetInstance() { + DumpTime(const DumpTime &) = delete; + DumpTime &operator=(const DumpTime &) = delete; + static DumpTime &GetInstance() { static DumpTime instance; return instance; } - void set_file_path(const std::string& save_path) { file_path_ = save_path; } - void Record(const std::string& name, const double time, const bool is_start); + void set_file_path(const std::string &save_path) { file_path_ = save_path; } + void Record(const std::string &name, const double time, const bool is_start); void Save(); private: @@ -188,8 +188,8 @@ class MsProfile { static void Reset() { GetSingleton().Clear(); } - static ProfileBase* GetProfile() { - MsProfile& ms_prof = GetSingleton(); + static ProfileBase *GetProfile() { + MsProfile &ms_prof = GetSingleton(); if (ms_prof.profile_ == nullptr) { #ifdef ENABLE_PROFILE ms_prof.profile_ = new Profile(); @@ -199,14 +199,14 @@ class MsProfile { } return ms_prof.profile_; } - static void StatTime(const std::string& id, double time) { GetSingleton().time_stat_[id] += time; } + static void StatTime(const std::string &id, double time) { GetSingleton().time_stat_[id] += time; } static void Print(); private: MsProfile() = default; - static MsProfile& GetSingleton() { + static MsProfile &GetSingleton() { static MsProfile profile; return profile; } @@ -220,7 +220,7 @@ class MsProfile { } std::map time_stat_; // record time and count info from some activity - ProfileBase* profile_ = nullptr; // record hierarchical profile info + ProfileBase *profile_ = nullptr; // record hierarchical profile info }; } // namespace mindspore diff --git a/mindspore/ccsrc/utils/signal.h b/mindspore/ccsrc/utils/signal.h index af7b36a8b5..9a43e23814 100644 --- a/mindspore/ccsrc/utils/signal.h +++ b/mindspore/ccsrc/utils/signal.h @@ -24,14 +24,14 @@ namespace mindspore { template -std::function bind_member(Type* 
instance, Return (Type::*method)(Args...)) { - return [=](Args&&... args) -> Return { return (instance->*method)(std::forward(args)...); }; +std::function bind_member(Type *instance, Return (Type::*method)(Args...)) { + return [=](Args &&... args) -> Return { return (instance->*method)(std::forward(args)...); }; } template class Slot { public: - explicit Slot(const std::function& callback) : callback(callback) {} + explicit Slot(const std::function &callback) : callback(callback) {} ~Slot() {} @@ -42,15 +42,15 @@ template class Signal { public: template - void operator()(Args&&... args) { - for (auto& slot : slots_) { + void operator()(Args &&... args) { + for (auto &slot : slots_) { if (slot->callback != nullptr) { slot->callback(std::forward(args)...); } } } - void add_slot(const std::function& func) { + void add_slot(const std::function &func) { auto slot = std::make_shared>(func); slots_.push_back(slot); } diff --git a/mindspore/ccsrc/utils/symbolic.cc b/mindspore/ccsrc/utils/symbolic.cc index 8764678288..8ad16e50c8 100644 --- a/mindspore/ccsrc/utils/symbolic.cc +++ b/mindspore/ccsrc/utils/symbolic.cc @@ -22,29 +22,29 @@ namespace mindspore { -std::ostream& operator<<(std::ostream& out, const std::shared_ptr& objPtr) { +std::ostream &operator<<(std::ostream &out, const std::shared_ptr &objPtr) { out << "("; MS_EXCEPTION_IF_NULL(objPtr); - for (auto& iter : objPtr->contents_) { + for (auto &iter : objPtr->contents_) { out << iter.first << ":" << iter.second << ";"; } out << ")"; return out; } -bool EnvInstance::operator==(const EnvInstance& other) const { +bool EnvInstance::operator==(const EnvInstance &other) const { if (Len() != other.Len()) { return false; } bool equal = std::all_of(contents_.begin(), contents_.end(), - [&other](const std::pair& item) -> bool { + [&other](const std::pair &item) -> bool { return other.contents_.find(item.first) != other.contents_.end(); }); return equal; } -bool EnvInstance::operator==(const Value& other) const { +bool 
EnvInstance::operator==(const Value &other) const { if (other.isa()) { - auto other_env_inst = static_cast(&other); + auto other_env_inst = static_cast(&other); return *this == *other_env_inst; } return false; diff --git a/mindspore/ccsrc/utils/symbolic.h b/mindspore/ccsrc/utils/symbolic.h index 3c712483ee..a373c23573 100644 --- a/mindspore/ccsrc/utils/symbolic.h +++ b/mindspore/ccsrc/utils/symbolic.h @@ -32,18 +32,18 @@ namespace mindspore { class SymbolicKeyInstance : public Value { public: - SymbolicKeyInstance(const AnfNodePtr& node, const abstract::AbstractBasePtr& abstract) + SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) : node_(node), abstract_(abstract) {} ~SymbolicKeyInstance() override = default; MS_DECLARE_PARENT(SymbolicKeyInstance, Value); AnfNodePtr node() const { return node_; } abstract::AbstractBasePtr abstract() const { return abstract_; } - bool operator==(const SymbolicKeyInstance& other) const { + bool operator==(const SymbolicKeyInstance &other) const { return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); } std::size_t hash() const override { return std::hash{}(node_); } - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr& inst) { + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &inst) { if (inst == nullptr) { os << "[Key][" << "Invalid symbolic key instance" @@ -56,9 +56,9 @@ class SymbolicKeyInstance : public Value { std::string ToString() const override { return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); } - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -106,19 +106,19 @@ using EnvInstanceContentsMap = // with inferred properties. 
class EnvInstance : public Value { public: - friend std::ostream& operator<<(std::ostream& out, const std::shared_ptr& env); + friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr &env); - explicit EnvInstance(const EnvInstanceContentsMap& contents = {}) : contents_(contents) {} + explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} ~EnvInstance() override = default; MS_DECLARE_PARENT(EnvInstance, Value); abstract::AbstractBasePtr ToAbstract() override { return std::make_shared(shared_from_base(), std::make_shared()); } - bool operator==(const EnvInstance& other) const; - bool operator==(const Value& other) const override; - EnvInstance(const EnvInstance& v) : Value(v), contents_(v.contents_) {} - EnvInstance(EnvInstance&& v) = default; - EnvInstance& operator=(EnvInstance&& src) noexcept { + bool operator==(const EnvInstance &other) const; + bool operator==(const Value &other) const override; + EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} + EnvInstance(EnvInstance &&v) = default; + EnvInstance &operator=(EnvInstance &&src) noexcept { if (&src != this) { contents_ = src.contents_; } @@ -126,7 +126,7 @@ class EnvInstance : public Value { }; // Get the sensitivity list for the given key - const Any& Get(const SymbolicKeyInstancePtr& key, const Any& def) const { + const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { auto iterator = contents_.find(key); if (iterator != contents_.end()) { return iterator->second; @@ -135,14 +135,14 @@ class EnvInstance : public Value { } // Set a value for the given key. - EnvInstance Set(const SymbolicKeyInstancePtr& key, const Any& value) const { + EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { EnvInstance rval(contents_); rval.contents_[key] = value; return rval; } // Add two EnvInstances. 
- EnvInstance Add(const EnvInstance& other) const { + EnvInstance Add(const EnvInstance &other) const { EnvInstance rval(contents_); for (auto iter_other : other.contents_) { auto item_self = contents_.find(iter_other.first); diff --git a/mindspore/ccsrc/utils/system/base.h b/mindspore/ccsrc/utils/system/base.h index dace2e7178..4cfb5b312d 100644 --- a/mindspore/ccsrc/utils/system/base.h +++ b/mindspore/ccsrc/utils/system/base.h @@ -108,7 +108,7 @@ constexpr bool kLittleEndian = true; // implement common define function // Get the 32 bits align value -inline uint32 DecodeFixed32(const char* ptr) { +inline uint32 DecodeFixed32(const char *ptr) { uint32 result; if (EOK != memcpy_s(&result, sizeof(result), ptr, sizeof(result))) { MS_LOG(EXCEPTION) << "Call DecodeFixed32 memcpy value failure."; @@ -116,14 +116,14 @@ inline uint32 DecodeFixed32(const char* ptr) { return result; } // Used to fetch a naturally-aligned 32-bit word in little endian byte-order -inline uint32 LE_LOAD32(const uint8_t* p) { return DecodeFixed32(reinterpret_cast(p)); } +inline uint32 LE_LOAD32(const uint8_t *p) { return DecodeFixed32(reinterpret_cast(p)); } // Encode the data to buffer -inline void EncodeFixed32(char* buf, uint32 value) { +inline void EncodeFixed32(char *buf, uint32 value) { if (EOK != memcpy_s(buf, sizeof(value), &value, sizeof(value))) { MS_LOG(EXCEPTION) << "Call EncodeFixed32 memcpy value failure."; } } -inline void EncodeFixed64(char* buf, const unsigned int array_len, int64 value) { +inline void EncodeFixed64(char *buf, const unsigned int array_len, int64 value) { if (sizeof(value) > array_len) { MS_LOG(EXCEPTION) << "Buffer overflow, real size is " << array_len << ", but required " << sizeof(value) << "."; } diff --git a/mindspore/ccsrc/utils/system/crc32c.h b/mindspore/ccsrc/utils/system/crc32c.h index 4411423bab..d23b9ad463 100644 --- a/mindspore/ccsrc/utils/system/crc32c.h +++ b/mindspore/ccsrc/utils/system/crc32c.h @@ -40,10 +40,10 @@ class Crc32c { ~Crc32c() = 
default; // Calculate the crc32c value, use the 8 table method - static uint32 MakeCrc32c(uint32 init_crc, const char* data, size_t size); + static uint32 MakeCrc32c(uint32 init_crc, const char *data, size_t size); // retrun the crc32c value(need mask) - static uint32 GetMaskCrc32cValue(const char* data, size_t n) { + static uint32 GetMaskCrc32cValue(const char *data, size_t n) { auto crc = MakeCrc32c(0, data, n); // Rotate right by kRightShift bits and add kMaskDelta(a constant). return ((crc >> kRightShift) | (crc << kLeftShift)) + kMaskDelta; diff --git a/mindspore/ccsrc/utils/system/file_system.cc b/mindspore/ccsrc/utils/system/file_system.cc index aee89d4b7b..ce27108a39 100644 --- a/mindspore/ccsrc/utils/system/file_system.cc +++ b/mindspore/ccsrc/utils/system/file_system.cc @@ -25,7 +25,7 @@ namespace system { #if defined(SYSTEM_ENV_POSIX) // Implement the Posix file systen -WriteFilePtr PosixFileSystem::CreateWriteFile(const string& file_name) { +WriteFilePtr PosixFileSystem::CreateWriteFile(const string &file_name) { if (file_name.empty()) { MS_LOG(ERROR) << "Create write file failed because the file name is null."; return nullptr; @@ -43,7 +43,7 @@ WriteFilePtr PosixFileSystem::CreateWriteFile(const string& file_name) { return fp; } -bool PosixFileSystem::FileExist(const string& file_name) { +bool PosixFileSystem::FileExist(const string &file_name) { if (file_name.empty()) { MS_LOG(WARNING) << "The file name is null."; return false; @@ -56,7 +56,7 @@ bool PosixFileSystem::FileExist(const string& file_name) { return true; } -bool PosixFileSystem::DeleteFile(const string& file_name) { +bool PosixFileSystem::DeleteFile(const string &file_name) { if (file_name.empty()) { MS_LOG(WARNING) << "The file name is null."; return false; @@ -70,7 +70,7 @@ bool PosixFileSystem::DeleteFile(const string& file_name) { } static const int DEFAULT_MKDIR_MODE = 0700; -bool PosixFileSystem::CreateDir(const string& dir_name) { +bool PosixFileSystem::CreateDir(const string 
&dir_name) { if (dir_name.empty()) { MS_LOG(WARNING) << "The directory name is null."; return false; @@ -83,7 +83,7 @@ bool PosixFileSystem::CreateDir(const string& dir_name) { return true; } -bool PosixFileSystem::DeleteDir(const string& dir_name) { +bool PosixFileSystem::DeleteDir(const string &dir_name) { if (dir_name.empty()) { MS_LOG(WARNING) << "The directory name is null."; return false; diff --git a/mindspore/ccsrc/utils/system/file_system.h b/mindspore/ccsrc/utils/system/file_system.h index ef0cf885be..ed9db874c8 100644 --- a/mindspore/ccsrc/utils/system/file_system.h +++ b/mindspore/ccsrc/utils/system/file_system.h @@ -45,25 +45,25 @@ class FileSystem { virtual ~FileSystem() = default; // Create a new read/write file - virtual WriteFilePtr CreateWriteFile(const string& file_name) = 0; + virtual WriteFilePtr CreateWriteFile(const string &file_name) = 0; // Check the file is exist? - virtual bool FileExist(const string& file_name) = 0; + virtual bool FileExist(const string &file_name) = 0; // Delete the file - virtual bool DeleteFile(const string& file_name) = 0; + virtual bool DeleteFile(const string &file_name) = 0; // Create a directory - virtual bool CreateDir(const string& dir_name) = 0; + virtual bool CreateDir(const string &dir_name) = 0; // Delete the specified directory - virtual bool DeleteDir(const string& dir_name) = 0; + virtual bool DeleteDir(const string &dir_name) = 0; }; // A file that can be read and write class WriteFile { public: - explicit WriteFile(const string& file_name) : file_name_(file_name) {} + explicit WriteFile(const string &file_name) : file_name_(file_name) {} virtual ~WriteFile() = default; @@ -71,7 +71,7 @@ class WriteFile { virtual bool Open() = 0; // append the content to file - virtual bool Write(const std::string& data) { + virtual bool Write(const std::string &data) { MS_LOG(WARNING) << "Attention: Maybe not call the function."; return true; } @@ -101,27 +101,27 @@ class PosixFileSystem : public FileSystem { 
~PosixFileSystem() override = default; // create a new write file - WriteFilePtr CreateWriteFile(const string& file_name) override; + WriteFilePtr CreateWriteFile(const string &file_name) override; // check the file is exist? - bool FileExist(const string& file_name) override; + bool FileExist(const string &file_name) override; // delete the file - bool DeleteFile(const string& file_name) override; + bool DeleteFile(const string &file_name) override; // Create a Directory - bool CreateDir(const string& dir_name) override; + bool CreateDir(const string &dir_name) override; // Delete the specified directory. - bool DeleteDir(const string& dir_name) override; + bool DeleteDir(const string &dir_name) override; }; // A file that can be read and write for posix class PosixWriteFile : public WriteFile { public: - explicit PosixWriteFile(const string& file_name) : WriteFile(file_name), file_(nullptr) {} - PosixWriteFile(const PosixWriteFile&); - PosixWriteFile& operator=(const PosixWriteFile&); + explicit PosixWriteFile(const string &file_name) : WriteFile(file_name), file_(nullptr) {} + PosixWriteFile(const PosixWriteFile &); + PosixWriteFile &operator=(const PosixWriteFile &); ~PosixWriteFile() override { try { @@ -129,7 +129,7 @@ class PosixWriteFile : public WriteFile { (void)fclose(file_); file_ = nullptr; } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file."; } catch (...) 
{ MS_LOG(ERROR) << "Non standard exception when closing file."; @@ -159,7 +159,7 @@ class PosixWriteFile : public WriteFile { return true; } - bool Write(const std::string& data) override { + bool Write(const std::string &data) override { MS_LOG(DEBUG) << "Write data(" << data.size() << ") to file(" << this->file_name_ << ")."; size_t r = fwrite(data.data(), 1, data.size(), file_); if (r != data.size()) { @@ -194,7 +194,7 @@ class PosixWriteFile : public WriteFile { bool Sync() override { return Flush(); } private: - FILE* file_; + FILE *file_; }; #endif diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 10ef4abf62..6829a7e888 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -92,6 +92,7 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum"; constexpr auto kGreaterOpName = "Greater"; constexpr auto kSqrtOpName = "Sqrt"; constexpr auto kRsqrtOpName = "Rsqrt"; +constexpr auto kErfOpName = "Erf"; constexpr auto kRealDivOpName = "RealDiv"; constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR"; constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay"; @@ -114,6 +115,9 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum"; constexpr auto kBiasAddOpName = "BiasAdd"; constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; +constexpr auto kStreamSwitchOpName = "StreamSwitch"; +constexpr auto kStreamActiveOpName = "StreamActive"; +constexpr auto kAssignAddOpName = "AssignAdd"; constexpr auto kSendOpName = "Send"; constexpr auto kRecvOpName = "Recv"; constexpr auto kReluV2OpName = "ReluV2"; @@ -149,7 +153,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; constexpr auto kAttrSrcFormat = "src_format"; constexpr auto kAttrOutputUsedNum = "output_used_num"; constexpr auto kAttrHasBias = "has_bias"; -constexpr auto kAttrN = "N"; +constexpr auto kAttrN = "n"; constexpr auto kAttrLabelForInsertStreamActive = 
"label_for_insert_stream_active"; // attr value @@ -210,7 +214,7 @@ const std::set kOptOperatorSet = { const std::set kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; -static inline void ChangeFileMode(const std::string& file_name, mode_t mode) { +static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { if (access(file_name.c_str(), F_OK) != 0) { MS_LOG(DEBUG) << "File `" << file_name << "` does not exist."; return; diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc old mode 100755 new mode 100644 index e69d25d2dc..d754667cce --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -189,6 +189,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) { } else if (utils::isa(arg)) { auto value = utils::cast(arg).object_; inputs.push_back(py::cast(value)); + } else if (utils::isa(arg)) { + auto args_new = utils::cast(arg); + (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs), + [](const BaseRef &v) { return utils::cast(v); }); + } else { + MS_LOG(WARNING) << "Invalid input type."; } } diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index d7d5a4c096..ae052770ff 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -47,7 +47,7 @@ void ClearConvertCache() { g_ConvertCache.clear(); } // lst: list of nodes (the segment) // users: dict mapping each node to its users (globally) // seen: set of nodes that are part of the segment -AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, const std::vector& seen) { +AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector &seen) { AnfNodePtrList output; if (users.size() == 0) { return output; @@ -57,7 +57,7 @@ AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, c 
std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr { auto usersn = users.find(n); bool is_referred_out_of_segment = std::any_of( - std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair& u) -> bool { + std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair &u) -> bool { return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen); }); if (n->isa() && is_referred_out_of_segment) { @@ -78,7 +78,7 @@ AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, c return output; } -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList& lst) { +std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { auto fg = std::make_shared(); AnfNodePtrList inputs; AnfNodePtrToAnfNodePtrMap eqv; @@ -86,7 +86,7 @@ std::tuple TransformSegmentToAnfGr MS_LOG(EXCEPTION) << "Input anf node list is empty"; } - auto ref = [&eqv, &inputs, &fg](const AnfNodePtr& a) -> AnfNodePtr { + auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { if (a->isa() && !IsValueNode(a)) { eqv[a] = a; } else if (eqv.find(a) == eqv.end()) { @@ -102,7 +102,7 @@ std::tuple TransformSegmentToAnfGr if (!n->isa()) { MS_LOG(EXCEPTION) << "Inst is not CNode"; } - auto& inps = n->cast()->inputs(); + auto &inps = n->cast()->inputs(); if (inps.empty()) { MS_LOG(EXCEPTION) << "Input is empty"; @@ -120,13 +120,13 @@ std::tuple TransformSegmentToAnfGr std::vector eqv_keys; (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), - [](const std::pair& elem) -> AnfNodePtr { return elem.first; }); + [](const std::pair &elem) -> AnfNodePtr { return elem.first; }); auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); std::vector output_args; output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args), - [&eqv](const 
AnfNodePtr& o) -> AnfNodePtr { return eqv[o]; }); + [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); // Set output for AnfGraph auto fg_output = fg->NewCNode(output_args); @@ -148,7 +148,7 @@ std::tuple TransformSegmentToAnfGr // This implementation will convert the nodes into a subgraph // that will run using the MsVM. template -LinConvertResult Convert(const AnfNodePtrList& lst) { +LinConvertResult Convert(const AnfNodePtrList &lst) { auto cached = g_ConvertCache.find(lst); if (cached != g_ConvertCache.end()) { return cached->second; @@ -168,7 +168,7 @@ LinConvertResult Convert(const AnfNodePtrList& lst) { std::shared_ptr vm = std::make_shared(); result.run = - std::make_shared([fg, vm](const VectorRef& args) -> VectorRef { return vm->RunGraph(fg, args); }); + std::make_shared([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); }); result.inputs = inputs; result.outputs = outputs; result.graph_id = UINT32_MAX; diff --git a/mindspore/ccsrc/vm/segment_runner.h b/mindspore/ccsrc/vm/segment_runner.h index 112a770de8..8ea87da50c 100644 --- a/mindspore/ccsrc/vm/segment_runner.h +++ b/mindspore/ccsrc/vm/segment_runner.h @@ -43,7 +43,7 @@ struct LinConvertResult { uint32_t graph_id; }; -using LinkFuncType = std::function; +using LinkFuncType = std::function; using ConvertCache = std::unordered_map; extern LinkFuncType MsVmConvert; extern LinkFuncType GeVmConvert; @@ -53,7 +53,7 @@ extern std::set backend_list; void ClearConvertCache(); -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList& lst); +std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst); } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 92976e0ddb..1c3c917dae 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -41,12 +41,12 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, 
prim::kPrimSwitch, prim::kPrimMakeTuple}; -const std::vector& GetMsNonlinearOps() { +const std::vector &GetMsNonlinearOps() { static const std::vector ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch}; return ms_nonlinear_ops; } -CompileGraph::CompileGraph(const BackendPtr& backend, const std::vector& cut_list) +CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector &cut_list) : backend_(backend), cut_list_(cut_list) { MS_EXCEPTION_IF_NULL(backend_); lin_convert_ = backend_->convert_fn(); @@ -61,11 +61,11 @@ CompileGraph::CompileGraph(const BackendPtr& backend, const std::vectorisa()) { auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); if (inputs.empty()) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } @@ -76,7 +76,7 @@ bool CompileGraph::IsCut(const AnfNodePtr& node) { } PrimitivePtr node_prim = GetValueNode(fn); - for (auto& prim : cut_list_) { + for (auto &prim : cut_list_) { MS_EXCEPTION_IF_NULL(prim); if (prim->name() == node_prim->name()) { return true; @@ -97,14 +97,14 @@ bool CompileGraph::IsCut(const AnfNodePtr& node) { return false; } -VectorRef CompileGraph::SplitNodes(const FuncGraphPtr& graph) { +VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); VectorRef splits; VectorRef split; std::vector nodes = TopoSort(graph->get_return()); MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); - for (auto& node : nodes) { + for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (IsCut(node)) { MS_LOG(DEBUG) << "Cut node:" << node->DebugString(10) << ", size:" << split.size(); @@ -123,7 +123,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr& graph) { } // Push the value node on the stack. 
-void CompileGraph::Push(const AnfNodePtr& node) { +void CompileGraph::Push(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (slots_.count(node) > 0) { MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString() @@ -135,25 +135,25 @@ void CompileGraph::Push(const AnfNodePtr& node) { set_height(height_ + 1); } -void CompileGraph::AddInst(const Instruction& inst, const int& arg) { +void CompileGraph::AddInst(const Instruction &inst, const int &arg) { VectorRef args; args.push_back(arg); AddInst(inst, args); } -void CompileGraph::AddInst(const Instruction& inst, const ValuePtr& arg) { +void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) { VectorRef args; args.push_back(arg); AddInst(inst, args); } -void CompileGraph::AddInst(const Instruction& inst, const VectorRef& args) { +void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) { inst_.push_back(std::make_pair(inst, args)); } // Gets the stack reference for the node value. If the node is a constant, // it may actually cause the push in to not be mentioned before. -int CompileGraph::Ref(const AnfNodePtr& node) { +int CompileGraph::Ref(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_; if (slots_.count(node) == 0 && node->isa()) { @@ -176,7 +176,7 @@ int CompileGraph::Ref(const AnfNodePtr& node) { } // Make sure the value of node is at the top of the stack. 
-void CompileGraph::AddInput(const AnfNodePtr& node) { +void CompileGraph::AddInput(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (slots_.count(node) == 0) { MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true); @@ -190,7 +190,7 @@ void CompileGraph::AddInput(const AnfNodePtr& node) { // Call back effect in stack void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); } -void CompileGraph::PushParameters(const FuncGraphPtr& graph) { +void CompileGraph::PushParameters(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); std::vector parameters = graph->parameters(); for (size_t i = parameters.size(); i != 0; i--) { @@ -199,7 +199,7 @@ void CompileGraph::PushParameters(const FuncGraphPtr& graph) { } } -int CompileGraph::LinConvert(const FuncGraphPtr& graph, const AnfNodePtrList& node_list) { +int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) { MS_LOG(DEBUG) << "LinConvert start"; LinConvertResult result; @@ -227,14 +227,14 @@ int CompileGraph::LinConvert(const FuncGraphPtr& graph, const AnfNodePtrList& no } } AddExternal(result); - for (auto& o : result.outputs) { + for (auto &o : result.outputs) { Push(o); } return RET_SUCCESS; } -void CompileGraph::AddSinkSwitch(const CNodePtr& node) { +void CompileGraph::AddSinkSwitch(const CNodePtr &node) { MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString(); if (backend_->is_multi_graph_sink()) { VectorRef args; @@ -255,7 +255,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr& node) { } } -int CompileGraph::InterpretNode(const FuncGraphPtr& graph, const CNodePtr& node) { +int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true); std::vector node_inputs = node->inputs(); @@ -293,7 +293,7 @@ int CompileGraph::InterpretNode(const FuncGraphPtr& graph, const CNodePtr& node) return RET_SUCCESS; } -void CompileGraph::GenMultiGraphsRun(const 
FuncGraphPtr& graph) { +void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) { auto ret = LinConvert(graph, {}); if (ret == RET_FAILED) { MS_LOG(EXCEPTION) << "MultiGraphRun failed."; @@ -301,20 +301,20 @@ void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr& graph) { AddReturn(nullptr); } -bool CompileGraph::SplitGraph(const FuncGraphPtr& graph) { +bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start split graph"; MS_EXCEPTION_IF_NULL(graph); VectorRef splits = SplitNodes(graph); MS_LOG(DEBUG) << "Split nodes size:" << splits.size(); - for (auto& split : splits) { + for (auto &split : splits) { int ret = RET_SUCCESS; if (utils::isa(split)) { MS_LOG(DEBUG) << "Start a extern LinConvert"; std::vector args; auto vec_ref = utils::cast(split); (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), - [](const BaseRef& v) { return utils::cast(v); }); + [](const BaseRef &v) { return utils::cast(v); }); ret = LinConvert(graph, args); MS_LOG(DEBUG) << "End a extern LinConvert"; if (ret == RET_FAILED) { @@ -340,12 +340,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr& graph) { return true; } -InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr& graph) { +InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) { InstSet inst = Run(graph); return inst; } -InstSet CompileGraph::Run(const FuncGraphPtr& graph) { +InstSet CompileGraph::Run(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(DEBUG) << "Compile start graph: " << graph->ToString(); @@ -378,7 +378,7 @@ void CompileGraph::AddPadStack(int param_height) { } } -void CompileGraph::AddTailCall(const AnfNodePtr& fn, size_t size) { +void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) { VectorRef args; args.emplace_back(Ref(fn)); args.emplace_back(height_); @@ -387,7 +387,7 @@ void CompileGraph::AddTailCall(const AnfNodePtr& fn, size_t size) { AddInst(Instruction::kTailCall, args); } -void 
CompileGraph::AddPartial(const CNodePtr& node) { +void CompileGraph::AddPartial(const CNodePtr &node) { auto inputs = node->inputs(); VectorRef args; for (size_t i = 1; i < inputs.size(); i++) { @@ -396,7 +396,7 @@ void CompileGraph::AddPartial(const CNodePtr& node) { AddInst(Instruction::kPartial, args); } -void CompileGraph::AddMakeTuple(const CNodePtr& node) { +void CompileGraph::AddMakeTuple(const CNodePtr &node) { auto inputs = node->inputs(); VectorRef args; for (size_t i = 1; i < inputs.size(); i++) { @@ -405,7 +405,7 @@ void CompileGraph::AddMakeTuple(const CNodePtr& node) { AddInst(Instruction::kTuple, args); } -void CompileGraph::AddSwitch(const CNodePtr& node) { +void CompileGraph::AddSwitch(const CNodePtr &node) { auto inputs = node->inputs(); if (inputs.size() < 4) { MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4"; @@ -420,7 +420,7 @@ void CompileGraph::AddSwitch(const CNodePtr& node) { AddInst(Instruction::kSwitch, args); } -void CompileGraph::AddReturn(const CNodePtr& node) { +void CompileGraph::AddReturn(const CNodePtr &node) { VectorRef args; if (backend_->simu_flag()) { args.emplace_back(Ref(backend_->final_output())); @@ -431,7 +431,7 @@ void CompileGraph::AddReturn(const CNodePtr& node) { AddInst(Instruction::kReturn, args); } -void CompileGraph::AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim) { +void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) { auto inputs = node->inputs(); VectorRef args; args.push_back(prim); @@ -441,7 +441,7 @@ void CompileGraph::AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim) AddInst(Instruction::kPrim, args); } -int CompileGraph::AddCall(const FuncGraphPtr& graph, const CNodePtr& node) { +int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) { auto node_inputs = node->inputs(); AnfNodePtr fn = node_inputs[0]; (void)Ref(fn); @@ -459,7 +459,7 @@ int CompileGraph::AddCall(const FuncGraphPtr& 
graph, const CNodePtr& node) { return RET_SUCCESS; } -void CompileGraph::AddExternal(const LinConvertResult& result) { +void CompileGraph::AddExternal(const LinConvertResult &result) { VectorRef args; args.push_back(result.run); args.push_back(result.simu_run); @@ -471,16 +471,16 @@ void CompileGraph::AddExternal(const LinConvertResult& result) { } void TraverseGraphMap( - const FuncGraphManagerPtr& manager_ptr, FuncGraphTransaction* const tr, const FuncGraphToAnfNodeCounterMap& cts, - const std::function(const PrimitivePtr, const AbstractFunctionPtr)>& get_prim_graph) { + const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, + const std::function(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(tr); - for (const auto& ct_graphs : cts) { - for (const auto& ct_any : ct_graphs.second) { + for (const auto &ct_graphs : cts) { + for (const auto &ct_any : ct_graphs.second) { AnfNodePtr const_primitive_node = ct_any.first; if (const_primitive_node != nullptr && IsValueNode(const_primitive_node)) { auto users = manager_ptr->node_users()[const_primitive_node]; - for (auto& use : users) { + for (auto &use : users) { CNodePtr node = use.first->cast(); MS_EXCEPTION_IF_NULL(node); int key = use.second; @@ -503,12 +503,12 @@ void TraverseGraphMap( } } -FuncGraphPtr WrapPrimitives(const FuncGraphPtr& graph) { +FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); FuncGraphManagerPtr manager_ptr = graph->manager(); MS_EXCEPTION_IF_NULL(manager_ptr); MapPrimTypeFuncGraph prim_graphs; - auto get_prim_graph = [&](const PrimitivePtr& prim, const AbstractFunctionPtr& type) { + auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) { PrimTypePair prim_type = std::make_pair(prim, type); if (prim_graphs.end() == prim_graphs.find(prim_type)) { FuncGraphPtr g = std::make_shared(); @@ -536,13 +536,13 
@@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr& graph) { }; FuncGraphTransaction tr = manager_ptr->Transact(); - auto& cts = manager_ptr->valuenodes(); + auto &cts = manager_ptr->valuenodes(); TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); return graph; } -CompileGraphs::CompileGraphs(const BackendPtr& backend, const std::vector& cut_list) : backend_(backend) { +CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector &cut_list) : backend_(backend) { MS_EXCEPTION_IF_NULL(backend); MS_LOG(DEBUG) << "Start vm: " << backend->name(); transform_ = std::make_shared(backend, cut_list); @@ -550,12 +550,12 @@ CompileGraphs::CompileGraphs(const BackendPtr& backend, const std::vectormanager(); MS_EXCEPTION_IF_NULL(graph_manager); FuncGraphSet graphs = graph_manager->func_graphs(); - for (auto& g : graphs) { + for (auto &g : graphs) { mapping_[g] = static_cast(insts_.size()); if (transform_ != nullptr) { InstSet insts = transform_->Run(g); @@ -568,7 +568,7 @@ void CompileGraphs::Compile(const FuncGraphPtr& graph) { } // Link instructions from multiple function graphs together. -FinalVMPtr CompileGraphs::Link(const FuncGraphPtr& graph) { +FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start"; for (std::size_t i = 0; i < insts_.size(); i++) { InstType inst = insts_[i]; @@ -600,7 +600,7 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr& graph) { } // Convert all graphs to unlinked instructions and link them. -FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr& graph) { +FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(DEBUG) << "Start"; Reset(); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 290af10049..711c1777ab 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -42,26 +42,26 @@ extern const char kGeVm[]; // A sub namespace in ME to support compile related definition. 
namespace compile { extern std::vector nonlinear_ops; -const std::vector& GetMsNonlinearOps(); +const std::vector &GetMsNonlinearOps(); -using VmEvalFunc = std::function; -using VmEvalFuncPtr = std::shared_ptr>; +using VmEvalFunc = std::function; +using VmEvalFuncPtr = std::shared_ptr>; class CompileGraph { public: - explicit CompileGraph(const BackendPtr& backend, const std::vector& cut_list = nonlinear_ops); + explicit CompileGraph(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); ~CompileGraph() = default; - InstSet Run(const FuncGraphPtr& func_graph); - InstSet GenMultiGraphsSinkInst(const FuncGraphPtr& graph); - bool IsCut(const AnfNodePtr& node); - void Push(const AnfNodePtr& node); - void Tie(const AnfNodePtr& n1, const AnfNodePtr& n2) { slots_[n2] = slots_[n1]; } + InstSet Run(const FuncGraphPtr &func_graph); + InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph); + bool IsCut(const AnfNodePtr &node); + void Push(const AnfNodePtr &node); + void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } void Ret(int nargs); - void GenMultiGraphsRun(const FuncGraphPtr& graph); - int Ref(const AnfNodePtr& node); - VectorRef SplitNodes(const FuncGraphPtr& func_graph); + void GenMultiGraphsRun(const FuncGraphPtr &graph); + int Ref(const AnfNodePtr &node); + VectorRef SplitNodes(const FuncGraphPtr &func_graph); void set_height(int h) { height_ = h; @@ -78,24 +78,24 @@ class CompileGraph { } private: - void PushParameters(const FuncGraphPtr& func_graph); - bool SplitGraph(const FuncGraphPtr& func_graph); - int LinConvert(const FuncGraphPtr& func_graph, const AnfNodePtrList& node_list); - int InterpretNode(const FuncGraphPtr& func_graph, const CNodePtr& node); - int AddCall(const FuncGraphPtr& graph, const CNodePtr& node); - void AddSinkSwitch(const CNodePtr& node); + void PushParameters(const FuncGraphPtr &func_graph); + bool SplitGraph(const FuncGraphPtr &func_graph); + int LinConvert(const FuncGraphPtr &func_graph, 
const AnfNodePtrList &node_list); + int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); + int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); + void AddSinkSwitch(const CNodePtr &node); void AddPadStack(int param_height); - void AddTailCall(const AnfNodePtr& fn, size_t size); - void AddPartial(const CNodePtr& node); - void AddMakeTuple(const CNodePtr& node); - void AddSwitch(const CNodePtr& node); - void AddReturn(const CNodePtr& node); - void AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim); - void AddInput(const AnfNodePtr& node); - void AddExternal(const LinConvertResult& result); - void AddInst(const Instruction& inst, const int& arg); - void AddInst(const Instruction& inst, const ValuePtr& arg); - void AddInst(const Instruction& inst, const VectorRef& args); + void AddTailCall(const AnfNodePtr &fn, size_t size); + void AddPartial(const CNodePtr &node); + void AddMakeTuple(const CNodePtr &node); + void AddSwitch(const CNodePtr &node); + void AddReturn(const CNodePtr &node); + void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); + void AddInput(const AnfNodePtr &node); + void AddExternal(const LinConvertResult &result); + void AddInst(const Instruction &inst, const int &arg); + void AddInst(const Instruction &inst, const ValuePtr &arg); + void AddInst(const Instruction &inst, const VectorRef &args); BackendPtr backend_; LinkFuncType lin_convert_; @@ -112,7 +112,7 @@ using CompileGraphPtr = std::shared_ptr; // CompileGraphs is used to Convert a graph cluster into instruction lists. 
class CompileGraphs { public: - explicit CompileGraphs(const BackendPtr& backend, const std::vector& cut_list = nonlinear_ops); + explicit CompileGraphs(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); ~CompileGraphs() = default; @@ -121,9 +121,9 @@ class CompileGraphs { mapping_.clear(); } - void Compile(const FuncGraphPtr& func_graph); - FinalVMPtr Link(const FuncGraphPtr& func_graph); - FinalVMPtr CompileAndLink(const FuncGraphPtr& func_graph); + void Compile(const FuncGraphPtr &func_graph); + FinalVMPtr Link(const FuncGraphPtr &func_graph); + FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); private: InstSet insts_; diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index 493873b0bc..95ceceb67f 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -32,29 +32,29 @@ namespace compile { // Arguments: // fn_: Callable function. // args_: Sequence of function args. -StructPartial::StructPartial(int fn, const VectorRef& args) : fn_(fn), args_(args) {} +StructPartial::StructPartial(int fn, const VectorRef &args) : fn_(fn), args_(args) {} -std::ostream& operator<<(std::ostream& os, const StructPartial& other) { +std::ostream &operator<<(std::ostream &os, const StructPartial &other) { os << "partial(" << other.fn_ << ", " << other.args_.ToString() << ")"; return os; } -bool operator==(const StructPartial& lhs, const StructPartial& rhs) { +bool operator==(const StructPartial &lhs, const StructPartial &rhs) { return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_); } -StructSimuSwitch::StructSimuSwitch(const BaseRef& fn, const BaseRef& value) : fn_(fn), value_(value) {} +StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {} -std::ostream& operator<<(std::ostream& os, const StructSimuSwitch& other) { +std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) { os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")"; 
return os; } -bool operator==(const StructSimuSwitch& lhs, const StructSimuSwitch& rhs) { +bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) { return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_); } -std::ostream& operator<<(std::ostream& os, const SwitchCondStatus& other) { +std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) { os << "SwitchCondStatus(" << static_cast(other) << ")"; return os; } @@ -66,13 +66,13 @@ std::ostream& operator<<(std::ostream& os, const SwitchCondStatus& other) { // retp_: The call stack. // pc_: program counter (next instruction) // sp_: stack pointer (for the value stack) -FinalVM::FinalVM(const InstSet& insts, const BackendPtr& backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) { +FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) { MS_LOG(DEBUG) << "InstSet size:" << insts_.size(); insts_stack_.emplace_back(BaseRef()); retp_.push(-1); } -void FinalVM::Push(const BaseRef& v) { +void FinalVM::Push(const BaseRef &v) { MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_; insts_stack_[IntToSize(sp_++)] = v; } @@ -140,7 +140,7 @@ void FinalVM::Popsp() { } } -void FinalVM::DoJmp(const BaseRef& jmp_orig) { +void FinalVM::DoJmp(const BaseRef &jmp_orig) { MS_LOG(DEBUG) << "Start"; BaseRef jmp = jmp_orig; @@ -173,7 +173,7 @@ void FinalVM::DoJmp(const BaseRef& jmp_orig) { MS_LOG(DEBUG) << "End do jump pc_:" << pc_; } -BaseRef FinalVM::Eval(const VectorRef& args) { +BaseRef FinalVM::Eval(const VectorRef &args) { MS_LOG(DEBUG) << "Start: " << args.size(); insts_stack_.clear(); insts_stack_.resize(args.size()); @@ -212,7 +212,7 @@ BaseRef FinalVM::Eval(const VectorRef& args) { return insts_stack_[0]; } -void FinalVM::InstCall(const VectorRef& args) { +void FinalVM::InstCall(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -228,7 +228,7 @@ void 
FinalVM::InstCall(const VectorRef& args) { MS_LOG(DEBUG) << "Instcall end sp :" << sp_; } -void FinalVM::InstTailCall(const VectorRef& args) { +void FinalVM::InstTailCall(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 3; if (args.size() != args_size) { @@ -258,7 +258,7 @@ void FinalVM::InstTailCall(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSwitchReturn(const VectorRef& args) { +void FinalVM::InstSwitchReturn(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; if (args.size() != 1) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << "."; @@ -268,7 +268,7 @@ void FinalVM::InstSwitchReturn(const VectorRef& args) { Popsp(); } -void FinalVM::InstReturn(const VectorRef& args) { +void FinalVM::InstReturn(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 2; if (args.size() != args_size) { @@ -291,7 +291,7 @@ void FinalVM::InstReturn(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPartial(const VectorRef& args) { +void FinalVM::InstPartial(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() < args_size) { @@ -306,12 +306,12 @@ void FinalVM::InstPartial(const VectorRef& args) { std::vector outs(args.size() - 1); (void)std::transform(args.begin() + 1, args.end(), outs.begin(), - [&, this](const BaseRef& a) { return Ref(utils::cast(a)); }); + [&, this](const BaseRef &a) { return Ref(utils::cast(a)); }); Push(std::make_shared(fn, VectorRef(outs))); MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSimuSwitch(const VectorRef& args) { +void FinalVM::InstSimuSwitch(const VectorRef &args) { const size_t args_size = 4; if (args.size() != args_size) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " @@ -365,7 +365,7 @@ void FinalVM::InstSimuSwitch(const VectorRef& args) { } } -void FinalVM::InstRealSwitch(const VectorRef& args) 
{ +void FinalVM::InstRealSwitch(const VectorRef &args) { const size_t args_size = 3; if (args.size() != args_size) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " @@ -392,7 +392,7 @@ void FinalVM::InstRealSwitch(const VectorRef& args) { } } -void FinalVM::InstSwitch(const VectorRef& args) { +void FinalVM::InstSwitch(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; if (backend_->is_multi_graph_sink()) { InstSimuSwitch(args); @@ -401,7 +401,7 @@ void FinalVM::InstSwitch(const VectorRef& args) { } } -void FinalVM::InstTuple(const VectorRef& args) { +void FinalVM::InstTuple(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; VectorRef tuple; auto iter = args.begin(); @@ -413,7 +413,7 @@ void FinalVM::InstTuple(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPush(const VectorRef& args) { +void FinalVM::InstPush(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -427,7 +427,7 @@ void FinalVM::InstPush(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstInput(const VectorRef& args) { +void FinalVM::InstInput(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -441,7 +441,7 @@ void FinalVM::InstInput(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPadStack(const VectorRef& args) { +void FinalVM::InstPadStack(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -461,7 +461,7 @@ void FinalVM::InstPadStack(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstExternal(const VectorRef& args) { +void FinalVM::InstExternal(const VectorRef &args) { MS_LOG(DEBUG) << "Start:" << args.size(); if (args.empty()) { @@ -490,14 +490,14 @@ void FinalVM::InstExternal(const VectorRef& args) { auto outs = (*fn)(tuple); MS_LOG(DEBUG) << "'fn' out size:" << outs.size(); - 
for (auto& o : outs) { + for (auto &o : outs) { MS_LOG(DEBUG) << "InstExternal value:" << o.ToString(); Push(o); } MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPushPrim(const VectorRef& args) { +void FinalVM::InstPushPrim(const VectorRef &args) { MS_LOG(DEBUG) << "Start: " << args.size(); const size_t args_size = 2; if (args.size() < args_size) { diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index 3e1e5b5c08..eab726a9b7 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -53,14 +53,14 @@ enum Instruction { using InstType = std::pair; using InstSet = std::vector; -using InstFunctionMap = std::map>; +using InstFunctionMap = std::map>; const std::vector inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", "input", "external", "push", "primitive", "graph", "pad_stack"}; class StructPartial : public Base { public: // Initialize StructPartial. - StructPartial(int fn, const VectorRef& args); + StructPartial(int fn, const VectorRef &args); virtual ~StructPartial() = default; MS_DECLARE_PARENT(StructPartial, Base) @@ -69,12 +69,12 @@ class StructPartial : public Base { VectorRef args_; }; -std::ostream& operator<<(std::ostream& os, const StructPartial& other); -bool operator==(const StructPartial& lhs, const StructPartial& rhs); +std::ostream &operator<<(std::ostream &os, const StructPartial &other); +bool operator==(const StructPartial &lhs, const StructPartial &rhs); class StructSimuSwitch : public Base { public: - StructSimuSwitch(const BaseRef& fn, const BaseRef& value); + StructSimuSwitch(const BaseRef &fn, const BaseRef &value); virtual ~StructSimuSwitch() = default; MS_DECLARE_PARENT(StructSimuSwitch, Base) @@ -83,43 +83,43 @@ class StructSimuSwitch : public Base { BaseRef value_; }; -std::ostream& operator<<(std::ostream& os, const StructSimuSwitch& other); -bool operator==(const StructSimuSwitch& lhs, const StructSimuSwitch& rhs); +std::ostream &operator<<(std::ostream &os, const StructSimuSwitch 
&other); +bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs); class FinalVM { public: // Create a VM with the specified instructions and backend. - explicit FinalVM(const InstSet& insts, const BackendPtr& backend); + explicit FinalVM(const InstSet &insts, const BackendPtr &backend); virtual ~FinalVM() = default; - BaseRef Eval(const VectorRef& args); - void InstCall(const VectorRef& args); - void InstTailCall(const VectorRef& args); - void InstReturn(const VectorRef& args); - void InstPartial(const VectorRef& args); - void InstSwitch(const VectorRef& args); - void InstSimuSwitch(const VectorRef& args); - void InstRealSwitch(const VectorRef& args); - void InstTuple(const VectorRef& args); - void InstPush(const VectorRef& args); - void InstInput(const VectorRef& args); - void InstPadStack(const VectorRef& args); - void InstExternal(const VectorRef& args); - void InstPushPrim(const VectorRef& args); - void InstSwitchReturn(const VectorRef& args); - void set_insts(const InstSet& value) { insts_ = value; } + BaseRef Eval(const VectorRef &args); + void InstCall(const VectorRef &args); + void InstTailCall(const VectorRef &args); + void InstReturn(const VectorRef &args); + void InstPartial(const VectorRef &args); + void InstSwitch(const VectorRef &args); + void InstSimuSwitch(const VectorRef &args); + void InstRealSwitch(const VectorRef &args); + void InstTuple(const VectorRef &args); + void InstPush(const VectorRef &args); + void InstInput(const VectorRef &args); + void InstPadStack(const VectorRef &args); + void InstExternal(const VectorRef &args); + void InstPushPrim(const VectorRef &args); + void InstSwitchReturn(const VectorRef &args); + void set_insts(const InstSet &value) { insts_ = value; } protected: BaseRef Ref(int i); - void Push(const BaseRef& v); + void Push(const BaseRef &v); void Pop(int n = 1); void MoveStack(int nitems, int height); void Pushp(); void Popp(); void Pushsp(); void Popsp(); - void DoJmp(const BaseRef& jmp); + void 
DoJmp(const BaseRef &jmp); private: InstSet insts_; @@ -130,18 +130,18 @@ class FinalVM { int sp_; BackendPtr backend_; const InstFunctionMap inst_function_map = { - {Instruction::kCall, [this](const VectorRef& args) { InstCall(args); }}, - {Instruction::kTailCall, [this](const VectorRef& args) { InstTailCall(args); }}, - {Instruction::kReturn, [this](const VectorRef& args) { InstReturn(args); }}, - {Instruction::kPartial, [this](const VectorRef& args) { InstPartial(args); }}, - {Instruction::kSwitch, [this](const VectorRef& args) { InstSwitch(args); }}, - {Instruction::kTuple, [this](const VectorRef& args) { InstTuple(args); }}, - {Instruction::kPush, [this](const VectorRef& args) { InstPush(args); }}, - {Instruction::kInput, [this](const VectorRef& args) { InstInput(args); }}, - {Instruction::kPadStack, [this](const VectorRef& args) { InstPadStack(args); }}, - {Instruction::kExternal, [this](const VectorRef& args) { InstExternal(args); }}, - {Instruction::kPrim, [this](const VectorRef& args) { InstPushPrim(args); }}, - {Instruction::kSwitchReturn, [this](const VectorRef& args) { InstSwitchReturn(args); }}, + {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, + {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }}, + {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }}, + {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }}, + {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }}, + {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }}, + {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }}, + {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }}, + {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }}, + {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, + {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, + 
{Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, }; }; diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index ee9a817dd8..017121f334 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -40,25 +40,25 @@ using PrimitivePyPtr = std::shared_ptr; // Indicate a call to a new frame. struct CallWrap : public Base { - explicit CallWrap(const VMFramePtr& vm_frame) : frame(vm_frame) {} + explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {} VMFramePtr frame{nullptr}; }; using CallWrapPtr = std::shared_ptr; // Indicates a return with its value. struct ReturnWrap : public Base { - explicit ReturnWrap(const BaseRef& r_value) : value(r_value) {} + explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {} BaseRef value{BaseRef()}; }; using ReturnWrapPtr = std::shared_ptr; -VMFrame::VMFrame(const AnfNodePtrList& nodes, const AnfNodePtrToBaseRefMap& values, - const AnfNodePtrToBaseRefMap& closure) +VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, + const AnfNodePtrToBaseRefMap &closure) : values_(values), todo_(nodes), closure_(closure) { std::reverse(std::begin(todo_), std::end(todo_)); } -const BaseRef VMFrame::operator[](const AnfNodePtr& node) { +const BaseRef VMFrame::operator[](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto ret = values_.find(node); if (ret != values_.end()) { @@ -77,31 +77,31 @@ const BaseRef VMFrame::operator[](const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "ValueError " << node->type_name(); } -Closure::Closure(const FuncGraphPtr& graph, const AnfNodePtrToBaseRefMap& values) +Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values) : func_graph_(graph), values_(values) {} -BaseRef Closure::operator()(const VectorRef& args) { +BaseRef Closure::operator()(const VectorRef &args) { MS_LOG(DEBUG) << "start closure"; return vm_->Evaluate(func_graph_, args, values_); } 
-Partial::Partial(const BaseRef& fn, const VectorRef& args, const VMPtr& vm) : fn_(fn), args_(args), vm_(vm) {} +Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {} -BaseRef Partial::operator()(const VectorRef& nodes) { +BaseRef Partial::operator()(const VectorRef &nodes) { VectorRef arglist; (void)arglist.insert(arglist.end(), args_.begin(), args_.end()); (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end()); return vm_->Call(fn_, arglist); } -SetRef VM::ComputeFvs(const FuncGraphPtr& graph) { +SetRef VM::ComputeFvs(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); SetRef rval; - for (auto& fkv : graph->free_variables_total()) { + for (auto &fkv : graph->free_variables_total()) { if (utils::isa(fkv.first)) { // Add all value_nodes of g that refer to a fv graph auto g = utils::cast(fkv.first); - for (auto& ctkv : g->value_nodes()) { + for (auto &ctkv : g->value_nodes()) { auto ct = ctkv.first; if (GetValueNode(ct) == g) { (void)rval.insert(ct); @@ -116,7 +116,7 @@ SetRef VM::ComputeFvs(const FuncGraphPtr& graph) { return rval; } -void VM::AcquireGraph(const FuncGraphPtr& graph) { +void VM::AcquireGraph(const FuncGraphPtr &graph) { // Already acquired if (vars_.find(graph) != vars_.end()) { return; @@ -130,30 +130,30 @@ void VM::AcquireGraph(const FuncGraphPtr& graph) { } } -VectorRef VM::ExportSequence(const VectorRef& seq) { +VectorRef VM::ExportSequence(const VectorRef &seq) { std::vector ret; (void)std::transform(std::begin(seq), std::end(seq), std::back_inserter(ret), - [&, this](const BaseRef& x) -> BaseRef { return Export(x); }); + [&, this](const BaseRef &x) -> BaseRef { return Export(x); }); return VectorRef(ret); } -ClosurePtr VM::ExportClosure(const ClosurePtr& clos) { +ClosurePtr VM::ExportClosure(const ClosurePtr &clos) { MS_EXCEPTION_IF_NULL(clos); clos->set_vm(shared_from_this()); return clos; } // transform graph to executable closure -ClosurePtr VM::ExportGraph(const 
FuncGraphPtr& g) { +ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) { auto c = std::make_shared(g, AnfNodePtrToBaseRefMap()); MS_EXCEPTION_IF_NULL(c); c->set_vm(shared_from_this()); return c; } -BaseRef VM::ExportObj(const BaseRef& obj) const { return obj; } +BaseRef VM::ExportObj(const BaseRef &obj) const { return obj; } -BaseRef VM::Export(const BaseRef& value) { +BaseRef VM::Export(const BaseRef &value) { if (utils::isa(value) && utils::cast(value)->isa()) { return ExportGraph(utils::cast(value)->cast()); } @@ -183,7 +183,7 @@ BaseRef VM::Export(const BaseRef& value) { // Run a graph. // This will evaluate the passed-in graph and return the resulting value. -BaseRef VM::Evaluate(const FuncGraphPtr& graph, const VectorRef& args, const AnfNodePtrToBaseRefMap& closure) { +BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) { AcquireGraph(graph); MS_LOG(DEBUG) << "evalue arg size: " << args.size(); if (args.size() != graph->parameters().size()) { @@ -237,15 +237,15 @@ BaseRef VM::Evaluate(const FuncGraphPtr& graph, const VectorRef& args, const Anf MS_LOG(EXCEPTION) << "VM Evaluate error"; } -SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { - auto fn = [&, this](const AnfNodePtr& node) -> AnfNodePtrList { +SuccFunc VM::SuccVm(const FuncGraphPtr &graph) { + auto fn = [&, this](const AnfNodePtr &node) -> AnfNodePtrList { MS_EXCEPTION_IF_NULL(node); AnfNodePtrList ret; // Follow node.incoming if (node->isa()) { - auto& inputs = node->cast()->inputs(); - for (auto& i : inputs) { + auto &inputs = node->cast()->inputs(); + for (auto &i : inputs) { if (i->func_graph() == node->func_graph() || (IsValueNode(i) && GetValueNode(i)->parent() == graph)) { ret.push_back(i); @@ -257,7 +257,7 @@ SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { if (IsValueNode(node) && GetValueNode(node)->parent() == graph) { auto fvs = utils::cast(vars_[GetValueNode(node)]); (void)std::transform(fvs.begin(), fvs.end(), 
std::back_inserter(ret), - [](const BaseRef& value) -> AnfNodePtr { return utils::cast(value); }); + [](const BaseRef &value) -> AnfNodePtr { return utils::cast(value); }); } return ret; @@ -265,7 +265,7 @@ SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { return fn; } -BaseRef VM::Call(const BaseRef& fn, const VectorRef& args) { +BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn)) { return RunOperation(utils::cast(fn), args); } @@ -283,7 +283,7 @@ BaseRef VM::Call(const BaseRef& fn, const VectorRef& args) { } // make call frame for graph -BaseRef VM::_Call(const BaseRef& graph, const VectorRef& args) { +BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) { AnfNodePtrToBaseRefMap clos; auto func_graph = graph; if (utils::isa(func_graph)) { @@ -319,11 +319,11 @@ BaseRef VM::_Call(const BaseRef& graph, const VectorRef& args) { } // make closure out of graph with fv values from frame -ClosurePtr VM::MakeClosure(const FuncGraphPtr& graph, const VMFramePtr& frame) { +ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(frame); AnfNodePtrToBaseRefMap clos; - for (auto& v : utils::cast(vars_[graph])) { + for (auto &v : utils::cast(vars_[graph])) { auto anf = utils::cast(v); clos[anf] = (*frame)[anf]; } @@ -331,7 +331,7 @@ ClosurePtr VM::MakeClosure(const FuncGraphPtr& graph, const VMFramePtr& frame) { return std::make_shared(graph, clos); } -BaseRef VM::DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const BaseRef& fn, const VectorRef& args) { +BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn) && utils::cast(fn)->isa()) { auto fnval = utils::cast(fn)->cast(); MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true); @@ -384,7 +384,7 @@ BaseRef VM::DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const MS_LOG(EXCEPTION) << "Invalid fn to 
call"; } -BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { +BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { // pass @@ -409,10 +409,10 @@ BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { if (node->isa()) { std::vector fnArgs; - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); // set args' values in frame (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs), - [&](const AnfNodePtr& inp) -> BaseRef { return (*frame)[inp]; }); + [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; }); if (fnArgs.empty()) { MS_LOG(EXCEPTION) << "function arguments is empty"; } else { @@ -425,7 +425,7 @@ BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { MS_LOG(EXCEPTION) << "Unknown node type"; } -VectorRef VM::RunGraph(const FuncGraphPtr& g, const VectorRef& args) { +VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) { this->manager_ = Manage(g); auto fn = utils::cast(Export(g)); @@ -439,7 +439,7 @@ VectorRef VM::RunGraph(const FuncGraphPtr& g, const VectorRef& args) { } } -BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args) { +BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { PrimitivePyPtr operation = dyn_cast(prim); MS_LOG(DEBUG) << "operation start " << prim->name(); @@ -451,7 +451,7 @@ BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args) { py::tuple py_args = py::tuple(args.size()); MS_LOG(DEBUG) << "input for operation:"; size_t i = 0; - for (auto& arg : args) { + for (auto &arg : args) { py_args[i] = BaseRefToPyData(arg); MS_LOG(DEBUG) << "arg: " << i << ":"; i++; diff --git a/mindspore/ccsrc/vm/vmimpl.h b/mindspore/ccsrc/vm/vmimpl.h index 4ef507af82..11d026fe72 100644 --- a/mindspore/ccsrc/vm/vmimpl.h +++ b/mindspore/ccsrc/vm/vmimpl.h @@ -53,14 +53,14 @@ using VMPtr = std::shared_ptr; 
class Partial; using PartialPtr = std::shared_ptr; -using RunFunc = std::function; +using RunFunc = std::function; using RunFuncPtr = std::shared_ptr; using SuccFunc = std::function; class VMImpl { public: - virtual VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) = 0; + virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0; virtual ~VMImpl() = default; }; @@ -76,11 +76,11 @@ class VMImpl { // closure: values for the closure if the current application is a closure class VMFrame { public: - VMFrame(const AnfNodePtrList& nodes, const AnfNodePtrToBaseRefMap& values, const AnfNodePtrToBaseRefMap& closure); - const BaseRef operator[](const AnfNodePtr& node); - const AnfNodePtrList& todo() const { return todo_; } + VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure); + const BaseRef operator[](const AnfNodePtr &node); + const AnfNodePtrList &todo() const { return todo_; } - AnfNodePtrToBaseRefMap& values() { return values_; } + AnfNodePtrToBaseRefMap &values() { return values_; } virtual ~VMFrame() = default; @@ -94,16 +94,16 @@ class VMFrame { // Representation of a closure. 
class Closure : public Base { public: - Closure(const FuncGraphPtr& func_graph, const AnfNodePtrToBaseRefMap& values); - BaseRef operator()(const VectorRef& args); + Closure(const FuncGraphPtr &func_graph, const AnfNodePtrToBaseRefMap &values); + BaseRef operator()(const VectorRef &args); - const VMPtr& vm() const { return vm_; } + const VMPtr &vm() const { return vm_; } - void set_vm(const VMPtr& vm) { vm_ = vm; } + void set_vm(const VMPtr &vm) { vm_ = vm; } - const FuncGraphPtr& func_graph() const { return func_graph_; } + const FuncGraphPtr &func_graph() const { return func_graph_; } - const AnfNodePtrToBaseRefMap& values() const { return values_; } + const AnfNodePtrToBaseRefMap &values() const { return values_; } virtual ~Closure() = default; @@ -118,11 +118,11 @@ class Closure : public Base { // Representation of a partial application. class Partial : public Base { public: - Partial(const BaseRef& fn, const VectorRef& args, const VMPtr& vm); - BaseRef operator()(const VectorRef& nodes); - const BaseRef& fn() const { return fn_; } + Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm); + BaseRef operator()(const VectorRef &nodes); + const BaseRef &fn() const { return fn_; } - const VectorRef& args() const { return args_; } + const VectorRef &args() const { return args_; } virtual ~Partial() = default; MS_DECLARE_PARENT(Partial, Base) @@ -136,52 +136,52 @@ class Partial : public Base { // Virtual Machine interface. 
class VM : public std::enable_shared_from_this, public VMImpl { public: - SetRef ComputeFvs(const FuncGraphPtr& func_graph); + SetRef ComputeFvs(const FuncGraphPtr &func_graph); - void AcquireGraph(const FuncGraphPtr& func_graph); + void AcquireGraph(const FuncGraphPtr &func_graph); - VectorRef ExportSequence(const VectorRef& seq); + VectorRef ExportSequence(const VectorRef &seq); - BaseRef ExportPrimitive(const PrimitivePtr&) const { return kAnyValue; } + BaseRef ExportPrimitive(const PrimitivePtr &) const { return kAnyValue; } - ClosurePtr ExportClosure(const ClosurePtr& clos); + ClosurePtr ExportClosure(const ClosurePtr &clos); // Return an object that executes `fg` when called on arguments. - ClosurePtr ExportGraph(const FuncGraphPtr& fg); + ClosurePtr ExportGraph(const FuncGraphPtr &fg); - BaseRef ExportObj(const BaseRef& obj) const; + BaseRef ExportObj(const BaseRef &obj) const; - BaseRef Export(const BaseRef& value); + BaseRef Export(const BaseRef &value); // Run a graph. // This will evaluate the passed-in graph and return the // resulting value. - BaseRef Evaluate(const FuncGraphPtr& func_graph, const VectorRef& args, - const AnfNodePtrToBaseRefMap& closure = AnfNodePtrToBaseRefMap()); + BaseRef Evaluate(const FuncGraphPtr &func_graph, const VectorRef &args, + const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap()); // Return a visitor for the graph. - SuccFunc SuccVm(const FuncGraphPtr& func_graph); + SuccFunc SuccVm(const FuncGraphPtr &func_graph); // Call the `fn` object. // `fn` can be anything that would be valid as the first element of an apply. 
- BaseRef Call(const BaseRef& fn, const VectorRef& args); + BaseRef Call(const BaseRef &fn, const VectorRef &args); - BaseRef _Call(const BaseRef& graph, const VectorRef& args); + BaseRef _Call(const BaseRef &graph, const VectorRef &args); - ClosurePtr MakeClosure(const FuncGraphPtr& func_graph, const VMFramePtr& frame); + ClosurePtr MakeClosure(const FuncGraphPtr &func_graph, const VMFramePtr &frame); - BaseRef DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const BaseRef& fn, const VectorRef& args); + BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args); - BaseRef HandleNode(const AnfNodePtr& node, const VMFramePtr& frame); + BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame); - VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) override; + VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) override; private: FuncGraphManagerPtr manager_; FuncGraphPtrToBaseRefMap vars_; }; -extern BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args); +extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args); } // namespace compile } // namespace mindspore diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index c354bcd235..5f56d23956 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -15,7 +15,6 @@ """Parameter for cell.""" from copy import copy, deepcopy -import numpy as np from .initializer import initializer from .tensor import Tensor from .._checkparam import _check_str_by_regular @@ -176,14 +175,15 @@ class Parameter: return res def set_parameter_data(self, data): - if isinstance(data, (Tensor, list, int, float, - np.float16, np.float32, np.int32, np.int16, np.ndarray)) and not isinstance(data, bool): - if isinstance(data, Tensor): - # make a copy of Tensor to init the parameter - data = Tensor(data.asnumpy().copy()) - self.default_input = data + """Set 
`default_input` of current `Parameter`.""" + if isinstance(data, bool): + raise ValueError('Parameter data can not be `bool`') + if isinstance(data, Tensor): + # make a copy of Tensor to init the parameter + data = Tensor(data.asnumpy().copy()) else: - raise ValueError("Parameter data must be tensor or number.") + data = Tensor(data) + self.default_input = data class ParameterTuple(tuple): diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py index 099c8cfc2d..508aa2e7a9 100644 --- a/mindspore/communication/_comm_helper.py +++ b/mindspore/communication/_comm_helper.py @@ -334,8 +334,8 @@ def _create_group_helper(group, rank_ids, backend): if not isinstance(rank_ids, list): raise TypeError("Rank_ids {} should be list".format(rank_ids)) rank_size = len(rank_ids) - if rank_size < 2: - raise ValueError("Rank_ids size {} should be large than 1".format(rank_size)) + if rank_size < 1: + raise ValueError("Rank_ids size {} should be large than 0".format(rank_size)) if len(rank_ids) - len(list(set(rank_ids))) > 0: raise ValueError("List rank_ids in Group {} has duplicate data!".format(group)) hccl.create_group(group, rank_size, rank_ids) diff --git a/mindspore/context.py b/mindspore/context.py index ba0ac36b66..159522a87a 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -14,14 +14,15 @@ # ============================================================================ """ The context of mindspore, used to configure the current execution environment, -including execution mode, execution backend and other feature switchs. +including execution mode, execution backend and other feature switches. 
""" +import os import threading from collections import namedtuple from types import FunctionType from mindspore import log as logger from mindspore._c_expression import MSContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ _reset_auto_parallel_context @@ -33,10 +34,36 @@ GRAPH_MODE = 0 PYNATIVE_MODE = 1 +def _make_directory(path): + """Make directory.""" + real_path = None + if path is None or not isinstance(path, str) or path.strip() == "": + raise ValueError(f"Input path `{path}` is invalid type") + + # convert the relative paths + path = os.path.realpath(path) + logger.debug("The absolute path is %r", path) + + # check whether the path is already existed and has written permissions + if os.path.exists(path): + real_path = path + else: + # All exceptions need to be caught because create directory maybe have some limit(permissions) + logger.debug("The directory(%s) doesn't exist, will create it", path) + try: + os.makedirs(path) + real_path = path + except PermissionError as e: + logger.error(f"No write permission on the directory `{path}, error = {e}") + raise ValueError(f"No write permission on the directory `{path}`.") + return real_path + + class _ThreadLocalInfo(threading.local): """ Thread local Info used for store thread local attributes. """ + def __init__(self): super(_ThreadLocalInfo, self).__init__() self._reserve_class_name_in_scope = True @@ -64,6 +91,7 @@ class _ContextSwitchInfo(threading.local): Args: is_pynative (bool): Whether to adopt the PyNative mode. 
""" + def __init__(self, is_pynative): super(_ContextSwitchInfo, self).__init__() self.context_stack = [] @@ -173,7 +201,7 @@ class _Context: @save_graphs_path.setter def save_graphs_path(self, save_graphs_path): - self._context_handle.set_save_graphs_path(save_graphs_path) + self._context_handle.set_save_graphs_path(_make_directory(save_graphs_path)) @property def device_target(self): @@ -183,7 +211,7 @@ class _Context: def device_target(self, target): success = self._context_handle.set_device_target(target) if not success: - raise ValueError("target device name is invalid!!!") + raise ValueError("Target device name is invalid!!!") @property def device_id(self): @@ -309,7 +337,7 @@ class _Context: @graph_memory_max_size.setter def graph_memory_max_size(self, graph_memory_max_size): - if check_input_fotmat(graph_memory_max_size): + if check_input_format(graph_memory_max_size): graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024" self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) else: @@ -321,7 +349,7 @@ class _Context: @variable_memory_max_size.setter def variable_memory_max_size(self, variable_memory_max_size): - if check_input_fotmat(variable_memory_max_size): + if check_input_format(variable_memory_max_size): variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" self._context_handle.set_variable_memory_max_size(variable_memory_max_size_) else: @@ -341,12 +369,13 @@ class _Context: thread_info.debug_runtime = enable -def check_input_fotmat(x): +def check_input_format(x): import re pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' result = re.match(pattern, x) return result is not None + _k_context = None diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 479c66045f..1b0397ae26 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -20,14 +20,14 @@ can also create samplers with this module to sample data. 
from .core.configuration import config from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \ - GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \ - Shuffle, zip + GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ + Schema, Shuffle, zip from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ - WeightedRandomSampler + WeightedRandomSampler, Sampler from .engine.serializer_deserializer import serialize, deserialize, show __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", - "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", + "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"] diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index 720b56b96d..86d2971332 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", - "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", - "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] + "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", + "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 
8de56a6dff..1648734704 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -24,21 +24,24 @@ import math import os import random import uuid +import multiprocessing +import queue from enum import Enum from importlib import import_module +import threading import numpy as np from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ - MindRecordOp, CBatchInfo + MindRecordOp, TextFileOp, CBatchInfo from mindspore._c_expression import typing from mindspore import log as logger from . import samplers from .iterators import DictIterator, TupleIterator -from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ +from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_zip_dataset, check_add_column + check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -139,6 +142,7 @@ class Dataset: self._batch_size = None self._num_classes = None self._repeat_count = None + self._sync = False def get_args(self): """ @@ -196,6 +200,30 @@ class Dataset: """ return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns) + @check_sync_wait + def sync_wait(self, condition_name, num_batch=1, callback=None): + ''' + Add a blocking condition to the input Dataset + + Args: + input_dataset (Dataset): Input dataset to apply flow control + num_batch (int): the number of batches without blocking at the start of each epoch + condition_name (str): The condition name that is used to toggle sending next row + 
callback (function): The callback funciton that will be invoked when sync_update is called + + Raises: + RuntimeError: If condition name already exists. + + Examples: + >>> import mindspore.dataset as ds + >>> # data is an instance of Dataset object. + >>> data = data.sync_wait("callback1") + >>> data = data.batch(batch_size) + >>> for batch_data in data.create_dict_iterator(): + >>> data = data.sync_update("callback1") + ''' + return SyncWaitDataset(self, condition_name, num_batch, callback) + @check_shuffle def shuffle(self, buffer_size): """ @@ -218,6 +246,9 @@ class Dataset: Returns: ShuffleDataset, dataset shuffled. + Raises: + RuntimeError: If exist sync operators before shuffle. + Examples: >>> import mindspore.dataset as ds >>> # data is an instance of Dataset object @@ -231,7 +262,7 @@ class Dataset: @check_map def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None): + num_parallel_workers=None, python_multiprocessing=False): """ Applies each operation in operations to this dataset. @@ -270,6 +301,8 @@ class Dataset: same). num_parallel_workers (int, optional): Number of threads used to process the dataset in parallel (default=None, the value from the config will be used). + python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This + option could be beneficial if the python operation is computational heavy (default=False). Returns: MapDataset, dataset after mapping operation. 
@@ -383,7 +416,34 @@ class Dataset: >>> columns_order = ["mod7", "mod3", "col1"] >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order) """ - return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers) + return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers, + python_multiprocessing) + + @check_filter + def filter(self, predicate, input_columns=None, num_parallel_workers=1): + """ + Filter dataset by predicate. + + Note: + If input_columns not provided or empty, all columns will be used. + + Args: + predicate: python callable which returns a boolean value. + input_columns: (list[str]): List of names of the input columns, when + default=None, the predicate will be applied on all columns in the dataset. + num_parallel_workers (int, optional): Number of workers to process the Dataset + in parallel (default=None). + + Returns: + FilterDataset, dataset filter. + + Examples: + >>> import mindspore.dataset as ds + >>> # generator data(0 ~ 63) + >>> # filter the data that greater than or equal to 11 + >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"]) + """ + return FilterDataset(self, predicate, input_columns, num_parallel_workers) @check_repeat def repeat(self, count=None): @@ -790,6 +850,9 @@ class Dataset: self._input_indexs = value def _get_pipeline_info(self): + """ + Gets pipeline information. 
+ """ device_iter = TupleIterator(self) self._output_shapes = device_iter.get_output_shapes() self._output_types = device_iter.get_output_types() @@ -844,6 +907,30 @@ class Dataset: return self.input[0].num_classes() return None + def get_sync_notifiers(self): + if self.input: + return self.input[0].get_sync_notifiers() + return {} + + def is_sync(self): + if self.input: + return self.input[0].is_sync() + return False + + def sync_update(self, condition_name, num_batch=None, data=None): + """ + condition_name (str): The condition name that is used to toggle sending next row + step_size (int or None): The number of steps(rows) that are released + when pass_rows is None, will update the same number as sync_wait specified + data (dict or None): The data passed to the callback + """ + notifiers_dict = self.get_sync_notifiers() + if condition_name not in notifiers_dict: + raise RuntimeError("Condition name not found") + if num_batch is not None: + num_batch *= self.get_batch_size() + notifiers_dict[condition_name](num_batch, data) + def get_batch_size(self): """ Get the size of a batch. @@ -888,6 +975,38 @@ class SourceDataset(Dataset): # No need for __init__ since it is the same as the super's init + @staticmethod + def _find_files(patterns): + """ + Utility function to search for files with the given glob patterns. + + Args: + patterns (str or list[str]): string or list of patterns to be searched. + + Returns: + List, files. 
+ """ + + if not isinstance(patterns, list): + patterns = [patterns] + + file_list = [] + unmatched_patterns = [] + for pattern in patterns: + matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)] + + if matches: + file_list.extend(matches) + else: + unmatched_patterns.append(pattern) + + if unmatched_patterns: + raise ValueError("The following patterns did not match any files: ", unmatched_patterns) + + if file_list: # not empty + return file_list + raise ValueError("The list of path names matching the patterns is empty.") + class DatasetOp(Dataset): """ @@ -915,6 +1034,8 @@ class BatchDataset(DatasetOp): if BatchDataset._is_ancestor_of_repeat(input_dataset): logger.warning("Repeat is located before batch, data from two epochs can be batched together.") + BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + self.batch_size = batch_size self.drop_remainder = drop_remainder self.per_batch_map = per_batch_map @@ -971,6 +1092,20 @@ class BatchDataset(DatasetOp): flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) return flag + @staticmethod + def _update_batch_size_for_syncwait(dataset, batch_size): + """ + Utility function to notify batch size to sync_wait. 
+ + Args: + dataset (Dataset): dataset to be checked + batchsize (int): batch size to notify + """ + if isinstance(dataset, SyncWaitDataset): + dataset.update_sync_batch_size(batch_size) + for input_dataset in dataset.input: + BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + class BatchInfo(CBatchInfo): """ @@ -995,6 +1130,108 @@ class BatchInfo(CBatchInfo): """ return +class BlockReleasePair: + """ + The blocking condition class used by SyncWaitDataset + + Args: + init_release_rows (int): Number of lines to allow through the pipeline + callback (function): The callback funciton that will be called when release is called + """ + def __init__(self, init_release_rows, callback=None): + self.row_count = -init_release_rows + self.cv = threading.Condition() + self.callback = callback + self.default_rows = init_release_rows + + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + memodict[id(self)] = self + # condition variable and callback are the same, but reset the counter + self.reset() + return self + + def reset(self): + with self.cv: + self.row_count = -self.default_rows + self.cv.notify_all() + + def update_batched_size(self, batch_size): + # should only use before the pipeline creates + self.row_count *= batch_size + self.default_rows *= batch_size + + def block_func(self): + with self.cv: + self.cv.wait_for(lambda: self.row_count < 0) + self.row_count += 1 + return True + + def release_func(self, pass_rows=None, data=None): + with self.cv: + if pass_rows is None: + pass_rows = self.default_rows + self.row_count -= pass_rows + if self.callback is not None: + self.callback(data) + self.cv.notify_all() + +class SyncWaitDataset(DatasetOp): + """ + The result of adding a blocking condition to the input Dataset + + Args: + input_dataset (Dataset): Input dataset to apply flow control + num_batch (int): the number of batches without blocking at the start of each epoch + condition_name (str): The condition 
name that is used to toggle sending next row + callback (function): The callback funciton that will be invoked when sync_update is called + + Raises: + RuntimeError: If condition name already exists. + """ + + def __init__(self, input_dataset, condition_name, num_batch, callback=None): + super().__init__() + self.input.append(input_dataset) + input_dataset.output.append(self) + # set to the default value, waiting for the batch to update it + self._condition_name = condition_name + self._pair = BlockReleasePair(num_batch, callback) + if self._condition_name in self.input[0].get_sync_notifiers(): + raise RuntimeError("Condition name is already in use") + + def get_sync_notifiers(self): + return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} + + def is_sync(self): + return True + + def get_args(self): + args = super().get_args() + args["condition_name"] = self._condition_name + args["condition_func"] = self._pair.block_func + return args + + def update_sync_batch_size(self, batch_size): + self._pair.update_batched_size(batch_size) + + @staticmethod + def _is_ancestor_of_batch(dataset): + """ + Utility function to find the case where sync_wait is used before batch. + + Args: + dataset (Dataset): dataset to be checked + Return: + True or False + """ + if isinstance(dataset, BatchDataset): + return True + flag = False + for input_dataset in dataset.input: + flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) + return flag class ShuffleDataset(DatasetOp): """ @@ -1003,6 +1240,9 @@ class ShuffleDataset(DatasetOp): Args: input_dataset (Dataset): Input Dataset to be shuffled. buffer_size (int): The size of the buffer. + + Raises: + RuntimeError: If exist sync operators before shuffle. 
""" def __init__(self, input_dataset, buffer_size): @@ -1011,6 +1251,8 @@ class ShuffleDataset(DatasetOp): self.input.append(input_dataset) input_dataset.output.append(self) self._input_indexs = input_dataset.input_indexs + if self.is_sync(): + raise RuntimeError("No shuffle after sync operators") def get_args(self): args = super().get_args() @@ -1018,6 +1260,55 @@ class ShuffleDataset(DatasetOp): return args +# Pyfunc collection for multiprocess pyfunc +# This global variable will only be used within subprocesses +_GLOBAL_PYFUNC_LIST = [] + + +# Pyfunc worker init function +# Python multiprocessing library forbid sending lambda function through pipe. +# This init function allow us to add all python function to a global collection and then fork afterwards. +def _pyfunc_worker_init(pyfunc_list): + global _GLOBAL_PYFUNC_LIST + _GLOBAL_PYFUNC_LIST = pyfunc_list + + +# Pyfunc worker execution function +# All exceptions will be raised to main processes +def _pyfunc_worker_exec(index, *args): + try: + return _GLOBAL_PYFUNC_LIST[index](*args) + except KeyboardInterrupt: + raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") + + +# PythonCallable wrapper for multiprocess pyfunc +class _PythonCallable: + """ + Internal python function wrapper for multiprocessing pyfunc + """ + def __init__(self, py_callable, idx, pool=None): + # Original python callable from user. + self.py_callable = py_callable + # Process pool created for current iterator. + self.pool = pool + # Python callable index for subprocess _GLOBAL_PYFUNC_LIST + self.idx = idx + + def __call__(self, *args): + if self.pool is not None: + try: + # This call will send the tensors along with Python callable index to the process pool. + # Block, yield GIL. Current thread will reacquire GIL once result is returned. 
+ return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args]) + except KeyboardInterrupt: + self.pool.terminate() + self.pool.join() + raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") + # Invoke original python callable in master process in case the pool is gone. + return self.py_callable(*args) + + class MapDataset(DatasetOp): """ The result of applying Map operator to the input Dataset. @@ -1037,13 +1328,15 @@ class MapDataset(DatasetOp): The argument is mandatory if len(input_columns) != len(output_columns). num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None). + python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This + option could be beneficial if the python operation is computational heavy (default=False). Raises: ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified. """ def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None): + num_parallel_workers=None, python_multiprocessing=False): super().__init__(num_parallel_workers) self.input.append(input_dataset) if input_columns is not None and not isinstance(input_columns, list): @@ -1064,6 +1357,8 @@ class MapDataset(DatasetOp): input_dataset.output.append(self) self._input_indexs = input_dataset.input_indexs + self.python_multiprocessing = python_multiprocessing + self.process_pool = None def get_args(self): args = super().get_args() @@ -1081,6 +1376,78 @@ class MapDataset(DatasetOp): """ return self.input[0].get_dataset_size() + # Iterator bootstrap will be called on iterator construction. + # A deep copy of Dataset object is created prior of iterator_bootstrap. + # This method will create per iterator process pool and bind pyfunc execution to the pool. + def iterator_bootstrap(self): + """ + Per iterator bootstrap callback. 
+ """ + if self.python_multiprocessing: + iter_specific_operations = [] + callable_list = [] + + # Pass #1, look for python callables and build list + for op in self.operations: + if callable(op): + callable_list.append(op) + + if callable_list: + # Construct pool with the callable list + # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses + self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, + initializer=_pyfunc_worker_init, + initargs=(callable_list,)) + # Pass #2 + idx = 0 + for op in self.operations: + if callable(op): + # Wrap python callable into _PythonCallable + iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) + idx += 1 + else: + # CPP ops remain the same + iter_specific_operations.append(op) + self.operations = iter_specific_operations + + +class FilterDataset(DatasetOp): + """ + The result of applying filter predicate to the input Dataset. + + Args: + input_dataset: Input Dataset to be mapped. + predicate: python callable which returns a boolean value. + input_columns: (list[str]): List of names of the input columns, when + default=None, the predicate will be applied all columns in the dataset. + num_parallel_workers (int, optional): Number of workers to process the Dataset + in parallel (default=None). + """ + + def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): + super().__init__(num_parallel_workers) + self.predicate = lambda *args: bool(predicate(*args)) + self.input.append(input_dataset) + input_dataset.output.append(self) + if input_columns is not None and not isinstance(input_columns, list): + input_columns = [input_columns] + self.input_columns = input_columns + + def get_args(self): + args = super().get_args() + args["predicate"] = self.predicate + args["input_columns"] = self.input_columns + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. 
+ the size cannot be determined before we run the pipeline + Return: + 0 + """ + return 0 + class RepeatDataset(DatasetOp): """ @@ -1239,6 +1606,9 @@ class ZipDataset(DatasetOp): """ return None + def is_sync(self): + return any([c.is_sync() for c in self.input]) + def get_args(self): args = super().get_args() return args @@ -1359,7 +1729,7 @@ class StorageDataset(SourceDataset): Args: dataset_files (list[str]): List of files to be read. - schema (str): Path to the json schema file. + schema (str): Path to the json schema file. If numRows(parsed from schema) is not exist, read the full dataset. distribution (str, optional): Path of distribution config file (default=""). columns_list (list[str], optional): List of columns to be read (default=None, read all columns). num_parallel_workers (int, optional): Number of parallel working threads (default=None). @@ -1786,7 +2156,8 @@ class MindDataset(SourceDataset): block_reader (bool, optional): Whether read data by block mode (default=False). sampler (Sampler, optional): Object used to choose samples from the dataset (default=None, sampler is exclusive - with shuffle and block_reader). Support list: SubsetRandomSampler. + with shuffle and block_reader). Support list: SubsetRandomSampler, + PkSampler Raises: ValueError: If num_shards is specified but shard_id is None. @@ -1819,8 +2190,10 @@ class MindDataset(SourceDataset): if block_reader is True: logger.warning("WARN: global shuffle is not used.") - if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False: - raise ValueError("the sampler is not supported yet.") + if sampler is not None: + if isinstance(sampler, samplers.SubsetRandomSampler) is False and \ + isinstance(sampler, samplers.PKSampler) is False: + raise ValueError("the sampler is not supported yet.") # sampler exclusive if block_reader is True and sampler is not None: @@ -1856,7 +2229,7 @@ class MindDataset(SourceDataset): Number, number of batches. 
""" - num_rows = MindRecordOp.get_num_rows(self.dataset_file) + num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler) if self.partitions is not None and self.partitions[0] > 0: if num_rows % self.partitions[0] == 0: num_rows = num_rows // self.partitions[0] @@ -1934,6 +2307,142 @@ def _cpp_sampler_fn(sampler, dataset): yield tuple([np.array(x) for x in val]) +def _cpp_sampler_fn_mp(sampler, dataset, num_worker): + """ + Multiprocessing generator function wrapper for mappable dataset with cpp sampler + """ + indices = sampler.get_indices() + return _sampler_fn_mp(indices, dataset, num_worker) + + +def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): + """ + Multiprocessing generator function wrapper for mappable dataset with python sampler + """ + indices = _fetch_py_sampler_indices(sampler, num_samples) + return _sampler_fn_mp(indices, dataset, num_worker) + + +def _fetch_py_sampler_indices(sampler, num_samples): + """ + Indices fetcher for python sampler + """ + if num_samples is not None: + sampler_iter = iter(sampler) + ret = [] + for _ in range(num_samples): + try: + val = next(sampler_iter) + ret.append(val) + except StopIteration: + break + return ret + return [i for i in sampler] + + +def _fill_worker_indices(workers, indices, idx): + """ + Worker index queue filler, fill worker index queue in round robin order + """ + num_worker = len(workers) + while idx < len(indices): + try: + workers[idx % num_worker].put(indices[idx]) + idx += 1 + except queue.Full: + break + return idx + + +def _sampler_fn_mp(indices, dataset, num_worker): + """ + Multiprocessing generator function wrapper master process + """ + workers = [] + # Event for end of epoch + eoe = multiprocessing.Event() + + # Create workers + for _ in range(num_worker): + worker = _GeneratorWorker(dataset, eoe) + worker.daemon = True + workers.append(worker) + + # Fill initial index queues + idx_cursor = 0 + idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) + + # 
Start all workers + for w in workers: + w.start() + + # Fetch results + for i in range(len(indices)): + # Fetch result and put index + try: + result = workers[i % num_worker].get() + except queue.Empty: + raise Exception("Generator worker process timeout") + except KeyboardInterrupt: + for w in workers: + w.terminate() + w.join() + raise Exception("Generator worker receives KeyboardInterrupt") + if idx_cursor < len(indices): + idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) + # Set eoe event once all indices are sent + if idx_cursor == len(indices) and not eoe.is_set(): + eoe.set() + yield tuple([np.array(x) for x in result]) + + +def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): + """ + Multiprocessing generator worker process loop + """ + while True: + # Fetch index, block + try: + idx = idx_queue.get() + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + if idx is None: + # When the queue is out of scope from master process, a None item can be fetched from the queue. + # Upon receiving None, worker process should check if EOE is set. + assert eoe.is_set(), "" + return + # Fetch data, any exception from __getitem__ will terminate worker and timeout master process + result = dataset[idx] + # Send data, block + try: + result_queue.put(result) + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + del result, idx + + +class _GeneratorWorker(multiprocessing.Process): + """ + Worker process for multiprocess Generator + """ + def __init__(self, dataset, eoe): + self.idx_queue = multiprocessing.Queue(16) + self.res_queue = multiprocessing.Queue(16) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe)) + + def put(self, item): + """ + Put function for worker index queue. Never block. Raise queue.Full on failure. 
+ """ + self.idx_queue.put_nowait(item) + + def get(self): + """ + Get function for worker result queue. Block with timeout. + """ + return self.res_queue.get(timeout=5) + + class GeneratorDataset(SourceDataset): """ A source dataset that generate data from python by invoking python data source each epoch. @@ -1981,6 +2490,7 @@ class GeneratorDataset(SourceDataset): If the schema is not provided, the meta data from column_names and column_types is considered the schema. num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images). + num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. (default=None, expected order behavior shown in the table). sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is @@ -2032,16 +2542,22 @@ class GeneratorDataset(SourceDataset): if self.sampler is not None and hasattr(source, "__getitem__"): if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, - samplers.WeightedRandomSampler)): + samplers.WeightedRandomSampler, samplers.Sampler)): if num_samples is None: num_samples = len(source) sampler_instance = self.sampler.create() sampler_instance.set_num_rows(len(source)) sampler_instance.set_num_samples(num_samples) sampler_instance.initialize() - self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) + if num_parallel_workers > 1: + self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) + else: + self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) else: - self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) + if num_parallel_workers > 1: + self.source = (lambda: _py_sampler_fn_mp(self.sampler, 
num_samples, source, num_parallel_workers)) + else: + self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) else: try: iter(source) @@ -2094,7 +2610,10 @@ class TFRecordDataset(SourceDataset): schema (str or Schema, optional): Path to the json schema file or schema object (default=None). If the schema is not provided, the meta data from the TFData file is considered the schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) - num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). + num_samples (int, optional): number of samples(rows) to read (default=None). + If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset; + If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows; + If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). @@ -2126,30 +2645,6 @@ class TFRecordDataset(SourceDataset): >>> # 3) get all rows from dataset_files with schema file "./schema.json": >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json") """ - - @staticmethod - def _find_files(patterns): - """ - Utility function to search for files with the given glob patterns. - - Args: - patterns (str or list[str]): string or list of patterns to be searched. - - Returns: - List, files. 
- """ - - def flat(lists): - return list(np.array(lists).flatten()) - - if not isinstance(patterns, list): - patterns = [patterns] - - file_list = flat([glob.glob(file, recursive=True) for file in patterns]) - if file_list: # not empty - return file_list - raise ValueError("The list of path names matching the patterns is empty.") - @check_tfrecorddataset def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): @@ -2636,10 +3131,10 @@ class Schema: """ def __init__(self, schema_file=None): + self.num_rows = None if schema_file is None: self.columns = [] self.dataset_type = '' - self.num_rows = 0 else: if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK): raise ValueError("The file %s does not exist or permission denied!" % schema_file) @@ -2784,6 +3279,9 @@ class Schema: raise RuntimeError("DatasetType field is missing.") if self.columns is None: raise RuntimeError("Columns are missing.") + if self.num_rows is not None: + if not isinstance(self.num_rows, int) or self.num_rows <= 0: + raise ValueError("numRows must be greater than 0") def __str__(self): return self.to_json() @@ -2837,14 +3335,17 @@ class VOCDataset(SourceDataset): decode (bool, optional): Decode the images after reading (default=False). sampler (Sampler, optional): Object used to choose samples from the dataset (default=None, expected order behavior shown in the table). - distribution (str, optional): Path to the json distribution file to configure - dataset sharding (default=None). This argument should be specified - only when no 'sampler' is used. + num_shards (int, optional): Number of shards that the dataset should be divided + into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. 
Raises: - RuntimeError: If distribution and sampler are specified at the same time. - RuntimeError: If distribution is failed to read. - RuntimeError: If shuffle and sampler are specified at the same time. + RuntimeError: If sampler and shuffle are specified at the same time. + RuntimeError: If sampler and sharding are specified at the same time. + RuntimeError: If num_shards is specified but shard_id is None. + RuntimeError: If shard_id is specified but num_shards is None. + ValueError: If shard_id is invalid (< 0 or >= num_shards). Examples: >>> import mindspore.dataset as ds @@ -2858,27 +3359,15 @@ class VOCDataset(SourceDataset): @check_vocdataset def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, - shuffle=None, decode=False, sampler=None, distribution=None): + shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir - self.sampler = sampler - if distribution is not None: - if sampler is not None: - raise RuntimeError("Cannot specify distribution and sampler at the same time.") - try: - with open(distribution, 'r') as load_d: - json.load(load_d) - except json.decoder.JSONDecodeError: - raise RuntimeError("Json decode error when load distribution file") - except Exception: - raise RuntimeError("Distribution file has failed to load.") - elif shuffle is not None: - if sampler is not None: - raise RuntimeError("Cannot specify shuffle and sampler at the same time.") + self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) self.num_samples = num_samples self.decode = decode - self.distribution = distribution self.shuffle_level = shuffle + self.num_shards = num_shards + self.shard_id = shard_id def get_args(self): args = super().get_args() @@ -2887,7 +3376,8 @@ class VOCDataset(SourceDataset): args["sampler"] = self.sampler args["decode"] = self.decode args["shuffle"] = self.shuffle_level - args["distribution"] = 
self.distribution + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id return args def get_dataset_size(self): @@ -2952,3 +3442,82 @@ class CelebADataset(SourceDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id return args + +class TextFileDataset(SourceDataset): + """ + A source dataset that reads and parses datasets stored on disk in text format. + The generated dataset has one columns ['text']. + + Args: + dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of + files. The list will be sorted in a lexicographical order. + num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). + num_parallel_workers (int, optional): number of workers to read the data + (default=None, number set in the config). + shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). + If shuffle is False, no shuffling will be performed; + If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL + Otherwise, there are two levels of shuffling: + + - Shuffle.GLOBAL: Shuffle both the files and samples. + + - Shuffle.FILES: Shuffle files only. + + num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. 
+ Examples: + >>> import mindspore.dataset as ds + >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files + >>> dataset = ds.TextFileDataset(dataset_files=dataset_files) + """ + + @check_textfiledataset + def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, + shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + super().__init__(num_parallel_workers) + self.dataset_files = self._find_files(dataset_files) + self.dataset_files.sort() + self.num_samples = num_samples + + if not isinstance(shuffle, (bool, Shuffle)): + raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") + if not isinstance(shuffle, Shuffle): + if shuffle: + self.shuffle_level = Shuffle.GLOBAL + self.shuffle_files = True + else: + self.shuffle_level = None + self.shuffle_files = False + else: + self.shuffle_level = shuffle + self.shuffle_files = True + + self.num_shards = num_shards + self.shard_id = shard_id + + def get_args(self): + args = super().get_args() + args["dataset_files"] = self.dataset_files + args["num_samples"] = self.num_samples + if self.shuffle_files is not None: + args["shuffle_files"] = self.shuffle_files + args["shuffle"] = self.shuffle_level + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + + Return: + Number, number of batches. 
+ """ + if self._dataset_size is None: + num_rows = TextFileOp.get_num_rows(self.dataset_files) + num_rows = get_num_rows(num_rows, self.num_shards) + if self.num_samples is None: + return num_rows + return min(self.num_samples, num_rows) + return self._dataset_size diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 2bb130f303..81bad14810 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -48,17 +48,25 @@ def alter_tree(node): def _alter_node(node): """Performing some alteration to a dataset node. A common alteration is to insert a node.""" - if isinstance(node, de.TFRecordDataset) and node.shuffle_level == de.Shuffle.GLOBAL: + if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL: # Remove the connection between the parent's node to the current node because we are inserting a node. if node.output: node.output.pop() # Perform a fast scan for average rows per file - avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) + if isinstance(node, de.TFRecordDataset): + avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) + else: + avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files) + # Shuffle between 4 files with a minimum size of 10000 rows new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000)) return new_shuffle if isinstance(node, de.MapDataset): + if node.python_multiprocessing: + # Bootstrap can only be performed on a copy of the original dataset node. + # Bootstrap on original dataset node will make all iterators share the same process pool + node.iterator_bootstrap() if node.columns_order is not None: # Remove the connection between the parent's node to the current node because we are inserting a node. 
if node.output: @@ -121,10 +129,14 @@ class Iterator: op_type = OpName.MINDRECORD elif isinstance(dataset, de.BatchDataset): op_type = OpName.BATCH + elif isinstance(dataset, de.SyncWaitDataset): + op_type = OpName.BARRIER elif isinstance(dataset, de.ZipDataset): op_type = OpName.ZIP elif isinstance(dataset, de.MapDataset): op_type = OpName.MAP + elif isinstance(dataset, de.FilterDataset): + op_type = OpName.FILTER elif isinstance(dataset, de.RepeatDataset): op_type = OpName.REPEAT elif isinstance(dataset, de.SkipDataset): @@ -157,6 +169,8 @@ class Iterator: op_type = OpName.CIFAR100 elif isinstance(dataset, de.CelebADataset): op_type = OpName.CELEBA + elif isinstance(dataset, de.TextFileDataset): + op_type = OpName.TEXTFILE else: raise ValueError("Unsupported DatasetOp") diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 0bba559210..82759989cb 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -16,11 +16,90 @@ Sampler module provides several samplers to generate sampling data from dataset. There are following samplers: DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. +User can also define custom sampler by extending from Sampler class. """ import mindspore._c_dataengine as cde +import numpy as np -class DistributedSampler(): + +class Sampler: + """ + Base class for user defined sampler. + User defined sampler can be used with any existing dataset with sampler support. + + An required _iter_() method should by overridden by user for sample index generation. + An optional reset() method can be overridden for per repeat reset, + + dataset_size and num_samples will be set by dataset once a dataset iterator is created. 
+ + Examples: + >>> import mindspore.dataset as ds + >>> + >>> class ReverseSampler(ds,Sampler): + >>> def __iter__(self): + >>> for i in range(self.dataset_size - 1, -1, -1): + >>> yield i + >>> + >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler()) + """ + + def __init__(self): + self.dataset_size = 0 + self.num_samples = 0 + + def __iter__(self): + """ + User defined iterator, must be overridden. + _handshake is guaranteed to be called prior to iterator construction + + """ + raise NotImplementedError + + def reset(self): + """ + Per repeat reset callback, override this method if necessary + """ + + # Initialization handshake callback + # Do not override this method! + def _handshake(self, ds_size, num_samples): + self.dataset_size = ds_size + self.num_samples = num_samples + + # Indices fetcher + # Do not override this method! + def _get_indices(self): + sampler_iter = iter(self) + ret = [] + for _ in range(self.num_samples): + try: + idx = next(sampler_iter) + ret.append(idx) + except StopIteration: + break + return np.array(ret) + + # Instance fetcher + # Do not override this method! + def create(self): + return cde.PythonSampler(self) + + +class BuiltinSampler: + """ + Base class for BuiltinSampler. + + User should not extend this class. + """ + def __init__(self): + pass + + def create(self): + pass + + +class DistributedSampler(BuiltinSampler): """ Sampler that access a shard of the dataset. @@ -65,7 +144,7 @@ class DistributedSampler(): return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) -class PKSampler(): +class PKSampler(BuiltinSampler): """ Samples K elements for each P class in the dataset. @@ -105,8 +184,10 @@ class PKSampler(): def create(self): return cde.PKSampler(self.num_val, self.shuffle) + def _create_for_minddataset(self): + return cde.MindrecordPkSampler(self.num_val, self.shuffle) -class RandomSampler(): +class RandomSampler(BuiltinSampler): """ Samples the elements randomly. 
@@ -147,7 +228,7 @@ class RandomSampler(): return cde.RandomSampler(self.replacement, self.num_samples) -class SequentialSampler(): +class SequentialSampler(BuiltinSampler): """ Samples the dataset elements sequentially, same as not having a sampler. @@ -165,7 +246,7 @@ class SequentialSampler(): return cde.SequentialSampler() -class SubsetRandomSampler(): +class SubsetRandomSampler(BuiltinSampler): """ Samples the elements randomly from a sequence of indices. @@ -196,7 +277,8 @@ class SubsetRandomSampler(): def _create_for_minddataset(self): return cde.MindrecordSubsetRandomSampler(self.indices) -class WeightedRandomSampler(): + +class WeightedRandomSampler(BuiltinSampler): """ Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index 61417e4d52..f588d572bb 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -286,7 +286,8 @@ def create_node(node): elif dataset_op == 'VOCDataset': sampler = construct_sampler(node.get('sampler')) pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'), - node.get('shuffle'), node.get('decode'), sampler, node.get('distribution')) + node.get('shuffle'), node.get('decode'), sampler, node.get('num_shards'), + node.get('shard_id')) elif dataset_op == 'CelebADataset': sampler = construct_sampler(node.get('sampler')) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b74e913202..dabeb2d424 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -233,8 +233,13 @@ def make_param_dict(method, args, kwargs): params = sig.parameters keys = list(params.keys()) param_dict = dict() - for name, value in enumerate(args): - param_dict[keys[name]] = value + try: + for name, value in 
enumerate(args): + param_dict[keys[name]] = value + except IndexError: + raise TypeError("{0}() expected {1} arguments, but {2} were given".format( + method.__name__, len(keys) - 1, len(args) - 1)) + param_dict.update(zip(params.keys(), args)) param_dict.update(kwargs) @@ -297,9 +302,7 @@ def check_sampler_shuffle_shard_options(param_dict): shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') - if sampler is not None and not isinstance(sampler, ( - samplers.DistributedSampler, samplers.PKSampler, samplers.RandomSampler, samplers.SequentialSampler, - samplers.SubsetRandomSampler, samplers.WeightedRandomSampler)): + if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): raise ValueError("sampler is not a valid Sampler type.") if sampler is not None: @@ -445,9 +448,8 @@ def check_vocdataset(method): def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) - nreq_param_int = ['num_samples', 'num_parallel_workers'] + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] - nreq_param_str = ['distribution'] # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') @@ -459,7 +461,7 @@ def check_vocdataset(method): check_param_type(nreq_param_bool, param_dict, bool) - check_param_type(nreq_param_str, param_dict, str) + check_sampler_shuffle_shard_options(param_dict) return method(*args, **kwargs) @@ -579,11 +581,11 @@ def check_generatordataset(method): raise ValueError("PKSampler is not supported by GeneratorDataset") if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, - samplers.WeightedRandomSampler)): + samplers.WeightedRandomSampler, samplers.Sampler)): try: iter(sampler) except TypeError: - raise TypeError("sampler should be 
either iterable or from dataset.samplers.py") + raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers") return method(*args, **kwargs) @@ -654,6 +656,22 @@ def check_batch(method): return new_method +def check_sync_wait(method): + """check the input arguments of sync_wait.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_str = ['condition_name'] + nreq_param_int = ['step_size'] + + check_param_type(nreq_param_int, param_dict, int) + + check_param_type(nreq_param_str, param_dict, str) + + return method(*args, **kwargs) + + return new_method def check_shuffle(method): """check the input arguments of shuffle.""" @@ -695,6 +713,26 @@ def check_map(method): return new_method +def check_filter(method): + """"check the input arguments of filter.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + predicate = param_dict.get("predicate") + if not callable(predicate): + raise ValueError("Predicate should be a python function or a callable python object.") + + nreq_param_int = ['num_parallel_workers'] + check_param_type(nreq_param_int, param_dict, int) + param_name = "input_columns" + param = param_dict.get(param_name) + if param is not None: + check_columns(param, param_name) + return method(*args, **kwargs) + + return new_method + + def check_repeat(method): """check the input arguments of repeat.""" @wraps(method) @@ -849,3 +887,25 @@ def check_add_column(method): return method(*args, **kwargs) return new_method + + +def check_textfiledataset(method): + """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] + + # check dataset_files; required argument + dataset_files = 
param_dict.get('dataset_files') + if dataset_files is None: + raise ValueError("dataset_files is not provided.") + if not isinstance(dataset_files, (str, list)): + raise TypeError("dataset_files should be of type str or a list of strings.") + + check_param_type(nreq_param_int, param_dict, int) + + return method(*args, **kwargs) + + return new_method diff --git a/mindspore/dataset/transforms/nlp/__init__.py b/mindspore/dataset/transforms/nlp/__init__.py new file mode 100644 index 0000000000..01d425e2eb --- /dev/null +++ b/mindspore/dataset/transforms/nlp/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module is to support nlp augmentations. It includes two parts: +c_transforms and py_transforms. C_transforms is a high performance +image augmentation module which is developed with c++ opencv. Py_transforms +provide more kinds of image augmentations which is developed with python PIL. +""" +from .utils import as_text diff --git a/mindspore/dataset/transforms/nlp/utils.py b/mindspore/dataset/transforms/nlp/utils.py new file mode 100644 index 0000000000..adcc7cc71d --- /dev/null +++ b/mindspore/dataset/transforms/nlp/utils.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Some basic function for nlp +""" +import numpy as np + +def as_text(array, encoding='utf8'): + """ + Convert data of array to unicode. + + Args: + array (numpy array): Data of array should be ASCII values of each character after converted. + encoding (string): Indicating the charset for decoding. + Returns: + A 'str' object. + + """ + + if not isinstance(array, np.ndarray): + raise ValueError('input should be a numpy array') + + byte_array = bytearray(list(array)) + return byte_array.decode(encoding) diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 171eb846a8..07011b1d53 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -45,7 +45,7 @@ import mindspore._c_dataengine as cde from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ - check_resize, check_rescale, check_pad, check_cutout + check_resize, check_rescale, check_pad, check_cutout, check_uniform_augmentation DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -447,3 +447,19 @@ class Pad(cde.PadOp): fill_value = tuple([fill_value] * 3) padding_mode = DE_C_BORDER_TYPE[padding_mode] super().__init__(*padding, padding_mode, *fill_value) + + +class UniformAugment(cde.UniformAugOp): + """ + Tensor 
operation to perform randomly selected augmentation + + Args: + operations: list of python operations. + NumOps (int): number of OPs to be selected and applied. + """ + + @check_uniform_augmentation + def __init__(self, operations, num_ops=2): + self.operations = operations + self.num_ops = num_ops + super().__init__(operations, num_ops) diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index f5ab5d873b..51bea80b21 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -220,7 +220,7 @@ class Decode: class Normalize: """ - Normalize the input Numpy image array of shape (H, W, C) with the given mean and standard deviation. + Normalize the input Numpy image array of shape (C, H, W) with the given mean and standard deviation. The values of the array need to be in range [0.0, 1.0]. @@ -1312,3 +1312,177 @@ class HsvToRgb: rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs. """ return util.hsv_to_rgbs(hsv_imgs, self.is_hwc) + + +class RandomColor: + """ + Adjust the color of the input PIL image by a random degree. + + Args: + degrees (sequence): Range of random color adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.RandomColor(0.5,1.5), + >>> py_transforms.ToTensor()]) + """ + + def __init__(self, degrees=(0.1, 1.9)): + self.degrees = degrees + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be color adjusted. + + Returns: + img (PIL Image), Color adjusted image. + """ + + return util.random_color(img, self.degrees) + +class RandomSharpness: + """ + Adjust the sharpness of the input PIL image by a random degree. + + Args: + degrees (sequence): Range of random sharpness adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). 
+ + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.RandomColor(0.5,1.5), + >>> py_transforms.ToTensor()]) + + """ + + def __init__(self, degrees=(0.1, 1.9)): + self.degrees = degrees + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be sharpness adjusted. + + Returns: + img (PIL Image), Color adjusted image. + """ + + return util.random_sharpness(img, self.degrees) + + +class AutoContrast: + """ + Automatically maximize the contrast of the input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.AutoContrast(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be augmented with AutoContrast. + + Returns: + img (PIL Image), Augmented image. + """ + + return util.auto_contrast(img) + + +class Invert: + """ + Invert colors of input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.Invert(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be color Inverted. + + Returns: + img (PIL Image), Color inverted image. + """ + + return util.invert_color(img) + + +class Equalize: + """ + Equalize the histogram of input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.Equalize(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be equalized. + + Returns: + img (PIL Image), Equalized image. + """ + + return util.equalize(img) + + +class UniformAugment: + """ + Uniformly select and apply a number of transforms sequentially from + a list of transforms. Randomly assigns a probability to each transform for + each image to decide whether apply it or not. 
+ + Args: + transforms (list): List of transformations to be chosen from to apply. + num_ops (int, optional): number of transforms to sequentially apply (default=2). + + Examples: + >>> transforms_list = [py_transforms.CenterCrop(64), + >>> py_transforms.RandomColor(), + >>> py_transforms.RandomSharpness(), + >>> py_transforms.RandomRotation(30)] + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.UniformAugment(transforms_list), + >>> py_transforms.ToTensor()]) + """ + + def __init__(self, transforms, num_ops=2): + self.transforms = transforms + self.num_ops = num_ops + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be applied transformation. + + Returns: + img (PIL Image), Transformed image. + """ + return util.uniform_augment(img, self.transforms, self.num_ops) diff --git a/mindspore/dataset/transforms/vision/py_transforms_util.py b/mindspore/dataset/transforms/vision/py_transforms_util.py index 10c71bbe38..54fb4c8274 100644 --- a/mindspore/dataset/transforms/vision/py_transforms_util.py +++ b/mindspore/dataset/transforms/vision/py_transforms_util.py @@ -1408,3 +1408,160 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc): if batch_size == 0: return hsv_to_rgb(np_hsv_imgs, is_hwc) return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs]) + + +def random_color(img, degrees): + + """ + Adjust the color of the input PIL image by a random degree. + + Args: + img (PIL Image): Image to be color adjusted. + degrees (sequence): Range of random color adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Returns: + img (PIL Image), Color adjusted image. + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence length 2.") + if degrees[0] < 0: + raise ValueError("Degree value must be non-negative.") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. Got (max,min).") + + else: + raise TypeError("Degrees must be a sequence in (min,max) format.") + + v = (degrees[1] - degrees[0]) * random.random() + degrees[0] + return ImageEnhance.Color(img).enhance(v) + + +def random_sharpness(img, degrees): + + """ + Adjust the sharpness of the input PIL image by a random degree. + + Args: + img (PIL Image): Image to be sharpness adjusted. + degrees (sequence): Range of random sharpness adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Returns: + img (PIL Image), Sharpness adjusted image. + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence length 2.") + if degrees[0] < 0: + raise ValueError("Degree value must be non-negative.") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. Got (max,min).") + + else: + raise TypeError("Degrees must be a sequence in (min,max) format.") + + v = (degrees[1] - degrees[0]) * random.random() + degrees[0] + return ImageEnhance.Sharpness(img).enhance(v) + + +def auto_contrast(img): + + """ + Automatically maximize the contrast of the input PIL image. + + Args: + img (PIL Image): Image to be augmented with AutoContrast. + + Returns: + img (PIL Image), Augmented image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.autocontrast(img) + + +def invert_color(img): + + """ + Invert colors of input PIL image. + + Args: + img (PIL Image): Image to be color inverted. 
+ + Returns: + img (PIL Image), Color inverted image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.invert(img) + + +def equalize(img): + + """ + Equalize the histogram of input PIL image. + + Args: + img (PIL Image): Image to be equalized + + Returns: + img (PIL Image), Equalized image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.equalize(img) + + +def uniform_augment(img, transforms, num_ops): + + """ + Uniformly select and apply a number of transforms sequentially from + a list of transforms. Randomly assigns a probability to each transform for + each image to decide whether apply it or not. + + Args: + img: Image to be applied transformation. + transforms (list): List of transformations to be chosen from to apply. + num_ops (int): number of transforms to sequentially aaply. + + Returns: + img, Transformed image. + """ + + if transforms is None: + raise ValueError("transforms is not provided.") + if not isinstance(transforms, list): + raise ValueError("The transforms needs to be a list.") + + if not isinstance(num_ops, int): + raise ValueError("Number of operations should be a positive integer.") + if num_ops < 1: + raise ValueError("Number of operators should equal or greater than one.") + + for _ in range(num_ops): + AugmentOp = random.choice(transforms) + pr = random.random() + if random.random() < pr: + img = AugmentOp(img.copy()) + transforms.remove(AugmentOp) + + return img diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index ef4b879f8c..713d9c5714 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -812,3 +812,36 @@ def check_rescale(method): return method(self, **kwargs) return new_method + + +def check_uniform_augmentation(method): + """Wrapper method to check the 
parameters of UniformAugmentation.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + operations, num_ops = (list(args) + 2 * [None])[:2] + if "operations" in kwargs: + operations = kwargs.get("operations") + else: + raise ValueError("operations list required") + if "num_ops" in kwargs: + num_ops = kwargs.get("num_ops") + else: + num_ops = 2 + + if num_ops <= 0: + raise ValueError("num_ops should be greater than zero") + if num_ops > len(operations): + raise ValueError("num_ops is greater than operations list size") + if not isinstance(operations, list): + raise ValueError("operations is not a python list") + for op in operations: + if not callable(op): + raise ValueError("non-callable op in operations list") + + kwargs["num_ops"] = num_ops + kwargs["operations"] = operations + + return method(self, **kwargs) + + return new_method diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py index 90bca48038..62bcc2df79 100644 --- a/mindspore/mindrecord/filewriter.py +++ b/mindspore/mindrecord/filewriter.py @@ -200,13 +200,24 @@ class FileWriter: raw_data.pop(i) logger.warning(v) - def write_raw_data(self, raw_data): + def open_and_set_header(self): + """ + Open writer and set header + + """ + if not self._writer.is_open: + self._writer.open(self._paths) + if not self._writer.get_shard_header(): + self._writer.set_shard_header(self._header) + + def write_raw_data(self, raw_data, parallel_writer=False): """ Write raw data and generate sequential pair of MindRecord File and \ validate data based on predefined schema by default. Args: raw_data (list[dict]): List of raw data. + parallel_writer (bool, optional): Load data parallel if it equals to True (default=False). Raises: ParamTypeError: If index field is invalid. 
@@ -225,7 +236,7 @@ class FileWriter: if not isinstance(each_raw, dict): raise ParamTypeError('raw_data item', 'dict') self._verify_based_on_schema(raw_data) - return self._writer.write_raw_data(raw_data, True) + return self._writer.write_raw_data(raw_data, True, parallel_writer) def set_header_size(self, header_size): """ diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 0ef23d4ce6..0913201861 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -135,7 +135,7 @@ class ShardWriter: def get_shard_header(self): return self._header - def write_raw_data(self, data, validate=True): + def write_raw_data(self, data, validate=True, parallel_writer=False): """ Write raw data of cv dataset. @@ -145,6 +145,7 @@ class ShardWriter: Args: data (list[dict]): List of raw data. validate (bool, optional): verify data according schema if it equals to True. + parallel_writer (bool, optional): Load data parallel if it equals to True. Returns: MSRStatus, SUCCESS or FAILED. 
@@ -165,7 +166,7 @@ class ShardWriter: if row_raw: raw_data.append(row_raw) raw_data = {0: raw_data} if raw_data else {} - ret = self._writer.write_raw_data(raw_data, blob_data, validate) + ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer) if ret != ms.MSRStatus.SUCCESS: logger.error("Failed to write dataset.") raise MRMWriteDatasetError diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 5507d12af8..a694489f5a 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -103,6 +103,10 @@ class Cell: def parameter_layout_dict(self): return self._parameter_layout_dict + @property + def cls_name(self): + return self.__class__.__name__ + @parameter_layout_dict.setter def parameter_layout_dict(self, value): if not isinstance(value, dict): diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index cf25f1f50e..00e6a45901 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -15,7 +15,7 @@ """dynamic learning rate""" import math -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel @@ -28,11 +28,11 @@ def piecewise_constant_lr(milestone, learning_rates): `milestone`. Let the output learning rate be `y`. .. math:: - y[i] = x_t for i \in [M_{t-1}, M_t) + y[i] = x_t,\ for\ i \in [M_{t-1}, M_t) Args: - milestone (list[int]): A list of milestone. This list is a monotone increasing list. - learning_rates (list[float]): A list of learning rates. + milestone (Union[list[int], tuple[int]]): A list of milestone. This list is a monotone increasing list. + learning_rates (Union[list[float], tuple[float]]): A list of learning rates. Returns: list[float]. The size of list is :math:`M_N`. 
@@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates): >>> lr = piecewise_constant_lr(milestone, learning_rates) [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01] """ - validator.check_type('milestone', milestone, (tuple, list)) - validator.check_type('learning_rates', learning_rates, (tuple, list)) + validator.check_value_type('milestone', milestone, (tuple, list), None) + validator.check_value_type('learning_rates', learning_rates, (tuple, list), None) if len(milestone) != len(learning_rates): raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.') lr = [] last_item = 0 for i, item in enumerate(milestone): - validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT) - validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float]) + validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) + validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None) if item < last_item: raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') lr += [learning_rates[i]] * (item - last_item) @@ -62,12 +62,14 @@ def piecewise_constant_lr(milestone, learning_rates): def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) - validator.check_float_positive('learning_rate', learning_rate) - validator.check_float_positive('decay_rate', decay_rate) - validator.check_type('is_stair', is_stair, [bool]) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_float_positive('learning_rate', learning_rate, None) + 
validator.check_float_legal_value('learning_rate', learning_rate, None) + validator.check_float_positive('decay_rate', decay_rate, None) + validator.check_float_legal_value('decay_rate', decay_rate, None) + validator.check_value_type('is_stair', is_stair, [bool], None) def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): @@ -228,11 +230,15 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] """ - validator.check_float_positive('min_lr', min_lr) - validator.check_float_positive('max_lr', max_lr) - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) + validator.check_float_positive('min_lr', min_lr, None) + validator.check_float_legal_value('min_lr', min_lr, None) + validator.check_float_positive('max_lr', max_lr, None) + validator.check_float_legal_value('max_lr', max_lr, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + if min_lr >= max_lr: + raise ValueError('`max_lr` should be greater than `min_lr`.') delta = 0.5 * (max_lr - min_lr) lr = [] @@ -251,11 +257,11 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e .. math:: decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) * - (1 - tmp\_epoch / decay\_epoch)^{power} + end\_learning\_rate + (1 - tmp\_epoch / tmp\_decay\_epoch)^{power} + end\_learning\_rate - Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch), current\_epoch=floor(\frac{i}{step\_per\_epoch})`. 
- If `update_decay_epoch` is true, update the value of `decay_epoch` every epoch. The formula is - :math:`decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)` + Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})`, and + :math:`tmp\_decay\_epoch = decay\_epoch`. If `update_decay_epoch` is true, update the value of `tmp_decay_epoch` + every epoch. The formula is :math:`tmp\_decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)` Args: learning_rate (float): The initial value of learning rate. @@ -263,7 +269,7 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e total_step (int): The total number of steps. step_per_epoch (int): The number of steps in per epoch. decay_epoch (int): A value used to calculate decayed learning rate. - power (float): A value used to calculate decayed learning rate. + power (float): A value used to calculate decayed learning rate. This parameter should be greater than 0. update_decay_epoch (bool): If true, update `decay_epoch`. Default: False. 
Returns: @@ -279,17 +285,21 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] """ - validator.check_float_positive('learning_rate', learning_rate) - validator.check_float_positive('end_learning_rate', end_learning_rate) - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) - validator.check_type('power', power, [float]) - validator.check_type('update_decay_epoch', update_decay_epoch, [bool]) - + validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_legal_value('learning_rate', learning_rate, None) + validator.check_float_positive('end_learning_rate', end_learning_rate, None) + validator.check_float_legal_value('end_learning_rate', end_learning_rate, None) + validator.check_float_positive('power', power, None) + validator.check_float_legal_value('power', power, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) + + origin_decay_epoch = decay_epoch function = lambda x, y: (x, min(x, y)) if update_decay_epoch: - function = lambda x, y: (x * max(math.ceil(y / x), 1), y) + function = lambda x, y: (origin_decay_epoch * max(math.ceil(y / origin_decay_epoch), 1), y) lr = [] delta = learning_rate - end_learning_rate @@ -298,3 +308,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e decay_epoch, tmp_epoch = function(decay_epoch, current_epoch) lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + 
end_learning_rate) return lr + + +__all__ = [ + 'piecewise_constant_lr', + 'exponential_decay_lr', + 'natural_exp_decay_lr', + 'inverse_decay_lr', + 'cosine_decay_lr', + 'polynomial_decay_lr' +] diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py index 098489a91d..b9f79b6cf7 100644 --- a/mindspore/nn/layer/__init__.py +++ b/mindspore/nn/layer/__init__.py @@ -24,7 +24,7 @@ from .conv import Conv2d, Conv2dTranspose from .lstm import LSTM from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold from .embedding import Embedding -from .pooling import AvgPool2d, MaxPool2d +from .pooling import AvgPool2d, MaxPool2d, AvgPool1d from .image import ImageGradients, SSIM, PSNR __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', @@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', 'LSTM', 'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Embedding', - 'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold', + 'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'Pad', 'Unfold', 'ImageGradients', 'SSIM', 'PSNR', ] diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 6485e27228..8845247a65 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -346,7 +346,7 @@ class HSwish(Cell): where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. Inputs: - - **input_data** (Tensor) - The input of Hswish. + - **input_data** (Tensor) - The input of HSwish. Outputs: Tensor, with the same type and shape as the `input_data`. 
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 5ac52acac7..2449eea9b4 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register from ..cell import Cell from .activation import get_activation -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator class Dropout(Cell): @@ -73,7 +73,7 @@ class Dropout(Cell): super(Dropout, self).__init__() if keep_prob <= 0 or keep_prob > 1: raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) - validator.check_subclass("dtype", dtype, mstype.number_type) + validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) self.keep_prob = Tensor(keep_prob) self.seed0 = seed0 self.seed1 = seed1 @@ -421,7 +421,7 @@ class Pad(Cell): super(Pad, self).__init__() self.mode = mode self.paddings = paddings - validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"]) + validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name) if not isinstance(paddings, tuple): raise TypeError('Paddings must be tuple type.') for item in paddings: diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index dfa8e66469..24b94f2f3c 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -19,7 +19,7 @@ from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from ..cell import Cell -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator class Embedding(Cell): @@ -59,7 +59,7 @@ class Embedding(Cell): """ def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): super(Embedding, 
self).__init__() - validator.check_subclass("dtype", dtype, mstype.number_type) + validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) self.vocab_size = vocab_size self.embedding_size = embedding_size self.use_one_hot = use_one_hot diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 72c4c6d8e2..b46ac4cd6e 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops.primitive import constexpr -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from ..cell import Cell @@ -134,15 +134,15 @@ class SSIM(Cell): """ def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): super(SSIM, self).__init__() - validator.check_type('max_val', max_val, [int, float]) - validator.check('max_val', max_val, '', 0.0, Rel.GT) + validator.check_value_type('max_val', max_val, [int, float], self.cls_name) + validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val - self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE) - self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma) - validator.check_type('k1', k1, [float]) - self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_type('k2', k2, [float]) - self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER) + self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) + self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) + validator.check_value_type('k1', k1, [float], self.cls_name) + self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, 
self.cls_name) + validator.check_value_type('k2', k2, [float], self.cls_name) + self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) def construct(self, img1, img2): @@ -231,8 +231,8 @@ class PSNR(Cell): """ def __init__(self, max_val=1.0): super(PSNR, self).__init__() - validator.check_type('max_val', max_val, [int, float]) - validator.check('max_val', max_val, '', 0.0, Rel.GT) + validator.check_value_type('max_val', max_val, [int, float], self.cls_name) + validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val def construct(self, img1, img2): diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index cef926d365..84c156a1c2 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -17,7 +17,7 @@ from mindspore.ops import operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator class LSTM(Cell): @@ -114,7 +114,7 @@ class LSTM(Cell): self.hidden_size = hidden_size self.num_layers = num_layers self.has_bias = has_bias - self.batch_first = validator.check_type("batch_first", batch_first, [bool]) + self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) self.dropout = float(dropout) self.bidirectional = bidirectional diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 6456a3603d..3ef2381ba1 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -33,7 +33,6 @@ class _BatchNorm(Cell): @cell_attr_register def __init__(self, num_features, - group=1, eps=1e-5, momentum=0.9, affine=True, @@ -41,7 +40,8 @@ class _BatchNorm(Cell): beta_init='zeros', 
moving_mean_init='zeros', moving_var_init='ones', - use_batch_statistics=True): + use_batch_statistics=True, + group=1): super(_BatchNorm, self).__init__() if num_features < 1: raise ValueError("num_features must be at least 1") @@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) >>> net(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True): + super(BatchNorm1d, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics) def _check_data_dim(self, x): if x.dim() != 2: pass @@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> net(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True): + super(BatchNorm2d, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics) def _check_data_dim(self, x): if x.dim() != 4: pass @@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> global_bn_op(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True, + group=1): + super(GlobalBatchNorm, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics, + group) + self.group = check_int_positive(group) + if 
self.group <= 1: + raise ValueError("the number of group must be greater than 1.") def _check_data_dim(self, x): if x.dim == 0: pass diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 746b6d240f..6cf06de029 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -14,45 +14,39 @@ # ============================================================================ """pooling""" from mindspore.ops import operations as P -from mindspore._checkparam import ParamValidator as validator -from mindspore._checkparam import Rel +from mindspore.ops import functional as F +from mindspore._checkparam import Validator as validator from ... import context from ..cell import Cell +from ..._checkparam import Rel +from ..._checkparam import ParamValidator class _PoolNd(Cell): """N-D AvgPool""" def __init__(self, kernel_size, stride, pad_mode): - name = self.__class__.__name__ super(_PoolNd, self).__init__() - validator.check_type('kernel_size', kernel_size, [int, tuple]) - validator.check_type('stride', stride, [int, tuple]) - self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) - - if isinstance(kernel_size, int): - validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) - else: - if (len(kernel_size) != 2 or - (not isinstance(kernel_size[0], int)) or - (not isinstance(kernel_size[1], int)) or - kernel_size[0] <= 0 or - kernel_size[1] <= 0): - raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or' - f'a tuple of two positive int numbers, but got {kernel_size}') - self.kernel_size = kernel_size - - if isinstance(stride, int): - validator.check_integer("stride", stride, 1, Rel.GE) - else: - if (len(stride) != 2 or - (not isinstance(stride[0], int)) or - (not isinstance(stride[1], int)) or - stride[0] <= 0 or - stride[1] <= 0): - raise ValueError(f'The stride passed to cell {name} should be an positive int number or' - f'a tuple of two positive int numbers, 
but got {stride}') - self.stride = stride + self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name) + + def _check_int_or_tuple(arg_name, arg_value): + validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name) + error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \ + f'a tuple of two positive int numbers, but got {arg_value}' + if isinstance(arg_value, int): + if arg_value <= 0: + raise ValueError(error_msg) + elif len(arg_value) == 2: + for item in arg_value: + if isinstance(item, int) and item > 0: + continue + raise ValueError(error_msg) + else: + raise ValueError(error_msg) + return arg_value + + self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size) + self.stride = _check_int_or_tuple('stride', stride) def construct(self, *inputs): pass @@ -217,3 +211,81 @@ class AvgPool2d(_PoolNd): def construct(self, x): return self.avg_pool(x) + + +class AvgPool1d(_PoolNd): + r""" + Average pooling for temporal data. + + Applies a 1D average pooling over an input Tensor which can be regarded as a composition of 1D input planes. + + Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs + regional average in the :math:`(W_{in})`-dimension. Given kernel size + :math:`ks = w_{ker}` and stride :math:`s = s_0`, the operation is as follows. + + .. math:: + \text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1} + \text{input}(N_i, C_j, h_k, s_0 \times w + n) + + Note: + pad_mode for training only supports "same" and "valid". + + Args: + kernel_size (int): The size of kernel window used to take the average value, Default: 1. + stride (int): The distance of kernel moving, an int number that represents + the width of movement is strides, Default: 1. + pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". + + - same: Adopts the way of completion. 
Output height and width will be the same as + the input. Total number of padding will be calculated for horizontal and vertical + direction and evenly distributed to top and bottom, left and right if possible. + Otherwise, the last extra padding will be done from the bottom and the right side. + + - valid: Adopts the way of discarding. The possibly largest height and width of output + will be return without padding. Extra pixels will be discarded. + + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> pool = nn.AvgPool1d(kernel_size=3, strides=1) + >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) + >>> output = pool(x) + >>> output.shape() + (1, 2, 4, 2) + """ + + def __init__(self, + kernel_size=1, + stride=1, + pad_mode="valid"): + super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) + ParamValidator.check_type('kernel_size', kernel_size, [int,]) + ParamValidator.check_type('stride', stride, [int,]) + self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) + ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) + ParamValidator.check_integer("stride", stride, 1, Rel.GE) + self.kernel_size = (1, kernel_size) + self.stride = (1, stride) + self.avg_pool = P.AvgPool(ksize=self.kernel_size, + strides=self.stride, + padding=self.pad_mode) + self.shape = F.shape + self.reduce_mean = P.ReduceMean(keep_dims=True) + self.slice = P.Slice() + + def construct(self, x): + batch, channel, high, width = self.shape(x) + if width == self.kernel_size[1]: + x = self.reduce_mean(x, 3) + elif width - self.kernel_size[1] < self.stride[1]: + x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1])) + x = self.reduce_mean(x, 3) + else: + x = self.avg_pool(x) + return x diff --git a/mindspore/nn/metrics/fbeta.py b/mindspore/nn/metrics/fbeta.py index 
68df4318b0..3ae5c44bc2 100755 --- a/mindspore/nn/metrics/fbeta.py +++ b/mindspore/nn/metrics/fbeta.py @@ -15,7 +15,7 @@ """Fbeta.""" import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .metric import Metric @@ -104,7 +104,7 @@ class Fbeta(Metric): Returns: Float, computed result. """ - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') diff --git a/mindspore/nn/metrics/precision.py b/mindspore/nn/metrics/precision.py index ad7b6c576f..633b9f8e2c 100644 --- a/mindspore/nn/metrics/precision.py +++ b/mindspore/nn/metrics/precision.py @@ -17,7 +17,7 @@ import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .evaluation import EvaluationBase @@ -136,7 +136,7 @@ class Precision(EvaluationBase): if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) result = self._true_positives / (self._positives + self.eps) if average: diff --git a/mindspore/nn/metrics/recall.py b/mindspore/nn/metrics/recall.py index 45ebf0d7db..da06321aa3 100644 --- a/mindspore/nn/metrics/recall.py +++ b/mindspore/nn/metrics/recall.py @@ -17,7 +17,7 @@ import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .evaluation import EvaluationBase @@ -136,7 +136,7 @@ class Recall(EvaluationBase): if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], 
self.__class__.__name__) result = self._true_positives / (self._actual_positives + self.eps) if average: diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index eb4e33751f..4e88c3ef93 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -22,7 +22,7 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer @@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad return next_v -def _check_param_value(beta1, beta2, eps, weight_decay): +def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): """Check the type of inputs.""" - validator.check_type("beta1", beta1, [float]) - validator.check_type("beta2", beta2, [float]) - validator.check_type("eps", eps, [float]) - validator.check_type("weight_dacay", weight_decay, [float]) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + 
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", @@ -101,17 +101,6 @@ def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, ep return success -@adam_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", - "Tensor") -def _run_opt_with_two_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1, - moment2): - """Apply adam optimizer to the weight parameter using Tensor.""" - success = True - success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, - eps, gradient)) - return success - - class Adam(Optimizer): r""" Updates gradients by Adaptive Moment Estimation (Adam) algorithm. @@ -168,11 +157,11 @@ class Adam(Optimizer): use_nesterov=False, weight_decay=0.0, loss_scale=1.0, decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) - _check_param_value(beta1, beta2, eps, weight_decay) - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("use_nesterov", use_nesterov, [bool]) - validator.check_type("loss_scale", loss_scale, [float]) - validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) + validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) + validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name) self.beta1 = Tensor(beta1, mstype.float32) self.beta2 = Tensor(beta2, 
mstype.float32) @@ -183,7 +172,6 @@ class Adam(Optimizer): self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') - self.decay_tf = tuple(decay_filter(x) for x in self.parameters) self.hyper_map = C.HyperMap() self.opt = P.Adam(use_locking, use_nesterov) @@ -241,7 +229,7 @@ class AdamWeightDecay(Optimizer): """ def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): super(AdamWeightDecay, self).__init__(learning_rate, params) - _check_param_value(beta1, beta2, eps, weight_decay) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) self.lr = Tensor(np.array([learning_rate]).astype(np.float32)) self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) @@ -304,7 +292,7 @@ class AdamWeightDecayDynamicLR(Optimizer): eps=1e-6, weight_decay=0.0): super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) - _check_param_value(beta1, beta2, eps, weight_decay) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations self.global_step = Parameter(initializer(0, [1]), name="global_step") diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index ee8fc9355f..2bc329f42d 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -18,41 +18,42 @@ from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.common import Tensor import mindspore.common.dtype as mstype -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer, apply_decay, grad_scale ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") -@ftrl_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", 
"Tensor", "Tensor", "Tensor") +@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment): """Apply ftrl optimizer to the weight parameter.""" success = True success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) return success -def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0): - validator.check_type("initial_accum", initial_accum, [float]) - validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE) +def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0, + prim_name=None): + validator.check_value_type("initial_accum", initial_accum, [float], prim_name) + validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name) - validator.check_type("learning_rate", learning_rate, [float]) - validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT) + validator.check_value_type("learning_rate", learning_rate, [float], prim_name) + validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name) - validator.check_type("lr_power", lr_power, [float]) - validator.check("lr_power", lr_power, "", 0.0, Rel.LE) + validator.check_value_type("lr_power", lr_power, [float], prim_name) + validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name) - validator.check_type("l1", l1, [float]) - validator.check("l1", l1, "", 0.0, Rel.GE) + validator.check_value_type("l1", l1, [float], prim_name) + validator.check_number("l1", l1, 0.0, Rel.GE, prim_name) - validator.check_type("l2", l2, [float]) - validator.check("l2", l2, "", 0.0, Rel.GE) + validator.check_value_type("l2", l2, [float], prim_name) + validator.check_number("l2", l2, 0.0, Rel.GE, prim_name) - validator.check_type("use_locking", use_locking, [bool]) + 
validator.check_value_type("use_locking", use_locking, [bool], prim_name) - validator.check_type("loss_scale", loss_scale, [float]) - validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE) + validator.check_value_type("loss_scale", loss_scale, [float], prim_name) + validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name) - validator.check_type("weight_decay", weight_decay, [float]) - validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE) + validator.check_value_type("weight_decay", weight_decay, [float], prim_name) + validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name) class FTRL(Optimizer): @@ -94,7 +95,8 @@ class FTRL(Optimizer): use_locking=False, loss_scale=1.0, weight_decay=0.0): super(FTRL, self).__init__(learning_rate, params) - _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay) + _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay, + self.cls_name) self.moments = self.parameters.clone(prefix="moments", init=initial_accum) self.linear = self.parameters.clone(prefix="linear", init='zeros') self.l1 = l1 diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index e74d6fc6a8..afcbf8cda4 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -21,7 +21,7 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer from .. 
import layer @@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para def _check_param_value(decay_steps, warmup_steps, start_learning_rate, - end_learning_rate, power, beta1, beta2, eps, weight_decay): + end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): """Check the type of inputs.""" - validator.check_type("decay_steps", decay_steps, [int]) - validator.check_type("warmup_steps", warmup_steps, [int]) - validator.check_type("start_learning_rate", start_learning_rate, [float]) - validator.check_type("end_learning_rate", end_learning_rate, [float]) - validator.check_type("power", power, [float]) - validator.check_type("beta1", beta1, [float]) - validator.check_type("beta2", beta2, [float]) - validator.check_type("eps", eps, [float]) - validator.check_type("weight_dacay", weight_decay, [float]) - validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_value_type("decay_steps", decay_steps, [int], prim_name) + validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name) + validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name) + validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) + validator.check_value_type("power", power, [float], prim_name) + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("decay_steps", 
decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) class Lamb(Optimizer): @@ -182,7 +182,7 @@ class Lamb(Optimizer): super(Lamb, self).__init__(start_learning_rate, params) _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, - power, beta1, beta2, eps, weight_decay) + power, beta1, beta2, eps, weight_decay, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations self.global_step = Parameter(initializer(0, [1]), name="global_step") diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index 02538aa61a..73451f3bf5 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -43,23 +43,6 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f return gradient -@lars_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Bool", "Bool") -def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag): - """Apply lars optimizer to the weight parameter.""" - if lars_flag: - op_reduce = P.ReduceSum() - w_square_sum = op_reduce(F.square(weight)) - grad_square_sum = op_reduce(F.square(gradient)) - if decay_flag: - grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate) - else: - num_zero = 0.0 - grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate) - return grad_t - - return gradient - - class LARS(Optimizer): """ Implements the LARS algorithm with LARSUpdate Operator. 
diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index bac8e74a42..c69e226df9 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -15,19 +15,13 @@ """momentum""" from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype from .optimizer import Optimizer momentum_opt = C.MultitypeFuncGraph("momentum_opt") -@momentum_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, moment): - """Apply momentum optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) - return success - - @momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): """Apply momentum optimizer to the weight parameter using Tensor.""" @@ -36,14 +30,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): return success -@momentum_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, moment): - """Apply momentum optimizer to the weight parameter using dynamic learning rate.""" - success = True - success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) - return success - - class Momentum(Optimizer): """ Implements the Momentum algorithm. 
@@ -86,7 +72,7 @@ class Momentum(Optimizer): super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) if isinstance(momentum, float) and momentum < 0.0: raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) - self.momentum = Parameter(momentum, name="momentum") + self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.params = self.parameters self.moments = self.params.clone(prefix="moments", init='zeros') self.hyper_map = C.HyperMap() diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 00d3fd3b7b..bab539461e 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -22,7 +22,8 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer -from mindspore._checkparam import ParamValidator as validator +import mindspore.common.dtype as mstype +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore.common.tensor import Tensor from mindspore import log as logger @@ -45,8 +46,10 @@ class Optimizer(Cell): learning_rate (float): A floating point value for the learning rate. Should be greater than 0. parameters (list): A list of parameter, which will be updated. The element in `parameters` should be class mindspore.Parameter. - weight_decay (float): A floating point value for the weight decay. Default: 0.0. - loss_scale (float): A floating point value for the loss scale. Default: 1.0. Should be greater than 0. + weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay` + input is int, it will be convertd to float. Default: 0.0. + loss_scale (float): A floating point value for the loss scale. It should be greater than 0. 
If the + type of `loss_scale` input is int, it will be convertd to float. Default: 1.0. decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda x: 'beta' not in x.name and 'gamma' not in x.name. @@ -63,7 +66,8 @@ class Optimizer(Cell): self.gather = None self.assignadd = None self.global_step = None - validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + learning_rate = Tensor(learning_rate, mstype.float32) else: self.dynamic_lr = True self.gather = P.GatherV2() @@ -84,14 +88,12 @@ class Optimizer(Cell): if isinstance(weight_decay, int): weight_decay = float(weight_decay) - if not isinstance(weight_decay, float): - raise TypeError("weight_decay should be a float number!") + validator.check_float_legal_value('weight_decay', weight_decay, None) if isinstance(loss_scale, int): loss_scale = float(loss_scale) - if not isinstance(loss_scale, float): - raise TypeError("loss_scale should be a float number!") + validator.check_float_legal_value('loss_scale', loss_scale, None) if loss_scale <= 0.0: raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale)) @@ -175,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay") def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): """Get grad with weight_decay.""" if if_apply: - return op_add((gradient, weight * weight_decay)) + return op_add((weight * weight_decay, gradient)) return gradient diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index a68dc6f7c4..a8f118b709 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -14,41 +14,24 @@ # ============================================================================ """rmsprop""" from mindspore.ops import functional as F, composite as C, operations as P -from mindspore._checkparam 
import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .optimizer import Optimizer rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") -@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): - """Apply rmsprop optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) - return success - - @rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): +def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): """Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" success = True success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) return success -@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor") -def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): - """Apply centered rmsprop optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) - return success - - @centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): +def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): """Apply centered rmsprop optimizer to the weight parameter using 
dynamic learning rate.""" success = True success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) @@ -144,8 +127,8 @@ class RMSProp(Optimizer): self.decay = decay self.epsilon = epsilon - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("centered", centered, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("centered", centered, [bool], self.cls_name) self.centered = centered if centered: self.opt = P.ApplyCenteredRMSProp(use_locking) diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 983be4bf80..cda5aa904a 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -15,20 +15,14 @@ """sgd""" from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.parameter import Parameter -from mindspore._checkparam import ParamValidator as validator +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype +from mindspore._checkparam import Validator as validator from .optimizer import Optimizer sgd_opt = C.MultitypeFuncGraph("sgd_opt") -@sgd_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, accum, stat): - """Apply sgd optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat)) - return success - - @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat): """Apply sgd optimizer to the weight parameter using Tensor.""" @@ -37,14 +31,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, s return success -@sgd_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def 
_tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, accum, stat): - """Apply sgd optimizer to the weight parameter using dynamic learning rate.""" - success = True - success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat)) - return success - - class SGD(Optimizer): """ Implements stochastic gradient descent (optionally with momentum). @@ -100,12 +86,12 @@ class SGD(Optimizer): raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) self.dampening = dampening - validator.check_type("nesterov", nesterov, [bool]) + validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) self.nesterov = nesterov self.opt = P.SGD(dampening, weight_decay, nesterov) - self.momentum = Parameter(momentum, name="momentum") + self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.accum = self.parameters.clone(prefix="accum", init='zeros') self.stat = self.parameters.clone(prefix="stat", init='ones') self.hyper_map = C.HyperMap() diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 64c382557a..641558921a 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -13,17 +13,10 @@ # limitations under the License. 
# ============================================================================ """Cell_wrapper.""" -import copy - -import numpy as np - from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, _get_parallel_mode) from mindspore.train.parallel_utils import ParallelMode - -from ...common import Tensor from ...common import dtype as mstype -from ...common.initializer import initializer from ...common.parameter import Parameter, ParameterTuple from ...ops import composite as C from ...ops import functional as F @@ -348,25 +341,8 @@ class ParameterUpdate(Cell): super(ParameterUpdate, self).__init__(auto_prefix=False) if not isinstance(param, Parameter): raise TypeError("`param` must be `Parameter`, but got {}".format(param)) - - default_input = param.default_input - if isinstance(default_input, Tensor): - shape = default_input.shape() - zero_dtype = default_input.dtype() - elif isinstance(default_input, float): - shape = [1] - zero_dtype = mstype.float32 - elif isinstance(default_input, int): - shape = [1] - zero_dtype = mstype.int32 - else: - raise TypeError("`default_input` in `param` must be Tensor, float or int, but got {}".format(default_input)) - - self._param = Parameter(initializer(copy.deepcopy(default_input), shape), param.name) - self._param.is_init = True - self._zero = Tensor(np.zeros(shape), zero_dtype) + self._param = param def construct(self, x): - zero = self._param + self._zero - F.control_depend(zero, F.assign(self._param, x)) - return zero + F.assign(self._param, x) + return x diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index ba8e6cbb7c..65d66f0150 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -209,6 +209,7 @@ class TrainOneStepWithLossScaleCell(Cell): self.gpu_target = True self.float_status = P.FloatStatus() self.addn = P.AddN() + self.reshape = P.Reshape() else: self.gpu_target = False self.alloc_status = NPUAllocFloatStatus() @@ -260,6 +261,8 @@ class 
TrainOneStepWithLossScaleCell(Cell): else: flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) flag_sum = self.addn(flag_sum) + # convert flag_sum to scalar + flag_sum = self.reshape(flag_sum, (())) if self.is_distributed: # sum overflow flag over devices flag_reduce = self.allreduce(flag_sum) diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 2d819718c8..c334050218 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -17,6 +17,7 @@ from functools import reduce +import numpy as np from .. import functional as F from .. import operations as P from ..operations import _grad_ops as G @@ -333,6 +334,23 @@ def get_bprop_log(self): return bprop +@bprop_getters.register(P.Erf) +def get_bprop_erf(self): + """Grad definition for `Erf` operation.""" + exp = P.Exp() + square = P.Square() + sqrt = P.Sqrt() + cast = P.Cast() + dtype = P.DType() + + def bprop(x, out, dout): + half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x)) + x_square = square(x) + dx = dout * half_root_pi * exp(-x_square) + return (dx,) + return bprop + + @bprop_getters.register(P.Pow) def get_bprop_pow(self): """Grad definition for `Pow` operation.""" diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index ae730d78a7..6db059a7bb 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -227,6 +227,18 @@ def get_bprop_relu6(self): return bprop +@bprop_getters.register(P.ReLUV2) +def get_bprop_relu_v2(self): + """Grad definition for `ReLUV2` operation.""" + input_grad = G.ReluGradV2() + + def bprop(x, out, dout): + mask = out[1] + dx = input_grad(dout[0], mask) + return (dx,) + return bprop + + @bprop_getters.register(P.HSwish) def get_bprop_hswish(self): """Grad definition for `HSwish` operation.""" @@ -344,12 +356,10 @@ def get_bprop_batch_norm(self): if is_training: saved_reserve_1 = out[3] saved_reserve_2 = out[4] - 
saved_reserve_3 = out[5] else: saved_reserve_1 = mean saved_reserve_2 = variance - saved_reserve_3 = variance - out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3) + out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2) dx = out[0] dscale = out[1] dbias = out[2] @@ -456,6 +466,17 @@ def get_bprop_smooth_l1_loss(self): return bprop +@bprop_getters.register(P.L2Loss) +def get_bprop_l2_loss(self): + """Grad definition for `L2Loss` operation.""" + + def bprop(x, out, dout): + dx = x * dout + return (dx,) + + return bprop + + @bprop_getters.register(P.PReLU) def get_bprop_prelu(self): """Grad definition for `PReLU` operation.""" diff --git a/mindspore/ops/_op_impl/akg/gpu/__init__.py b/mindspore/ops/_op_impl/akg/gpu/__init__.py index 2135794b5f..08beb44340 100644 --- a/mindspore/ops/_op_impl/akg/gpu/__init__.py +++ b/mindspore/ops/_op_impl/akg/gpu/__init__.py @@ -23,3 +23,12 @@ from .relu6_grad import _relu6_grad_akg from .squeeze import _squeeze_akg from .squeeze_grad import _squeeze_grad_akg from .tile import _tile_akg +from .hsigmoid import _hsigmoid_akg +from .hsigmoid_grad import _hsigmoid_grad_akg +from .hswish import _hswish_akg +from .hswish_grad import _hswish_grad_akg +from .sub import _sub_akg +from .logical_and import _logical_and_akg +from .logical_not import _logical_not_akg +from .logical_or import _logical_or_akg +from .lessequal import _lessequal_akg diff --git a/mindspore/ops/_op_impl/akg/gpu/lessequal.py b/mindspore/ops/_op_impl/akg/gpu/lessequal.py new file mode 100644 index 0000000000..a3e4d4dc35 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/lessequal.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LessEqual op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +equal_op_info = AkgRegOp("LessEqual") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(equal_op_info) +def _lessequal_akg(): + """LessEqual register""" + return diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_and.py b/mindspore/ops/_op_impl/akg/gpu/logical_and.py new file mode 100644 index 0000000000..da5b696512 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/logical_and.py @@ -0,0 +1,29 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""LogicalAnd op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +logicaland_op_info = AkgRegOp("LogicalAnd") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + +@op_info_register(logicaland_op_info) +def _logical_and_akg(): + """LogicalAnd register""" + return diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_not.py b/mindspore/ops/_op_impl/akg/gpu/logical_not.py new file mode 100644 index 0000000000..4b3c7bf647 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/logical_not.py @@ -0,0 +1,28 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""LogicalNot op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +logical_not_op_info = AkgRegOp("LogicalNot") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + +@op_info_register(logical_not_op_info) +def _logical_not_akg(): + """LogicalNot AutoDiff register""" + return diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_or.py b/mindspore/ops/_op_impl/akg/gpu/logical_or.py new file mode 100644 index 0000000000..3a642511c6 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/logical_or.py @@ -0,0 +1,29 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""LogicalOr op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +logicalor_op_info = AkgRegOp("LogicalOr") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + +@op_info_register(logicalor_op_info) +def _logical_or_akg(): + """LogicalOr register""" + return diff --git a/mindspore/ops/_op_impl/akg/gpu/sub.py b/mindspore/ops/_op_impl/akg/gpu/sub.py new file mode 100644 index 0000000000..06b92fb49e --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/sub.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sub op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +sub_op_info = AkgRegOp("Sub") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + +@op_info_register(sub_op_info) +def _sub_akg(): + """Sub AutoDiff register""" + return diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 2cffc37491..8030aac5c6 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -33,6 +33,7 @@ from .cast import _cast_tbe from .conv2d import _conv2d_tbe from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe from .conv2d_backprop_input import _conv2d_backprop_input_tbe +from .confusion_mul_grad import _confusion_mul_grad_tbe from .dropout_do_mask import _dropout_do_mask_tbe from .gelu import _gelu_tbe from .gelu_grad import _gelu_grad_tbe @@ -46,6 +47,8 @@ from .relu import _relu_tbe from .relu_grad import _relu_grad_tbe from .relu6 import _relu6_tbe from .relu6_grad import _relu6_grad_tbe +from .relu_v2 import _relu_v2_tbe +from .relu_grad_v2 import _relu_grad_v2_tbe from .softmax_cross_entropy_with_logits import _softmax_cross_entropy_with_logits_tbe from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe @@ -117,6 +120,7 @@ from .layer_norm_beta_gamma_backprop import _layer_norm_beta_gamma_backprop_tbe from .layer_norm import _layer_norm_tbe from .layer_norm_grad import _layer_norm_grad_tbe from .layer_norm_x_backprop import _layer_norm_x_backprop_tbe +from .l2_loss import _l2_loss_tbe from .square_sum_v1 import 
_square_sum_v1_tbe from .square_sum_v2 import _square_sum_v2_tbe from .confusion_transpose_d import _confusion_transpose_d_tbe @@ -138,6 +142,8 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe from .fused_mul_add import _fused_mul_add_tbe from .fused_mul_add_n import _fused_mul_add_n_tbe from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe +from .fill_d import _fill_d_op_tbe +from .erf import _erf_op_tbe from .depthwise_conv2d import _depthwise_conv2d_tbe from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe diff --git a/mindspore/ops/_op_impl/tbe/assign_add.py b/mindspore/ops/_op_impl/tbe/assign_add.py index fbbb9a997f..2b20a7781d 100644 --- a/mindspore/ops/_op_impl/tbe/assign_add.py +++ b/mindspore/ops/_op_impl/tbe/assign_add.py @@ -25,7 +25,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \ .partial_flag(True) \ .input(0, "ref", False, "required", "all") \ .input(1, "value", False, "required", "all") \ - .output(0, "output_ref", False, "required", "all") \ + .output(0, "ref", False, "required", "all") \ .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ diff --git a/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py b/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py new file mode 100644 index 0000000000..e49d5386f2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ConfusionMulGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +confusion_mul_grad_op_info = TBERegOp("ConfusionMulGrad") \ + .fusion_type("OPAQUE") \ + .attr("axis", "required", "listInt", "all") \ + .attr("keep_dims", "required", "bool", "all") \ + .input(0, "input0", False, "required", "all") \ + .input(1, "input1", False, "required", "all") \ + .input(2, "input2", False, "required", "all") \ + .output(0, "output0", False, "required", "all") \ + .output(1, "output1", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(confusion_mul_grad_op_info) +def _confusion_mul_grad_tbe(): + """ConfusionMulGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py b/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py index e32e99d888..04b55bb2a3 100644 --- a/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +++ b/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py @@ -25,7 +25,7 @@ conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \ .partial_flag(True) \ .attr("filter_sizes", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ - .attr("pad_mode", "required", "str", "all") \ + 
.attr("pad_list", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ .input(0, "out_backprop", False, "required", "all") \ .input(1, "x", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py b/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py index 2c1dd6aea2..7756cb3ae4 100644 --- a/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +++ b/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py @@ -25,7 +25,7 @@ conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \ .partial_flag(True) \ .attr("input_sizes", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ - .attr("pad_mode", "required", "str", "all") \ + .attr("pad_list", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ .input(0, "out_backprop", False, "required", "all") \ .input(1, "filter", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py index c19a311009..f4d8069b12 100644 --- a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +++ b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py @@ -26,7 +26,7 @@ depthwise_conv2d_backprop_filter_op_info = TBERegOp("DepthwiseConv2dNativeBackpr .attr("filter_size", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ - .attr("pads", "required", "str", "all") \ + .attr("pads", "required", "listInt", "all") \ .attr("data_format", "required", "str", "all") \ .input(0, "input", False, "required", "all") \ .input(1, "out_backprop", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py index 9e671f18e2..61c1406b32 100644 --- a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +++ 
b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py @@ -26,7 +26,7 @@ depthwise_conv2d_backprop_input_op_info = TBERegOp("DepthwiseConv2dNativeBackpro .attr("input_size", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ - .attr("pads", "required", "str", "all") \ + .attr("pads", "required", "listInt", "all") \ .attr("data_format", "required", "str", "all") \ .input(0, "filter", False, "required", "all") \ .input(1, "out_backprop", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/erf.py b/mindspore/ops/_op_impl/tbe/erf.py new file mode 100644 index 0000000000..2247197c4e --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/erf.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Erf op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +erf_op_info = TBERegOp("Erf") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("erf.so") \ + .compute_cost(10) \ + .kernel_name("erf") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(erf_op_info) +def _erf_op_tbe(): + """Erf TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/fill_d.py b/mindspore/ops/_op_impl/tbe/fill_d.py new file mode 100644 index 0000000000..97c6b73cf5 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/fill_d.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""FillD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fill_d_op_info = TBERegOp("FillD") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("fill_d.so") \ + .compute_cost(10) \ + .kernel_name("fill_d") \ + .partial_flag(True) \ + .attr("dims", "required", "listInt", "all") \ + .input(0, "value", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \ + .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \ + .dtype_format(DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \ + .dtype_format(DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(fill_d_op_info) +def _fill_d_op_tbe(): + """FillD TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/l2_loss.py b/mindspore/ops/_op_impl/tbe/l2_loss.py new file mode 100644 index 0000000000..7d1394ad64 --- /dev/null +++ 
b/mindspore/ops/_op_impl/tbe/l2_loss.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""L2Loss op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +l2_loss_op_info = TBERegOp("L2Loss") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("l2_loss.so") \ + .compute_cost(10) \ + .kernel_name("l2_loss") \ + .partial_flag(True) \ + .input(0, "x", None, "required", None) \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_Default) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_Default) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(l2_loss_op_info) +def _l2_loss_tbe(): + """L2Loss TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/relu_grad_v2.py b/mindspore/ops/_op_impl/tbe/relu_grad_v2.py new file mode 100644 index 0000000000..93d7dede62 
--- /dev/null +++ b/mindspore/ops/_op_impl/tbe/relu_grad_v2.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReluGradV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +relu_grad_v2_op_info = TBERegOp("ReluGradV2") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("relu_grad_v2.so") \ + .compute_cost(10) \ + .kernel_name("relu_grad_v2") \ + .partial_flag(True) \ + .input(0, "gradients", False, "required", "all") \ + .input(1, "mask", False, "required", "all") \ + .output(0, "backprops", True, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \ + .dtype_format(DataType.I32_5HD, DataType.U8_Default, DataType.I32_5HD) \ + .dtype_format(DataType.I8_5HD, DataType.U8_Default, DataType.I8_5HD) \ + .dtype_format(DataType.U8_5HD, DataType.U8_Default, DataType.U8_5HD) \ + .get_op_info() + + +@op_info_register(relu_grad_v2_op_info) +def _relu_grad_v2_tbe(): + """ReluGradV2 TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/relu_v2.py b/mindspore/ops/_op_impl/tbe/relu_v2.py new file mode 100644 index 0000000000..c03858c1a7 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/relu_v2.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co.,
Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReluV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +relu_v2_op_info = TBERegOp("ReLUV2") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("relu_v2.so") \ + .compute_cost(10) \ + .kernel_name("relu_v2") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .output(1, "mask", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U8_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.U8_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.U8_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.U8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(relu_v2_op_info) +def _relu_v2_tbe(): + """ReluV2 TBE register""" + return diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py index fbd81c4f0d..90496afc9b 100644 --- a/mindspore/ops/_utils/utils.py +++ b/mindspore/ops/_utils/utils.py @@ -15,7 +15,7 @@ """utils for operator""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype @@ -62,25 +62,25 @@ def 
_get_broadcast_shape(x_shape, y_shape, prim_name): return broadcast_shape -def _get_concat_offset(x_shp, x_type, axis): +def _get_concat_offset(x_shp, x_type, axis, prim_name): """for concat and concatoffset check args and compute offset""" - validator.check_type("shape", x_shp, [tuple]) - validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) - validator.check_subclass("shape0", x_type[0], mstype.tensor) - validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) + validator.check_value_type("shape", x_shp, [tuple], prim_name) + validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) + validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) + validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name) rank_base = len(x_shp[0]) - validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) + validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) if axis < 0: axis = axis + rank_base all_shp = x_shp[0][axis] offset = [0,] for i in range(1, len(x_shp)): v = x_shp[i] - validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) - validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) + validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) + validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name) for j in range(rank_base): if j != axis and v[j] != x_shp[0][j]: - raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) + raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element") offset.append(all_shp) all_shp += v[axis] return offset, all_shp, axis diff --git a/mindspore/ops/composite/multitype_ops/__init__.py b/mindspore/ops/composite/multitype_ops/__init__.py index 40bf71d49a..b7f4f671b8 100644 --- 
a/mindspore/ops/composite/multitype_ops/__init__.py +++ b/mindspore/ops/composite/multitype_ops/__init__.py @@ -23,6 +23,7 @@ from .pow_impl import pow_ from .floordiv_impl import floordiv from .mod_impl import mod from .getitem_impl import getitem +from .setitem_impl import setitem from .zeros_like_impl import zeros_like from .ones_like_impl import ones_like from .equal_impl import equal @@ -55,6 +56,7 @@ __all__ = [ 'greater_equal', 'negative', 'getitem', + 'setitem', 'logical_and', 'logical_or', 'logical_not' diff --git a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py new file mode 100644 index 0000000000..b3687c553c --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""constexpr util""" + +from ...primitive import constexpr + + +@constexpr +def is_same_type(inst, type_): + """ + Check whether an object is an instance of a target type. + + Inputs: + inst (mindspore.dtype): Inspected type. + type_ (mindspore.dtype): Target type. + + Outputs: + bool, the check result. + """ + return inst == type_ + + +@constexpr +def error_msg(msg="", format_values=""): + """ + Used to throw exception information. + + Inputs: + msg (str): information content. 
+ """ + + raise ValueError(msg.format(*format_values)) diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index b2b46ebbb1..56617c06a8 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index): @getitem.register("Tensor", "Slice") def _tensor_getitem_by_slice(data, slice_index): """ - Getting item of tensor by slice index. + Getting item of tensor by slice. Inputs: data (Tensor): A tensor. @@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index): @getitem.register("Tensor", "Tuple") def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): """ - Getting item of tensor by slice tuple index. + Getting item of tensor by slice tuple. Inputs: data (Tensor): A tensor. @@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): Tensor, element type is same as the element type of data. """ return _tensor_slice(data, slice_tuple_index) + + +@getitem.register("Tensor", "Ellipsis") +def _tensor_getitem_by_ellipsis(data, ellipsis_index): + """ + Getting item of tensor by Ellipsis. + + Inputs: + data (Tensor): A tensor. + ellipsis (Ellipsis): A Ellipsis object. + + Outputs: + Tensor, same as data. + """ + return _tensor_slice(data, ellipsis_index) diff --git a/mindspore/ops/composite/multitype_ops/not_equal_impl.py b/mindspore/ops/composite/multitype_ops/not_equal_impl.py index de099a2b8f..7196f370cb 100644 --- a/mindspore/ops/composite/multitype_ops/not_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/not_equal_impl.py @@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y): @not_equal.register("Tuple", "Tuple") -def _euqal_tuple(x, y): +def _not_euqal_tuple(x, y): """ Determine if two tuples are not equal by element. 
@@ -147,7 +147,7 @@ def _euqal_tuple(x, y): @not_equal.register("List", "List") -def _euqal_list(x, y): +def _not_euqal_list(x, y): """ Determine if two lists are not equal by element. @@ -162,7 +162,7 @@ def _euqal_list(x, y): @not_equal.register("Tuple", "None") -def _tuple_euqal_none(x, y): +def _tuple_not_euqal_none(x, y): """ Determine if tuple element not equals none element. @@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y): """ return True + @not_equal.register("Tensor", "Number") @not_equal.register("Number", "Tensor") @not_equal.register("Tensor", "Tensor") @@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y): bool, return True. """ return True + + +@not_equal.register("List", "None") +def _list_not_equal_none(x, y): + """ + Determine if list not equal none. + + Args: + x (list): The first input which is a list. + y (none): The second input which is none. + + Returns: + bool, return true. + """ + return True + + +@not_equal.register("None", "List") +def _none_not_equal_list(x, y): + """ + Determine if none not equal list. + + Args: + x (none): The first input which is none. + y (list): The second input which is a list. + + Returns: + bool, return true. + """ + return True diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py new file mode 100644 index 0000000000..31c96932c5 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -0,0 +1,194 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Implementation for setitem.""" + +from ...composite import base +from ....common import dtype as mstype +from ... import functional as F +from . import _multitype_ops_util as mult_util + +setitem = base.MultitypeFuncGraph('setitem') + +@setitem.register("List", "Number", "String") +def _list_setitem_with_string(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (String): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "Number") +def _list_setitem_with_number(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (Number): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "Tensor") +def _list_setitem_with_Tensor(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (Tensor): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "List") +def _list_setitem_with_List(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (List): Value given. + + Outputs: + List, type is same as the element type of data. 
+ """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("Dictionary", "String", "Tensor") +def _dict_setitem_with_tensor(data, key, value): + """ + Assign value to dictionary. + + Inputs: + data (Dictionary): Data of type dict. + key (str): Key of the data. + value (Tensor): Value given. + + Outputs: + Dict, type is as same as the element type of data. + """ + return F.dict_setitem(data, key, value) + + +@setitem.register("Dictionary", "String", "Number") +def _dict_setitem_with_number(data, key, value): + """ + Assign value to dictionary. + + Inputs: + data (Dictionary): Data of type dict. + key (str): Key of the data. + value (Number): Value given. + + Outputs: + Dict, type is as same as the element type of data. + """ + return F.dict_setitem(data, key, value) + + +@setitem.register("Tensor", "Tensor", "Tensor") +def _tensor_setitem_by_tensor_v1(data, index, value_tensor): + """ + Tensor assignment. + + Note: + Syntax support: A[B] = U and A[A>n] = U. + Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor. + 2) A.shape == B.shape + 3) U.size == 1 + 4) n is a number + + Inputs: + data (Tensor): Assigned tensor. + index (Tensor): Tensor of bool type. + value_tensor (Tensor): Tensor with size 1. + + Outputs: + Tensor, element type and shape is same as data. + """ + index_dtype = F.dtype(index) + index_shape = F.shape(index) + is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) + if not is_bool: + return mult_util.error_msg( + "The tensor index should be a bool type tensor. 
{} type tensor is not supported yet.", (index_dtype,)) + data_shape = F.shape(data) + if index_shape != data_shape: + return mult_util.error_msg( + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape)) + size = F.size(value_tensor) + if size != 1: + return mult_util.error_msg( + "When assign value is a tensor, its size should be 1, but current size is {}.", (size,)) + dtype = F.dtype(data) + u_cast = F.cast(value_tensor, dtype) + one_data = F.ones_like(data) + u = F.tensor_mul(one_data, u_cast) + return F.select(index, u, data) + + +@setitem.register("Tensor", "Tensor", "Number") +def _tensor_setitem_by_tensor_v2(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[B] = u and A[A>n] = u. + Restraint condition: 1) A is a Tensor, and B is a bool Tensor. + 2) A.shape == B.shape + 3) u is a scalar + 4) n is a number + + Inputs: + data (Tensor): Assigned tensor. + index (Tensor): Tensor of bool type. + value_tensor (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + index_dtype = F.dtype(index) + index_shape = F.shape(index) + is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) + if not is_bool: + return mult_util.error_msg( + "The tensor index should be a bool type tensor. 
{} type tensor is not supported yet.", (index_dtype,)) + shape = F.shape(data) + if index_shape != shape: + return mult_util.error_msg( + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape)) + dtype = F.dtype(data) + u = F.fill(dtype, shape, value) + return F.select(index, u, data) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 611c569553..c5b8752ae2 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -31,6 +31,9 @@ dtype = P.DType() issubclass_ = P.IsSubClass() isinstance_ = P.IsInstance() fill = P.Fill() +select = P.Select() +size = P.Size() +ones_like = P.OnesLike() shape = P.Shape() rank = P.Rank() reshape = P.Reshape() @@ -63,12 +66,15 @@ scalar_to_array = P.ScalarToArray() scalar_to_tensor = P.ScalarToTensor() tuple_to_array = P.TupleToArray() scalar_cast = P.ScalarCast() - +print_ = P.Print() +expand_dims = P.ExpandDims() tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') list_getitem = Primitive('list_getitem') +list_setitem = Primitive('list_setitem') dict_getitem = Primitive('dict_getitem') +dict_setitem = Primitive('dict_setitem') tuple_div = Primitive("tuple_div") tuple_len = Primitive("tuple_len") tuple_reversed = Primitive("tuple_reversed") diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index e4b0bfdbfe..752b367023 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -19,7 +19,7 @@ import os import json import inspect from mindspore._c_expression import Oplib -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator # path of built-in op info register. 
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" @@ -43,7 +43,7 @@ def op_info_register(op_info): op_info_real = json.dumps(op_info) else: op_info_real = op_info - validator.check_type("op_info", op_info_real, [str]) + validator.check_value_type("op_info", op_info_real, [str], None) op_lib = Oplib() file_path = os.path.realpath(inspect.getfile(func)) # keep the path custom ops implementation. diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 1f0ee8a04d..c75c2031d7 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge from .inner_ops import ScalarCast from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, - Cos, Div, Equal, EqualCount, Exp, Floor, FloorDiv, FloorMod, Acosh, + Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh, Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd, LogicalNot, LogicalOr, MatMul, Maximum, Minimum, Mul, Neg, NMSWithMask, NotEqual, @@ -55,11 +55,11 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, DropoutDoMask, DropoutGenMask, Flatten, FusedBatchNorm, Gelu, Elu, - GetNext, L2Normalize, LayerNorm, + GetNext, L2Normalize, LayerNorm, L2Loss, LogSoftmax, MaxPool, ExtractImagePatches, - AvgPool, Conv2DBackpropInput, - MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid, + AvgPool, Conv2DBackpropInput, ConfusionMulGrad, + MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, SmoothL1Loss, Softmax, @@ -101,6 +101,7 @@ __all__ = [ 'LogSoftmax', 'SoftmaxCrossEntropyWithLogits', 'ROIAlign', + 'ConfusionMulGrad', 'SparseSoftmaxCrossEntropyWithLogits', 'SGD', 'ApplyMomentum', @@ -138,7 +139,9 @@ __all__ = [ 'Split', 'ReLU', 
'ReLU6', + 'ReLUV2', 'Elu', + 'Erf', 'Sigmoid', 'HSwish', 'HSigmoid', @@ -167,6 +170,7 @@ __all__ = [ 'FloatStatus', 'Reciprocal', 'SmoothL1Loss', + 'L2Loss', 'ReduceAll', 'ScalarToArray', 'ScalarToTensor', diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 48d1a2a89c..782784ca00 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -18,8 +18,7 @@ from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_int_positive, check_bool +from ..._checkparam import Validator as validator, Rel from .._utils import _get_concat_offset from ...common import dtype as mstype @@ -51,12 +50,12 @@ class ACosGrad(PrimitiveWithInfer): """init ACosGrad""" def infer_shape(self, x, dout): - validator.check_param_equal("x", x, "dout", dout) + validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) return x def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x @@ -65,15 +64,15 @@ class BatchNormGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5): - self.is_training = validator.check_type('is_training', is_training, (bool,)) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) + self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) + self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") - def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): + def infer_shape(self, 
y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) - def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type): + def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) @@ -93,21 +92,22 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): """Computes gradients for `BinaryCrossEntropy` operation.""" @prim_attr_register def __init__(self, reduction='mean'): - self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum']) + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape): - validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape) + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) if weight_shape: - validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape) + validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, y_type, doutput_type, weight_type): args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) if weight_type: - validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) + validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) return x_type + class ConcatOffset(PrimitiveWithInfer): """primitive for computing Concat's gradient.""" @@ -119,7 +119,7 @@ class ConcatOffset(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] 
x_type = input_x['dtype'] - offset, _, axis = _get_concat_offset(x_shp, x_type, axis) + offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) offset_values = [] for i in range(len(x_shp)): @@ -183,16 +183,16 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): def __infer__(self, doutput, x, w_size): w_size_v = w_size['value'] - validator.check_type('w_size', w_size_v, [tuple]) + validator.check_value_type('w_size', w_size_v, [tuple], self.name) for i, dim_len in enumerate(w_size_v): - validator.check_type("w_size[%d]" % i, dim_len, [int]) - validator.check_typename('x_dtype', x['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'x_dtype', x['dtype']) + validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) + args = {"x": x['dtype'], "doutput": doutput['dtype']} + validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name) out = { 'value': None, 'shape': w_size_v, 'dtype': doutput['dtype'], - } + } return out @@ -249,8 +249,8 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): def __infer__(self, x, w_size, dout): w_size_v = w_size['value'] - args = {'x_dtype': x['dtype'], 'dout_type': dout['dtype']} - validator.check_type_same(args, mstype.number_type) + args = {'x': x['dtype'], 'dout': dout['dtype']} + validator.check_tensor_type_same(args, mstype.number_type, self.name) out = { 'value': None, 'shape': w_size_v, @@ -309,8 +309,8 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer): raise NotImplementedError def __infer__(self, x_size, w, dout): - args = {'w_dtype': w['dtype'], 'dout_type': dout['dtype']} - validator.check_type_same(args, mstype.number_type) + args = {'w': w['dtype'], 'dout': dout['dtype']} + validator.check_tensor_type_same(args, mstype.number_type, self.name) x_size_v = x_size['value'] out = { 'value': 
None, @@ -332,7 +332,7 @@ class FlattenGrad(PrimitiveWithInfer): 'value': None, 'shape': args[1]['value'], 'dtype': args[0]['dtype'], - } + } return out @@ -359,9 +359,9 @@ class GeluGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): - validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("y_dtype", y_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -372,56 +372,36 @@ class _PoolGrad(PrimitiveWithInfer): def __init__(self, ksize, strides, padding="VALID"): self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) - validator.check_type('ksize', ksize, [int, tuple]) - validator.check_type('strides', strides, [int, tuple]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + validator.check_value_type('ksize', ksize, [int, tuple], self.name) + validator.check_value_type('strides', strides, [int, tuple], self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax") if not self.is_maxpoolgradwithargmax: self.add_prim_attr('data_format', "NCHW") - if isinstance(ksize, int): - validator.check_integer("ksize", ksize, 1, Rel.GE) - if self.is_maxpoolgradwithargmax: - self.ksize = (1, ksize, ksize, 1) - else: - self.ksize = (1, 1, ksize, ksize) - else: - ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be an positive 
int number" - f"or a tuple of two or four positive int numbers, but got {ksize}") - if len(ksize) != 2 and len(ksize) != 4: - raise ksize_error - for ksize_val in ksize: - if not isinstance(ksize_val, int) or (ksize_val <= 0): - raise ksize_error - if len(ksize) == 2 and self.is_maxpoolgradwithargmax: - self.ksize = (1, ksize[0], ksize[1], 1) - elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax: - self.ksize = (1, 1, ksize[0], ksize[1]) + def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax): + validator.check_value_type(arg_name, arg_val, (int, tuple), self.name) + error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number " + f"or a tuple of two or four positive int numbers, but got {arg_val}") + if isinstance(arg_val, int): + ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val) + elif len(arg_val) == 2: + ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1]) + elif len(arg_val) == 4: + ret = arg_val else: - self.ksize = ksize + raise error_msg + # whether all elements of tuple are positive integers + for item in ret: + if not isinstance(item, int) or item <= 0: + raise error_msg + return ret + + self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax) self.add_prim_attr("ksize", self.ksize) - if isinstance(strides, int): - validator.check_integer("strides", strides, 1, Rel.GE) - if self.is_maxpoolgradwithargmax: - self.strides = (1, strides, strides, 1) - else: - self.strides = (1, 1, strides, strides) - else: - strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number" - f"or a tuple of two or four positive int numbers, but got {strides}") - if len(strides) != 2 and len(strides) != 4: - raise strides_error - for strides_val in strides: - if not isinstance(strides_val, int) or (strides_val <= 0): - raise strides_error - if len(strides) == 2 and self.is_maxpoolgradwithargmax: - 
self.strides = (1, strides[0], strides[1], 1) - elif len(strides) == 2 and not self.is_maxpoolgradwithargmax: - self.strides = (1, 1, strides[0], strides[1]) - else: - self.strides = strides + self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax) self.add_prim_attr("strides", self.strides) @@ -528,17 +508,17 @@ class L2NormalizeGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0, epsilon=1e-4): - validator.check_type('axis', axis, [int]) - validator.check_type('epsilon', epsilon, [int, float]) + validator.check_value_type('axis', axis, [int], self.name) + validator.check_value_type('epsilon', epsilon, [int, float], self.name) def infer_shape(self, input_x, out, dout): - validator.check_param_equal('input_x', input_x, 'out', out) - validator.check_param_equal('input_x', input_x, 'dout', dout) + validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name) + validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name) return input_x def infer_dtype(self, input_x, out, dout): args = {'input_x': input_x, 'out': out, 'dout': dout} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return input_x @@ -559,8 +539,8 @@ class LayerNormGrad(Primitive): @prim_attr_register def __init__(self, begin_norm_axis=1, begin_params_axis=1): """init""" - self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int]) - self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int]) + self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) + self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) def __call__(self, x, dy, variance, mean, gamma): raise NotImplementedError @@ -572,15 +552,15 @@ class LogSoftmaxGrad(PrimitiveWithInfer): @prim_attr_register def 
__init__(self, axis=-1): """init LogSoftmaxGrad""" - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def infer_shape(self, dout, logits): rank = len(logits) - validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH) + validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name) return logits def infer_dtype(self, dout, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -589,13 +569,13 @@ class LSTMGradData(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -605,19 +585,19 @@ class LSTMGradData(PrimitiveWithInfer): def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape, hx_shape, cx_shape, 
reserve_shape, state_shape): # dhy and dcy should be same shape - validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ) - validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ) - validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ) - validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ) - validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ) + validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name) + validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name) + validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name) + validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name) + validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name) - validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ) - validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ) + validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) + validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name) # dy: (seq_len, batch_size, hidden_size * num_directions) - validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ) - validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ) - validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ) + validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name) + validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name) + validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name) # (seq_len, batch_size, input_size) dx_shape = (y_shape[0], y_shape[1], self.input_size) @@ -628,11 +608,8 @@ class LSTMGradData(PrimitiveWithInfer): def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, 
dcy_dtype, w_dtype, hx_dtype, cx_dtype, reserve_dtype, state_dtype): - validator.check_typename("dy_dtype", dy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("dhy_dtype", dhy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("dcy_dtype", dcy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("datatype", dy_dtype, (dhy_dtype.element_type(),)) - validator.check_typename("datatype", dy_dtype, (dcy_dtype.element_type(),)) + args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} + validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) return (dy_dtype, dy_dtype, dy_dtype) @@ -641,13 +618,13 @@ class LSTMGradWeight(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -692,9 +669,10 @@ class PReLUGrad(PrimitiveWithInfer): return 
y_backprop_shape, w_shape def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype): - validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("A_dtype", A_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("w_dtype", w_dtype, (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name) return y_backprop_dtype, w_dtype @@ -724,11 +702,30 @@ class ReLU6Grad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad_dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype +class ReluGradV2(PrimitiveWithInfer): + """Performs grad of ReLUV2 operation.""" + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output']) + + def __call__(self, gradients, mask): + raise NotImplementedError + + def infer_shape(self, gradients_shape, mask_shape): + return gradients_shape + + def infer_dtype(self, gradients_dtype, mask_dtype): + validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name) + validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name) + return gradients_dtype + + class EluGrad(PrimitiveWithInfer): """Performs grad of Elu operation.""" @@ -740,10 +737,8 @@ class EluGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, 
x_dtype): - args_type = {'y_grad': y_grad_dtype, 'x': x_dtype} - validator.check_args_tensor(args_type) - args_dtype = {'y_grad_dtype': y_grad_dtype, 'x_dtype': x_dtype} - validator.check_type_same(args_dtype, mstype.float_type) + args = {'y_grad': y_grad_dtype, 'x': x_dtype} + validator.check_tensor_type_same(args, mstype.float_type, self.name) return x_dtype @@ -799,11 +794,11 @@ class ROIAlignGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2): """init ROIAlignGrad""" - validator.check_type("pooled_height", pooled_height, [int]) - validator.check_type("pooled_width", pooled_width, [int]) - validator.check_type("spatial_scale", spatial_scale, [float]) - validator.check_type("sample_num", sample_num, [int]) - validator.check_type("xdiff_shape", xdiff_shape, [tuple]) + validator.check_value_type("pooled_height", pooled_height, [int], self.name) + validator.check_value_type("pooled_width", pooled_width, [int], self.name) + validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) + validator.check_value_type("sample_num", sample_num, [int], self.name) + validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name) self.xdiff_shape = xdiff_shape self.pooled_height = pooled_height self.pooled_width = pooled_width @@ -828,10 +823,8 @@ class SigmoidGrad(PrimitiveWithInfer): return out def infer_dtype(self, out, dout): - validator.check_typename("dout dtype", dout, (mstype.float16, mstype.float32)) - validator.check_typename("out dtype", out, (mstype.float16, mstype.float32)) - args = {"out type": out, "dout type": dout} - validator.check_type_same(args, mstype.number_type) + args = {'out': out, 'dout': dout} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return out @@ -846,8 +839,8 @@ class HSigmoidGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad dtype", 
y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -862,8 +855,8 @@ class HSwishGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -876,13 +869,13 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad']) def infer_shape(self, x_shape, y_shape, dout_shape): - validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape) - validator.check_param_equal("x_shape", x_shape, "dout_shape", dout_shape) + validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) + validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_dtype, y_dtype, dout_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return dout_dtype @@ -898,8 +891,8 @@ class SliceGrad(PrimitiveWithInfer): dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value'] dy_shape_len = len(dy_shape) for i in range(dy_shape_len): - validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE) - validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', 
size_value[i], Rel.EQ) + validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name) + validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name) return {'shape': x_shape, 'dtype': x['dtype'], 'value': None} @@ -913,13 +906,13 @@ class SmoothL1LossGrad(PrimitiveWithInfer): pass def infer_shape(self, prediction, target, dloss): - validator.check_param_equal('prediction', prediction, 'target', target) - validator.check_param_equal('prediction', prediction, 'dloss', dloss) + validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) + validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name) return prediction def infer_dtype(self, prediction, target, dloss): args = {"prediction": prediction, "target": target, 'dloss': dloss} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return dloss @@ -946,11 +939,11 @@ class StridedSliceGrad(PrimitiveWithInfer): new_axis_mask=0, shrink_axis_mask=0): """init StrideSliceGrad""" - validator.check_type('begin_mask', begin_mask, [int]) - validator.check_type('end_mask', end_mask, [int]) - validator.check_type('ellipsis_mask', ellipsis_mask, [int]) - validator.check_type('new_axis_mask', new_axis_mask, [int]) - validator.check_type('shrink_axis_mask', shrink_axis_mask, [int]) + validator.check_value_type('begin_mask', begin_mask, [int], self.name) + validator.check_value_type('end_mask', end_mask, [int], self.name) + validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) + validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name) + validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name) self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) def __infer__(self, dy, shapex, begin, end, strides): @@ -970,10 +963,8 @@ 
class TanhGrad(PrimitiveWithInfer): return out def infer_dtype(self, out, dout): - validator.check_subclass("out", out, mstype.tensor) - validator.check_subclass("dout", dout, mstype.tensor) - args = {"out type": out, "dout type": dout} - validator.check_type_same(args, mstype.number_type) + args = {"out": out, "dout": dout} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return out @@ -983,13 +974,13 @@ class MirrorPadGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, mode="REFLECT"): """init MirrorPad""" - validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC']) + validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) self.mode = mode def __infer__(self, dout, paddings, x): - validator.check_subclass("dout", dout['dtype'], mstype.tensor) - validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) - validator.check_subclass("input_x", x['dtype'], mstype.tensor) + validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name) + validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name) + validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) return {'shape': x['shape'], 'dtype': dout['dtype'], 'value': None} diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 14d1bc9234..4c7d64b581 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -15,8 +15,8 @@ """Operators for quantization.""" -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_bool, check_int_positive, check_int +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ..primitive import PrimitiveWithInfer, prim_attr_register from ...common import dtype as mstype @@ -69,36 +69,31 @@ class FakeQuantWithMinMax(PrimitiveWithInfer): training=True): """init FakeQuantWithMinMax OP""" if num_bits not in 
self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") if ema and not ema_decay: - raise ValueError( - "Attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = check_bool(ema) - self.symmetric = check_bool(symmetric) - self.narrow_range = check_bool(narrow_range) - self.training = check_bool(training) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) - self.num_bits = check_int_positive(num_bits) - self.quant_delay = check_int(quant_delay) + raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) + self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) + self.training = validator.check_value_type('training', training, (bool,), self.name) + self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x shape", len(x_shape), 1, Rel.GT) - validator.check("min shape", min_shape, "max shape", max_shape) - validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) - validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) + validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, min_type, 
max_type): - validator.check_typename( - "x type", x_type, (mstype.float16, mstype.float32)) - validator.check_typename("min type", min_type, - (mstype.float16, mstype.float32)) - validator.check_typename("max type", max_type, - (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return x_type @@ -109,29 +104,24 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, num_bits=8, quant_delay=0): if num_bits not in self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") - self.quant_delay = check_int(quant_delay) - self.num_bits = check_int_positive(num_bits) - self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], - outputs=['dx']) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): - validator.check("dout shape", dout_shape, "x shape", x_shape) - validator.check("min shape", min_shape, "max shape", max_shape) - validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) - validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) + validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) + validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) + validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) return dout_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): - validator.check_typename( - "dout type", 
dout_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "x type", x_type, (mstype.float16, mstype.float32)) - validator.check_typename("min type", min_type, - (mstype.float16, mstype.float32)) - validator.check_typename("max type", max_type, - (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return dout_type @@ -172,37 +162,30 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): training=True): """init FakeQuantWithMinMaxPerChannel OP""" if num_bits not in self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.") if ema and not ema_decay: - raise ValueError( - "Attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = check_bool(ema) - self.symmetric = check_bool(symmetric) - self.narrow_range = check_bool(narrow_range) - self.training = check_bool(training) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) - self.num_bits = check_int_positive(num_bits) - self.quant_delay = check_int(quant_delay) - self.init_prim_io_names(inputs=['x', 'min', 'max'], - outputs=['out']) + raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) + self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) + self.training = validator.check_value_type('training', training, (bool,), self.name) + self.ema_decay = 
validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) + self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x shape", len(x_shape), 1, Rel.GT) - validator.check_integer( - "min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ) - validator.check_integer( - "max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ) + validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) + validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, min_type, max_type): - validator.check_typename( - "x type", x_type, (mstype.float16, mstype.float32)) - validator.check_typename("min type", min_type, - (mstype.float16, mstype.float32)) - validator.check_typename("max type", max_type, - (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return x_type @@ -214,12 +197,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): def __init__(self, num_bits=8, quant_delay=0): """init FakeQuantWithMinMaxPerChannel Fill""" if num_bits not in self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") - self.quant_delay = check_int(quant_delay) - self.num_bits = check_int_positive(num_bits) - self.init_prim_io_names(inputs=['dout', 
'x', 'min', 'max'], - outputs=['dx']) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): validator.check("dout shape", dout_shape, "x shape", x_shape) @@ -227,13 +209,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): return dout_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): - validator.check_typename( - "dout", dout_type, (mstype.float16, mstype.float32)) - validator.check_typename("x", x_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "min", min_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "max", max_type, (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return dout_type @@ -269,31 +249,26 @@ class BatchNormFold(PrimitiveWithInfer): @prim_attr_register def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0): """init batch norm fold layer""" - self.momentum = validator.check_number_range( - 'momentum', momentum, 0, 1, Rel.INC_BOTH) - self.epsilon = validator.check_float_positive('epsilon', epsilon) - self.is_training = check_bool(is_training) - self.freeze_bn = check_int(freeze_bn) + self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) + self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) + self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) 
+ self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'], outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std']) def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): - validator.check("mean shape", mean_shape, - "gamma_shape", variance_shape) - validator.check("mean_shape size", - mean_shape[0], "input channel", x_shape[self.channel]) - validator.check_integer("global_step shape", - len(global_step_shape), 1, Rel.EQ) + validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) + validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return mean_shape, mean_shape, mean_shape, mean_shape def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "variance type", variance_type) - validator.check_typename("input type", x_type, - (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"x": x_type, "mean": mean_type, "variance": variance_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) return x_type, x_type, x_type, x_type @@ -304,39 +279,31 @@ class BatchNormFoldGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0): """init BatchNormGrad layer""" - self.is_training = check_bool(is_training) - self.freeze_bn = check_int(freeze_bn) - self.epsilon = validator.check_float_positive('epsilon', epsilon) + self.is_training = validator.check_value_type('is_training', is_training, 
(bool,), self.name) + self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) + self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'], outputs=['dx']) def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape, global_step_shape): validator.check("d_batch_mean shape", d_batch_mean_shape, - "d_batch_std shape", d_batch_std_shape) + "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name) validator.check("d_batch_mean shape", d_batch_mean_shape, - "batch_mean shape", batch_mean_shape) + "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) validator.check("d_batch_mean shape", d_batch_mean_shape, - "batch_std shape", batch_std_shape) - validator.check( - "x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel]) - validator.check_integer("global_step shape", - len(global_step_shape), 1, Rel.EQ) + "batch_std shape", batch_std_shape, Rel.EQ, self.name) + validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, + self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return x_shape def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, global_step_type): - validator.check("input type", x_type, - "d_batch_mean type", d_batch_mean_type) - validator.check("input type", x_type, - "d_batch_std type", d_batch_std_type) - validator.check("input type", x_type, - "batch_mean type", batch_mean_type) - validator.check("input type", x_type, "batch_std type", batch_std_type) - validator.check_typename("input type", x_type, - (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"input": x_type, "d_batch_mean": d_batch_mean_type, 
"d_batch_std": d_batch_std_type, + "batch_mean": batch_mean_type, "batch_std": batch_std_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) return x_type @@ -364,18 +331,14 @@ class CorrectionMul(PrimitiveWithInfer): outputs=['out']) def infer_shape(self, x_shape, batch_std_shape, running_std_shape): - validator.check("batch_std shape", batch_std_shape, - "running_std shape", running_std_shape) - validator.check( - "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) + validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) + validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], + Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, batch_std_type, running_std_type): - validator.check("batch_std type", batch_std_type, - "running_std type", running_std_type) - validator.check("batch_std_type", batch_std_type, "x_type", x_type) - validator.check_typename( - "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) + args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return x_type @@ -390,20 +353,16 @@ class CorrectionMulGrad(PrimitiveWithInfer): outputs=['dx', 'd_gamma']) def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): - validator.check("dout shape", dout_shape, "x_shape x", x_shape) - validator.check( - "gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel]) - validator.check( - "running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel]) + validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name) + validator.check("gamma_shape[0]", gamma_shape[0], "dout 
channel size", dout_shape[self.channel], + Rel.EQ, self.name) + validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel], + Rel.EQ, self.name) return x_shape, gamma_shape def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): - validator.check("x type", x_type, "dout type", dout_type) - validator.check("gamma type", gamma_type, "dout type", dout_type) - validator.check("running_std type", running_std_type, - "dout type", dout_type) - validator.check_typename( - "dout type", dout_type, (mstype.float16, mstype.float32)) + args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return x_type, x_type @@ -432,46 +391,29 @@ class BatchNormFold2(PrimitiveWithInfer): @prim_attr_register def __init__(self, freeze_bn=0): """init conv2d fold layer""" - self.freeze_bn = check_int(freeze_bn) + self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std', 'running_mean', 'global_step'], outputs=['y']) def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape, running_mean_shape, global_step_shape): - validator.check("batch_std shape", batch_std_shape, - "running_std shape", running_std_shape) - validator.check("batch_std shape", batch_std_shape, - "batch_mean shape", batch_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "beta shape", beta_shape) - validator.check("batch_std shape", batch_std_shape, - "running_mean shape", running_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "batch_mean shape", gamma_shape) - validator.check( - "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) - validator.check_integer("global_step shape", - 
len(global_step_shape), 1, Rel.EQ) + validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) + validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], + Rel.EQ, self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, running_mean_type, global_step_type): - validator.check("batch_std type", batch_std_type, - "running_std type", running_std_type) - validator.check("batch_std type", batch_std_type, - "batch_mean type", batch_mean_type) - validator.check("batch_std type", batch_std_type, - "beta type", beta_type) - validator.check("batch_std type", batch_std_type, - "running_mean type", running_mean_type) - validator.check("batch_std type", batch_std_type, - "gamma type", gamma_type) - validator.check("x_type", x_type, "batch_std type", batch_std_type) - validator.check_typename( - "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, + "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, 
(mstype.int32,), self.name) return x_type @@ -491,18 +433,13 @@ class BatchNormFold2Grad(PrimitiveWithInfer): def infer_shape(self, dout_shape, x_shape, gamma_shape, batch_std_shape, batch_mean_shape, running_std_shape, running_mean_shape, global_step_shape): - validator.check("batch_std shape", batch_std_shape, - "batch_mean shape", batch_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "running_std shape", running_std_shape) - validator.check("batch_std shape", batch_std_shape, - "running_mean shape", running_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "gamma shape", gamma_shape) - validator.check( - "batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel]) - validator.check_integer("global_step shape", - len(global_step_shape), 1, Rel.EQ) + validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) + validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel], + Rel.EQ, self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape def infer_dtype(self, dout_type, x_type, gamma_type, @@ -518,8 +455,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer): "running_mean type", running_mean_type) validator.check("batch_std_type", batch_std_type, "dout type", dout_type) - validator.check_typename( - "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"batch_std": batch_std_type, 
"batch_mean": batch_mean_type, "gamma": gamma_type, + "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2219e3bb50..21dbf81730 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -745,7 +745,7 @@ class Fill(PrimitiveWithInfer): out = { 'value': Tensor(ret), 'shape': dims['value'], - 'dtype': x_nptype, + 'dtype': x_dtype, } return out @@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] x_type = input_x['dtype'] - _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis) + _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) self.add_prim_attr('inputNums', len(x_shp)) ret_shp = x_shp[0].copy() @@ -1329,7 +1329,7 @@ class Concat(PrimitiveWithInfer): def _get_pack_shape(x_shape, x_type, axis): """for pack output shape""" - validator.check_type("shape", x_shape, [tuple]) + validator.check_type("shape", x_shape, [tuple, list]) validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT) validator.check_subclass("shape0", x_type[0], mstype.tensor) validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT) diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index a5a4c9f236..fbad5b49d3 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -15,7 +15,8 @@ """comm_ops""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ...communication.management 
import get_rank, get_group_size, GlobalComm, get_group from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register @@ -148,12 +149,10 @@ class AllGather(PrimitiveWithInfer): @prim_attr_register def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): - if not isinstance(get_group(group), str): - raise TypeError("The group of AllGather should be str.") + validator.check_value_type('group', get_group(group), (str,), self.name) self.rank = get_rank(get_group(group)) self.rank_size = get_group_size(get_group(group)) - if self.rank >= self.rank_size: - raise ValueError("The rank of AllGather should be less than the rank_size.") + validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) self.add_prim_attr('rank_size', self.rank_size) self.add_prim_attr('group', get_group(group)) @@ -163,7 +162,7 @@ class AllGather(PrimitiveWithInfer): def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("AllGather does not support 'Bool' as the dtype of input!") + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -205,10 +204,8 @@ class ReduceScatter(PrimitiveWithInfer): @prim_attr_register def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): - if not isinstance(op, type(ReduceOp.SUM)): - raise TypeError("The operation of ReduceScatter should be {}.".format(type(ReduceOp.SUM))) - if not isinstance(get_group(group), str): - raise TypeError("The group of ReduceScatter should be str.") + validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) + validator.check_value_type('group', get_group(group), (str,), self.name) self.op = op self.rank_size = get_group_size(get_group(group)) self.add_prim_attr('rank_size', self.rank_size) @@ -216,13 +213,13 @@ class ReduceScatter(PrimitiveWithInfer): def infer_shape(self, x_shape): if x_shape[0] % self.rank_size != 0: - raise ValueError("The first dimension 
of x should be divided by rank_size.") + raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.") x_shape[0] = int(x_shape[0]/self.rank_size) return x_shape def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!") + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -270,10 +267,8 @@ class Broadcast(PrimitiveWithInfer): @prim_attr_register def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP): - if not isinstance(root_rank, int): - raise TypeError("The root_rank of Broadcast should be int.") - if not isinstance(get_group(group), str): - raise TypeError("The group of Broadcast should be str.") + validator.check_value_type('root_rank', root_rank, (int,), self.name) + validator.check_value_type('group', get_group(group), (str,), self.name) self.add_prim_attr('group', get_group(group)) def infer_shape(self, x_shape): @@ -281,7 +276,7 @@ class Broadcast(PrimitiveWithInfer): def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("Broadcast does not support 'Bool' as the dtype of input!") + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype @@ -311,8 +306,7 @@ class _AlltoAll(PrimitiveWithInfer): @prim_attr_register def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP): """init AlltoAll""" - if not isinstance(get_group(group), str): - raise TypeError("The group of AllGather should be str.") + validator.check_value_type('group', get_group(group), (str,), self.name) self.split_count = split_count self.split_dim = split_dim self.concat_dim = concat_dim @@ -325,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer): def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!") + raise 
TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -420,6 +414,6 @@ class _GetTensorSlice(PrimitiveWithInfer): def infer_value(self, x, dev_mat, tensor_map): from mindspore.parallel._tensor import _load_tensor - validator.check_type("dev_mat", dev_mat, [tuple]) - validator.check_type("tensor_map", tensor_map, [tuple]) + validator.check_value_type("dev_mat", dev_mat, [tuple], self.name) + validator.check_value_type("tensor_map", tensor_map, [tuple], self.name) return _load_tensor(x, dev_mat, tensor_map) diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index ca161cfad0..9743f9e3fd 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -16,7 +16,8 @@ """control_ops""" from ...common import dtype as mstype -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register @@ -123,11 +124,11 @@ class GeSwitch(PrimitiveWithInfer): raise NotImplementedError def infer_shape(self, data, pred): - validator.check_scalar_shape_input("pred", pred) + validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name) return (data, data) def infer_dtype(self, data_type, pred_type): - validator.check_type("pred", pred_type, [type(mstype.bool_)]) + validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) return (data_type, data_type) diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 1d8fdedc26..21c9c519b9 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -14,7 +14,7 @@ # ============================================================================ """debug_ops""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as 
validator from ...common import dtype as mstype from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer @@ -210,10 +210,14 @@ class Print(PrimitiveWithInfer): def __init__(self): pass + def __call__(self, *args): + for arg in args: + print(arg) + def infer_shape(self, *inputs): return [1] def infer_dtype(self, *inputs): for dtype in inputs: - validator.check_subclass("input", dtype, (mstype.tensor, mstype.string)) + validator.check_subclass("input", dtype, (mstype.tensor, mstype.string), self.name) return mstype.int32 diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 98665dd27a..8de4108435 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -24,7 +24,7 @@ from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor from .._utils import _get_broadcast_shape -from ..primitive import PrimitiveWithInfer, prim_attr_register +from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op def _infer_shape_reduce(x, axis, keep_dims, prim_name): @@ -225,6 +225,11 @@ class _Reduce(PrimitiveWithInfer): validator.check_value_type('keep_dims', keep_dims, [bool], self.name) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) + def __call__(self, x, axis=()): + args = [x, axis] + output = _run_op(self, self.name, args) + return output + def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): axis_v = axis['value'] input_shp = input_x['shape'] @@ -768,8 +773,8 @@ class Mul(_MathBinaryOp): Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'. 
Examples: - >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) - >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.int32) + >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32) + >>> input_y = Tensor(np.array([4.0, 5.0, 6.0]), mindspore.float32) >>> mul = P.Mul() >>> mul(input_x, input_y) [4, 10, 18] @@ -1002,6 +1007,36 @@ class Log(PrimitiveWithInfer): return x +class Erf(PrimitiveWithInfer): + r""" + Computes the Gauss error function of `input_x` element-wise. + + Inputs: + - **input_x** (Tensor) - The input tensor. + + Outputs: + Tensor, has the same shape and dtype as the `input_x`. + + Examples: + >>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32) + >>> erf = P.Erf() + >>> erf(input_x) + [-0.8427168, 0., 0.8427168, 0.99530876, 0.99997765] + """ + + @prim_attr_register + def __init__(self): + """init Erf""" + self.init_prim_io_names(inputs=['x'], outputs=['y']) + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) + return x_type + + class Minimum(_MathBinaryOp): """ Computes the element-wise minimum of input tensors. @@ -1490,6 +1525,7 @@ class LogicalNot(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init LogicalNot""" + self.init_prim_io_names(inputs=['x'], outputs=['output']) def infer_shape(self, x_shape): return x_shape diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 9827975fd0..9750549dc5 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -24,12 +24,40 @@ import numpy as np from ... 
import context from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_bool, check_int_positive +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register +from ..operations.math_ops import _infer_shape_reduce +def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): + """ + Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements. + """ + def _raise_message(): + raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two " + f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}") + def _get_return_value(): + if isinstance(arg_value, int): + ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value) + elif len(arg_value) == 2: + ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value + elif len(arg_value) == 4: + if not allow_four: + _raise_message() + ret = arg_value if ret_four else (arg_value[2], arg_value[3]) + else: + _raise_message() + return ret + validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) + ret_value = _get_return_value() + for item in ret_value: + if isinstance(item, int) and item > 0: + continue + _raise_message() + return ret_value + class Flatten(PrimitiveWithInfer): r""" Flattens a tensor without changing its batch size on the 0-th axis. 
@@ -53,12 +81,12 @@ class Flatten(PrimitiveWithInfer): pass def infer_shape(self, input_x): - validator.check('input_x rank', len(input_x), '', 1, Rel.GE) + validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name) prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:]) return input_x[0], prod def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) + validator.check_subclass("input_x", input_x, mstype.tensor, self.name) return input_x @@ -88,21 +116,21 @@ class Softmax(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_type("axis", axis, [int, tuple]) + validator.check_value_type("axis", axis, [int, tuple], self.name) if isinstance(axis, int): self.add_prim_attr('axis', (axis,)) for item in self.axis: - validator.check_type("item of axis", item, [int]) + validator.check_value_type("item of axis", item, [int], self.name) def infer_shape(self, logits): - validator.check_shape_length("axis shape", len(self.axis), 1, Rel.GE) + validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) rank = len(logits) for axis_v in self.axis: - validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT) + validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -131,15 +159,15 @@ class LogSoftmax(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def infer_shape(self, logits): rank = len(logits) - validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH) + validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name) 
return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -171,8 +199,7 @@ class ReLU(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, mstype.number_type) + validator.check_tensor_type_same({'input_x': input_x}, mstype.number_type, self.name) return input_x @@ -203,11 +230,66 @@ class ReLU6(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({'input_x': input_x}, (mstype.float16, mstype.float32), self.name) return input_x +class ReLUV2(PrimitiveWithInfer): + r""" + Computes ReLU(Rectified Linear Unit) of input tensor element-wise. + + It returns :math:`\max(x,\ 0)` element-wise. + + Inputs: + - **input_x** (Tensor) - The input tensor should be a 4-D tensor. + + Outputs: + - **output** (Tensor) - Has the same type and shape as the `input_x`. + - **mask** (Tensor) - A tensor whose data type must be uint8. 
+ + Examples: + >>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32) + >>> relu_v2 = P.ReLUV2() + >>> output = relu_v2(input_x) + ([[[[1., 0.], [0., 4.]], [[0., 6.], [7., 0.]]]], + [[[[1, 0], [2, 0]], [[2, 0], [1, 0]]]]) + """ + @prim_attr_register + def __init__(self): + """init ReLUV2""" + self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask']) + + def __infer__(self, input_x): + input_shape = list(input_x['shape']) + input_dtype = input_x['dtype'] + mask_shape = [] + if len(input_shape) != 4: + raise ValueError("The `input_x` should be a 4-D tensor, " + f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}") + for i in enumerate(input_shape): + if i[0] == 1: + if input_dtype == mstype.uint8 and input_dtype == mstype.int8: + mask_shape.append((input_shape[1] + 31) // 32) + else: + mask_shape.append((input_shape[1] + 15) // 16) + else: + mask_shape.append(i[1]) + if input_dtype == mstype.uint8 and input_dtype == mstype.int8: + mask_shape.append(4) + else: + mask_shape.append(2) + + output_shape = (input_x['shape'], mask_shape) + validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name) + validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name) + mask_dtype = mstype.uint8 + output_dtype = (input_dtype, mask_dtype) + + return {'shape': output_shape, + 'dtype': output_dtype, + 'value': None} + + class Elu(PrimitiveWithInfer): r""" Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise. 
@@ -233,14 +315,13 @@ class Elu(PrimitiveWithInfer): @prim_attr_register def __init__(self, alpha=1.0): """Init Elu""" - validator.check_type("alpha", alpha, [float]) + validator.check_value_type("alpha", alpha, [float], self.name) def infer_shape(self, input_x): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x_dtype", input_x, mstype.float_type) + validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) return input_x @@ -258,7 +339,7 @@ class HSwish(PrimitiveWithInfer): where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. Inputs: - - **input_data** (Tensor) - The input of Hswish. + - **input_data** (Tensor) - The input of HSwish. Outputs: Tensor, with the same type and shape as the `input_data`. @@ -272,8 +353,7 @@ class HSwish(PrimitiveWithInfer): return xshape def infer_dtype(self, x_dtype): - validator.check_subclass("x_dtype", x_dtype, mstype.tensor) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -305,8 +385,7 @@ class Sigmoid(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -339,8 +418,7 @@ class HSigmoid(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_subclass("x_dtype", x_dtype, mstype.tensor) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -370,7 +448,7 @@ class Tanh(PrimitiveWithInfer): return 
input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) + validator.check_subclass("input_x", input_x, mstype.tensor, self.name) return input_x @@ -418,9 +496,9 @@ class FusedBatchNorm(Primitive): def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) - self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) - self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH) + self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) + self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) + self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) class BatchNorm(PrimitiveWithInfer): @@ -459,38 +537,38 @@ class BatchNorm(PrimitiveWithInfer): - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`. - **reserve_space_1** (Tensor) - Tensor of shape :math:`(C,)`. - **reserve_space_2** (Tensor) - Tensor of shape :math:`(C,)`. - - **reserve_space_3** (Tensor) - Tensor of shape :math:`(C,)`. 
""" @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5): - self.is_training = validator.check_type('is_training', is_training, (bool,)) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) + validator.check_value_type('is_training', is_training, (bool,), self.name) + validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], - outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2', - 'reserve_space_3']) + outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) def infer_shape(self, input_x, scale, bias, mean, variance): - validator.check("BatchNorm scale shape length", len(scale), "1", 1, Rel.EQ) - validator.check("BatchNorm scale shape", scale, "BatchNorm bias shape", bias) - validator.check("BatchNorm scale shape", scale[0], "BatchNorm input_x shape[1]", input_x[1]) + validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) + validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) + validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) if not self.is_training: - validator.check("BatchNorm mean shape length", len(mean), "1", 1, Rel.EQ) - validator.check("BatchNorm mean shape", mean, "BatchNorm variance shape", variance) - validator.check("BatchNorm mean shape", mean, "BatchNorm scale shape", scale) - return (input_x, scale, scale, scale, scale, scale) + validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) + validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) + validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) + return (input_x, scale, scale, scale, scale) def infer_dtype(self, input_x, scale, bias, mean, variance): - args = {"BatchNorm scale type": scale, "BatchNorm 
bias type": bias} - args_moving = {"BatchNorm mean type": mean, "BatchNorm variance type": variance} - validator.check_typename("input_x", input_x, [mstype.float32, mstype.float16]) - validator.check_type_same(args, [mstype.float32, mstype.float16]) + validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) + args = {"scale": scale, "bias": bias} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + args_moving = {"mean": mean, "variance": variance} if self.is_training: - validator.check_type_same(args_moving, [mstype.float32, mstype.float16, None]) + valid_types = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None] + validator.check_type_same(args_moving, valid_types, self.name) else: - validator.check_type_same(args_moving, [mstype.float32, mstype.float16]) - return (input_x, scale, bias, input_x, input_x, input_x) + args_moving = {"mean": mean, "variance": variance} + validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) + return (input_x, scale, bias, input_x, input_x) class Conv2D(PrimitiveWithInfer): @@ -559,53 +637,28 @@ class Conv2D(PrimitiveWithInfer): group=1): """init Conv2D""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - if len(self.stride) != 2 or (not 
isinstance(self.stride[0], int)) or \ - (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {stride}") + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) if self.pad_mode == 'pad': - validator.check_integer('pad', self.pad, 
0, Rel.GE) + validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.group = validator.check_integer('group', group, 0, Rel.GT) + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - validator.check_param_equal("x_shape[1]", x_shape[1] // self.group, "w_shape[1]", w_shape[1]) - validator.check_param_equal('out_channel', self.out_channel, 'w_shape[0]', w_shape[0]) - validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) + validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -647,10 +700,9 @@ class Conv2D(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} - validator.check_subclass('input', x_dtype, mstype.tensor) - validator.check_subclass('weight', w_dtype, mstype.tensor) - validator.check_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) + args = {'x': x_dtype, 'w': w_dtype} + valid_types = [mstype.int8, 
mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) return x_dtype @@ -697,49 +749,25 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): group=1): """init DepthwiseConv2dNative""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or \ - (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {stride}") + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (dilation, dilation) - if len(self.dilation) != 2 or (not isinstance(self.dilation[0], int)) or \ - (not isinstance(self.dilation[1], int)) or \ - self.dilation[0] < 1 or self.dilation[1] < 1: - raise ValueError(f"The \'dilation\' of \'DepthwiseConv2dNative\' should be an 
positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name) self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1])) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - if pad_mode not in ("same", "valid", "pad"): - raise ValueError(f"Attr pad_mode of DepthwiseConv2dNative Op not passed" - f"{pad_mode} not in valid, same, pad.") - self.pad_mode = pad_mode - self.mode = validator.check_integer("mode", mode, 3, Rel.EQ) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT) - self.group = validator.check_integer("group", group, 0, Rel.GT) - self.pad = pad + self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT, + self.name) + self.group = validator.check_integer("group", group, 0, Rel.GT, self.name) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - validator.check_param_equal("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1]) - validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) + validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), 
Rel.EQ, self.name) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -772,9 +800,6 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): / stride_w h_out = math.floor(h_out) w_out = math.floor(w_out) - else: - raise ValueError(f"Attr pad_mode of DepthwiseConv2dNative Op not passed" - "{pad_mode} not in valid, same, pad.") self.pad_list = (pad_top, pad_bottom, pad_left, pad_right) self.add_prim_attr('pads', self.pad_list) @@ -784,8 +809,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} - validator.check_type_same(args, mstype.number_type) + args = {'x': x_dtype, 'w': w_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -805,48 +830,26 @@ class _Pool(PrimitiveWithInfer): @prim_attr_register def __init__(self, ksize=1, strides=1, padding="valid"): self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_type('ksize', ksize, [int, tuple]) - validator.check_type('strides', strides, [int, tuple]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + validator.check_value_type('ksize', ksize, [int, tuple], self.name) + validator.check_value_type('strides', strides, [int, tuple], self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax") if not self.is_maxpoolwithargmax: self.add_prim_attr('data_format', "NCHW") - if isinstance(ksize, int): - validator.check_integer("ksize", ksize, 1, Rel.GE) - self.ksize = (1, 1, ksize, ksize) - else: - if (len(ksize) != 2 or - (not isinstance(ksize[0], int)) or - (not isinstance(ksize[1], int)) or - ksize[0] <= 0 or - ksize[1] <= 0): - raise ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number or " - f"a tuple of two positive int 
numbers, but got {ksize}") - self.ksize = (1, 1, ksize[0], ksize[1]) + self.ksize = _check_positive_int_or_tuple("ksize", ksize, self.name, allow_four=False, ret_four=True) if self.is_maxpoolwithargmax: self.ksize = (1, self.ksize[-2], self.ksize[-1], 1) self.add_prim_attr("ksize", self.ksize) - if isinstance(strides, int): - validator.check_integer("strides", strides, 1, Rel.GE) - self.strides = (1, 1, strides, strides) - else: - if (len(strides) != 2 or - (not isinstance(strides[0], int)) or - (not isinstance(strides[1], int)) or - strides[0] <= 0 or - strides[1] <= 0): - raise ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number or " - f"a tuple of two positive int numbers, but got {strides}") - self.strides = (1, 1, strides[0], strides[1]) + self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True) if self.is_maxpoolwithargmax: self.strides = (1, self.strides[-2], self.strides[-1], 1) self.add_prim_attr("strides", self.strides) def infer_shape(self, x_shape): - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) batch, channel, input_h, input_w = x_shape if self.is_maxpoolwithargmax: _, kernel_h, kernel_w, _ = self.ksize @@ -861,18 +864,16 @@ class _Pool(PrimitiveWithInfer): elif self.padding == "SAME": out_h = math.ceil(input_h / stride_h) out_w = math.ceil(input_w / stride_w) - else: - raise ValueError(f"The padding of operator {self.name} should be a str and must be 'SAME' or 'VALID', " - f"but got {self.padding}.") out_shape = [batch, channel, out_h, out_w] for shape_value in out_shape: if shape_value <= 0: - raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.") + raise ValueError(f"For '{self.name}' The kernel size is not valid, " + f"please check it if is larger than data's shape size.") return out_shape def infer_dtype(self, x_dtype): - 
validator.check_subclass("input", x_dtype, mstype.tensor) + validator.check_subclass("input", x_dtype, mstype.tensor, self.name) return x_dtype @@ -913,6 +914,11 @@ class MaxPool(_Pool): Outputs: Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32) + >>> maxpool_op = P.MaxPool(padding="VALID", ksize=2, strides=1) + >>> output_tensor = maxpool_op(input_tensor) """ @prim_attr_register @@ -959,6 +965,11 @@ class MaxPoolWithArgmax(_Pool): - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. - **mask** (Tensor) - Max values' index represented by the mask. + + Examples: + >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32) + >>> maxpool_arg_op = P.MaxPoolWithArgmax(padding="VALID", ksize=2, strides=1) + >>> output_tensor, argmax = maxpool_arg_op(input_tensor) """ def __init__(self, ksize=1, strides=1, padding="valid"): super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding) @@ -987,7 +998,7 @@ class MaxPoolWithArgmax(_Pool): def infer_dtype(self, x_dtype): out_dtype = x_dtype - validator.check_typename("x_type", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) argmax_dtype = mstype.uint16 return out_dtype, argmax_dtype @@ -1071,56 +1082,33 @@ class Conv2DBackpropInput(PrimitiveWithInfer): group=1): """init Conv2DBackpropInput""" self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not 
isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - elif isinstance(stride, tuple) and len(stride) == 4: - self.stride = (stride[2], stride[3]) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {stride}") + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) self.add_prim_attr('stride', self.stride) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) 
self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) - self.group = validator.check_integer('group', group, 0, Rel.GT) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.add_prim_attr('data_format', "NCHW") if pad_list: - self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list) + for x in pad_list: + validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name) + self.pad_list = pad_list def __infer__(self, doutput, w, x_size): x_size_v = x_size['value'] - validator.check_type('x_size', x_size_v, [tuple]) + validator.check_value_type('x_size', x_size_v, [tuple], self.name) for i, dim_len in enumerate(x_size_v): - validator.check_type("x_size[%d]" % i, dim_len, [int]) - validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype']) + validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) + args = {'doutput': doutput['dtype'], 'w': w['dtype']} + valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) # infer shape 
dout_shape = doutput['shape'] @@ -1173,16 +1161,15 @@ class BiasAdd(PrimitiveWithInfer): self.add_prim_attr('data_format', 'NCHW') def infer_shape(self, x_shape, b_shape): - if len(b_shape) != 1 or len(x_shape) < 2 or b_shape[0] != x_shape[1]: - raise ValueError("Input_x and bias shapes do not match", - "(require: rank of input_x must be at least 2, rank of bias must be 1, " - "input_x.dim[1] must equal bias.dim[0])," - " but got input_x shape {}, bias shape {}.".format(x_shape, b_shape)) + validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) + validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name) + validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, b_type): - args = {"input_x type": x_type, "bias type": b_type} - validator.check_type_same(args, (mstype.float16, mstype.float32, mstype.int8, mstype.int32)) + args = {"input_x": x_type, "bias": b_type} + valid_types = (mstype.int8, mstype.int32, mstype.float16, mstype.float32) + validator.check_tensor_type_same(args, valid_types, self.name) return x_type @@ -1215,22 +1202,21 @@ class TopK(PrimitiveWithInfer): @prim_attr_register def __init__(self, sorted=False): - validator.check_type("sorted", sorted, [bool]) + validator.check_value_type("sorted", sorted, [bool], self.name) self.init_prim_io_names(inputs=['input', 'k'], outputs=['values', 'indices']) def __infer__(self, input_x, k): + x_dtype = input_x['dtype'] + valid_types = (mstype.int32, mstype.float16, mstype.float32) + validator.check_tensor_type_same({'x': x_dtype}, valid_types, self.name) + k_v = k['value'] + validator.check_value_type('k', k_v, (int,), self.name) x_shape = list(input_x['shape']) ndim = len(x_shape) - 1 - k_v = k['value'] x_shape[ndim] = k_v - input_dtype = input_x['dtype'] - validator.check_typename("TopK input_dtype", - input_dtype, (mstype.float16, mstype.float32, mstype.int32)) - if not isinstance(k_v, int): - raise 
ValueError('The k must int.', k) return {'shape': (x_shape, x_shape), - 'dtype': (input_dtype, mstype.int32), + 'dtype': (x_dtype, mstype.int32), 'value': None} @@ -1260,16 +1246,14 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): pass def infer_shape(self, logits_shape, labels_shape): - validator.check_param_equal("SoftmaxCrossEntropyWithLogits logits_shape", logits_shape, - "SoftmaxCrossEntropyWithLogits labels_shape", labels_shape) + validator.check("logits_shape", logits_shape, "labels_shape", labels_shape, Rel.EQ, self.name) loss_shape = [logits_shape[0]] dlogits_shape = logits_shape return (loss_shape, dlogits_shape) def infer_dtype(self, logits_type, labels_type): - args = {"SoftmaxCrossEntropyWithLogits logits_type": logits_type, - "SoftmaxCrossEntropyWithLogits labels_type": labels_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + args = {"logits": logits_type, "labels": labels_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return (logits_type, logits_type) @@ -1308,18 +1292,15 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): self.add_prim_attr('sens', 1.0) def infer_shape(self, logits_shape, labels_shape): - validator.check_param_equal("SparseSoftmaxCrossEntropyWithLogits logits_shape", logits_shape[0], - "SparseSoftmaxCrossEntropyWithLogits labels_shape", labels_shape[0]) + validator.check("logits_shape[0]", logits_shape[0], "labels_shape[0]", labels_shape[0], Rel.EQ, self.name) loss_shape = [] if self.is_grad: return logits_shape return loss_shape def infer_dtype(self, logits_type, labels_type): - validator.check_typename("SparseSoftmaxCrossEntropyWithLogits logits_type", - logits_type, (mstype.float16, mstype.float32)) - validator.check_typename("SparseSoftmaxCrossEntropyWithLogits labels_type", - labels_type, (mstype.int32, mstype.int64)) + validator.check_tensor_type_same({"logits": logits_type}, (mstype.float16, mstype.float32), self.name) + 
validator.check_tensor_type_same({"labels": labels_type}, (mstype.int32, mstype.int64), self.name) return logits_type @@ -1364,14 +1345,13 @@ class ApplyMomentum(PrimitiveWithInfer): return v_shape def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): + valid_types = [mstype.float16, mstype.float32, mstype.float64] if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey: - validator.check_subclass("v_dtype", v_dtype, mstype.tensor) - validator.check_subclass("a_dtype", a_dtype, mstype.tensor) - validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64]) + validator.check_tensor_type_same({"v": v_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"a": a_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name) return g_dtype @@ -1403,20 +1383,55 @@ class SmoothL1Loss(PrimitiveWithInfer): @prim_attr_register def __init__(self, sigma=1.0): - validator.check_type('sigma', sigma, [float]) - validator.check('sigma', sigma, '', 0, Rel.GT) + validator.check_value_type('sigma', sigma, [float], self.name) + validator.check('sigma', sigma, '', 0, Rel.GT, self.name) self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output']) def infer_shape(self, prediction, target): - validator.check_param_equal('prediction shape', prediction, 'target shape', target) + 
validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) return prediction def infer_dtype(self, prediction, target): args = {"prediction": prediction, "target": target} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return prediction +class L2Loss(PrimitiveWithInfer): + """ + Calculates half of the L2 norm of a tensor without using the `sqrt`. + + Set `input_x` as x and output as loss. + + .. math:: + loss = sum(x ** 2) / 2 + + Inputs: + - **input_x** (Tensor) - A input Tensor. + + Outputs: + Tensor. Has the same dtype as `input_x`. The output tensor is the value of loss which is a scalar tensor. + + Examples + >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16) + >>> l2_loss = P.L2Loss() + >>> l2_loss(input_x) + 7.0 + """ + @prim_attr_register + def __init__(self): + """init L2Loss""" + + def infer_shape(self, input_x): + loss_shape = [] + return loss_shape + + def infer_dtype(self, x_type): + validator.check_subclass("x_type", x_type, mstype.tensor, self.name) + validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name) + return x_type + + class SGD(PrimitiveWithInfer): """ Computes stochastic gradient descent (optionally with momentum). 
@@ -1446,29 +1461,30 @@ class SGD(PrimitiveWithInfer): @prim_attr_register def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False): - validator.check_type("nesterov", nesterov, [bool]) + validator.check_value_type("nesterov", nesterov, [bool], self.name) self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'], outputs=['output']) def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, accum_shape, momentum_shape, stat_shape): - validator.check(f'parameters shape {parameters_shape}', len(parameters_shape), '', 0, Rel.GT) - validator.check(f'gradient shape {gradient_shape}', len(gradient_shape), '', 0, Rel.GE) - validator.check(f'learning rate shape {learning_rate_shape}', len(learning_rate_shape), '', 0, Rel.GE) - validator.check(f'accumulation shape {accum_shape}', len(accum_shape), '', 0, Rel.GT) - validator.check(f'momentum shape {momentum_shape}', len(momentum_shape), '', 0, Rel.GE) - validator.check(f'stat shape {stat_shape}', len(stat_shape), '', 0, Rel.GE) - validator.check("gradient shape", gradient_shape, "stat shape", stat_shape) + validator.check_integer(f'parameters rank', len(parameters_shape), 0, Rel.GT, self.name) + validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name) + validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name) + validator.check_integer(f'accumulation rank', len(accum_shape), 0, Rel.GT, self.name) + validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name) + validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name) + validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) return parameters_shape def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype): - validator.check_typename("parameters_dtype", parameters_dtype, [mstype.float16, 
mstype.float32]) - validator.check_typename("gradient_dtype", gradient_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("learning_rate_dtype", learning_rate_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("accum_dtype", accum_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("momentum_dtype", momentum_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32]) + valid_types = [mstype.float16, mstype.float32] + validator.check_tensor_type_same({"parameters": parameters_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"gradient": gradient_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"learning_rate": learning_rate_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"accum": accum_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"momentum": momentum_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name) return parameters_dtype class ApplyRMSProp(PrimitiveWithInfer): @@ -1514,28 +1530,23 @@ class ApplyRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): - validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) - validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "moment_shape", 
moment_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype, momentum_dtype, epsilon_dtype): - validator.check_subclass("var_dtype", var_dtype, mstype.tensor) - validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) - validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) - validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) - args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, - "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) - - args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype, - 'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + + args = {"learning_rate": learning_rate_dtype, "decay": decay_dtype, + 'momentum': momentum_dtype, "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) return var_dtype @@ -1587,30 +1598,25 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): - validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape) - 
validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) - validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype): - validator.check_subclass("var_dtype", var_dtype, mstype.tensor) - validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor) - validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) - validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) - validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) - args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype, - "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) - - args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype, 'momentum_dtype': momentum_dtype, - "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype, + "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + + args = {"learning_rate": learning_rate_dtype, "rho": rho_dtype, 'momentum': momentum_dtype, + "epsilon": epsilon_dtype} + 
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) return var_dtype @@ -1622,7 +1628,7 @@ class LayerNorm(Primitive): `Layer Normalization `_. .. math:: - y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta + y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. @@ -1651,8 +1657,8 @@ class LayerNorm(Primitive): @prim_attr_register def __init__(self, begin_norm_axis=1, begin_params_axis=1): - validator.check_type('begin_norm_axis', begin_norm_axis, [int]) - validator.check_type('begin_params_axis', begin_params_axis, [int]) + validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) + validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) class L2Normalize(PrimitiveWithInfer): @@ -1679,16 +1685,16 @@ class L2Normalize(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0, epsilon=1e-4): - validator.check_type('axis', axis, [int]) - validator.check_type('epsilon', epsilon, [int, float]) + validator.check_value_type('axis', axis, [int], self.name) + validator.check_value_type('epsilon', epsilon, [int, float], self.name) def infer_shape(self, input_x): dim = len(input_x) - validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT) + validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) return input_x def infer_dtype(self, input_x): - validator.check_subclass("x", input_x, mstype.tensor) + validator.check_subclass("x", input_x, mstype.tensor, self.name) return input_x @@ -1718,8 +1724,8 @@ class DropoutGenMask(Primitive): @prim_attr_register def __init__(self, Seed0=0, Seed1=0): self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output']) - validator.check_type("Seed0", Seed0, [int]) - validator.check_type("Seed1", Seed1, [int]) + validator.check_value_type("Seed0", Seed0, [int], self.name) + 
validator.check_value_type("Seed1", Seed1, [int], self.name) class DropoutDoMask(PrimitiveWithInfer): @@ -1759,7 +1765,7 @@ class DropoutDoMask(PrimitiveWithInfer): input_x_shape = input_x['shape'] mask_shape = mask['shape'] keep_prob_shape = keep_prob['shape'] - validator.check("keep_prob's dim", len(keep_prob_shape), '0(scalar)', 0) + validator.check("keep_prob's dim", len(keep_prob_shape), '0(scalar)', 0, Rel.EQ, self.name) size_x = reduce(lambda x, y: x * y, input_x_shape) if len(mask_shape) != 1: raise ValueError("DropoutDoMask mask shape should be 1-dimension.") @@ -1768,13 +1774,13 @@ class DropoutDoMask(PrimitiveWithInfer): raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:" "{input_x_shape}, mask shape: {mask_shape}.") - validator.check_typename("input_x type", input_x['dtype'], [mstype.float32, mstype.float16, mstype.int32]) - validator.check_typename("input_mask type", mask['dtype'], [mstype.uint8]) + validator.check_tensor_type_same({"input_x": input_x['dtype']}, [mstype.float32, mstype.float16, mstype.int32], + self.name) + validator.check_tensor_type_same({"input_mask": mask['dtype']}, [mstype.uint8], self.name) keep_prob_v = keep_prob['value'] if keep_prob_v is not None: - validator.check_const_input('keep_prob', keep_prob_v) - validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH) + validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name) out = {'shape': input_x_shape, 'dtype': input_x['dtype'], @@ -1858,23 +1864,20 @@ class OneHot(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output']) - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def __infer__(self, indices, depth, on_value, off_value): # check type - validator.check_subclass("indices", indices['dtype'], mstype.tensor) - 
validator.check_typename("indices", indices['dtype'], (mstype.int32,)) - validator.check_typename("depth", depth['dtype'], mstype.int_type) - validator.check_subclass("on_value", on_value['dtype'], mstype.tensor) - validator.check_subclass("off_value", off_value['dtype'], mstype.tensor) - args = {"on_value dtype": on_value['dtype'], "off_value dtype": off_value['dtype']} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"indices": indices['dtype']}, (mstype.int32,), self.name) + validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name) + args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) # check shape indices_shp = indices['shape'] - validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH) + validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) depth_val = depth['value'] - validator.check_integer("depth", depth_val, 0, Rel.GE) + validator.check_integer("depth", depth_val, 0, Rel.GE, self.name) # create new dimension at end if self.axis is -1 indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val) @@ -1919,8 +1922,7 @@ class Gelu(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -1953,10 +1955,10 @@ class GetNext(PrimitiveWithInfer): @prim_attr_register def __init__(self, types, shapes, output_num, shared_name): - validator.check_type("types", types, [list, tuple]) - validator.check_type("shapes", shapes, [list, tuple]) - validator.check("types length", len(types), "shapes length", len(shapes)) - 
validator.check_type("output_num", output_num, [int]) + validator.check_value_type("types", types, [list, tuple], self.name) + validator.check_value_type("shapes", shapes, [list, tuple], self.name) + validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name) + validator.check_value_type("output_num", output_num, [int], self.name) def infer_shape(self): return tuple(self.shapes) @@ -1997,24 +1999,22 @@ class PReLU(PrimitiveWithInfer): weight_dim = len(weight_shape) if weight_dim != 1: - raise ValueError(f'weight_dim must be 1, while weight_dim is {weight_dim}.') + raise ValueError(f'For \'{self.name}\' weight_dim must be 1, while weight_dim is {weight_dim}.') if input_x_dim == 1 and weight_shape[0] != 1: - raise ValueError(f'when input_x_dim is 1, weight_shape[0] must be 1, ' + raise ValueError(f'For \'{self.name}\' when input_x_dim is 1, weight_shape[0] must be 1, ' f'while weight_shape[0] is {weight_shape[0]}.') if input_x_dim != 1 and weight_shape[0] != input_x_shape[1] and weight_shape[0] != 1: - raise ValueError(f'channel of input_x and weight must be matched,' + raise ValueError(f'For \'{self.name}\' channel of input_x and weight must be matched,' f' while channel of input_x is {input_x_shape[1]},' f' weight_shape[0] is {weight_shape[0]}.') return input_x_shape def infer_dtype(self, input_x_dtype, weight_dtype): - validator.check_subclass("input_x_dtype", input_x_dtype, mstype.tensor) - validator.check_subclass("weight_dtype", weight_dtype, mstype.tensor) - validator.check_typename("input_x_dtype", input_x_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("weight_dtype", weight_dtype, (mstype.float16, mstype.float32)) + args = {"input_x": input_x_dtype, "weight": weight_dtype} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return input_x_dtype @@ -2027,13 +2027,13 @@ class LSTM(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, 
num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer("input_size", input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -2042,19 +2042,16 @@ class LSTM(PrimitiveWithInfer): def infer_shape(self, x_shape, h_shape, c_shape, w_shape): # (batch, seq, feature) - validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) # h and c should be same shape - validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ) - validator.check_integer("h_shape", len(h_shape), len(c_shape), Rel.EQ) - validator.check_integer("h_shape", h_shape[0], c_shape[0], Rel.EQ) - validator.check_integer("h_shape", h_shape[1], c_shape[1], Rel.EQ) - validator.check_integer("h_shape", h_shape[2], c_shape[2], Rel.EQ) + validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name) + validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) # (num_layers * num_directions, batch, hidden_size) - 
validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ) - validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ) - validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ) + validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) + validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name) + validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name) y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions) @@ -2064,13 +2061,8 @@ class LSTM(PrimitiveWithInfer): return (y_shape, h_shape, c_shape, reserved_shape, state_shape) def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype): - validator.check_typename("x_dtype", x_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("h_dtype", h_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("c_dtype", c_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("w_dtype", w_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("datatype", x_dtype, (h_dtype.element_type(),)) - validator.check_typename("datatype", x_dtype, (c_dtype.element_type(),)) - validator.check_typename("datatype", x_dtype, (w_dtype.element_type(),)) + args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype} + validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype) @@ -2101,12 +2093,12 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer): self.init_prim_io_names(inputs=['predict', 'target'], outputs=['loss']) def infer_shape(self, x_shape, y_shape): - validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape) + validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_dtype, y_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype} - validator.check_type_same(args, mstype.number_type) 
+ validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -2150,7 +2142,7 @@ class Pad(PrimitiveWithInfer): def infer_shape(self, x): paddings = np.array(self.paddings) - validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ) + validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name) if not np.all(paddings >= 0): raise ValueError('All elements of paddings must be >= 0.') y_shape = () @@ -2159,7 +2151,7 @@ class Pad(PrimitiveWithInfer): return y_shape def infer_dtype(self, x): - validator.check_subclass("input_x", x, mstype.tensor) + validator.check_subclass("input_x", x, mstype.tensor, self.name) return x @@ -2210,16 +2202,16 @@ class MirrorPad(PrimitiveWithInfer): @prim_attr_register def __init__(self, mode='REFLECT'): """Init Pad""" - validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC']) + validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) self.mode = mode def __infer__(self, input_x, paddings): - validator.check_subclass("input_x", input_x['dtype'], mstype.tensor) - validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) + validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name) + validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name) x_shape = list(input_x['shape']) paddings_value = paddings['value'].asnumpy() paddings_size = paddings_value.size - validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ) + validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name) if not np.all(paddings_size >= 0): raise ValueError('All elements of paddings must be >= 0.') y_shape = () @@ -2270,10 +2262,10 @@ class ROIAlign(PrimitiveWithInfer): @prim_attr_register def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2): """init ROIAlign""" - validator.check_type("pooled_height", pooled_height, [int]) - 
validator.check_type("pooled_width", pooled_width, [int]) - validator.check_type("spatial_scale", spatial_scale, [float]) - validator.check_type("sample_num", sample_num, [int]) + validator.check_value_type("pooled_height", pooled_height, [int], self.name) + validator.check_value_type("pooled_width", pooled_width, [int], self.name) + validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) + validator.check_value_type("sample_num", sample_num, [int], self.name) self.pooled_height = pooled_height self.pooled_width = pooled_width self.spatial_scale = spatial_scale @@ -2338,24 +2330,24 @@ class Adam(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False, use_nesterov=False): - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("use_nesterov", use_nesterov, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], self.name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name) def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape): - validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape) - validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape, m_shape, v_shape def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): - args = {"var_dtype": var_dtype, "m_dtype": m_dtype, "v_dtype": v_dtype, "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) + args = {"var": var_dtype, 
"m": m_dtype, "v": v_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) - args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype, - "beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, + "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) return var_dtype, m_dtype, v_dtype @@ -2397,12 +2389,12 @@ class BinaryCrossEntropy(PrimitiveWithInfer): @prim_attr_register def __init__(self, reduction='mean'): - self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum']) + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) def infer_shape(self, x_shape, y_shape, weight_shape): - validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape) + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) if weight_shape: - validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape) + validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name) if self.reduction in ('mean', 'sum'): shape = [] else: @@ -2410,10 +2402,11 @@ class BinaryCrossEntropy(PrimitiveWithInfer): return shape def infer_dtype(self, x_type, y_type, weight_type): - args = {'x_type': x_type, 'y_type': y_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + args = {'x': x_type, 'y': y_type} + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same(args, valid_types, self.name) if weight_type: - validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) + validator.check_tensor_type_same({'x': x_type, 
'weight': weight_type}, valid_types, self.name) return x_type @@ -2445,27 +2438,22 @@ class SparseApplyAdagrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, lr, use_locking=False): - self.lr = validator.check_type("lr", lr, [float]) - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.lr = validator.check_value_type("lr", lr, [float], self.name) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape): - validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape) - validator.check_param_equal('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape)) + validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) + validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) if len(var_shape) > 1: - validator.check_param_equal('var_shape', var_shape[1:], 'grad_shape', grad_shape[1:]) - validator.check_integer("len of indices shape", len(indices_shape), 1, Rel.EQ) - validator.check('the first dimension of grad', grad_shape[0], - 'the shape of indices', indices_shape[0], Rel.EQ) + validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) + validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) return var_shape def infer_dtype(self, var_type, accum_type, grad_type, indices_type): - validator.check_subclass("var_type", var_type, mstype.tensor) - validator.check_subclass("accum_type", accum_type, mstype.tensor) - validator.check_subclass("grad_type", grad_type, mstype.tensor) - validator.check_subclass("indices_type", indices_type, mstype.tensor) - args = {'var_type': var_type, 'accum_type': accum_type, 'grad_type': grad_type} - 
validator.check_type_same(args, (mstype.float32,)) - validator.check_typename('indices_type', indices_type, [mstype.int32]) + args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} + validator.check_tensor_type_same(args, (mstype.float32,), self.name) + validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) return var_type @@ -2493,34 +2481,34 @@ class LARSUpdate(PrimitiveWithInfer): @prim_attr_register def __init__(self, epsilon=1e-05, hyperpara=0.001, use_clip=False): """init""" - validator.check_type("epsilon", epsilon, [float]) - validator.check_type("hyperpara", hyperpara, [float]) - validator.check_type("use_clip", use_clip, [bool]) + validator.check_value_type("epsilon", epsilon, [float], self.name) + validator.check_value_type("hyperpara", hyperpara, [float], self.name) + validator.check_value_type("use_clip", use_clip, [bool], self.name) def infer_shape(self, weight_shape, gradient_shape, norm_weight_shape, norm_gradient_shape, weight_decay_shape, learning_rate_shape): - validator.check_param_equal("Weight shape", weight_shape, "gradient shape", gradient_shape) - validator.check_param_equal("Norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape) + validator.check("weight shape", weight_shape, "gradient shape", gradient_shape, Rel.EQ, self.name) + validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ, + self.name) shp_len = len(weight_decay_shape) - validator.check_shape_length("Weight decay's shape", shp_len, 1, Rel.LE) + validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name) if shp_len == 1: - validator.check_integer("Weight decay's shape", weight_decay_shape[0], 1, Rel.EQ) + validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name) shp_len = len(learning_rate_shape) - validator.check_shape_length("Learning rate's shape", shp_len, 1, Rel.LE) + validator.check_integer("learning 
rate's rank", shp_len, 1, Rel.LE, self.name) if shp_len == 1: - validator.check_integer("Learning rate's shape", learning_rate_shape[0], 1, Rel.EQ) + validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name) return weight_shape def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype, weight_decay_dtype, learning_rate_dtype): args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype, "norm gradient dtype": norm_gradient_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32]) - validator.check_args_tensor(args) - validator.check_typename("weight_decay_dtype", weight_decay_dtype, - [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("learning_rate_dtype", learning_rate_dtype, - [mstype.float16, mstype.float32, mstype.float64]) + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32], self.name) + validator.check_scalar_or_tensor_type_same({"weight_decay": weight_decay_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) + validator.check_scalar_or_tensor_type_same({"learning_rate": learning_rate_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) return weight_dtype @@ -2553,26 +2541,23 @@ class ApplyFtrl(PrimitiveWithInfer): def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'], outputs=['output']) - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape, lr_power_shape): - validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape) - validator.check_param_equal('var shape', 
var_shape, 'linear shape', linear_shape) + validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) + validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type): - validator.check_subclass("var_type", var_type, mstype.tensor) - validator.check_subclass("accum_type", accum_type, mstype.tensor) - validator.check_subclass("linear_type", linear_type, mstype.tensor) - validator.check_subclass("grad_type", grad_type, mstype.tensor) - args = {'var_type': var_type, 'accum_type': accum_type, 'linear_type': linear_type, 'grad_type': grad_type} - validator.check_type_same(args, (mstype.float32, mstype.float16)) - - validator.check_typename("lr", lr_type, [mstype.float16, mstype.float32]) - validator.check_typename("l1", l1_type, [mstype.float16, mstype.float32]) - validator.check_typename("l2", l2_type, [mstype.float16, mstype.float32]) - validator.check_typename("lr_power", lr_power_type, [mstype.float16, mstype.float32]) + valid_types = [mstype.float16, mstype.float32] + args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type} + validator.check_tensor_type_same(args, valid_types, self.name) + + validator.check_scalar_or_tensor_type_same({"lr": lr_type}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name) return var_type @@ -2607,36 +2592,22 @@ class ExtractImagePatches(PrimitiveWithInfer): @prim_attr_register def __init__(self, ksizes, strides, rates, padding="valid"): """init""" - validator.check_type("ksizes", ksizes, [tuple, list]) - validator.check_type("strides", strides, [tuple, list]) - validator.check_type("rates", 
rates, [tuple, list]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + def _check_tuple_or_list(arg_name, arg_val, prim_name): + validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) + if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: + raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, " + f"{arg_name}_col, 1], but got {arg_val}.") + if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1: + raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an " + f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col " + f"is {arg_val[2]}") + + _check_tuple_or_list("ksize", ksizes, self.name) + _check_tuple_or_list("stride", strides, self.name) + _check_tuple_or_list("rate", rates, self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) - if len(ksizes) != 4 or ksizes[0] != 1 or ksizes[3] != 1: - raise ValueError("The format of ksizes should be [1, ksize_row, ksize_col, 1], " - f"but got {ksizes}.") - if not isinstance(ksizes[1], int) or not isinstance(ksizes[2], int) or \ - ksizes[1] < 1 or ksizes[2] < 1: - raise ValueError("The ksize_row and ksize_col in ksizes should be an positive integer number, " - f"but got ksize_row is {ksizes[1]}, ksize_col is {ksizes[2]}") - - if len(strides) != 4 or strides[0] != 1 or strides[3] != 1: - raise ValueError("The format of strides should be [1, stride_row, stride_col, 1], " - f"but got {strides}.") - if not isinstance(strides[1], int) or not isinstance(strides[2], int) or \ - strides[1] < 1 or strides[2] < 1: - raise ValueError("The stride_row and stride_col in strides should be an positive integer number, " - f"but got stride_row is {strides[1]}, stride_col is {strides[2]}") - - if len(rates) != 4 or 
rates[0] != 1 or rates[3] != 1: - raise ValueError("The format of rates should be [1, rate_row, rate_col, 1], " - f"but got {rates}.") - if not isinstance(rates[1], int) or not isinstance(rates[2], int) or \ - rates[1] < 1 or rates[2] < 1: - raise ValueError("The rate_row and rate_col in rates should be an positive integer number, " - f"but got rate_row is {rates[1]}, rate_col is {rates[2]}") - def infer_shape(self, input_x): in_batch, in_row, in_col, in_depth = input_x _, ksize_row, ksize_col, _ = self.ksizes @@ -2662,6 +2633,53 @@ class ExtractImagePatches(PrimitiveWithInfer): return out_shape def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x_dtype", input_x, (mstype.int8, mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) return input_x + + +class ConfusionMulGrad(PrimitiveWithInfer): + """ + `output0` is the result of which input0 dot multily input1. + + `output1` is the result of which input0 dot multily input1, then reducesum it. + + Args: + axis (Union[int, tuple[int], list[int]]): The dimensions to reduce. + Default:(), reduce all dimensions. Only constant value is allowed. + keep_dims (bool): + - If true, keep these reduced dimensions and the length is 1. + - If false, don't keep these dimensions. Default:False. + + Inputs: + - **input_0** (Tensor) - The input Tensor. + - **input_1** (Tensor) - The input Tensor. + - **input_2** (Tensor) - The input Tensor. + + outputs: + - **output_0** (Tensor) - The same shape with `input0`. + - **output_1** (Tensor) + + - If axis is (), and keep_dims is false, the output is a 0-D array representing + the sum of all elements in the input array. + - If axis is int, set as 2, and keep_dims is false, + the shape of output is :math:`(x_1,x_3,...,x_R)`. 
+ - If axis is tuple(int), set as (2,3), and keep_dims is false, + the shape of output is :math:`(x_1,x_4,...x_R)`. + """ + + @prim_attr_register + def __init__(self, axis = (), keep_dims = False): + self.init_prim_io_names(inputs = ["input0", "input1", "input2"], outputs = ["output0", "output1"]) + self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name) + self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name) + + def infer_shape(self, input0_shape, input1_shape, input2_shape): + outshape0 = input0_shape + outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name) + return outshape0, outshape1 + + def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype): + validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name) + validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name) + validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name) + return input0_dtype, input1_dtype diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 2ece6b7088..12a8a2cfde 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -16,7 +16,7 @@ """Other operators.""" from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind -from ..._checkparam import ParamValidator as validator, Rel +from ..._checkparam import Validator as validator, Rel from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register @@ -82,22 +82,21 @@ class BoundingBoxEncode(PrimitiveWithInfer): @prim_attr_register def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): - validator.check_type('means', means, [tuple]) - validator.check_type('stds', stds, [tuple]) - validator.check("means len", len(means), '', 4) - validator.check("stds len", len(stds), '', 4) + 
validator.check_value_type('means', means, [tuple], self.name) + validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) + validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) def infer_shape(self, anchor_box, groundtruth_box): - validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0]) - validator.check('anchor_box shape[1]', anchor_box[1], '', 4) - validator.check('groundtruth_box shape[1]', groundtruth_box[1], '', 4) + validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ, + self.name) + validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) + validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name) return anchor_box def infer_dtype(self, anchor_box, groundtruth_box): - args = {"anchor_box": anchor_box, - "groundtruth_box": groundtruth_box - } - validator.check_type_same(args, mstype.number_type) + args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return anchor_box @@ -126,26 +125,24 @@ class BoundingBoxDecode(PrimitiveWithInfer): @prim_attr_register def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016): - validator.check_type('means', means, [tuple]) - validator.check_type('stds', stds, [tuple]) - validator.check_type('wh_ratio_clip', wh_ratio_clip, [float]) - validator.check("means", len(means), '', 4) - validator.check("stds", len(stds), '', 4) + validator.check_value_type('means', means, [tuple], self.name) + validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) + validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) + validator.check_integer("stds len", 
len(stds), 4, Rel.EQ, self.name) if max_shape is not None: - validator.check_type('max_shape', max_shape, [tuple]) - validator.check("max_shape", len(max_shape), '', 2) + validator.check_value_type('max_shape', max_shape, [tuple], self.name) + validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name) def infer_shape(self, anchor_box, deltas): - validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0]) - validator.check('anchor_box shape[1]', anchor_box[1], '', 4) - validator.check('deltas shape[1]', deltas[1], '', 4) + validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name) + validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) + validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name) return anchor_box def infer_dtype(self, anchor_box, deltas): - args = {"anchor_box": anchor_box, - "deltas": deltas - } - validator.check_type_same(args, mstype.number_type) + args = {"anchor_box": anchor_box, "deltas": deltas} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return anchor_box @@ -168,10 +165,10 @@ class CheckValid(PrimitiveWithInfer): self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output']) def infer_shape(self, bboxes_shape, metas_shape): - validator.check_shape_length("bboxes shape length", len(bboxes_shape), 2, Rel.EQ) - validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ) - validator.check_shape_length("img_metas shape length", len(metas_shape), 1, Rel.EQ) - validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ) + validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name) + validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name) + validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name) + validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name) return 
bboxes_shape[:-1] def infer_dtype(self, bboxes_type, metas_type): @@ -209,8 +206,8 @@ class IOU(PrimitiveWithInfer): Examples: >>> iou = P.IOU() - >>> anchor_boxes = Tensor(np.random.randint(1,5, [10, 4])) - >>> gt_boxes = Tensor(np.random.randint(1,5, [3, 4])) + >>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float32) + >>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float32) >>> iou(anchor_boxes, gt_boxes) """ @@ -221,18 +218,16 @@ class IOU(PrimitiveWithInfer): self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap']) def infer_shape(self, anchor_boxes, gt_boxes): - validator.check('gt_boxes shape[1]', gt_boxes[1], '', 4) - validator.check('anchor_boxes shape[1]', anchor_boxes[1], '', 4) - validator.check('anchor_boxes rank', len(anchor_boxes), '', 2) - validator.check('gt_boxes rank', len(gt_boxes), '', 2) + validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name) + validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name) + validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name) + validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name) iou = [gt_boxes[0], anchor_boxes[0]] return iou def infer_dtype(self, anchor_boxes, gt_boxes): - validator.check_subclass("anchor_boxes", anchor_boxes, mstype.tensor) - validator.check_subclass("gt_boxes", gt_boxes, mstype.tensor) args = {"anchor_boxes": anchor_boxes, "gt_boxes": gt_boxes} - validator.check_type_same(args, (mstype.float16,)) + validator.check_tensor_type_same(args, (mstype.float16,), self.name) return anchor_boxes @@ -270,7 +265,7 @@ class MakeRefKey(Primitive): @prim_attr_register def __init__(self, tag): - validator.check_type('tag', tag, (str,)) + validator.check_value_type('tag', tag, (str,), self.name) def __call__(self): pass diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 
18c2212b3d..2692b43b46 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -15,7 +15,7 @@ """Operators for random.""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register @@ -52,16 +52,15 @@ class RandomChoiceWithMask(PrimitiveWithInfer): @prim_attr_register def __init__(self, count=256, seed=0, seed2=0): """Init RandomChoiceWithMask""" - validator.check_type("count", count, [int]) - validator.check_integer("count", count, 0, Rel.GT) - validator.check_type('seed', seed, [int]) - validator.check_type('seed2', seed2, [int]) + validator.check_value_type("count", count, [int], self.name) + validator.check_integer("count", count, 0, Rel.GT, self.name) + validator.check_value_type('seed', seed, [int], self.name) + validator.check_value_type('seed2', seed2, [int], self.name) def infer_shape(self, x_shape): - validator.check_shape_length("input_x shape", len(x_shape), 1, Rel.GE) + validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) return ([self.count, len(x_shape)], [self.count]) def infer_dtype(self, x_dtype): - validator.check_subclass('x_dtype', x_dtype, mstype.tensor) - validator.check_typename('x_dtype', x_dtype, [mstype.bool_]) + validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) return (mstype.int32, mstype.bool_) diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index c99ac4a3c7..bf4b99085e 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -17,7 +17,7 @@ import threading import mindspore.context as context from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size from mindspore._c_expression import AutoParallelContext 
-from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check class _AutoParallelContext: diff --git a/mindspore/parallel/_cost_model_context.py b/mindspore/parallel/_cost_model_context.py index 0920d66f41..54cca5516b 100644 --- a/mindspore/parallel/_cost_model_context.py +++ b/mindspore/parallel/_cost_model_context.py @@ -15,7 +15,7 @@ """Context of cost_model in auto_parallel""" import threading from mindspore._c_expression import CostModelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check class _CostModelContext: diff --git a/mindspore/parallel/algo_parameter_config.py b/mindspore/parallel/algo_parameter_config.py index d1e4aa87a9..244156da33 100644 --- a/mindspore/parallel/algo_parameter_config.py +++ b/mindspore/parallel/algo_parameter_config.py @@ -16,7 +16,7 @@ import threading from mindspore._c_expression import CostModelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 85b7629002..7bc07b126e 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -87,7 +87,7 @@ def _make_directory(path: str): # All exceptions need to be caught because create directory maybe have some limit(permissions) logger.debug("The directory(%s) doesn't exist, will create it", path) try: - os.makedirs(path) + os.makedirs(path, exist_ok=True) real_path = path except PermissionError as e: logger.error("No write permission on the directory(%r), error = %r", path, e) diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index c4c115ef27..917b4c3359 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -16,7 +16,7 @@ from easydict import EasyDict as edict from .. 
import nn -from .._checkparam import ParamValidator as validator +from .._checkparam import Validator as validator from .._checkparam import Rel from ..common import dtype as mstype from ..nn.wrap.cell_wrapper import _VirtualDatasetCell @@ -73,14 +73,14 @@ def _check_kwargs(key_words): raise ValueError(f"Unsupported arg '{arg}'") if 'cast_model_type' in key_words: - validator.check('cast_model_type', key_words['cast_model_type'], - [mstype.float16, mstype.float32], Rel.IN) + validator.check_type_name('cast_model_type', key_words['cast_model_type'], + [mstype.float16, mstype.float32], None) if 'keep_batchnorm_fp32' in key_words: - validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool) + validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None) if 'loss_scale_manager' in key_words: loss_scale_manager = key_words['loss_scale_manager'] if loss_scale_manager: - validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager) + validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None) def _add_loss_network(network, loss_fn, cast_model_type): @@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): label = _mp_cast_helper(mstype.float32, label) return self._loss_fn(F.cast(out, mstype.float32), label) - validator.check_isinstance('loss_fn', loss_fn, nn.Cell) + validator.check_value_type('loss_fn', loss_fn, nn.Cell, None) if cast_model_type == mstype.float16: network = WithLossCell(network, loss_fn) else: @@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else scale the loss by LossScaleManager. If set, overwrite the level setting. 
""" - validator.check_isinstance('network', network, nn.Cell) - validator.check_isinstance('optimizer', optimizer, nn.Optimizer) - validator.check('level', level, "", ['O0', 'O2'], Rel.IN) + validator.check_value_type('network', network, nn.Cell, None) + validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) + validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) _check_kwargs(kwargs) config = dict(_config_level[level], **kwargs) config = edict(config) @@ -151,7 +151,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): loss_scale = loss_scale_manager.get_loss_scale() update_cell = loss_scale_manager.get_update_cell() if update_cell is not None: - if not context.get_context("enable_ge"): + if not (context.get_context("enable_ge") or (context.get_context("device_target") == "GPU")): raise ValueError("Only `loss_scale_manager=None` and " "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`" "are supported in current version. If you use `O2` option, please" diff --git a/mindspore/train/loss_scale_manager.py b/mindspore/train/loss_scale_manager.py index 5650c58f62..c8c28a72cb 100644 --- a/mindspore/train/loss_scale_manager.py +++ b/mindspore/train/loss_scale_manager.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """Loss scale manager abstract class.""" -from .._checkparam import ParamValidator as validator +from .._checkparam import Validator as validator from .._checkparam import Rel from .. 
import nn @@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager): if init_loss_scale < 1.0: raise ValueError("Loss scale value should be > 1") self.loss_scale = init_loss_scale - validator.check_integer("scale_window", scale_window, 0, Rel.GT) + validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__) self.scale_window = scale_window if scale_factor <= 0: raise ValueError("Scale factor should be > 1") diff --git a/mindspore/train/model.py b/mindspore/train/model.py index be3939d450..698105889a 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -62,6 +62,7 @@ class Model: loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument. e.g. Use `loss_scale_manager=None` to set the value. + keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True. 
Examples: >>> class Net(nn.Cell): @@ -96,7 +97,10 @@ class Model: self._optimizer = optimizer self._loss_scale_manager = None self._loss_scale_manager_set = False + self._keep_bn_fp32 = True self._check_kwargs(kwargs) + if 'keep_batchnorm_fp32' in kwargs: + self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] if 'loss_scale_manager' in kwargs: self._loss_scale_manager = kwargs['loss_scale_manager'] self._loss_scale_manager_set = True @@ -108,10 +112,11 @@ class Model: self._train_network = self._build_train_network() self._build_eval_network(metrics, eval_network, eval_indexes) + self._build_predict_network() def _check_kwargs(self, kwargs): for arg in kwargs: - if arg not in ['loss_scale_manager']: + if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: raise ValueError(f"Unsupport arg '{arg}'") def _build_train_network(self): @@ -123,12 +128,14 @@ class Model: self._optimizer, self._loss_fn, level=self._amp_level, - loss_scale_manager=self._loss_scale_manager) + loss_scale_manager=self._loss_scale_manager, + keep_batchnorm_fp32=self._keep_bn_fp32) else: network = amp.build_train_network(network, self._optimizer, self._loss_fn, - level=self._amp_level) + level=self._amp_level, + keep_batchnorm_fp32=self._keep_bn_fp32) elif self._loss_fn: network = nn.WithLossCell(network, self._loss_fn) # If need to check if loss_fn is not None, but optimizer is None @@ -153,6 +160,12 @@ class Model: self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) self._eval_indexes = [0, 1, 2] + def _build_predict_network(self): + """Build the network for prediction.""" + self._predict_network = self._network + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + self._predict_network = _VirtualDatasetCell(self._network) + def _clear_metrics(self): """Clear metrics local values.""" for metric in self._metric_fns.values(): @@ -470,6 +483,7 @@ class Model: dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) for 
next_element in dataset_helper: + cb_params.cur_step_num += 1 list_callback.step_begin(run_context) outputs = self._eval_network(*next_element) cb_params.net_outputs = outputs @@ -549,12 +563,9 @@ class Model: >>> model = Model(Net()) >>> model.predict(input_data) """ - if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - self._network = _VirtualDatasetCell(self._network) - - self._network.set_train(False) + self._predict_network.set_train(False) check_input_data(*predict_data, data_class=Tensor) - result = self._network(*predict_data) + result = self._predict_network(*predict_data) check_output_data(result) return result diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index ae17bf8116..49cc5318fa 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -36,7 +36,6 @@ tensor_to_ms_type = {"Int8": mstype.int8, "Int16": mstype.int16, "Int32": mstype tensor_to_np_type = {"Int8": np.int8, "Int16": np.int16, "Int32": np.int32, "Int64": np.int64, "Float16": np.float16, "Float32": np.float32, "Float64": np.float64} - def _special_process_par(par, new_par): """ Processes the special condition. 
@@ -182,8 +181,14 @@ def load_checkpoint(ckpoint_file_name, net=None): param_data = np.fromstring(data, np_type) dims = element.tensor.dims - if dims in [[0], [1]]: - parameter_dict[element.tag] = Parameter(param_data[0], name=element.tag) + if dims == [0]: + if 'Float' in data_type: + param_data = float(param_data[0]) + elif 'Int' in data_type: + param_data = int(param_data[0]) + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) + elif dims == [1]: + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) else: param_dim = [] for dim in dims: @@ -403,10 +408,11 @@ def _fill_param_into_net(net, parameter_list): for each_param in parameter_list: param_name = each_param["name"] np_val = each_param["data"].asnumpy() - if np_val.shape == (1,): # to scalar - parameter_dict[param_name] = Parameter(np_val[0], name=param_name) + if np_val.shape == (1,): + parameter_dict[param_name] = Parameter(np_val, name=param_name) elif np_val.shape == (): - parameter_dict[param_name] = Parameter(np_val.tolist(), name=param_name) + parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)), + name=param_name) else: parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name) diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index d96ac4773a..3dbe31f0e4 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -46,10 +46,14 @@ def _cache_summary_tensor_data(summary): class SummaryRecord: """ - Summary log record. - SummaryRecord is used to record the summary value. - The API will create an event file in a given directory and add summaries and events to it. + + Note: + The API will create an event file in a given directory and add summaries and events to it. + It writes the event log to a file by executing the record method. 
In addition, + if the SummaryRecord object is created and the summary operator is used in the network, + even if the record method is not called, the event in the cache will be written to the + file at the end of execution or when the summary is closed. Args: log_dir (str): The log_dir is a directory location to save the summary. diff --git a/predict/CMakeLists.txt b/predict/CMakeLists.txt index 2641932769..39ca6b27e8 100755 --- a/predict/CMakeLists.txt +++ b/predict/CMakeLists.txt @@ -6,6 +6,7 @@ set(CMAKE_BUILD_TYPE "Release") set(CMAKE_CXX_STANDARD 11) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=hidden") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s") option(ENABLE_ASAN "Enable Google Sanitizer to find memory bugs" OFF) option(ENABLE_PREDICT_ARM64 "predict arm64" OFF) diff --git a/predict/src/CMakeLists.txt b/predict/src/CMakeLists.txt index c32c047c82..92c45473d7 100644 --- a/predict/src/CMakeLists.txt +++ b/predict/src/CMakeLists.txt @@ -52,20 +52,6 @@ else() target_link_libraries(mspredict pthread tvm_kernel libsecurec.a) endif() -if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") - if(ENABLE_PREDICT_ARM64) - add_custom_command(TARGET mspredict POST_BUILD - COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip "${PREDICT_BUILD_DIR}/src/libmspredict.so" - COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip "${PREDICT_BUILD_DIR}/module/tvm_kernel/lite/libtvm_kernel.so" - ) - else() - add_custom_command(TARGET mspredict POST_BUILD - COMMAND strip "${PREDICT_BUILD_DIR}/src/libmspredict.so" - COMMAND strip "${PREDICT_BUILD_DIR}/module/tvm_kernel/lite/libtvm_kernel.so" - ) - endif() -endif() - add_dependencies(mspredict tvm_kernel) add_dependencies(mspredict securec) add_dependencies(mspredict gtest) diff --git a/tests/st/auto_parallel/onehot_model_parallel.py 
b/tests/st/auto_parallel/onehot_model_parallel.py index 14b351c0ee..1f35ac1f80 100644 --- a/tests/st/auto_parallel/onehot_model_parallel.py +++ b/tests/st/auto_parallel/onehot_model_parallel.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================ import os import pytest @@ -26,6 +27,7 @@ device_num = 2 device_id = int(os.getenv('DEVICE_ID')) rank_id = 0 + def setup_module(): global device_num global rank_id @@ -42,9 +44,11 @@ def setup_module(): context.set_auto_parallel_context(device_num=device_num, global_rank=rank_id) + def teardown_module(): distributedTool.release() + class Onehot(Cell): def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, strategy=None): super(Onehot, self).__init__() @@ -56,25 +60,26 @@ class Onehot(Cell): self.on_value = Tensor(on_value, ms.float32) self.off_value = Tensor(off_value, ms.float32) self.transpose = P.Transpose().set_strategy(strategy=trans_stra) - self.sub = P.Sub().set_strategy(strategy=((1,1),(1,1))) + self.sub = P.Sub().set_strategy(strategy=((1, 1), (1, 1))) def construct(self, input, indices): x = self.onehot(indices, self.depth, self.on_value, self.off_value) - x = self.transpose(x, (1,0)) + x = self.transpose(x, (1, 0)) x = self.sub(input, x) return x + class DataGenerator(): def get_parallel_blocks(self, input_, strategy): blocks = [input_] i = 0 for stra in strategy: temp = [] - while len(blocks)>0: + while len(blocks) > 0: block = blocks.pop(0) temp.extend(np.split(block, stra, axis=i)) blocks.extend(temp) - i+=1 + i += 1 return blocks def generate_data(self, shape): @@ -93,32 +98,33 @@ class DataGenerator(): stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data),Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) 
+ class OneHotFactory: def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None, strategy=None): dataGen = DataGenerator() self.input_full, self.input_part = dataGen.input_data((classes, batch_size)) - self.label_full, self.label_part = dataGen.label_data((batch_size,),classes) + self.label_full, self.label_part = dataGen.label_data((batch_size,), classes) self.depth = classes self.on_value = on_value self.off_value = off_value self.axis = axis self.strategy = strategy - + def forward_mindspore_single_impl(self): - net = Onehot(axis=self.axis, - depth=self.depth, - on_value=self.on_value, + net = Onehot(axis=self.axis, + depth=self.depth, + on_value=self.on_value, off_value=self.off_value) out = net(self.input_full, self.label_full) return out - + def forward_mindspore_parallel_impl(self): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") - net = Onehot(axis=self.axis, - depth=self.depth, - on_value=self.on_value, + net = Onehot(axis=self.axis, + depth=self.depth, + on_value=self.on_value, off_value=self.off_value, strategy=self.strategy) out = net.compile_and_run(self.input_full, self.label_full) return out @@ -137,7 +143,7 @@ def test_reid_onehot_forward_int32_128_depth1024_model_parallel(): on_value=1.000000, off_value=0.000000, axis=-1, - strategy=((1,device_num),(),())) + strategy=((1, device_num), (), ())) fact.forward_cmp() @@ -147,5 +153,5 @@ def test_reid_onehot_forward_int32_1024_depth128_model_parallel(): on_value=1.000000, off_value=0.000000, axis=-1, - strategy=((1,device_num),(),())) + strategy=((1, device_num), (), ())) fact.forward_cmp() diff --git a/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py b/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py index 17dbe8f304..86a8b89521 100644 --- a/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py +++ b/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 
express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================ import os import pytest @@ -31,7 +32,7 @@ from mindspore.nn.optim.momentum import Momentum from mindspore.train.callback import Callback np.set_printoptions(threshold=np.inf) -device_num=2 +device_num = 2 device_id = int(os.getenv('DEVICE_ID')) rank_id = 0 embed = 128 @@ -39,6 +40,7 @@ classes = 32 batch_size = 32*2 MatmulParamShape = (classes, embed) + def setup_module(): global device_num global rank_id @@ -55,26 +57,28 @@ def setup_module(): context.set_auto_parallel_context(device_num=device_num, global_rank=device_id) + def teardown_module(): distributedTool.release() + class DataGenerator(): def get_parallel_blocks(self, input_, strategy): blocks = [input_] i = 0 for stra in strategy: temp = [] - while len(blocks)>0: + while len(blocks) > 0: block = blocks.pop(0) temp.extend(np.split(block, stra, axis=i)) blocks.extend(temp) - i+=1 + i += 1 return blocks def generate_data(self, shape): size = np.cumprod(shape)[-1] num_range = min(size, 1000) - data = (np.arange(0, size)%num_range)/num_range + data = (np.arange(0, size) % num_range)/num_range data = np.reshape(data, shape) return data @@ -83,14 +87,15 @@ class DataGenerator(): stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data), Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) def label_data(self, shape, embed): data = (self.generate_data(shape)*(embed-1)).astype(np.int32) stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data),Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) + class Dataset(): def __init__(self, predict, label, length=1, input_num=2): @@ -121,15 +126,18 @@ class Dataset(): def get_repeat_count(self): return self.length + class 
ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() self.loss_list = [] + def epoch_end(self, run_context, *args): cb_params = run_context.original_args() result = cb_params.net_outputs self.loss_list.append(result.asnumpy().mean()) + class SoftmaxCrossEntropyExpand(Cell): def __init__(self, sparse=False, stra_list=[]): super(SoftmaxCrossEntropyExpand, self).__init__() @@ -164,22 +172,25 @@ class SoftmaxCrossEntropyExpand(Cell): loss = self.reduce_mean(loss, -1) return loss + class MatmulNet(Cell): - def __init__(self, matmul_stra = None, loss_stra_list=[]): + def __init__(self, matmul_stra=None, loss_stra_list=[]): super(MatmulNet, self).__init__() self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy=matmul_stra) self.loss = SoftmaxCrossEntropyExpand(sparse=True, stra_list=loss_stra_list) - self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight") + self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight") + def construct(self, x, label): loss_input = self.matmul(x, self.weight) out = self.loss(loss_input, label) return out + class LossFactory(): def __init__(self): dataGen = DataGenerator() self.input_full, self.input_part = dataGen.input_data((batch_size, embed)) - self.label_full, self.label_part = dataGen.label_data((batch_size,),embed) + self.label_full, self.label_part = dataGen.label_data((batch_size,), embed) def single_matmul_trains(self): single_callback = ModelCallback() @@ -196,32 +207,33 @@ class LossFactory(): parallel_callback = ModelCallback() context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") net = MatmulNet() - optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(net, optimizer=optimizer) epoch_size = 6 dataset = Dataset(self.input_part, self.label_part) model.train(epoch_size, dataset, 
callbacks=parallel_callback, dataset_sink_mode=False) loss_value = np.array(parallel_callback.loss_list) return loss_value - + def model_parallel_matmul_trains(self): parallel_callback = ModelCallback() - matmul_stra = ((1,1),(device_num,1)) - reduce_max_stra = ((1,device_num),) - sub_stra = ((1,device_num),(1,1)) - exp_stra = ((1,device_num),) - reduce_sum_stra = ((1,device_num),) - div_stra = ((1,device_num),(1,1)) - log_stra = ((1,device_num),) - mul_stra = ((1,device_num),(1,device_num)) - sum_cross_entropy_stra = ((1,device_num),) - mul2_stra = ((),(device_num,)) + matmul_stra = ((1, 1), (device_num, 1)) + reduce_max_stra = ((1, device_num),) + sub_stra = ((1, device_num), (1, 1)) + exp_stra = ((1, device_num),) + reduce_sum_stra = ((1, device_num),) + div_stra = ((1, device_num), (1, 1)) + log_stra = ((1, device_num),) + mul_stra = ((1, device_num), (1, device_num)) + sum_cross_entropy_stra = ((1, device_num),) + mul2_stra = ((), (device_num,)) reduce_mean_stra = ((device_num,),) - onehot_stra = ((1,device_num),(),()) - loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] + onehot_stra = ((1, device_num), (), ()) + loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, + sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] context.set_auto_parallel_context(parallel_mode="auto_parallel") - net = MatmulNet(matmul_stra = matmul_stra, loss_stra_list = loss_stra_list) - optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(net, optimizer=optimizer) epoch_size = 6 dataset = Dataset(self.input_part, self.label_part) @@ -231,21 +243,22 @@ class LossFactory(): def mix_parallel_matmul_trains(self): 
parallel_callback = ModelCallback() - matmul_stra = ((device_num,1),(1,1)) - reduce_max_stra = ((1,device_num),) - sub_stra = ((device_num,1),(device_num,1)) - exp_stra = ((1,device_num),) - reduce_sum_stra = ((1,device_num),) - div_stra = ((1,device_num),(1,1)) - log_stra = ((1,device_num),) - mul_stra = ((1,device_num),(1,device_num)) - sum_cross_entropy_stra = ((1,device_num),) - mul2_stra = ((),(device_num,)) + matmul_stra = ((device_num, 1), (1, 1)) + reduce_max_stra = ((1, device_num),) + sub_stra = ((device_num, 1), (device_num, 1)) + exp_stra = ((1, device_num),) + reduce_sum_stra = ((1, device_num),) + div_stra = ((1, device_num), (1, 1)) + log_stra = ((1, device_num),) + mul_stra = ((1, device_num), (1, device_num)) + sum_cross_entropy_stra = ((1, device_num),) + mul2_stra = ((), (device_num,)) reduce_mean_stra = ((device_num,),) - onehot_stra = ((1,device_num),(),()) - loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] + onehot_stra = ((1, device_num), (), ()) + loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, + sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] context.set_auto_parallel_context(parallel_mode="auto_parallel") - net = MatmulNet(matmul_stra = matmul_stra, loss_stra_list = loss_stra_list) + net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list) optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(net, optimizer=optimizer) epoch_size = 6 @@ -254,6 +267,7 @@ class LossFactory(): loss_value = np.array(parallel_callback.loss_list) return loss_value + def test_all_trains(): loss_factory = LossFactory() context.reset_auto_parallel_context() diff --git a/tests/st/auto_parallel/test_expand_loss.py b/tests/st/auto_parallel/test_expand_loss.py index 786cbff980..e89ee5d3c8 100644 --- 
a/tests/st/auto_parallel/test_expand_loss.py +++ b/tests/st/auto_parallel/test_expand_loss.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - import os import pytest + @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training @@ -23,4 +23,4 @@ import pytest def test_expand_loss(): sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/run_auto_parallel_loss_expand.sh") - assert(ret==0) + assert(ret == 0) diff --git a/tests/st/auto_parallel/test_model_parallel_onehot.py b/tests/st/auto_parallel/test_model_parallel_onehot.py index 1df7ad1e99..55217421a4 100644 --- a/tests/st/auto_parallel/test_model_parallel_onehot.py +++ b/tests/st/auto_parallel/test_model_parallel_onehot.py @@ -16,6 +16,7 @@ import os import pytest + def test_expand_loss(): ret = os.system("sh run_onehot_model_parallel.sh") - assert(ret==0) + assert(ret == 0) diff --git a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py b/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py index 62711ccf6a..b28ad510e3 100644 --- a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py +++ b/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+# ============================================================================ +import os import numpy as np import pytest -from numpy import allclose +import mindspore.context as context import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore import Tensor @@ -22,21 +24,21 @@ from mindspore.ops import operations as P from mindspore.nn.optim.momentum import Momentum from mindspore.common.initializer import One from mindspore.train.model import Model, ParallelMode -from mindspore import context -import os from mindspore.communication.management import init import mindspore.ops.functional as F from mindspore.nn.loss.loss import _Loss from mindspore.train.callback import Callback from mindspore.parallel import set_algo_parameters + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_hccl=True) -context.set_context(enable_task_sink=True,device_id=int(os.getenv('DEVICE_ID'))) +context.set_context(enable_task_sink=True, device_id=int(os.getenv('DEVICE_ID'))) context.set_context(enable_ir_fusion=True) context.set_context(enable_loop_sink=False) init() context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL) + def weight_variable(shape, factor=0.1): return One() @@ -52,6 +54,7 @@ def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value) + def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): init_value = weight_variable((out_channels, in_channels, 7, 7)) return nn.Conv2d(in_channels, out_channels, @@ -63,6 +66,7 @@ def _fused_bn(channels, momentum=0.9): init_bias = weight_variable((channels,)) return nn.BatchNorm2d(channels, momentum=momentum) + class BasicBlock(nn.Cell): expansion = 1 @@ -172,7 +176,7 @@ class ResNet(nn.Cell): layer_nums, in_channels, out_channels, - strides=[1,2,2,2], + 
strides=[1, 2, 2, 2], num_classes=100): super(ResNet, self).__init__() @@ -292,17 +296,19 @@ class SoftmaxCrossEntropyExpand(_Loss): rank_id = int(os.environ["RANK_ID"]) device_num = int(os.environ["RANK_SIZE"]) + + class DataGenerator(): def get_parallel_blocks(self, input_, strategy): blocks = [input_] i = 0 for stra in strategy: temp = [] - while len(blocks)>0: + while len(blocks) > 0: block = blocks.pop(0) temp.extend(np.split(block, stra, axis=i)) blocks.extend(temp) - i+=1 + i += 1 return blocks def generate_data(self, shape): @@ -321,7 +327,7 @@ class DataGenerator(): stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data),Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) class Dataset(): @@ -359,6 +365,7 @@ class ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() self.loss_list = [] + def epoch_end(self, run_context, *args): cb_params = run_context.original_args() result = cb_params.net_outputs @@ -382,7 +389,7 @@ def test_train_feed(num_classes=8192): model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback) loss_value = np.array(parallel_callback.loss_list) expect_out = [9.010913, 8.855984, 8.56246, 8.146317, 7.624489] - assert allclose(loss_value, expect_out, 0.0001, 0.0001) + assert np.allclose(loss_value, expect_out, 0.0001, 0.0001) @pytest.mark.level0 @@ -402,4 +409,4 @@ def test_train_feed2(num_classes=1001): model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback) loss_value = np.array(parallel_callback.loss_list) expect_out = [6.908755, 6.8358116, 6.6986914, 6.506859, 6.2708097] - assert allclose(loss_value, expect_out, 0.0001, 0.0001) + assert np.allclose(loss_value, expect_out, 0.0001, 0.0001) diff --git a/tests/st/control/test_while.py b/tests/st/control/test_while.py index 56b38f7f9a..6c659b6581 100644 --- a/tests/st/control/test_while.py +++ b/tests/st/control/test_while.py @@ -13,12 +13,12 @@ # 
limitations under the License. # ============================================================================ import numpy as np -from mindspore.common.tensor import Tensor -from mindspore.common import dtype as mstype import mindspore.context as context -from mindspore.ops import operations as P import mindspore.nn as nn -from mindspore.common import ms_function +from mindspore import Tensor, ms_function +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + @ms_function def t1_while(x, y, z): @@ -28,8 +28,9 @@ def t1_while(x, y, z): x = x + 3 return x + def test_net(): - context.set_context(mode=context.GRAPH_MODE,device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_task_sink=True) c1 = Tensor([2], mstype.int32) c2 = Tensor([14], mstype.int32) @@ -38,5 +39,6 @@ def test_net(): ret = t1_while(c1, c2, c3) assert (ret == expect) + if __name__ == "__main__": - test_net() \ No newline at end of file + test_net() diff --git a/tests/st/fusion/test_add_relu_buffer_fusion.py b/tests/st/fusion/test_add_relu_buffer_fusion.py index fbb0f73991..d63c8b355d 100644 --- a/tests/st/fusion/test_add_relu_buffer_fusion.py +++ b/tests/st/fusion/test_add_relu_buffer_fusion.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor, ms_function +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_id=5, device_target="Ascend") -#context.set_context(enable_task_sink=True) + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -35,17 +34,14 @@ class Net(nn.Cell): def construct(self, x, y): x = self.cast(x, mstype.float16) y = self.cast(y, mstype.float16) - #x = self.softmax(x) x = self.add(x, y) - #x = self.relu(x) x = self.relu(x) - #x = self.softmax(x) x = self.reduce_mean(x) return x + def test_net(): x = np.random.randn(32, 10).astype(np.float32) relu = Net() output = relu(Tensor(x), Tensor(x)) - print(x) print(output.asnumpy()) diff --git a/tests/st/fusion/test_conv_bn1_fusion.py b/tests/st/fusion/test_conv_bn1_fusion.py index 6149b9fd71..c3547ae1cf 100644 --- a/tests/st/fusion/test_conv_bn1_fusion.py +++ b/tests/st/fusion/test_conv_bn1_fusion.py @@ -13,15 +13,13 @@ # limitations under the License. 
# ============================================================================ import numpy as np +import mindspore.context as context import mindspore.nn as nn +from mindspore import Tensor, Parameter, Model, ms_function from mindspore.ops import operations as P from mindspore.common.initializer import initializer -from mindspore import Tensor, Parameter, Model from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.optim import Momentum -from mindspore.common.api import ms_function -import mindspore.nn as wrap -import mindspore.context as context context.set_context(device_target="Ascend", enable_task_sink=True) @@ -35,6 +33,7 @@ class MsWrapper(nn.Cell): def __init__(self, network): super(MsWrapper, self).__init__(auto_prefix=False) self._network = network + @ms_function def construct(self, *args): return self._network(*args) @@ -42,16 +41,16 @@ class MsWrapper(nn.Cell): def me_train_tensor(net, input_np, label_np, epoch_size=2): loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - opt = nn.Momentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), filter(lambda x: x.requires_grad, net.get_parameters())) + opt = nn.Momentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), + filter(lambda x: x.requires_grad, net.get_parameters())) context.set_context(mode=context.GRAPH_MODE) Model(net, loss, opt) - _network = wrap.WithLossCell(net, loss) - _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt)) + _network = nn.WithLossCell(net, loss) + _train_net = MsWrapper(nn.TrainOneStepCell(_network, opt)) _train_net.set_train() for epoch in range(0, epoch_size): - print(f"epoch %d"%(epoch)) + print(f"epoch %d" % (epoch)) output = _train_net(Tensor(input_np), Tensor(label_np)) - print("********output***********") print(output.asnumpy()) @@ -60,9 +59,9 @@ def test_conv_bn_add_relu_fusion(): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, 
has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.conv1 = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.bn = nn.BatchNorm2d(output_channel, momentum=0.1, eps=0.0001) self.add = P.TensorAdd() self.relu = P.ReLU() @@ -91,7 +90,7 @@ def test_conv_bn_relu_fusion(): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.bn = nn.BatchNorm2d(output_channel, momentum=0.1, eps=0.0001) self.relu = P.ReLU() self.mean = P.ReduceMean(keep_dims=True) @@ -118,7 +117,7 @@ def test_conv_bn_fusion(): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.bn = nn.BatchNorm2d(output_channel, momentum=0.1, eps=0.0001) self.mean = P.ReduceMean(keep_dims=True) self.reshape = P.Reshape() diff --git a/tests/st/fusion/test_tbe_eltwise_fusion_1.py b/tests/st/fusion/test_tbe_eltwise_fusion_1.py index 0b9ae1fe63..5d4fd09d30 100644 --- a/tests/st/fusion/test_tbe_eltwise_fusion_1.py +++ b/tests/st/fusion/test_tbe_eltwise_fusion_1.py @@ -13,16 +13,15 @@ # limitations under the License. 
# ============================================================================ import pytest -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -35,6 +34,7 @@ class Net(nn.Cell): x = self.relu(x) return x + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -43,5 +43,4 @@ def test_net(): x = np.random.randn(32, 10).astype(np.float32) relu_relu = Net() output = relu_relu(Tensor(x)) - print(x) print(output.asnumpy()) diff --git a/tests/st/fusion/test_tbe_eltwise_fusion_2.py b/tests/st/fusion/test_tbe_eltwise_fusion_2.py index 8f6084ae5b..3ae754d385 100644 --- a/tests/st/fusion/test_tbe_eltwise_fusion_2.py +++ b/tests/st/fusion/test_tbe_eltwise_fusion_2.py @@ -13,16 +13,15 @@ # limitations under the License. 
# ============================================================================ import pytest -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -41,6 +40,7 @@ class Net(nn.Cell): x = self.relu(x) return x + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -50,5 +50,4 @@ def test_net(): y = np.random.randn(10).astype(np.float32) net = Net() output = net(Tensor(x), Tensor(y)) - print(x) - print(output.asnumpy()) \ No newline at end of file + print(output.asnumpy()) diff --git a/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py b/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py index 9a900a4739..76cf639da0 100644 --- a/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py +++ b/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_id=4, device_target="Ascend") -#context.set_context(enable_task_sink=True) + class Net(nn.Cell): def __init__(self): @@ -39,6 +38,7 @@ class Net(nn.Cell): z = self.add(z1, z2) return z + def test_net(): x = np.random.randn(32, 10).astype(np.float32) y = np.random.randn(32, 10).astype(np.float32) @@ -46,6 +46,4 @@ def test_net(): h = np.random.randn(10).astype(np.float32) relu_relu = Net() output = relu_relu(Tensor(x), Tensor(y), Tensor(k), Tensor(h)) - print(x) print(output.asnumpy()) - diff --git a/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py b/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py index 63b1cc542d..93c7174b52 100644 --- a/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py +++ b/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py @@ -13,17 +13,16 @@ # limitations under the License. 
# ============================================================================ import pytest +import numpy as np +import mindspore.context as context +import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -41,6 +40,7 @@ class Net(nn.Cell): x = self.relu(x) return x + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -49,5 +49,4 @@ def test_net(): x = np.random.randn(32, 10).astype(np.float32) net = Net() output = net(Tensor(x)) - print(x) - print(output.asnumpy()) \ No newline at end of file + print(output.asnumpy()) diff --git a/tests/st/mem_reuse/check_file.py b/tests/st/mem_reuse/check_file.py index 2f6fe82d2d..30b3b690a4 100644 --- a/tests/st/mem_reuse/check_file.py +++ b/tests/st/mem_reuse/check_file.py @@ -14,6 +14,7 @@ # ============================================================================ import os import filecmp + curr_path = os.path.abspath(os.curdir) file_memreuse = curr_path + "/mem_reuse_check/memreuse.ir" file_normal = curr_path + "/mem_reuse_check/normal_mem.ir" @@ -23,5 +24,3 @@ checker = os.path.exists(file_normal) assert (checker, True) checker = filecmp.cmp(file_memreuse, file_normal) assert (checker, True) - - diff --git a/tests/st/mem_reuse/resnet.py b/tests/st/mem_reuse/resnet.py index fb4075f0a4..1c1b880b57 100644 --- a/tests/st/mem_reuse/resnet.py +++ b/tests/st/mem_reuse/resnet.py @@ -19,6 +19,7 @@ from mindspore.ops import operations as P from 
mindspore.common.initializer import initializer from mindspore.common import dtype as mstype + def weight_variable(shape): return initializer('XavierUniform', shape=shape, dtype=mstype.float32) @@ -297,4 +298,3 @@ class ResNet(nn.Cell): def resnet50(batch_size, num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes, batch_size) - diff --git a/tests/st/mem_reuse/resnet_cifar_memreuse.py b/tests/st/mem_reuse/resnet_cifar_memreuse.py index 4edcdd8fb8..d6310612b6 100644 --- a/tests/st/mem_reuse/resnet_cifar_memreuse.py +++ b/tests/st/mem_reuse/resnet_cifar_memreuse.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import argparse +import os +import numpy as np +import mindspore.context as context import mindspore.nn as nn +import mindspore.common.dtype as mstype from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model, ParallelMode -from mindspore import context -import mindspore.common.dtype as mstype -import os -import numpy as np -import mindspore.ops.functional as F from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset as de @@ -30,11 +31,11 @@ import mindspore.dataset.transforms.vision.c_transforms as vision from mindspore.communication.management import init from resnet import resnet50 import random + random.seed(1) np.random.seed(1) de.config.set_seed(1) -import argparse parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--device_num', type=int, default=1, help='Device num.') @@ -47,9 +48,9 @@ 
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoin parser.add_argument('--dataset_path', type=str, default="/var/log/npu/datasets/cifar", help='Dataset path') args_opt = parser.parse_args() -device_id=int(os.getenv('DEVICE_ID')) +device_id = int(os.getenv('DEVICE_ID')) -data_home=args_opt.dataset_path +data_home = args_opt.dataset_path context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_task_sink=True, device_id=device_id) @@ -64,8 +65,8 @@ def create_dataset(repeat_num=1, training=True): ds = de.Cifar10Dataset(data_dir) if args_opt.run_distribute: - rank_id=int(os.getenv('RANK_ID')) - rank_size=int(os.getenv('RANK_SIZE')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) ds = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) resize_height = 224 @@ -74,9 +75,9 @@ def create_dataset(repeat_num=1, training=True): shift = 0.0 # define map operations - random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_horizontal_op = vision.RandomHorizontalFlip() - resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR rescale_op = vision.Rescale(rescale, shift) normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) changeswap_op = vision.HWC2CHW() diff --git a/tests/st/mem_reuse/resnet_cifar_normal.py b/tests/st/mem_reuse/resnet_cifar_normal.py index 39f6e7fe59..2b6741e57a 100644 --- a/tests/st/mem_reuse/resnet_cifar_normal.py +++ b/tests/st/mem_reuse/resnet_cifar_normal.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ +import argparse +import os +import numpy as np +import mindspore.context as context import mindspore.nn as nn +import mindspore.common.dtype as mstype from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model, ParallelMode -from mindspore import context -import mindspore.common.dtype as mstype -import os -import numpy as np -import mindspore.ops.functional as F from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset as de @@ -35,7 +36,6 @@ random.seed(1) np.random.seed(1) de.config.set_seed(1) -import argparse parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') diff --git a/tests/st/nccl/test_nccl_all.py b/tests/st/nccl/test_nccl_all.py index 99494bb741..faa6394f9a 100644 --- a/tests/st/nccl/test_nccl_all.py +++ b/tests/st/nccl/test_nccl_all.py @@ -15,6 +15,7 @@ import os import pytest + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_single @@ -22,6 +23,7 @@ def test_nccl_lenet(): return_code = os.system("mpirun -n 8 pytest -s test_nccl_lenet.py") assert(return_code == 0) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_single @@ -29,6 +31,7 @@ def test_nccl_all_reduce_op(): return_code = os.system("mpirun -n 8 pytest -s test_nccl_all_reduce_op.py") assert(return_code == 0) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_single @@ -36,6 +39,7 @@ def test_nccl_all_gather_op(): return_code = os.system("mpirun -n 8 pytest -s test_nccl_all_gather_op.py") assert(return_code == 0) + @pytest.mark.level0 
@pytest.mark.platform_x86_gpu_training @pytest.mark.env_single diff --git a/tests/st/nccl/test_nccl_all_gather_op.py b/tests/st/nccl/test_nccl_all_gather_op.py index f2a2c7133c..0a37a692da 100644 --- a/tests/st/nccl/test_nccl_all_gather_op.py +++ b/tests/st/nccl/test_nccl_all_gather_op.py @@ -12,23 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn import numpy as np import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') init('nccl') rank = get_rank() size = get_group_size() -x = np.ones([1,1,3,3]).astype(np.float32) * 0.01 * (rank + 1) +x = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + class Net(nn.Cell): - def __init__( self): + def __init__(self): super(Net, self).__init__() self.all_gather = P.AllGather(group=NCCL_WORLD_COMM_GROUP) self.x = Parameter(initializer(Tensor(x), x.shape), name='x') @@ -36,6 +38,7 @@ class Net(nn.Cell): def construct(self): return self.all_gather(self.x) + def test_AllGather(): all_gather = Net() output = all_gather() diff --git a/tests/st/nccl/test_nccl_all_reduce_op.py b/tests/st/nccl/test_nccl_all_reduce_op.py index 7c2e579463..a1a732fd08 100644 --- a/tests/st/nccl/test_nccl_all_reduce_op.py +++ b/tests/st/nccl/test_nccl_all_reduce_op.py @@ -12,23 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn import numpy as np import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size -context.set_context(mode=context.GRAPH_MODE, device_target='GPU', enable_dynamic_memory=False) + +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') init('nccl') rank = get_rank() size = get_group_size() -x = np.ones([3,1,3,3]).astype(np.float32) * 0.01 * (rank + 1) +x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + class Net(nn.Cell): - def __init__( self): + def __init__(self): super(Net, self).__init__() self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') @@ -47,6 +49,7 @@ class Net(nn.Cell): self.all_reduce2(self.x2), self.all_reduce3(self.x3)) + def test_AllReduce(): all_reduce = Net() output = all_reduce() @@ -58,16 +61,16 @@ def test_AllReduce(): diff0 = output[0].asnumpy() - expect0 error0 = np.ones(shape=expect0.shape) * 1.0e-5 assert np.all(diff0 < error0) - assert (output[0].shape() == expect0.shape) + assert output[0].shape() == expect0.shape expect1 = expect0 diff1 = output[1].asnumpy() - expect1 error1 = np.ones(shape=expect1.shape) * 1.0e-5 assert np.all(diff1 < error1) - assert (output[1].shape() == expect1.shape) + assert output[1].shape() == expect1.shape expect2 = expect1 diff2 = output[2].asnumpy() - expect2 error2 = np.ones(shape=expect2.shape) * 1.0e-5 assert np.all(diff2 < error2) - assert (output[2].shape() == expect2.shape) + assert output[2].shape() == expect2.shape diff --git 
a/tests/st/nccl/test_nccl_lenet.py b/tests/st/nccl/test_nccl_lenet.py index 2aebc5da50..3880f1d473 100644 --- a/tests/st/nccl/test_nccl_lenet.py +++ b/tests/st/nccl/test_nccl_lenet.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import numpy as np -from mindspore.nn import Dense -import mindspore.nn as nn import datetime +import numpy as np import mindspore.context as context -from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size +import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.ops import operations as P -from mindspore.common.tensor import Tensor +from mindspore.communication.management import init, get_rank, get_group_size context.set_context(mode=context.GRAPH_MODE, device_target="GPU") init('nccl') @@ -31,6 +30,7 @@ total = 5000 batch_size = 32 mini_batch = total // batch_size + class LeNet(nn.Cell): def __init__(self): super(LeNet, self).__init__() @@ -43,15 +43,15 @@ class LeNet(nn.Cell): self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0) self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid") self.reshape = P.Reshape() - + weight1 = Tensor(np.ones([120, 400]).astype(np.float32) * 0.01) - self.fc1 = Dense(400, 120, weight_init=weight1) - + self.fc1 = nn.Dense(400, 120, weight_init=weight1) + weight2 = Tensor(np.ones([84, 120]).astype(np.float32) * 0.01) - self.fc2 = Dense(120, 84, weight_init=weight2) - + self.fc2 = nn.Dense(120, 84, weight_init=weight2) + weight3 = Tensor(np.ones([10, 84]).astype(np.float32) * 0.01) - self.fc3 = Dense(84, 10, weight_init=weight3) + self.fc3 = nn.Dense(84, 10, weight_init=weight3) def construct(self, input_x): output = self.conv1(input_x) @@ -66,6 
+66,7 @@ class LeNet(nn.Cell): output = self.fc3(output) return output + def test_lenet_nccl(): net = LeNet() net.set_train() diff --git a/tests/st/nccl/test_nccl_reduce_scatter_op.py b/tests/st/nccl/test_nccl_reduce_scatter_op.py index 32c1f31788..f3322d07a3 100644 --- a/tests/st/nccl/test_nccl_reduce_scatter_op.py +++ b/tests/st/nccl/test_nccl_reduce_scatter_op.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn import numpy as np import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size @@ -27,8 +27,9 @@ rank = get_rank() size = get_group_size() x = np.ones([size, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + class Net(nn.Cell): - def __init__( self): + def __init__(self): super(Net, self).__init__() self.x = Parameter(initializer(Tensor(x), x.shape), name='x') @@ -46,6 +47,7 @@ class Net(nn.Cell): self.reduce_scatter2(self.x), self.reduce_scatter3(self.x)) + def test_ReduceScatter(): reduce_scatter = Net() output = reduce_scatter() @@ -53,7 +55,7 @@ def test_ReduceScatter(): sum = np.ones([size, 1, 3, 3]).astype(np.float32) * 0 for i in range(size): sum += np.ones([size, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 1) - expect0 = sum[rank : rank + 1] + expect0 = sum[rank: rank + 1] diff0 = output[0].asnumpy() - expect0 error0 = np.ones(shape=expect0.shape) * 1.0e-5 assert np.all(diff0 < error0) diff --git a/tests/st/networks/models/alexnet.py b/tests/st/networks/models/alexnet.py index 4c8981f04a..f74d09353c 100644 --- 
a/tests/st/networks/models/alexnet.py +++ b/tests/st/networks/models/alexnet.py @@ -16,6 +16,7 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.nn import Dense + class AlexNet(nn.Cell): def __init__(self, num_classes=10): super(AlexNet, self).__init__() diff --git a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py index 5b6268505b..7d30592044 100644 --- a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py @@ -18,21 +18,22 @@ import os import pytest import numpy as np -from numpy import allclose +import mindspore.context as context import mindspore.common.dtype as mstype import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C -from mindspore import context -from mindspore.common.tensor import Tensor +from mindspore import Tensor from mindspore.train.model import Model from mindspore.train.callback import Callback from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell from mindspore.nn.optim import Momentum from mindspore import log as logger + _current_dir = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" + def get_config(version='base', batch_size=1): """get config""" if version == 'base': @@ -99,13 +100,14 @@ def get_config(version='base', batch_size=1): bert_config = BertConfig(batch_size=batch_size) return bert_config + def me_de_train_dataset(): """test me de train dataset""" # apply repeat operations repeat_count = 1 ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", - "next_sentence_labels", "masked_lm_positions", - "masked_lm_ids", "masked_lm_weights"], shuffle=False) + "next_sentence_labels", 
"masked_lm_positions", + "masked_lm_ids", "masked_lm_weights"], shuffle=False) type_cast_op = C.TypeCast(mstype.int32) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) @@ -137,6 +139,7 @@ class ModelCallback(Callback): self.loss_list.append(cb_params.net_outputs.asnumpy()[0]) logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -180,7 +183,8 @@ def test_bert_tdt(): expect_out = [12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289, 12.165594, 12.824818, 12.38842, 12.604046] logger.info("expected loss value output: {}".format(expect_out)) - assert allclose(loss_value, expect_out, 0.00001, 0.00001) + assert np.allclose(loss_value, expect_out, 0.00001, 0.00001) + if __name__ == '__main__': test_bert_tdt() diff --git a/tests/st/networks/models/lenet.py b/tests/st/networks/models/lenet.py index 9df91822f7..8f6b969cd7 100644 --- a/tests/st/networks/models/lenet.py +++ b/tests/st/networks/models/lenet.py @@ -14,9 +14,10 @@ # ============================================================================ import numpy as np import mindspore.nn as nn +from mindspore import Tensor from mindspore.ops import operations as P from mindspore.nn import Dense -from mindspore import Tensor + class LeNet(nn.Cell): def __init__(self): diff --git a/tests/st/networks/models/resnetv1_5.py b/tests/st/networks/models/resnetv1_5.py index 855aec7014..604389547e 100644 --- a/tests/st/networks/models/resnetv1_5.py +++ b/tests/st/networks/models/resnetv1_5.py @@ -13,9 +13,10 @@ # limitations under the License. 
# ============================================================================ import numpy as np -from mindspore.common.tensor import Tensor import mindspore.nn as nn -import mindspore.ops.operations as P +from mindspore import Tensor +from mindspore.ops import operations as P + def weight_variable(shape): ones = np.ones(shape).astype(np.float32) @@ -37,7 +38,7 @@ def conv3x3(in_channels, out_channels, stride=1, padding=0): weight_shape = (out_channels, in_channels, 3, 3) weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") def conv1x1(in_channels, out_channels, stride=1, padding=0): @@ -45,7 +46,7 @@ def conv1x1(in_channels, out_channels, stride=1, padding=0): weight_shape = (out_channels, in_channels, 1, 1) weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") def conv7x7(in_channels, out_channels, stride=1, padding=0): @@ -53,7 +54,7 @@ def conv7x7(in_channels, out_channels, stride=1, padding=0): weight_shape = (out_channels, in_channels, 7, 7) weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") def bn_with_initialize(out_channels): @@ -63,7 +64,7 @@ def bn_with_initialize(out_channels): beta = weight_variable_0(shape) gamma = weight_variable_1(shape) bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, - beta_init=beta, moving_mean_init=mean, 
moving_var_init=var) + beta_init=beta, moving_mean_init=mean, moving_var_init=var) return bn @@ -74,7 +75,7 @@ def bn_with_initialize_last(out_channels): beta = weight_variable_0(shape) gamma = weight_variable_0(shape) bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, - beta_init=beta, moving_mean_init=mean, moving_var_init=var) + beta_init=beta, moving_mean_init=mean, moving_var_init=var) return bn @@ -294,6 +295,6 @@ class ResNet(nn.Cell): x = self.fc(x) return x + def resnet50(batch_size, num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes, batch_size) - diff --git a/tests/st/networks/test_cpu_lenet.py b/tests/st/networks/test_cpu_lenet.py index 9fd50f5d9b..7101e29aa9 100644 --- a/tests/st/networks/test_cpu_lenet.py +++ b/tests/st/networks/test_cpu_lenet.py @@ -13,13 +13,15 @@ # limitations under the License. # ============================================================================ import pytest -from mindspore.nn import TrainOneStepCell, WithLossCell -import mindspore.context as context -from mindspore.nn.optim import Momentum import numpy as np +import mindspore.context as context import mindspore.nn as nn -from mindspore.ops import operations as P from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") class LeNet(nn.Cell): @@ -52,9 +54,6 @@ class LeNet(nn.Cell): return output -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - def train(net, data, label): learning_rate = 0.01 momentum = 0.9 diff --git a/tests/st/networks/test_gpu_alexnet.py b/tests/st/networks/test_gpu_alexnet.py index 9f92fc630e..699617b384 100644 --- a/tests/st/networks/test_gpu_alexnet.py +++ b/tests/st/networks/test_gpu_alexnet.py @@ -19,15 +19,17 @@ from __future__ import print_function import pytest import numpy as np +import 
mindspore.context as context import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore import Tensor from mindspore.common.initializer import initializer -import mindspore.context as context + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + class AlexNet(nn.Cell): def __init__(self, num_classes=10): super(AlexNet, self).__init__() @@ -66,6 +68,7 @@ class AlexNet(nn.Cell): x = self.fc3(x) return x + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -73,14 +76,14 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32): net = AlexNet(num_classes) lr = 0.1 momentum = 0.9 - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum, weight_decay = 0.0001) + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum, weight_decay=0.0001) criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) net_with_criterion = WithLossCell(net, criterion) train_network = TrainOneStepCell(net_with_criterion, optimizer) train_network.set_train() - losses=[] + losses = [] for i in range(0, epoch): - data = Tensor(np.ones([batch_size, 3 ,227, 227]).astype(np.float32) * 0.01) + data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01) label = Tensor(np.ones([batch_size]).astype(np.int32)) loss = train_network(data, label) losses.append(loss) diff --git a/tests/st/networks/test_gpu_lenet.py b/tests/st/networks/test_gpu_lenet.py index 4dac2247d0..b6b94cd23d 100644 --- a/tests/st/networks/test_gpu_lenet.py +++ b/tests/st/networks/test_gpu_lenet.py @@ -16,16 +16,19 @@ import pytest import numpy as np import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from 
mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import Dense -from mindspore import Tensor from mindspore.common.initializer import initializer from mindspore.common import dtype as mstype -import mindspore.context as context + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + class LeNet(nn.Cell): def __init__(self): super(LeNet, self).__init__() @@ -65,6 +68,7 @@ def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32): lr.append(lr_) return Tensor(np.array(lr), dtype) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -81,7 +85,7 @@ def test_train_lenet(): train_network.set_train() losses = [] for i in range(epoch): - data = Tensor(np.ones([net.batch_size, 3 ,32, 32]).astype(np.float32) * 0.01) + data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01) label = Tensor(np.ones([net.batch_size]).astype(np.int32)) loss = train_network(data, label) losses.append(loss) diff --git a/tests/st/networks/test_gpu_lstm.py b/tests/st/networks/test_gpu_lstm.py index e5208ff669..acf5ca9396 100644 --- a/tests/st/networks/test_gpu_lstm.py +++ b/tests/st/networks/test_gpu_lstm.py @@ -15,18 +15,20 @@ import pytest import numpy as np +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import Dense -from mindspore import Tensor from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter -import mindspore.context as context -import mindspore.nn as nn + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + def InitialLstmWeight(input_size, hidden_size, num_layers, bidirectional, has_bias=False): num_directions = 1 if bidirectional: @@ -56,6 +58,7 @@ def InitialLstmWeight(input_size, hidden_size, num_layers, 
bidirectional, has_bi return h, c, w + class SentimentNet(nn.Cell): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, bidirectional, weight, labels, batch_size): @@ -99,6 +102,7 @@ class SentimentNet(nn.Cell): outputs = self.decoder(encoding) return outputs + batch_size = 64 @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -130,7 +134,7 @@ def test_LSTM(): train_network.set_train() train_features = Tensor(np.ones([64, max_len]).astype(np.int32)) - train_labels = Tensor(np.ones([64,]).astype(np.int32)[0:64]) + train_labels = Tensor(np.ones([64, ]).astype(np.int32)[0:64]) losses = [] for epoch in range(num_epochs): loss = train_network(train_features, train_labels) diff --git a/tests/st/networks/test_gpu_resnet.py b/tests/st/networks/test_gpu_resnet.py index 6d8337a6a9..a5f450d5e3 100644 --- a/tests/st/networks/test_gpu_resnet.py +++ b/tests/st/networks/test_gpu_resnet.py @@ -19,36 +19,34 @@ from __future__ import print_function import pytest import numpy as np - +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.cell import Cell from mindspore.nn.layer.conv import Conv2d from mindspore.nn.layer.basic import Flatten from mindspore.nn.layer.normalization import BatchNorm2d from mindspore.nn.layer.pooling import MaxPool2d from mindspore.ops.operations import TensorAdd -import mindspore.nn as nn - from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import Dense -from mindspore import Tensor from mindspore.common.initializer import initializer -import mindspore.context as context - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + def random_normal_init(shape, mean=0.0, stddev=0.01, seed=None): init_value = np.ones(shape).astype(np.float32) * 0.01 return Tensor(init_value) + def variance_scaling_raw(shape): variance_scaling_value = 
np.ones(shape).astype(np.float32) * 0.01 return Tensor(variance_scaling_value) - def weight_variable_0(shape): zeros = np.zeros(shape).astype(np.float32) return Tensor(zeros) @@ -323,6 +321,7 @@ class ResNet(Cell): def resnet50(num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -335,9 +334,9 @@ def test_trainTensor(num_classes=10, epoch=8, batch_size=1): net_with_criterion = WithLossCell(net, criterion) train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer train_network.set_train() - losses=[] + losses = [] for i in range(0, epoch): - data = Tensor(np.ones([batch_size, 3 ,224, 224]).astype(np.float32) * 0.01) + data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01) label = Tensor(np.ones([batch_size]).astype(np.int32)) loss = train_network(data, label) losses.append(loss) diff --git a/tests/st/networks/test_network_main.py b/tests/st/networks/test_network_main.py index 4689adee54..79bd46d87a 100644 --- a/tests/st/networks/test_network_main.py +++ b/tests/st/networks/test_network_main.py @@ -22,16 +22,18 @@ import os import time import numpy as np import argparse +import mindspore.context as context import mindspore.nn as nn -from mindspore.common.tensor import Tensor +from mindspore import Tensor from mindspore.nn import TrainOneStepCell, WithLossCell -import mindspore.context as context from mindspore.nn.optim import Momentum from models.lenet import LeNet from models.resnetv1_5 import resnet50 from models.alexnet import AlexNet + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + def train(net, data, label): learning_rate = 0.01 momentum = 0.9 @@ -42,29 +44,31 @@ def train(net, data, label): train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer train_network.set_train() res = train_network(data, label) - print("+++++++++Loss+++++++++++++") print(res) - 
print("+++++++++++++++++++++++++++") assert res + def test_resnet50(): - data = Tensor(np.ones([32, 3 ,224, 224]).astype(np.float32) * 0.01) + data = Tensor(np.ones([32, 3, 224, 224]).astype(np.float32) * 0.01) label = Tensor(np.ones([32]).astype(np.int32)) net = resnet50(32, 10) train(net, data, label) + def test_lenet(): - data = Tensor(np.ones([32, 1 ,32, 32]).astype(np.float32) * 0.01) + data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) label = Tensor(np.ones([32]).astype(np.int32)) net = LeNet() train(net, data, label) + def test_alexnet(): - data = Tensor(np.ones([32, 3 ,227, 227]).astype(np.float32) * 0.01) + data = Tensor(np.ones([32, 3, 227, 227]).astype(np.float32) * 0.01) label = Tensor(np.ones([32]).astype(np.int32)) net = AlexNet() train(net, data, label) + parser = argparse.ArgumentParser(description='MindSpore Testing Network') parser.add_argument('--net', default='resnet50', type=str, help='net name') parser.add_argument('--device', default='Ascend', type=str, help='device target') diff --git a/tests/st/ops/davinci/test_add.py b/tests/st/ops/ascend/test_add.py similarity index 100% rename from tests/st/ops/davinci/test_add.py rename to tests/st/ops/ascend/test_add.py diff --git a/tests/st/ops/davinci/test_addn.py b/tests/st/ops/ascend/test_addn.py similarity index 100% rename from tests/st/ops/davinci/test_addn.py rename to tests/st/ops/ascend/test_addn.py diff --git a/tests/st/ops/davinci/test_apply_momentum.py b/tests/st/ops/ascend/test_apply_momentum.py similarity index 97% rename from tests/st/ops/davinci/test_apply_momentum.py rename to tests/st/ops/ascend/test_apply_momentum.py index 885356ce48..e20c4f4746 100644 --- a/tests/st/ops/davinci/test_apply_momentum.py +++ b/tests/st/ops/ascend/test_apply_momentum.py @@ -1,44 +1,44 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.apply_momentum = P.ApplyMomentum(gradient_scale=1024.0) - self.variable = Parameter(initializer( - 'normal', [2, 3, 3, 4]), name='variable') - self.accumulation = Parameter(initializer( - 'normal', [2, 3, 3, 4]), name='accumulation') - self.learning_rate = Parameter(initializer( - 'normal', [1, ]), name='learning_rate') - self.gradient = Parameter(initializer( - 'normal', [2, 3, 3, 4]), name='gradient') - self.momentum = Parameter(initializer( - 'normal', [1, ]), name='momentum') - def construct(self): - return self.apply_momentum(self.variable, self.accumulation, self.learning_rate, self.gradient, self.momentum) - -def test_net(): - apply_momentum = Net() - output = apply_momentum() - print(output.asnumpy()) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.apply_momentum = P.ApplyMomentum(gradient_scale=1024.0) + self.variable = Parameter(initializer( + 'normal', [2, 3, 3, 4]), name='variable') + self.accumulation = Parameter(initializer( + 'normal', [2, 3, 3, 4]), name='accumulation') + self.learning_rate = Parameter(initializer( + 'normal', [1, ]), name='learning_rate') + self.gradient = Parameter(initializer( + 'normal', [2, 3, 3, 4]), name='gradient') + self.momentum = Parameter(initializer( + 'normal', [1, ]), name='momentum') + def construct(self): + return self.apply_momentum(self.variable, self.accumulation, self.learning_rate, self.gradient, self.momentum) + +def test_net(): + apply_momentum = Net() + output = apply_momentum() + print(output.asnumpy()) diff --git a/tests/st/ops/davinci/test_argmax.py b/tests/st/ops/ascend/test_argmax.py similarity index 100% rename from tests/st/ops/davinci/test_argmax.py rename to tests/st/ops/ascend/test_argmax.py diff --git a/tests/st/ops/davinci/test_biasAddGrad.py b/tests/st/ops/ascend/test_biasAddGrad.py similarity index 97% rename 
from tests/st/ops/davinci/test_biasAddGrad.py rename to tests/st/ops/ascend/test_biasAddGrad.py index 29b63ac336..f2e8f7a9bc 100644 --- a/tests/st/ops/davinci/test_biasAddGrad.py +++ b/tests/st/ops/ascend/test_biasAddGrad.py @@ -1,42 +1,42 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.bias_add_grad = G.BiasAddGrad() - #self.dout = Parameter(initializer( - #'normal', [2, 3, 3, 4]), name='dout') - - - @ms_function - def construct(self, dout): - return self.bias_add_grad(dout) - -dout = np.ones([2,3,4,4]).astype(np.float32) -bias_add_grad = Net() -output = bias_add_grad(Tensor(dout)) -expect_output = np.array([32.,32.,32.]).astype(np.float32) -assert np.all(output.asnumpy()==expect_output), "bias_add_grad execute failed, please check current code commit" -print(output.asnumpy()) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.bias_add_grad = G.BiasAddGrad() + #self.dout = Parameter(initializer( + #'normal', [2, 3, 3, 4]), name='dout') + + + @ms_function + def construct(self, dout): + return self.bias_add_grad(dout) + +dout = np.ones([2,3,4,4]).astype(np.float32) +bias_add_grad = Net() +output = bias_add_grad(Tensor(dout)) +expect_output = np.array([32.,32.,32.]).astype(np.float32) +assert np.all(output.asnumpy()==expect_output), "bias_add_grad execute failed, please check current code commit" +print(output.asnumpy()) diff --git a/tests/st/ops/davinci/test_bias_add_grad.py b/tests/st/ops/ascend/test_bias_add_grad.py similarity index 97% rename from tests/st/ops/davinci/test_bias_add_grad.py rename to tests/st/ops/ascend/test_bias_add_grad.py index e58d376e93..c6a51d8b3b 100644 --- a/tests/st/ops/davinci/test_bias_add_grad.py +++ b/tests/st/ops/ascend/test_bias_add_grad.py @@ -1,39 +1,39 @@ -# Copyright 2019 Huawei Technologies 
Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.bias_add_grad = G.BiasAddGrad() - - - @ms_function - def construct(self, dout): - return self.bias_add_grad(dout) - -def test_net(): - dout = np.random.rand(1, 1001).astype(np.float32) - bias_add_grad = Net() - output = bias_add_grad(dout) - print(output.asnumpy()) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.bias_add_grad = G.BiasAddGrad() + + + @ms_function + def construct(self, dout): + return self.bias_add_grad(dout) + +def test_net(): + dout = np.random.rand(1, 1001).astype(np.float32) + bias_add_grad = Net() + output = bias_add_grad(dout) + print(output.asnumpy()) diff --git a/tests/st/ops/davinci/test_conv.py b/tests/st/ops/ascend/test_conv.py similarity index 100% rename from tests/st/ops/davinci/test_conv.py rename to tests/st/ops/ascend/test_conv.py diff --git a/tests/st/ops/davinci/test_conv2dGradFilter.py b/tests/st/ops/ascend/test_conv2dGradFilter.py similarity index 100% rename from tests/st/ops/davinci/test_conv2dGradFilter.py rename to tests/st/ops/ascend/test_conv2dGradFilter.py diff --git a/tests/st/ops/davinci/test_conv_grad.py b/tests/st/ops/ascend/test_conv_grad.py similarity index 100% rename from tests/st/ops/davinci/test_conv_grad.py rename to tests/st/ops/ascend/test_conv_grad.py diff --git a/tests/st/ops/davinci/test_dense.py b/tests/st/ops/ascend/test_dense.py similarity index 100% rename from tests/st/ops/davinci/test_dense.py rename to tests/st/ops/ascend/test_dense.py diff --git a/tests/st/ops/davinci/test_dense_grad.py b/tests/st/ops/ascend/test_dense_grad.py similarity index 100% rename from tests/st/ops/davinci/test_dense_grad.py rename to tests/st/ops/ascend/test_dense_grad.py diff --git a/tests/st/ops/davinci/test_drop_out_gen_mask.py b/tests/st/ops/ascend/test_drop_out_gen_mask.py 
similarity index 97% rename from tests/st/ops/davinci/test_drop_out_gen_mask.py rename to tests/st/ops/ascend/test_drop_out_gen_mask.py index 4d7c555219..ce7ebbfbe0 100644 --- a/tests/st/ops/davinci/test_drop_out_gen_mask.py +++ b/tests/st/ops/ascend/test_drop_out_gen_mask.py @@ -1,44 +1,44 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -import numpy as np -import mindspore.context as context -context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend") - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.mask = P.DropoutGenMask(10, 28) - self.shape = P.Shape() - - def construct(self, x, y): - shape_x = self.shape(x) - return self.mask(shape_x, y) - - -x = np.ones([2, 4, 2, 2]).astype(np.int32) -y = np.array([1.0]).astype(np.float32) - - -def test_net(): - mask = Net() - tx, ty = Tensor(x), Tensor(y) - output = mask(tx, ty) - print(output.asnumpy()) - assert ([255, 255, 255, 255] == output.asnumpy()).all() +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import numpy as np +import mindspore.context as context +context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.mask = P.DropoutGenMask(10, 28) + self.shape = P.Shape() + + def construct(self, x, y): + shape_x = self.shape(x) + return self.mask(shape_x, y) + + +x = np.ones([2, 4, 2, 2]).astype(np.int32) +y = np.array([1.0]).astype(np.float32) + + +def test_net(): + mask = Net() + tx, ty = Tensor(x), Tensor(y) + output = mask(tx, ty) + print(output.asnumpy()) + assert ([255, 255, 255, 255] == output.asnumpy()).all() diff --git a/tests/st/ops/davinci/test_equal_count.py b/tests/st/ops/ascend/test_equal_count.py similarity index 100% rename from tests/st/ops/davinci/test_equal_count.py rename to tests/st/ops/ascend/test_equal_count.py diff --git a/tests/st/ops/davinci/test_full_connection.py b/tests/st/ops/ascend/test_full_connection.py similarity index 100% rename from tests/st/ops/davinci/test_full_connection.py rename to tests/st/ops/ascend/test_full_connection.py diff --git a/tests/st/ops/davinci/test_fused_batchnorm.py b/tests/st/ops/ascend/test_fused_batchnorm.py similarity index 100% rename from tests/st/ops/davinci/test_fused_batchnorm.py rename to tests/st/ops/ascend/test_fused_batchnorm.py diff --git a/tests/st/ops/davinci/test_fused_batchnorm_grad.py 
b/tests/st/ops/ascend/test_fused_batchnorm_grad.py similarity index 100% rename from tests/st/ops/davinci/test_fused_batchnorm_grad.py rename to tests/st/ops/ascend/test_fused_batchnorm_grad.py diff --git a/tests/st/ops/davinci/test_image_gradients.py b/tests/st/ops/ascend/test_image_gradients.py similarity index 100% rename from tests/st/ops/davinci/test_image_gradients.py rename to tests/st/ops/ascend/test_image_gradients.py diff --git a/tests/st/ops/davinci/test_matmul.py b/tests/st/ops/ascend/test_matmul.py similarity index 100% rename from tests/st/ops/davinci/test_matmul.py rename to tests/st/ops/ascend/test_matmul.py diff --git a/tests/st/ops/davinci/test_maxpool.py b/tests/st/ops/ascend/test_maxpool.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool.py rename to tests/st/ops/ascend/test_maxpool.py diff --git a/tests/st/ops/davinci/test_maxpool_grad.py b/tests/st/ops/ascend/test_maxpool_grad.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool_grad.py rename to tests/st/ops/ascend/test_maxpool_grad.py diff --git a/tests/st/ops/davinci/test_maxpool_with_argmax.py b/tests/st/ops/ascend/test_maxpool_with_argmax.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool_with_argmax.py rename to tests/st/ops/ascend/test_maxpool_with_argmax.py diff --git a/tests/st/ops/davinci/test_maxpool_with_argmax_grad.py b/tests/st/ops/ascend/test_maxpool_with_argmax_grad.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool_with_argmax_grad.py rename to tests/st/ops/ascend/test_maxpool_with_argmax_grad.py diff --git a/tests/st/ops/davinci/test_relu.py b/tests/st/ops/ascend/test_relu.py similarity index 100% rename from tests/st/ops/davinci/test_relu.py rename to tests/st/ops/ascend/test_relu.py diff --git a/tests/st/ops/davinci/test_relu_grad.py b/tests/st/ops/ascend/test_relu_grad.py similarity index 100% rename from tests/st/ops/davinci/test_relu_grad.py rename to tests/st/ops/ascend/test_relu_grad.py diff 
--git a/tests/st/ops/davinci/test_reshape.py b/tests/st/ops/ascend/test_reshape.py similarity index 100% rename from tests/st/ops/davinci/test_reshape.py rename to tests/st/ops/ascend/test_reshape.py diff --git a/tests/st/ops/davinci/test_simplemean.py b/tests/st/ops/ascend/test_simplemean.py similarity index 100% rename from tests/st/ops/davinci/test_simplemean.py rename to tests/st/ops/ascend/test_simplemean.py diff --git a/tests/st/ops/davinci/test_simplemean_grad.py b/tests/st/ops/ascend/test_simplemean_grad.py similarity index 100% rename from tests/st/ops/davinci/test_simplemean_grad.py rename to tests/st/ops/ascend/test_simplemean_grad.py diff --git a/tests/st/ops/davinci/test_softmax.py b/tests/st/ops/ascend/test_softmax.py similarity index 100% rename from tests/st/ops/davinci/test_softmax.py rename to tests/st/ops/ascend/test_softmax.py diff --git a/tests/st/ops/davinci/test_sparseSoftmaxCrossEntropyWithLogits.py b/tests/st/ops/ascend/test_sparseSoftmaxCrossEntropyWithLogits.py similarity index 100% rename from tests/st/ops/davinci/test_sparseSoftmaxCrossEntropyWithLogits.py rename to tests/st/ops/ascend/test_sparseSoftmaxCrossEntropyWithLogits.py diff --git a/tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits.py b/tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits.py similarity index 100% rename from tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits.py rename to tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits.py diff --git a/tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits_grad.py b/tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits_grad.py similarity index 100% rename from tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits_grad.py rename to tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_AssignAdd.py b/tests/st/ops/ascend/test_tbe_ops/test_AssignAdd.py similarity 
index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_AssignAdd.py rename to tests/st/ops/ascend/test_tbe_ops/test_AssignAdd.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_AssignSub.py b/tests/st/ops/ascend/test_tbe_ops/test_AssignSub.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_AssignSub.py rename to tests/st/ops/ascend/test_tbe_ops/test_AssignSub.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_ReduceMean.py b/tests/st/ops/ascend/test_tbe_ops/test_ReduceMean.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_ReduceMean.py rename to tests/st/ops/ascend/test_tbe_ops/test_ReduceMean.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_add.py b/tests/st/ops/ascend/test_tbe_ops/test_add.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_add.py rename to tests/st/ops/ascend/test_tbe_ops/test_add.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_addn.py b/tests/st/ops/ascend/test_tbe_ops/test_addn.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_addn.py rename to tests/st/ops/ascend/test_tbe_ops/test_addn.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_apply_adam.py b/tests/st/ops/ascend/test_tbe_ops/test_apply_adam.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_apply_adam.py rename to tests/st/ops/ascend/test_tbe_ops/test_apply_adam.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_apply_momentum.py b/tests/st/ops/ascend/test_tbe_ops/test_apply_momentum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_apply_momentum.py rename to tests/st/ops/ascend/test_tbe_ops/test_apply_momentum.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_batchmatmul.py b/tests/st/ops/ascend/test_tbe_ops/test_batchmatmul.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_batchmatmul.py rename to tests/st/ops/ascend/test_tbe_ops/test_batchmatmul.py diff --git 
a/tests/st/ops/davinci/test_tbe_ops/test_batchnorm.py b/tests/st/ops/ascend/test_tbe_ops/test_batchnorm.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_batchnorm.py rename to tests/st/ops/ascend/test_tbe_ops/test_batchnorm.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_batchnorm_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_batchnorm_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_batchnorm_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_batchnorm_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_bias_add.py b/tests/st/ops/ascend/test_tbe_ops/test_bias_add.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_bias_add.py rename to tests/st/ops/ascend/test_tbe_ops/test_bias_add.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_bias_add_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_bias_add_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_bias_add_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_bias_add_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_concat.py b/tests/st/ops/ascend/test_tbe_ops/test_concat.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_concat.py rename to tests/st/ops/ascend/test_tbe_ops/test_concat.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_conv.py b/tests/st/ops/ascend/test_tbe_ops/test_conv.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_conv.py rename to tests/st/ops/ascend/test_tbe_ops/test_conv.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_filter.py b/tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_filter.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_filter.py rename to tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_filter.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_input.py 
b/tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_input.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_input.py rename to tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_input.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_dropout_do_mask.py b/tests/st/ops/ascend/test_tbe_ops/test_dropout_do_mask.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_dropout_do_mask.py rename to tests/st/ops/ascend/test_tbe_ops/test_dropout_do_mask.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_gelu.py b/tests/st/ops/ascend/test_tbe_ops/test_gelu.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_gelu.py rename to tests/st/ops/ascend/test_tbe_ops/test_gelu.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_gelu_grad_sens.py b/tests/st/ops/ascend/test_tbe_ops/test_gelu_grad_sens.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_gelu_grad_sens.py rename to tests/st/ops/ascend/test_tbe_ops/test_gelu_grad_sens.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_greater.py b/tests/st/ops/ascend/test_tbe_ops/test_greater.py similarity index 95% rename from tests/st/ops/davinci/test_tbe_ops/test_greater.py rename to tests/st/ops/ascend/test_tbe_ops/test_greater.py index 3976ad4301..b9dae700c2 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_greater.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_greater.py @@ -1,51 +1,51 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from mindspore.ops import operations as P -from mindspore.nn import Cell -from mindspore.common.tensor import Tensor -from mindspore.train.model import Model -from mindspore import log as logger -from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - -class Greater(Cell): - def __init__(self): - super(Greater, self).__init__() - self.greater = P.Greater() - - def construct(self, inputa, inputb): - return self.greater(inputa, inputb) - -def me_greater(inputa, inputb): - net = Greater() - net.set_train() - model = Model(net) - - out = model.predict(inputa, inputb) - logger.info("Check input a: ") - logger.info(inputa) - logger.info("Check input b: ") - logger.info(inputb) - return out.asnumpy() - -@pytest.mark.ssd_tbe -def test_greater_2d_scalar0(): - a = np.random.randint(-5, 5, [8, 32]).astype(np.int32) - b = np.random.randint(-5, 5, [8, 32]).astype(np.int32) - out_me = me_greater(Tensor(a), Tensor(b)) - logger.info("Check me result:") +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore import log as logger +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Greater(Cell): + def __init__(self): + super(Greater, self).__init__() + self.greater = P.Greater() + + def construct(self, inputa, inputb): + return self.greater(inputa, inputb) + +def me_greater(inputa, inputb): + net = Greater() + net.set_train() + model = Model(net) + + out = model.predict(inputa, inputb) + logger.info("Check input a: ") + logger.info(inputa) + logger.info("Check input b: ") + logger.info(inputb) + return out.asnumpy() + +@pytest.mark.ssd_tbe +def test_greater_2d_scalar0(): + a = np.random.randint(-5, 5, [8, 32]).astype(np.int32) + b = np.random.randint(-5, 5, [8, 32]).astype(np.int32) + out_me = me_greater(Tensor(a), Tensor(b)) + logger.info("Check me result:") logger.info(out_me) \ No newline at end of file diff --git a/tests/st/ops/davinci/test_tbe_ops/test_layernorm.py b/tests/st/ops/ascend/test_tbe_ops/test_layernorm.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_layernorm.py rename to tests/st/ops/ascend/test_tbe_ops/test_layernorm.py index 586c02cc1e..f3e4e43958 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_layernorm.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_layernorm.py @@ -1,55 +1,55 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -from mindspore.nn import LayerNorm -from mindspore.common.tensor import Tensor -from mindspore.nn import Cell -from mindspore.train.model import Model -from mindspore import log as logger -import pytest -from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - -class Net(Cell): - def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): - super(Net, self).__init__() - self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) - - def construct(self, input): - x = self.layernorm(input) - return x - -def pt_me_layernorm(input_data, normalized_shape, gamma, beta, axis): - net = Net(normalized_shape, begin_norm_axis=axis, - begin_params_axis=axis, - gamma=Tensor(gamma), - beta=Tensor(beta)) - net.set_train() - model = Model(net) - out_me = model.predict(Tensor(input_data)) - logger.info("Check me result:") - logger.info(out_me.asnumpy()) - -@pytest.mark.lower_bs -def test_normal_layernorm_1_128_1024_axis_2(): - """ - 2 input[1, 128, 1024],normalized_shape=[128, 1024] - """ - input_data = np.random.randn(1, 128, 1024).astype(np.float32) - gamma = np.random.randn(1024).astype(np.float32) - gamma.fill(1.1) - beta = np.random.randn(1024).astype(np.float32) - beta.fill(0.1) - pt_me_layernorm(input_data, (1024, ), gamma, beta, 2) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +from mindspore.nn import LayerNorm +from mindspore.common.tensor import Tensor +from mindspore.nn import Cell +from mindspore.train.model import Model +from mindspore import log as logger +import pytest +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(Cell): + def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): + super(Net, self).__init__() + self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) + + def construct(self, input): + x = self.layernorm(input) + return x + +def pt_me_layernorm(input_data, normalized_shape, gamma, beta, axis): + net = Net(normalized_shape, begin_norm_axis=axis, + begin_params_axis=axis, + gamma=Tensor(gamma), + beta=Tensor(beta)) + net.set_train() + model = Model(net) + out_me = model.predict(Tensor(input_data)) + logger.info("Check me result:") + logger.info(out_me.asnumpy()) + +@pytest.mark.lower_bs +def test_normal_layernorm_1_128_1024_axis_2(): + """ + 2 input[1, 128, 1024],normalized_shape=[128, 1024] + """ + input_data = np.random.randn(1, 128, 1024).astype(np.float32) + gamma = np.random.randn(1024).astype(np.float32) + gamma.fill(1.1) + beta = np.random.randn(1024).astype(np.float32) + beta.fill(0.1) + pt_me_layernorm(input_data, (1024, ), gamma, beta, 2) diff --git a/tests/st/ops/davinci/test_tbe_ops/test_layernorm_grad.py 
b/tests/st/ops/ascend/test_tbe_ops/test_layernorm_grad.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_layernorm_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_layernorm_grad.py index 9f8eefdb3f..5ae09886ce 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_layernorm_grad.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_layernorm_grad.py @@ -1,65 +1,65 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -from mindspore.nn import LayerNorm -from mindspore.common.tensor import Tensor -from mindspore.nn import Cell -from mindspore.ops.composite import GradOperation -from mindspore import log as logger -from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - -class Grad(Cell): - def __init__(self, network): - super(Grad, self).__init__() - self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) - self.network = network - - def construct(self, input, output_grad,): - gout = self.grad(self.network)(input, output_grad) - return gout - -class Net(Cell): - def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): - super(Net, self).__init__() - self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) - - def construct(self, input): - x = self.layernorm(input) - return x - -def 
py_me_layernorm_grad(input_data, normalized_shape, gamma, beta, axis, gradients): - input_me = Tensor(input_data) - net_me = Grad(Net(normalized_shape, begin_norm_axis=axis, - begin_params_axis=axis, - gamma=Tensor(gamma), - beta=Tensor(beta))) - net_me.set_train() - out_pool_grad_me = Tensor(gradients) - out_grad = net_me(input_me, out_pool_grad_me) - logger.info("Check me result:") - logger.info(out_grad.asnumpy()) - -def test_normal_layernorm_grad_normalize_2d(): - """ - 1 input[1, 128, 1024],normalized_shape=[1024],element_affine=False - """ - input_data = np.ones([1, 128, 1024]).astype(np.float32) - gradients = np.ones((1, 128, 1024)).astype(np.float32) - gamma = np.random.randn(1024).astype(np.float32) - gamma.fill(1.1) - beta = np.random.randn(1024).astype(np.float32) - beta.fill(0.1) - py_me_layernorm_grad(input_data, (1024,), gamma, beta, 2, gradients) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +from mindspore.nn import LayerNorm +from mindspore.common.tensor import Tensor +from mindspore.nn import Cell +from mindspore.ops.composite import GradOperation +from mindspore import log as logger +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Grad(Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + + def construct(self, input, output_grad,): + gout = self.grad(self.network)(input, output_grad) + return gout + +class Net(Cell): + def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): + super(Net, self).__init__() + self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) + + def construct(self, input): + x = self.layernorm(input) + return x + +def py_me_layernorm_grad(input_data, normalized_shape, gamma, beta, axis, gradients): + input_me = Tensor(input_data) + net_me = Grad(Net(normalized_shape, begin_norm_axis=axis, + begin_params_axis=axis, + gamma=Tensor(gamma), + beta=Tensor(beta))) + net_me.set_train() + out_pool_grad_me = Tensor(gradients) + out_grad = net_me(input_me, out_pool_grad_me) + logger.info("Check me result:") + logger.info(out_grad.asnumpy()) + +def test_normal_layernorm_grad_normalize_2d(): + """ + 1 input[1, 128, 1024],normalized_shape=[1024],element_affine=False + """ + input_data = np.ones([1, 128, 1024]).astype(np.float32) + gradients = np.ones((1, 128, 1024)).astype(np.float32) + gamma = np.random.randn(1024).astype(np.float32) + gamma.fill(1.1) + beta = np.random.randn(1024).astype(np.float32) + beta.fill(0.1) + py_me_layernorm_grad(input_data, (1024,), gamma, beta, 2, gradients) diff --git a/tests/st/ops/davinci/test_tbe_ops/test_less.py b/tests/st/ops/ascend/test_tbe_ops/test_less.py similarity index 
100% rename from tests/st/ops/davinci/test_tbe_ops/test_less.py rename to tests/st/ops/ascend/test_tbe_ops/test_less.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_less_equal.py b/tests/st/ops/ascend/test_tbe_ops/test_less_equal.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_less_equal.py rename to tests/st/ops/ascend/test_tbe_ops/test_less_equal.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_logical_and.py b/tests/st/ops/ascend/test_tbe_ops/test_logical_and.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_logical_and.py rename to tests/st/ops/ascend/test_tbe_ops/test_logical_and.py index c9f180a56e..1df04b27d4 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_logical_and.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_logical_and.py @@ -1,39 +1,39 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.logical_and = P.LogicalAnd() - - @ms_function - def construct(self, x1, x2): - return self.logical_and(x1, x2) - -x1 = [True, True, False, False, True, True, False, False] -x2 = [True, False, False, True, True, False, False, True] -def test_net(): - logical_and = Net() - output = logical_and(Tensor(x1), Tensor(x2)) - print(x1) - print(x2) - print(output.asnumpy()) - +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.logical_and = P.LogicalAnd() + + @ms_function + def construct(self, x1, x2): + return self.logical_and(x1, x2) + +x1 = [True, True, False, False, True, True, False, False] +x2 = [True, False, False, True, True, False, False, True] +def test_net(): + logical_and = Net() + output = logical_and(Tensor(x1), Tensor(x2)) + print(x1) + print(x2) + print(output.asnumpy()) + diff --git a/tests/st/ops/davinci/test_tbe_ops/test_logical_not.py b/tests/st/ops/ascend/test_tbe_ops/test_logical_not.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_logical_not.py rename to tests/st/ops/ascend/test_tbe_ops/test_logical_not.py index 97e9caa5c9..5d13a48138 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_logical_not.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_logical_not.py @@ -1,38 +1,38 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.logical_not = P.LogicalNot() - - @ms_function - def construct(self, x1): - return self.logical_not(x1) - -x1 = [True, True, False, False, True, True, False, False] - -def test_net(): - logical_not = Net() - output = logical_not(Tensor(x1)) - print(x1) - print(output.asnumpy()) - +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.logical_not = P.LogicalNot() + + @ms_function + def construct(self, x1): + return self.logical_not(x1) + +x1 = [True, True, False, False, True, True, False, False] + +def test_net(): + logical_not = Net() + output = logical_not(Tensor(x1)) + print(x1) + print(output.asnumpy()) + diff --git a/tests/st/ops/davinci/test_tbe_ops/test_logical_or.py b/tests/st/ops/ascend/test_tbe_ops/test_logical_or.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_logical_or.py rename to tests/st/ops/ascend/test_tbe_ops/test_logical_or.py index e34d94c3e7..a2b7841c71 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_logical_or.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_logical_or.py @@ -1,39 +1,39 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.logical_or = P.LogicalOr() - - @ms_function - def construct(self, x1, x2): - return self.logical_or(x1, x2) - -x1 = [True, True, False, False, True, True, False, False] -x2 = [True, False, False, True, True, False, False, True] -def test_net(): - logical_or = Net() - output = logical_or(Tensor(x1), Tensor(x2)) - print(x1) - print(x2) - print(output.asnumpy()) - +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.logical_or = P.LogicalOr() + + @ms_function + def construct(self, x1, x2): + return self.logical_or(x1, x2) + +x1 = [True, True, False, False, True, True, False, False] +x2 = [True, False, False, True, True, False, False, True] +def test_net(): + logical_or = Net() + output = logical_or(Tensor(x1), Tensor(x2)) + print(x1) + print(x2) + print(output.asnumpy()) + diff --git a/tests/st/ops/davinci/test_tbe_ops/test_matmul.py b/tests/st/ops/ascend/test_tbe_ops/test_matmul.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_matmul.py rename to tests/st/ops/ascend/test_tbe_ops/test_matmul.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_matmul_failed.py b/tests/st/ops/ascend/test_tbe_ops/test_matmul_failed.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_matmul_failed.py rename to tests/st/ops/ascend/test_tbe_ops/test_matmul_failed.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maximum.py b/tests/st/ops/ascend/test_tbe_ops/test_maximum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_maximum.py rename to tests/st/ops/ascend/test_tbe_ops/test_maximum.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maximum_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_maximum_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_maximum_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_maximum_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maxpool.py b/tests/st/ops/ascend/test_tbe_ops/test_maxpool.py similarity index 100% rename from 
tests/st/ops/davinci/test_tbe_ops/test_maxpool.py rename to tests/st/ops/ascend/test_tbe_ops/test_maxpool.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maxpool_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_maxpool_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_maxpool_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_maxpool_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_minimum.py b/tests/st/ops/ascend/test_tbe_ops/test_minimum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_minimum.py rename to tests/st/ops/ascend/test_tbe_ops/test_minimum.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_minimum_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_minimum_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_minimum_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_minimum_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_mul.py b/tests/st/ops/ascend/test_tbe_ops/test_mul.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_mul.py rename to tests/st/ops/ascend/test_tbe_ops/test_mul.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_npu_alloc_float_status.py b/tests/st/ops/ascend/test_tbe_ops/test_npu_alloc_float_status.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_npu_alloc_float_status.py rename to tests/st/ops/ascend/test_tbe_ops/test_npu_alloc_float_status.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_npu_clear_float_status.py b/tests/st/ops/ascend/test_tbe_ops/test_npu_clear_float_status.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_npu_clear_float_status.py rename to tests/st/ops/ascend/test_tbe_ops/test_npu_clear_float_status.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_npu_get_float_status.py b/tests/st/ops/ascend/test_tbe_ops/test_npu_get_float_status.py similarity index 100% rename from 
tests/st/ops/davinci/test_tbe_ops/test_npu_get_float_status.py rename to tests/st/ops/ascend/test_tbe_ops/test_npu_get_float_status.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_pad.py b/tests/st/ops/ascend/test_tbe_ops/test_pad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_pad.py rename to tests/st/ops/ascend/test_tbe_ops/test_pad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_pow.py b/tests/st/ops/ascend/test_tbe_ops/test_pow.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_pow.py rename to tests/st/ops/ascend/test_tbe_ops/test_pow.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_realdiv.py b/tests/st/ops/ascend/test_tbe_ops/test_realdiv.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_realdiv.py rename to tests/st/ops/ascend/test_tbe_ops/test_realdiv.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_reciprocal.py b/tests/st/ops/ascend/test_tbe_ops/test_reciprocal.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_reciprocal.py rename to tests/st/ops/ascend/test_tbe_ops/test_reciprocal.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_relu.py b/tests/st/ops/ascend/test_tbe_ops/test_relu.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_relu.py rename to tests/st/ops/ascend/test_tbe_ops/test_relu.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_relu_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_relu_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_relu_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_relu_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor.py b/tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor.py rename to tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor.py diff --git 
a/tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_scatter_nd.py b/tests/st/ops/ascend/test_tbe_ops/test_scatter_nd.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_scatter_nd.py rename to tests/st/ops/ascend/test_tbe_ops/test_scatter_nd.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_select.py b/tests/st/ops/ascend/test_tbe_ops/test_select.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_select.py rename to tests/st/ops/ascend/test_tbe_ops/test_select.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid_grad.py similarity index 100% 
rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_slice.py b/tests/st/ops/ascend/test_tbe_ops/test_slice.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_slice.py rename to tests/st/ops/ascend/test_tbe_ops/test_slice.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py b/tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py rename to tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_softmax.py b/tests/st/ops/ascend/test_tbe_ops/test_softmax.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_softmax.py rename to tests/st/ops/ascend/test_tbe_ops/test_softmax.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_softmax_cross_entropy_with_logits.py b/tests/st/ops/ascend/test_tbe_ops/test_softmax_cross_entropy_with_logits.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_softmax_cross_entropy_with_logits.py rename to tests/st/ops/ascend/test_tbe_ops/test_softmax_cross_entropy_with_logits.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_split.py b/tests/st/ops/ascend/test_tbe_ops/test_split.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_split.py rename to tests/st/ops/ascend/test_tbe_ops/test_split.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sqrt.py b/tests/st/ops/ascend/test_tbe_ops/test_sqrt.py similarity index 100% rename from 
tests/st/ops/davinci/test_tbe_ops/test_sqrt.py rename to tests/st/ops/ascend/test_tbe_ops/test_sqrt.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_square.py b/tests/st/ops/ascend/test_tbe_ops/test_square.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_square.py rename to tests/st/ops/ascend/test_tbe_ops/test_square.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_stridedslice.py b/tests/st/ops/ascend/test_tbe_ops/test_stridedslice.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_stridedslice.py rename to tests/st/ops/ascend/test_tbe_ops/test_stridedslice.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_stridedslice_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_stridedslice_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_stridedslice_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_stridedslice_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sub.py b/tests/st/ops/ascend/test_tbe_ops/test_sub.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sub.py rename to tests/st/ops/ascend/test_tbe_ops/test_sub.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_tanh.py b/tests/st/ops/ascend/test_tbe_ops/test_tanh.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_tanh.py rename to tests/st/ops/ascend/test_tbe_ops/test_tanh.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_tanh_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_tanh_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_tanh_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_tanh_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_tile.py b/tests/st/ops/ascend/test_tbe_ops/test_tile.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_tile.py rename to tests/st/ops/ascend/test_tbe_ops/test_tile.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_topk.py 
b/tests/st/ops/ascend/test_tbe_ops/test_topk.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_topk.py rename to tests/st/ops/ascend/test_tbe_ops/test_topk.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_transpose_d.py b/tests/st/ops/ascend/test_tbe_ops/test_transpose_d.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_transpose_d.py rename to tests/st/ops/ascend/test_tbe_ops/test_transpose_d.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_unsorted_segment_sum.py b/tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_unsorted_segment_sum.py rename to tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py diff --git a/tests/st/ops/davinci/test_tdt_data_ms.py b/tests/st/ops/ascend/test_tdt_data_ms.py similarity index 100% rename from tests/st/ops/davinci/test_tdt_data_ms.py rename to tests/st/ops/ascend/test_tdt_data_ms.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py b/tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py new file mode 100644 index 0000000000..28bf566c2d --- /dev/null +++ b/tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.ops.composite import GradOperation +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True) + self.network = network + + @ms_function + def construct(self, input): + return self.grad(self.network)(input) + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.relu_v2 = P.ReLUV2() + + def construct(self, x): + return self.relu_v2(x) + +def test_net(): + x = Tensor(np.ones((2,3,3,4)).astype(np.float32)) + relu_net = Net() + relu_output = relu_net(x) + net = Grad(Net()) + output_grad = net(x) + print(relu_output[0].asnumpy()) + print(relu_output[1].asnumpy()) + print(len(output_grad)) + print(output_grad[0].asnumpy()) diff --git a/tests/st/ops/gpu/test_float_status_op.py b/tests/st/ops/gpu/test_float_status_op.py new file mode 100644 index 0000000000..09fc90feaa --- /dev/null +++ b/tests/st/ops/gpu/test_float_status_op.py @@ -0,0 +1,118 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import numpy as np +import mindspore.context as context + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.status = P.FloatStatus() + + def construct(self, x): + return self.status(x) + +class Netnan(nn.Cell): + def __init__(self): + super(Netnan, self).__init__() + self.isnan = P.IsNan() + + def construct(self, x): + return self.isnan(x) + +class Netinf(nn.Cell): + def __init__(self): + super(Netinf, self).__init__() + self.isinf = P.IsInf() + + def construct(self, x): + return self.isinf(x) + +class Netfinite(nn.Cell): + def __init__(self): + super(Netfinite, self).__init__() + self.isfinite = P.IsFinite() + + def construct(self, x): + return self.isfinite(x) + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") +x1 = np.array([[1.2, 2, np.nan, 88]]).astype(np.float32) +x2 = np.array([[np.inf, 1, 88.0, 0]]).astype(np.float32) +x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_status(): + ms_status = Net(); + output1 = ms_status(Tensor(x1)) + output2 = ms_status(Tensor(x2)) + output3 = ms_status(Tensor(x3)) + expect1 = 1 + expect2 = 1 + expect3 = 0 + assert output1.asnumpy()[0] == expect1 + assert output2.asnumpy()[0] == expect2 + assert output3.asnumpy()[0] == expect3 + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nan(): + ms_isnan = Netnan(); + output1 = ms_isnan(Tensor(x1)) + output2 = ms_isnan(Tensor(x2)) + output3 = ms_isnan(Tensor(x3)) + expect1 = [[False, False, True, False]] + expect2 = [[False, False, False, False]] + expect3 = [[False, False], 
[False, False], [False, False]] + assert (output1.asnumpy() == expect1).all() + assert (output2.asnumpy() == expect2).all() + assert (output3.asnumpy() == expect3).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_inf(): + ms_isinf = Netinf(); + output1 = ms_isinf(Tensor(x1)) + output2 = ms_isinf(Tensor(x2)) + output3 = ms_isinf(Tensor(x3)) + expect1 = [[False, False, False, False]] + expect2 = [[True, False, False, False]] + expect3 = [[False, False], [False, False], [False, False]] + assert (output1.asnumpy() == expect1).all() + assert (output2.asnumpy() == expect2).all() + assert (output3.asnumpy() == expect3).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_finite(): + ms_isfinite = Netfinite(); + output1 = ms_isfinite(Tensor(x1)) + output2 = ms_isfinite(Tensor(x2)) + output3 = ms_isfinite(Tensor(x3)) + expect1 = [[True, True, False, True]] + expect2 = [[False, True, True, True]] + expect3 = [[True, True], [True, True], [True, True]] + assert (output1.asnumpy() == expect1).all() + assert (output2.asnumpy() == expect2).all() + assert (output3.asnumpy() == expect3).all() diff --git a/tests/st/ops/gpu/test_lessequal_op.py b/tests/st/ops/gpu/test_lessequal_op.py new file mode 100644 index 0000000000..08bb28b0af --- /dev/null +++ b/tests/st/ops/gpu/test_lessequal_op.py @@ -0,0 +1,49 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +import mindspore.context as context +import numpy as np + + +class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.lessequal = P.LessEqual() + + def construct(self, x, y): + return self.lessequal(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_lessequal(): + x = Tensor(np.array([[1, 2, 3]]).astype(np.float32)) + y = Tensor(np.array([[2]]).astype(np.float32)) + expect = [[True, True, False]] + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + lessequal = Net() + output = lessequal(x, y) + assert np.all(output.asnumpy() == expect) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + lessequal = Net() + output = lessequal(x, y) + assert np.all(output.asnumpy() == expect) + diff --git a/tests/st/ops/gpu/test_logical_op.py b/tests/st/ops/gpu/test_logical_op.py new file mode 100644 index 0000000000..ab95aa8f3f --- /dev/null +++ b/tests/st/ops/gpu/test_logical_op.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +import mindspore.context as context +import numpy as np + + +class NetAnd(Cell): + def __init__(self): + super(NetAnd, self).__init__() + self.logicaland = P.LogicalAnd() + + def construct(self, x, y): + return self.logicaland(x, y) + +class NetOr(Cell): + def __init__(self): + super(NetOr, self).__init__() + self.logicalor = P.LogicalOr() + + def construct(self, x, y): + return self.logicalor(x, y) + +class NetNot(Cell): + def __init__(self): + super(NetNot, self).__init__() + self.logicalnot = P.LogicalNot() + + def construct(self, x): + return self.logicalnot(x) + +x = np.array([True, False, False]).astype(np.bool) +y = np.array([False]).astype(np.bool) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_logicaland(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + logicaland = NetAnd() + output = logicaland(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_and(x, y)) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + logicaland = NetAnd() + output = logicaland(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_and(x, y)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_logicalor(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + logicalor = NetOr() + output = logicalor(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_or(x, y)) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + logicalor = NetOr() + output = logicalor(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_or(x, y)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_logicalnot(): + 
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + logicalnot = NetNot() + output = logicalnot(Tensor(x)) + assert np.all(output.asnumpy() == np.logical_not(x)) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + logicalnot = NetNot() + output = logicalnot(Tensor(x)) + assert np.all(output.asnumpy() == np.logical_not(x)) + diff --git a/tests/st/ops/gpu/test_maximum_op.py b/tests/st/ops/gpu/test_maximum_op.py new file mode 100644 index 0000000000..3193dafa61 --- /dev/null +++ b/tests/st/ops/gpu/test_maximum_op.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +import mindspore.context as context +import numpy as np + + +class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.max = P.Maximum() + + def construct(self, x, y): + return self.max(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_max(): + x = Tensor(np.array([[1, 2, 3]]).astype(np.float32)) + y = Tensor(np.array([[2]]).astype(np.float32)) + expect = [[2, 2, 3]] + error = np.ones(shape=[1, 3]) * 1.0e-5 + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + max = Net() + output = max(x, y) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + max = Net() + output = max(x, y) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + diff --git a/tests/st/pynative/test_ascend_lenet.py b/tests/st/pynative/test_ascend_lenet.py index 4009844791..5a84aaf930 100644 --- a/tests/st/pynative/test_ascend_lenet.py +++ b/tests/st/pynative/test_ascend_lenet.py @@ -14,7 +14,8 @@ # ============================================================================ import pytest import numpy as np -import time, math +import time +import math import mindspore.nn as nn from mindspore import context, Tensor, ParameterTuple from mindspore.ops import operations as P @@ -28,6 +29,7 @@ from mindspore.nn.optim import Momentum np.random.seed(1) + def weight_variable(): """weight initial""" return TruncatedNormal(0.02) @@ -58,6 +60,7 @@ class LeNet(nn.Cell): Examples: >>> LeNet(num_class=10) """ + def __init__(self, num_class=10): super(LeNet, self).__init__() self.num_class = num_class @@ -91,6 +94,7 @@ class CrossEntropyLoss(nn.Cell): 
""" Define loss for network """ + def __init__(self): super(CrossEntropyLoss, self).__init__() self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() @@ -111,6 +115,7 @@ class GradWrap(nn.Cell): """ GradWrap definition """ + def __init__(self, network): super(GradWrap, self).__init__() self.network = network @@ -154,4 +159,3 @@ def test_ascend_pynative_lenet(): print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) assert(loss_output.asnumpy() < 0.1) - \ No newline at end of file diff --git a/tests/st/summary/test_davinci_summary.py b/tests/st/summary/test_davinci_summary.py index 1611ca8ec7..a2ed840515 100644 --- a/tests/st/summary/test_davinci_summary.py +++ b/tests/st/summary/test_davinci_summary.py @@ -33,10 +33,12 @@ SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" context.set_context(device_target="Ascend") + class MsWrapper(nn.Cell): def __init__(self, network): super(MsWrapper, self).__init__(auto_prefix=False) self._network = network + @ms_function def construct(self, *args): return self._network(*args) @@ -45,14 +47,15 @@ class MsWrapper(nn.Cell): def me_train_tensor(net, input_np, label_np, epoch_size=2): context.set_context(mode=context.GRAPH_MODE) loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - opt = ApplyMomentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), filter(lambda x: x.requires_grad, net.get_parameters())) + opt = ApplyMomentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), + filter(lambda x: x.requires_grad, net.get_parameters())) Model(net, loss, opt) _network = wrap.WithLossCell(net, loss) _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt)) _train_net.set_train() summary_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) for epoch in range(0, epoch_size): - print(f"epoch %d"%(epoch)) + print(f"epoch %d" % (epoch)) output = _train_net(Tensor(input_np), Tensor(label_np)) summary_writer.record(i) print("********output***********") diff 
--git a/tests/st/summary/test_gpu_summary.py b/tests/st/summary/test_gpu_summary.py index c97c08c4e1..e8eadc66ab 100644 --- a/tests/st/summary/test_gpu_summary.py +++ b/tests/st/summary/test_gpu_summary.py @@ -108,6 +108,6 @@ def me_scalar_summary(steps, tag=None, value=None): def test_scalarsummary_scalar1_step10_summaryrecord1(): clean_environment_file(SUMMARY_DIR_ME_TEMP) output_dict = me_scalar_summary(10) - print("test_scalarsummary_scalar1_step10_summaryrecord1 \n",output_dict) + print("test_scalarsummary_scalar1_step10_summaryrecord1 \n", output_dict) save_summary_events_file(SUMMARY_DIR_ME_TEMP, SUMMARY_DIR_ME) clean_environment_file(SUMMARY_DIR_ME) diff --git a/tests/st/tbe_networks/export_geir.py b/tests/st/tbe_networks/export_geir.py index 467388c5e8..a4368e6320 100644 --- a/tests/st/tbe_networks/export_geir.py +++ b/tests/st/tbe_networks/export_geir.py @@ -24,12 +24,13 @@ import mindspore.nn as nn from mindspore import context from mindspore.train.serialization import save, load, save_checkpoint, load_checkpoint,\ - load_param_into_net, _exec_save_checkpoint,\ - _check_filedir_or_create, _chg_model_file_name_if_same_exist, \ - _read_file_last_line, context, export + load_param_into_net, _exec_save_checkpoint,\ + _check_filedir_or_create, _chg_model_file_name_if_same_exist, \ + _read_file_last_line, context, export + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", + enable_task_sink=True, enable_loop_sink=True, enable_ir_fusion=True) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", -enable_task_sink=True,enable_loop_sink=True,enable_ir_fusion=True) def test_resnet50_export(batch_size=1, num_classes=5): context.set_context(enable_ir_fusion=False) diff --git a/tests/st/tbe_networks/resnet.py b/tests/st/tbe_networks/resnet.py index 2024286b8f..4f2ff79a86 100644 --- a/tests/st/tbe_networks/resnet.py +++ b/tests/st/tbe_networks/resnet.py @@ -19,6 +19,7 @@ from mindspore.ops import operations as P from 
mindspore.common.initializer import initializer from mindspore.common import dtype as mstype + def weight_variable(shape): return initializer('XavierUniform', shape=shape, dtype=mstype.float32) @@ -297,4 +298,3 @@ class ResNet(nn.Cell): def resnet50(batch_size, num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes, batch_size) - diff --git a/tests/st/tbe_networks/resnet_cifar.py b/tests/st/tbe_networks/resnet_cifar.py index f1ab02afa3..7bd03f5d81 100644 --- a/tests/st/tbe_networks/resnet_cifar.py +++ b/tests/st/tbe_networks/resnet_cifar.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import argparse import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P @@ -35,7 +36,6 @@ random.seed(1) np.random.seed(1) ds.config.set_seed(1) -import argparse parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--device_num', type=int, default=1, help='Device num.') @@ -48,15 +48,16 @@ parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoin parser.add_argument('--dataset_path', type=str, default="/var/log/npu/datasets/cifar", help='Dataset path') args_opt = parser.parse_args() -device_id=int(os.getenv('DEVICE_ID')) +device_id = int(os.getenv('DEVICE_ID')) -data_home=args_opt.dataset_path +data_home = args_opt.dataset_path context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(enable_loop_sink=True) context.set_context(enable_mem_reuse=True) + def create_dataset(repeat_num=1, training=True): data_dir = data_home + "/cifar-10-batches-bin" if not training: @@ -64,8 +65,8 @@ def create_dataset(repeat_num=1, training=True): data_set = 
ds.Cifar10Dataset(data_dir) if args_opt.run_distribute: - rank_id=int(os.getenv('RANK_ID')) - rank_size=int(os.getenv('RANK_SIZE')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) resize_height = 224 @@ -74,9 +75,9 @@ def create_dataset(repeat_num=1, training=True): shift = 0.0 # define map operations - random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_horizontal_op = vision.RandomHorizontalFlip() - resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR rescale_op = vision.Rescale(rescale, shift) normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) changeswap_op = vision.HWC2CHW() @@ -103,6 +104,7 @@ def create_dataset(repeat_num=1, training=True): return data_set + class CrossEntropyLoss(nn.Cell): def __init__(self): super(CrossEntropyLoss, self).__init__() diff --git a/tests/st/tbe_networks/test_resnet_cifar_8p.py b/tests/st/tbe_networks/test_resnet_cifar_8p.py index 6e83f4180e..69f0a80d12 100644 --- a/tests/st/tbe_networks/test_resnet_cifar_8p.py +++ b/tests/st/tbe_networks/test_resnet_cifar_8p.py @@ -112,6 +112,7 @@ class CrossEntropyLoss(nn.Cell): loss = self.mean(loss, (-1,)) return loss + class LossGet(Callback): def __init__(self, per_print_times=1): super(LossGet, self).__init__() @@ -143,6 +144,7 @@ class LossGet(Callback): def get_loss(self): return self._loss + def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size, enable_hccl): os.system("mkdir " + str(device_id)) os.chdir(str(device_id)) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index ae9c46e62c..2224565c30 100644 --- 
a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -65,6 +65,8 @@ SET(DE_UT_SRCS cifar_op_test.cc celeba_op_test.cc take_op_test.cc + text_file_op_test.cc) + filter_op_test.cc ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/filter_op_test.cc b/tests/ut/cpp/dataset/filter_op_test.cc new file mode 100644 index 0000000000..45ee714337 --- /dev/null +++ b/tests/ut/cpp/dataset/filter_op_test.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/util/circular_pool.h" +#include "dataset/core/client.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +namespace de = mindspore::dataset; + +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestfilter_op : public UT::DatasetOpTesting { + +}; + + +std::shared_ptr Filter() { + Status rc; + std::shared_ptr op; + rc = de::FilterOp::Builder().Build(&op); + EXPECT_TRUE(rc.IsOk()); + return op; +} + +TEST_F(MindDataTestfilter_op, Testfilter_opFuntions) { + MS_LOG(INFO) << "Doing MindDataTest filter_op."; + auto my_tree = std::make_shared(); + + std::shared_ptr parent_op = Filter(); + + std::shared_ptr leaf_op = Filter(); + my_tree->AssociateNode(parent_op); + my_tree->AssociateNode(leaf_op); + ASSERT_NE(parent_op, nullptr); + ASSERT_NE(leaf_op, nullptr); +} diff --git a/tests/ut/cpp/dataset/take_op_test.cc b/tests/ut/cpp/dataset/take_op_test.cc index 7f8508de20..b7be066d6c 100644 --- a/tests/ut/cpp/dataset/take_op_test.cc +++ b/tests/ut/cpp/dataset/take_op_test.cc @@ -69,7 +69,7 @@ TEST_F(MindDataTestTakeOp, TestTakeProject) { rc = my_tree->AssignRoot(my_take_op); ASSERT_TRUE(rc.IsOk()); - MS_LOG(INFO) << "Launching tree and begin iteration."; + MS_LOG(DEBUG) << "Launching tree and begin iteration."; rc = my_tree->Prepare(); ASSERT_TRUE(rc.IsOk()); @@ -85,13 +85,13 @@ TEST_F(MindDataTestTakeOp, TestTakeProject) { int row_count = 0; while (!tensor_list.empty()) { - MS_LOG(INFO) << "Row display for row #: " << row_count << "."; + MS_LOG(DEBUG) << "Row display for row #: " << row_count << "."; // Display the tensor by calling the printer on it for (int i = 0; i < tensor_list.size(); i++) { std::ostringstream ss; ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; - MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; + MS_LOG(DEBUG) << "Tensor print: " << ss.str() << "."; } 
rc = di.FetchNextTensorRow(&tensor_list); diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc index 7437b3d942..494d4b2329 100644 --- a/tests/ut/cpp/dataset/tensor_test.cc +++ b/tests/ut/cpp/dataset/tensor_test.cc @@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { ASSERT_EQ(*t == *t6, true); } +// Test the bug of Tensor::ToString will exec failed for Tensor which store bool values +TEST_F(MindDataTestTensorDE, BoolTensor) { + std::shared_ptr t = std::make_shared(TensorShape({2}), + DataType(DataType::DE_BOOL)); + t->SetItemAt({0}, true); + t->SetItemAt({1}, true); + std::string out = t->ToString(); + ASSERT_TRUE(out.find("Template type and Tensor type are not compatible") == std::string::npos); +} + TEST_F(MindDataTestTensorDE, GetItemAt) { std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); t->Fill(254); diff --git a/tests/ut/cpp/dataset/text_file_op_test.cc b/tests/ut/cpp/dataset/text_file_op_test.cc new file mode 100644 index 0000000000..7887eda955 --- /dev/null +++ b/tests/ut/cpp/dataset/text_file_op_test.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include + +#include "dataset/core/client.h" +#include "common/common.h" +#include "common/utils.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/util/status.h" + +namespace common = mindspore::common; + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestTextFileOp : public UT::DatasetOpTesting { + +}; + +TEST_F(MindDataTestTextFileOp, TestTextFileBasic) { + // Start with an empty execution tree + auto tree = std::make_shared(); + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/testTextFileDataset/1.txt"; + + std::shared_ptr op; + TextFileOp::Builder builder; + builder.SetTextFilesList({dataset_path}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16) + .SetOpConnectorSize(2); + + Status rc = builder.Build(&op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssociateNode(op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssignRoot(op); + ASSERT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration."; + rc = tree->Prepare(); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->Launch(); + ASSERT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(tree); + TensorRow tensor_list; + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + + int row_count = 0; + while (!tensor_list.empty()) { + // Display the tensor by calling the printer on it + for (int i = 0; i < tensor_list.size(); i++) { + std::ostringstream ss; + ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; + MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; + } + + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + row_count++; + } + + ASSERT_EQ(row_count, 3); +} + +TEST_F(MindDataTestTextFileOp, TestTotalRows) { + std::string tf_file1 = datasets_root_path_ + 
"/testTextFileDataset/1.txt"; + std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt"; + std::vector files; + files.push_back(tf_file1); + int64_t total_rows = 0; + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 3); + files.clear(); + + files.push_back(tf_file2); + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 2); + files.clear(); + + files.push_back(tf_file1); + files.push_back(tf_file2); + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 5); + files.clear(); +} diff --git a/tests/ut/cpp/dataset/tfReader_op_test.cc b/tests/ut/cpp/dataset/tfReader_op_test.cc index 5fb1f4e909..9b312296d8 100644 --- a/tests/ut/cpp/dataset/tfReader_op_test.cc +++ b/tests/ut/cpp/dataset/tfReader_op_test.cc @@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) { TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true); ASSERT_EQ(total_rows, 60); } + +TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) { + // Start with an empty execution tree + auto my_tree = std::make_shared(); + + std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt"; + std::string nonexistent_file = "this/file/doesnt/exist"; + + std::shared_ptr my_tfreader_op; + TFReaderOp::Builder builder; + builder.SetDatasetFilesList({invalid_file, valid_file, schema_file}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16); + + std::unique_ptr schema = std::make_unique(); + schema->LoadSchemaFile(schema_file, {}); + builder.SetDataSchema(std::move(schema)); + + Status rc = builder.Build(&my_tfreader_op); + ASSERT_TRUE(!rc.IsOk()); + + builder.SetDatasetFilesList({invalid_file, valid_file, schema_file, nonexistent_file}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16); + + schema = std::make_unique(); + 
schema->LoadSchemaFile(schema_file, {}); + builder.SetDataSchema(std::move(schema)); + + rc = builder.Build(&my_tfreader_op); + ASSERT_TRUE(!rc.IsOk()); +} diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 549e2140f4..bfd49069b2 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -25,6 +25,7 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "mindrecord/include/shard_category.h" +#include "mindrecord/include/shard_pk_sample.h" #include "mindrecord/include/shard_reader.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" @@ -146,6 +147,57 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { ASSERT_TRUE(i <= 10); } +TEST_F(TestShardOperator, TestShardPkSamplerBasic) { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); + + std::string file_name = "./imagenet.shard01"; + auto column_list = std::vector{"file_name", "label"}; + + std::vector> ops; + ops.push_back(std::make_shared("label", 2)); + + ShardReader dataset; + dataset.Open(file_name, 4, column_list, ops); + dataset.Launch(); + + int i = 0; + while (true) { + auto x = dataset.GetNext(); + if (x.empty()) break; + std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + i++; + } + dataset.Finish(); + ASSERT_TRUE(i == 20); +} // namespace mindrecord + +TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); + + std::string file_name = "./imagenet.shard01"; + auto column_list = std::vector{"file_name", "label"}; + + std::vector> ops; + ops.push_back(std::make_shared("label", 2, 3, 0)); + + ShardReader dataset; + dataset.Open(file_name, 4, column_list, ops); + dataset.Launch(); + + int i = 0; + while 
(true) { + auto x = dataset.GetNext(); + if (x.empty()) break; + + std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + i++; + } + dataset.Finish(); + ASSERT_TRUE(i == 6); +} // namespace mindrecord + TEST_F(TestShardOperator, TestShardCategory) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index d9dd9e5e99..2c4b9b7146 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -128,8 +128,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { trace::ClearTraceStack(); engine_->Run(tupleSliceGraphPtr, args_spec_list); FAIL() << "Excepted exception :Args type is wrong"; - } catch (std::runtime_error const &err) { - ASSERT_TRUE(std::string(err.what()).find("TypeError") != std::string::npos); + } catch (pybind11::type_error const &err) { + ASSERT_TRUE(true); } catch (...) 
{ FAIL() << "Excepted exception :Args type is wrong"; } diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py index c120ac3e68..76d7e73a80 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py @@ -45,13 +45,10 @@ def test_addn_fission(tag): b = addn((input2, input3)) c = addn((input4, input5)) d = addn((input6, input7)) - e = addn((input8,)) f = addn((a, b)) g = addn((c, d)) - h = addn((e,)) i = addn((f, g)) - j = addn((h,)) - return addn((i, j)) + return addn((i, input8)) @fns def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): @@ -64,14 +61,12 @@ def test_addn_fission(tag): def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3)) b = addn((input4, input5, input6, input7)) - c = addn((input8,)) - return addn((a, b, c)) + return addn((a, b, input8)) @fns def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3, input4, input5, input6, input7)) - b = addn((input8,)) - return addn((a, b)) + return addn((a, input8)) @fns def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8): diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py index e5b0a15387..8ce64109c6 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py @@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag): @fns def before(a, b): res = mul(scalar, a) - res = addn((b, res)) + res = addn((res, b)) return res @fns diff --git 
a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index e0b5ab0d61..9c4fe2539d 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -24,9 +24,7 @@ void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } -std::vector AscendStreamAssign::GetWaitStreams() { return vector(); } - -std::vector AscendStreamAssign::GetHcomStreams() { return vector(); } +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { return; } namespace tasksink { bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *const task_info_list, diff --git a/tests/ut/cpp/transform/convert_test.cc b/tests/ut/cpp/transform/convert_test.cc index 4388312592..277aaa15c3 100644 --- a/tests/ut/cpp/transform/convert_test.cc +++ b/tests/ut/cpp/transform/convert_test.cc @@ -189,7 +189,8 @@ TEST_F(TestConvert, TestConvertBatchNorm) { TEST_F(TestConvert, TestConvertConvBackpropInput) { auto prim = prim::kPrimConv2DBackpropInput; - prim->AddAttr("stride", MakeValue(1)); + const std::vector list{1,1}; + prim->AddAttr("stride", MakeValue(list)); prim->AddAttr("pad", MakeValue(0)); prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); prim->AddAttr("dilation", MakeValue(1)); @@ -218,7 +219,8 @@ TEST_F(TestConvert, TestConvertConvBackpropInput) { TEST_F(TestConvert, TestConvertConvBackpropFilter) { auto prim = prim::kPrimConv2DBackpropFilter; - prim->AddAttr("stride", MakeValue(1)); + const std::vector list{1,1}; + prim->AddAttr("stride", MakeValue(list)); prim->AddAttr("pad", MakeValue(0)); prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); prim->AddAttr("dilation", MakeValue(1)); diff --git a/tests/ut/data/dataset/declient_filter.cfg b/tests/ut/data/dataset/declient_filter.cfg new file mode 100644 index 0000000000..89e1199f5a --- 
/dev/null +++ b/tests/ut/data/dataset/declient_filter.cfg @@ -0,0 +1,3 @@ +{ + "rowsPerBuffer": 10, +} diff --git a/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data index c5b5440cff..f3bb23af51 100644 Binary files a/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data and b/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data differ diff --git a/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data index c5b5440cff..f3bb23af51 100644 Binary files a/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data and b/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data differ diff --git a/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data index c5b5440cff..f3bb23af51 100644 Binary files a/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data and b/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data differ diff --git a/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json b/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json new file mode 100644 index 0000000000..92abf66ef8 --- /dev/null +++ b/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json @@ -0,0 +1,45 @@ +{ + "datasetType": "TF", + "columns": { + "col_sint16": { + "type": "int16", + "rank": 1, + "shape": [1] + }, + "col_sint32": { + "type": "int32", + "rank": 1, + "shape": [1] + }, + "col_sint64": { + "type": "int64", + "rank": 1, + "shape": [1] + }, + "col_float": { + "type": "float32", + "rank": 1, + "shape": [1] + }, + "col_1d": { + "type": "int64", + "rank": 1, + "shape": [2] + }, + "col_2d": { + "type": "int64", + "rank": 2, + "shape": [2, 2] + }, + "col_3d": { + "type": "int64", + "rank": 3, + "shape": [2, 2, 2] + }, + "col_binary": { + "type": "uint8", + "rank": 1, + "shape": [1] + } + } +} diff --git a/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt 
b/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt new file mode 100644 index 0000000000..3307b71672 --- /dev/null +++ b/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt @@ -0,0 +1 @@ +this is just a text file, not a valid tfrecord file. diff --git a/tests/ut/data/dataset/testTextFileDataset/1.txt b/tests/ut/data/dataset/testTextFileDataset/1.txt new file mode 100644 index 0000000000..9d911eacc0 --- /dev/null +++ b/tests/ut/data/dataset/testTextFileDataset/1.txt @@ -0,0 +1,3 @@ +This is a text file. +Be happy every day. +Good luck to everyone. diff --git a/tests/ut/data/dataset/testTextFileDataset/2.txt b/tests/ut/data/dataset/testTextFileDataset/2.txt new file mode 100644 index 0000000000..7382722eb8 --- /dev/null +++ b/tests/ut/data/dataset/testTextFileDataset/2.txt @@ -0,0 +1,2 @@ +Another file. +End of file. diff --git a/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json b/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json new file mode 100644 index 0000000000..e00fd39c10 --- /dev/null +++ b/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json @@ -0,0 +1,15 @@ +{ + "datasetType": "TF", + "columns": { + "image": { + "type": "uint8", + "rank": 1, + "t_impl": "cvmat" + }, + "label" : { + "type": "uint64", + "rank": 1, + "t_impl": "flex" + } + } +} diff --git a/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt b/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt new file mode 100644 index 0000000000..fbfbba025f --- /dev/null +++ b/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt @@ -0,0 +1,10 @@ +image_00001.jpg,164 +image_00002.jpg,164 +image_00003.jpg,164 +image_00004.jpg,599 +image_00005.jpg,599 +image_00006.jpg,599 +image_00007.jpg,13 +image_00008.jpg,13 +image_00009.jpg,13 +image_00010.jpg,13 diff --git a/tests/ut/python/communication/test_management_api.py b/tests/ut/python/communication/test_management_api.py index c455c5491b..d624c5ab59 100644 
--- a/tests/ut/python/communication/test_management_api.py +++ b/tests/ut/python/communication/test_management_api.py @@ -99,7 +99,7 @@ def test_raise_error_funcs(): assert has_raise_error(create_backend, 'nccl') is False assert has_raise_error(get_group_size_int, 123) is True assert has_raise_error(create_group0, (0,1)) is True - assert has_raise_error(create_group1, [0]) is True + assert has_raise_error(create_group1, [0]) is False assert has_raise_error(create_group2, [0,0,1]) is True assert has_raise_error(create_group3, [0,1]) is True assert has_raise_error(create_group4, [0,1]) is False diff --git a/tests/ut/python/dataset/test_autocontrast.py b/tests/ut/python/dataset/test_autocontrast.py new file mode 100644 index 0000000000..7dba2f21f6 --- /dev/null +++ b/tests/ut/python/dataset/test_autocontrast.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_auto_contrast): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_auto_contrast) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_auto_contrast[i]) + plt.title("DE AutoContrast image") + + plt.show() + + +def test_auto_contrast(plot=False): + """ + Test AutoContrast + """ + logger.info("Test AutoContrast") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # AutoContrast Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_auto_contrast = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.AutoContrast(), + F.ToTensor()]) + + ds_auto_contrast = ds.map(input_columns="image", + operations=transforms_auto_contrast()) + + ds_auto_contrast = ds_auto_contrast.batch(512) + + for idx, (image,label) in enumerate(ds_auto_contrast): + if idx == 0: + images_auto_contrast = np.transpose(image, (0, 2,3,1)) + else: + images_auto_contrast = np.append(images_auto_contrast, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + 
mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_auto_contrast[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_auto_contrast) + + +if __name__ == "__main__": + test_auto_contrast(plot=True) + diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py index 8cabe81aaa..0c1e0073af 100644 --- a/tests/ut/python/dataset/test_config.py +++ b/tests/ut/python/dataset/test_config.py @@ -12,8 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +""" +Testing configuration manager +""" +import filecmp +import glob +import os + import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" def test_basic(): ds.config.load('../data/dataset/declient.cfg') @@ -36,6 +46,34 @@ def test_basic(): assert ds.config.get_prefetch_size() == 4 assert ds.config.get_seed() == 5 +def test_pipeline(): + """ + Test that our configuration pipeline works when we set parameters at dataset interval + """ + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + ds.config.set_num_parallel_workers(2) + data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)]) + ds.serialize(data1, "testpipeline.json") + + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + ds.config.set_num_parallel_workers(4) + data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)]) + ds.serialize(data2, "testpipeline2.json") + + # check that the generated output is different + assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json')) + + # this test passes currently because our 
num_parallel_workers don't get updated. + + # remove generated jason files + file_list = glob.glob('*.json') + for f in file_list: + try: + os.remove(f) + except IOError: + logger.info("Error while deleting: {}".format(f)) + if __name__ == '__main__': test_basic() + test_pipeline() diff --git a/tests/ut/python/dataset/test_datasets_textfileop.py b/tests/ut/python/dataset/test_datasets_textfileop.py new file mode 100644 index 0000000000..720fcdcce0 --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_textfileop.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.transforms.nlp.utils as nlp + +DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" +DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*" + +def test_textline_dataset_one_file(): + data = ds.TextFileDataset(DATA_FILE) + count = 0 + for i in data.create_dict_iterator(): + logger.info("{}".format(i["text"])) + count += 1 + assert(count == 3) + +def test_textline_dataset_all_file(): + data = ds.TextFileDataset(DATA_ALL_FILE) + count = 0 + for i in data.create_dict_iterator(): + logger.info("{}".format(i["text"])) + count += 1 + assert(count == 5) + +def test_textline_dataset_totext(): + data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False) + count = 0 + line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."] + for i in data.create_dict_iterator(): + str = nlp.as_text(i["text"]) + assert(str == line[count]) + count += 1 + assert(count == 5) + +def test_textline_dataset_num_samples(): + data = ds.TextFileDataset(DATA_FILE, num_samples=2) + count = 0 + for i in data.create_dict_iterator(): + count += 1 + assert(count == 2) + +def test_textline_dataset_distribution(): + data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1) + count = 0 + for i in data.create_dict_iterator(): + count += 1 + assert(count == 3) + +def test_textline_dataset_repeat(): + data = ds.TextFileDataset(DATA_FILE, shuffle=False) + data = data.repeat(3) + count = 0 + line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.", + "This is a text file.", "Be happy every day.", "Good luck to everyone.", + "This is a text file.", "Be happy every day.", "Good luck to everyone."] + for i in data.create_dict_iterator(): + str = nlp.as_text(i["text"]) + assert(str == line[count]) + count += 1 + assert(count == 9) + +def 
test_textline_dataset_get_datasetsize(): + data = ds.TextFileDataset(DATA_FILE) + size = data.get_dataset_size() + assert(size == 3) + +if __name__ == "__main__": + test_textline_dataset_one_file() + test_textline_dataset_all_file() + test_textline_dataset_totext() + test_textline_dataset_num_samples() + test_textline_dataset_distribution() + test_textline_dataset_repeat() + test_textline_dataset_get_datasetsize() diff --git a/tests/ut/python/dataset/test_equalize.py b/tests/ut/python/dataset/test_equalize.py new file mode 100644 index 0000000000..077c316d67 --- /dev/null +++ b/tests/ut/python/dataset/test_equalize.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_equalize): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_equalize) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_equalize[i]) + plt.title("DE Color Equalized image") + + plt.show() + + +def test_equalize(plot=False): + """ + Test Equalize + """ + logger.info("Test Equalize") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Color Equalized Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_equalize = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.Equalize(), + F.ToTensor()]) + + ds_equalize = ds.map(input_columns="image", + operations=transforms_equalize()) + + ds_equalize = ds_equalize.batch(512) + + for idx, (image,label) in enumerate(ds_equalize): + if idx == 0: + images_equalize = np.transpose(image, (0, 2,3,1)) + else: + images_equalize = np.append(images_equalize, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = 
np.mean((images_equalize[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_equalize) + + +if __name__ == "__main__": + test_equalize(plot=True) + diff --git a/tests/ut/python/dataset/test_filterop.py b/tests/ut/python/dataset/test_filterop.py new file mode 100644 index 0000000000..90f512caa4 --- /dev/null +++ b/tests/ut/python/dataset/test_filterop.py @@ -0,0 +1,504 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import numpy as np +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as cde +import mindspore.dataset.transforms.c_transforms as C +import mindspore.common.dtype as mstype +from mindspore import log as logger + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" +# test for predicate +def test_diff_predicate_func(): + def test_filter(predicate_func): + transforms = [ + cde.Decode(), + cde.Resize([64, 64]) + ] + type_cast_op = C.TypeCast(mstype.int32) + dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False) + dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1) + dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4) + + num_iter = 0 + label_list = [] + for data in dataset.create_dict_iterator(): + num_iter += 1 + ori_img = data["image"] + label = data["label"] + label_list.append(label) + assert num_iter == 1 + assert label_list[0] == 3 + + test_filter(lambda image, label: label == 3) + test_filter(lambda image, label: label[0] == 3) + test_filter(lambda image, label: label == [3]) + test_filter(lambda image, label: label == np.array([3])) + test_filter(lambda image, label: label == np.array(3)) + +def filter_func_ge(data): + if data > 10: + return False + return True + + +def generator_1d(): + for i in range(64): + yield (np.array(i),) + +# test with GeneratorDataset +def test_filter_by_generator_with_no(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) + num_iter = 0 + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_f.create_dict_iterator(): + assert item["data"] == expected_rs[num_iter] + 
num_iter += 1 + +# test with repeatOp before +def test_filter_by_generator_with_repeat(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_r = dataset.repeat(4) + dataset_f = dataset_r.filter(predicate=filter_func_ge, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 44 + for i in range(4): + for ii in range(len(expected_rs)): + index = i * len(expected_rs) + ii + assert ret_data[index] == expected_rs[ii] + +# test with repeatOp after +def test_filter_by_generator_with_repeat_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_ge, num_parallel_workers=4) + dataset_r = dataset_f.repeat(4) + num_iter = 0 + ret_data = [] + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_r.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 44 + for i in range(4): + for ii in range(len(expected_rs)): + index = i * len(expected_rs) + ii + assert ret_data[index] == expected_rs[ii] + +def filter_func_batch(data): + if data[0] > 8: + return False + return True + +def filter_func_batch_after(data): + if data > 20: + return False + return True + +# test with batchOp before +def test_filter_by_generator_with_batch(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_b = dataset.batch(4) + dataset_f = dataset_b.filter(predicate=filter_func_batch, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 3 + assert ret_data[0][0] == 0 + assert ret_data[1][0] == 4 + assert ret_data[2][0] == 8 + +# test with batchOp after +def test_filter_by_generator_with_batch_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = 
dataset.filter(predicate=filter_func_batch_after, num_parallel_workers=4) + dataset_b = dataset_f.batch(4) + num_iter = 0 + ret_data = [] + for item in dataset_b.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 6 + assert ret_data[0][0] == 0 + assert ret_data[1][0] == 4 + assert ret_data[5][0] == 20 + + +def filter_func_shuffle(data): + if data > 20: + return False + return True + +# test with batchOp before +def test_filter_by_generator_with_shuffle(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_s = dataset.shuffle(4) + dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4) + num_iter = 0 + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + assert num_iter == 21 + + +def filter_func_shuffle_after(data): + if data > 20: + return False + return True + +# test with batchOp after +def test_filter_by_generator_with_shuffle_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4) + dataset_s = dataset_f.shuffle(4) + num_iter = 0 + for item in dataset_s.create_dict_iterator(): + num_iter += 1 + assert num_iter == 21 + + +def generator_1d_zip1(): + for i in range(64): + yield (np.array(i),) + + +def generator_1d_zip2(): + for i in range(64): + yield (np.array(i+100),) + + +def filter_func_zip(data1, data2): + if data1 > 20: + return False + return True + +def filter_func_zip_after(data1): + if data1 > 20: + return False + return True + +# test with zipOp before +def test_filter_by_generator_with_zip(): + dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"]) + dataset2 = ds.GeneratorDataset(generator_1d_zip2, ["data2"]) + dataz = ds.zip((dataset1, dataset2)) + dataset_f = dataz.filter(predicate=filter_func_zip, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append({"data1": 
item["data1"], "data2":item["data2"]}) + assert num_iter == 21 + assert ret_data[0]["data1"] == 0 + assert ret_data[0]["data2"] == 100 + assert ret_data[5]["data1"] == 5 + assert ret_data[5]["data2"] == 105 + + +# test with zipOp after +def test_filter_by_generator_with_zip_after(): + dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"]) + dataset2 = ds.GeneratorDataset(generator_1d_zip1, ["data2"]) + dt1 = dataset1.filter(predicate=filter_func_zip_after, num_parallel_workers=4) + dt2 = dataset2.filter(predicate=filter_func_zip_after, num_parallel_workers=4) + dataz = ds.zip((dt1, dt2)) + num_iter = 0 + ret_data = [] + for item in dataz.create_dict_iterator(): + num_iter += 1 + ret_data.append({"data1": item["data1"], "data2":item["data2"]}) + assert num_iter == 21 + assert ret_data[0]["data1"] == 0 + assert ret_data[0]["data2"] == 0 + assert ret_data[5]["data1"] == 5 + assert ret_data[5]["data2"] == 5 + + +def filter_func_map(col1, col2): + if col1[0] > 8: + return True + return False + + +def filter_func_map_part(col1): + if col1 < 3: + return True + else: + return False + + +def filter_func_map_all(col1, col2): + return True + +def generator_mc(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) + + +def func_map(data_col1, data_col2): + return (data_col1, data_col2) + + +def func_map_part(data_col1): + return (data_col1) + +# test with map +def test_filter_by_generator_with_map_all_col(): + dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["col1"] , operations=func_map_part) + # dataset_map = dataset.map( operations=func_map_part) + dataset_f = dataset_map.filter(input_columns=["col1"], predicate=filter_func_map_part, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["col1"]) + assert num_iter == 3 + assert ret_data[0] == 0 + assert 
ret_data[1] == 1 + +# test with map +def test_filter_by_generator_with_map_part_col(): + dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part) + + dataset_f = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_map, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + print(item) + ret_data.append(item["out1"]) + assert num_iter == 3 + assert ret_data[0] == 9 + assert ret_data[2] == 11 + + +def filter_func_rename(data): + if data> 8: + return True + return False + +# test with rename before +def test_filter_by_generator_with_rename(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_b = dataset.rename(input_columns=["data"], output_columns=["col1"]) + dataset_f = dataset_b.filter(predicate=filter_func_rename, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["col1"]) + assert num_iter == 55 + assert ret_data[0] == 9 + assert ret_data[54] == 63 + + +#test input_column +def filter_func_input_column1(col1, col2): + if col1[0] < 8: + return True + return False + +def filter_func_input_column2(col1): + if col1[0] < 8: + return True + return False + +def filter_func_input_column3(col1): + return True + +# test with input_columns +def test_filter_by_generator_with_input_column(): + dataset = ds.GeneratorDataset(generator_mc(64), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part) + dataset_f1 = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_input_column1, num_parallel_workers=4) + dataset_f2 = dataset_f1.filter(input_columns=["out1"], predicate=filter_func_input_column2, num_parallel_workers=4) + dataset_f3 = dataset_f2.filter(input_columns=["col2"], 
predicate=filter_func_input_column3, num_parallel_workers=4) + dataset_f4 = dataset_f3.filter(predicate=filter_func_input_column1, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f4.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["out1"]) + assert num_iter == 8 + assert ret_data[0] == 0 + assert ret_data[7] == 7 + + +#test kFilterPartial +def generator_mc_p0(maxid=20): + for i in range(maxid): + yield (np.array([i ]), np.array([i + 100])) + +def generator_mc_p1(maxid=20): + for i in range(maxid): + yield (np.array([i + 200 ]), np.array([i + 300])) + + +def filter_func_Partial_0(col1, col2, col3, col4): + filter_data = [0,1,2,3,4, 11] + if col1[0] in filter_data: + return False + return True + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial0(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + dataset_zip = ds.zip((dataset1, dataset2)) + dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2) + ret = [] + for item in dataset_f1.create_dict_iterator(): + ret.append(item["col1"]) + assert ret[0] == 5 + assert ret[6] == 12 + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial1(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + dataset_zip = ds.zip((dataset1, dataset2)) + dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2) + dataset_map = dataset_f1.map( input_columns=["col1"], output_columns=["out1"] , operations=lambda x1: x1 + 400) + ret = [] + for item in dataset_map.create_dict_iterator(): + ret.append(item["out1"]) + assert 
ret[0] == 405 + assert ret[6] == 412 + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial2(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + + dataset1f = dataset1.filter( input_columns= ["col1"], predicate=lambda x: x not in [3,7,9], num_parallel_workers=2) + dataset2f = dataset2.filter( input_columns= ["col3"], predicate=lambda x: x not in [203,207,209], num_parallel_workers=2) + dataset_zip = ds.zip((dataset1f, dataset2f)) + dataset_map = dataset_zip.map( input_columns=["col1", "col3"], output_columns=["out1", "out3"] , operations=lambda x1,x3: (x1 + 400, x3+500)) + ret1 = [] + ret3 = [] + for item in dataset_map.create_dict_iterator(): + ret1.append(item["out1"]) + ret3.append(item["out3"]) + assert ret1[0] == 400 + assert ret1[6] == 408 + assert ret3[0] == 700 + assert ret3[6] == 708 + + +def filter_func_Partial(col1, col2): + if col1[0] % 3 == 0: + return True + return False + +def generator_big(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset = ds.GeneratorDataset(source = generator_mc(99), column_names = ["col1", "col2"]) + dataset_s = dataset.shuffle(4) + dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1) + + for item in dataset_f1.create_dict_iterator(): + assert item["col1"] % 3 == 0 + + +def filter_func_cifar(col1, col2): + if col2 % 3 == 0: + return True + return False + +# test with cifar10 +def test_filte_case_dataset_cifar10(): + DATA_DIR_10 = "../data/dataset/testCifar10Data" + ds.config.load('../data/dataset/declient_filter.cfg') + dataset_c = ds.Cifar10Dataset(dataset_dir = 
DATA_DIR_10, num_samples = 100000, shuffle=False) + dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1) + num_iter = 0 + for item in dataset_f1.create_dict_iterator(): + # in this example, each dictionary has keys "image" and "label" + assert item["label"] % 3 == 0 + +# column id sort + +def generator_sort1(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([i + 100]), np.array([i + 200])) + +def generator_sort2(maxid=20): + for i in range(maxid): + yield (np.array([i + 300]), np.array([i + 400]), np.array([i + 500])) + + +def filter_func_part_sort(col1, col2, col3, col4, col5, col6): + return True + +def filter_func_map_sort(col1, col2, col3): + return (col1, col2, col3) + +def test_filter_by_generator_with_map_all_sort(): + dataset1 = ds.GeneratorDataset(generator_sort1(10), ["col1", "col2", "col3"]) + dataset2 = ds.GeneratorDataset(generator_sort2(10), ["col4 ", "col5", "col6"]) + + dataz = ds.zip((dataset1, dataset2)) + dataset_f = dataz.filter(predicate=filter_func_part_sort, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item) + + assert num_iter == 10 + assert ret_data[0]["col1"] == 0 + assert ret_data[9]["col6"] == 509 + + + +if __name__ == '__main__': + test_diff_predicate_func() + test_filte_case_dataset_cifar10() + test_filter_by_generator_Partial0() + test_filter_by_generator_Partial1() + test_filter_by_generator_Partial2() + test_filter_by_generator_with_batch() + test_filter_by_generator_with_batch_after() + test_filter_by_generator_with_input_column() + test_filter_by_generator_with_map_all_col() + test_filter_by_generator_with_map_all_sort() + test_filter_by_generator_with_map_part_col() + test_filter_by_generator_with_no() + test_filter_by_generator_with_rename() + test_filter_by_generator_with_repeat() + test_filter_by_generator_with_repeat_after() + 
test_filter_by_generator_with_shuffle() + test_filter_by_generator_with_shuffle_after() + test_filter_by_generator_with_zip() + test_filter_by_generator_with_zip_after() + test_filter_by_generator_Partial() diff --git a/tests/ut/python/dataset/test_generator.py b/tests/ut/python/dataset/test_generator.py index c224c5a2ea..4daf952eba 100644 --- a/tests/ut/python/dataset/test_generator.py +++ b/tests/ut/python/dataset/test_generator.py @@ -391,6 +391,80 @@ def test_case_13(): i = i + 1 +def test_case_14(): + """ + Test 1D Generator MP + CPP sampler + """ + logger.info("Test 1D Generator MP : 0 - 63") + + source = [(np.array([x]),) for x in range(256)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(data["data"], golden) + i = i + 1 + if i == 256: + i = 0 + + +def test_case_15(): + """ + Test 1D Generator MP + Python sampler + """ + logger.info("Test 1D Generator MP : 0 - 63") + + sampler = [x for x in range(256)] + source = [(np.array([x]),) for x in range(256)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(data["data"], golden) + i = i + 1 + if i == 256: + i = 0 + + +def test_case_16(): + """ + Test multi column generator Mp + CPP sampler + """ + logger.info("Test multi column generator") + + source = [(np.array([x]), np.array([x + 1])) for x in range(256)] + # apply dataset operations + data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler()) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(item["col0"], golden) + golden = np.array([i + 1]) + assert np.array_equal(item["col1"], golden) + i = i 
+ 1 + + +def test_case_17(): + """ + Test multi column generator Mp + Python sampler + """ + logger.info("Test multi column generator") + + sampler = [x for x in range(256)] + source = [(np.array([x]), np.array([x + 1])) for x in range(256)] + # apply dataset operations + data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(item["col0"], golden) + golden = np.array([i + 1]) + assert np.array_equal(item["col1"], golden) + i = i + 1 + + def test_case_error_1(): def generator_np(): for i in range(64): @@ -506,6 +580,25 @@ def test_num_samples_underflow(): count = count + 1 assert count == 64 +def manual_test_keyborad_interrupt(): + """ + Test keyborad_interrupt + """ + logger.info("Test 1D Generator MP : 0 - 63") + + class MyDS(): + def __getitem__(self, item): + while True: + pass + + def __len__(self): + return 1024 + + ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + pass + if __name__ == "__main__": test_case_0() @@ -522,6 +615,10 @@ if __name__ == "__main__": test_case_11() test_case_12() test_case_13() + test_case_14() + test_case_15() + test_case_16() + test_case_17() test_case_error_1() test_case_error_2() test_case_error_3() @@ -529,3 +626,5 @@ if __name__ == "__main__": test_sequential_sampler() test_distributed_sampler() test_random_sampler() + + diff --git a/tests/ut/python/dataset/test_invert.py b/tests/ut/python/dataset/test_invert.py new file mode 100644 index 0000000000..a1bfd63431 --- /dev/null +++ b/tests/ut/python/dataset/test_invert.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +def visualize(image_original, image_invert): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_invert) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_invert[i]) + plt.title("DE Color Inverted image") + + plt.show() + + +def test_invert(plot=False): + """ + Test Invert + """ + logger.info("Test Invert") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Color Inverted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_invert = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.Invert(), + F.ToTensor()]) + + ds_invert = ds.map(input_columns="image", + 
operations=transforms_invert()) + + ds_invert = ds_invert.batch(512) + + for idx, (image,label) in enumerate(ds_invert): + if idx == 0: + images_invert = np.transpose(image, (0, 2,3,1)) + else: + images_invert = np.append(images_invert, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_invert[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_invert) + + +if __name__ == "__main__": + test_invert(plot=True) + diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index 102fd0eea1..7c69adf561 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", def check(project_columns): - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS) - data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns) + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False) + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False) for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns), data2.create_tuple_iterator()): assert len(data_actual) == len(data_expected) diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 3cad3877ef..584bb88041 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -46,7 +46,7 @@ def add_and_remove_cv_file(): if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) writer = FileWriter(CV_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME) + data = get_data(CV_DIR_NAME, True) cv_schema_json = {"id": {"type": "int32"}, 
"file_name": {"type": "string"}, "label": {"type": "int32"}, @@ -61,6 +61,59 @@ def add_and_remove_cv_file(): os.remove("{}.db".format(x)) +def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(2) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + +def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(3, None, True) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 9 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + +def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(5, None, True) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + assert data_set.get_dataset_size() == 15 + 
num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] @@ -69,8 +122,7 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -93,8 +145,7 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 6 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -117,8 +168,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 0 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -133,7 +183,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): assert num_iter == 0 -def 
test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): +def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] num_readers = 4 @@ -141,8 +191,7 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -165,8 +214,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -181,7 +229,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): assert num_iter == 5 -def get_data(dir_name): +def get_data(dir_name, sampler=False): """ usage: get data from imagenet dataset params: @@ -191,7 +239,10 @@ def get_data(dir_name): if not os.path.isdir(dir_name): raise IOError("Directory {} not exists".format(dir_name)) img_dir = os.path.join(dir_name, "images") - ann_file = os.path.join(dir_name, "annotation.txt") + if sampler: + ann_file = os.path.join(dir_name, "annotation_sampler.txt") + else: + ann_file = os.path.join(dir_name, "annotation.txt") with open(ann_file, "r") as file_reader: lines = file_reader.readlines() diff --git a/tests/ut/python/dataset/test_normalizeOp.py b/tests/ut/python/dataset/test_normalizeOp.py index 1abee96173..c080b00105 100644 --- a/tests/ut/python/dataset/test_normalizeOp.py +++ 
b/tests/ut/python/dataset/test_normalizeOp.py @@ -15,7 +15,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision import numpy as np - +import matplotlib.pyplot as plt import mindspore.dataset as ds from mindspore import log as logger @@ -114,6 +114,7 @@ def test_decode_op(): # plt.subplot(131) # plt.imshow(image) # plt.title("DE image") + # plt.show() num_iter += 1 @@ -138,8 +139,8 @@ def test_decode_normalize_op(): # plt.subplot(131) # plt.imshow(image) # plt.title("DE image") + # plt.show() num_iter += 1 - break if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_pyfunc.py b/tests/ut/python/dataset/test_pyfunc.py index 4b0672a1f2..e7bdc48639 100644 --- a/tests/ut/python/dataset/test_pyfunc.py +++ b/tests/ut/python/dataset/test_pyfunc.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds from mindspore import log as logger @@ -181,6 +182,106 @@ def test_case_6(): i = i + 4 +def test_case_7(): + """ + Test PyFunc + """ + logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x") + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x), + num_parallel_workers=4, python_multiprocessing = True) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + # In this test, the dataset is 2x2 sequential tensors + golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) + assert np.array_equal(item["out"], golden) + i = i + 4 + + +def test_case_8(): + """ + Test PyFunc + """ + logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x , x + 1, x + y)") + + col = ["col0", "col1"] + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + data1 = data1.map(input_columns=col, 
output_columns=["out0", "out1", "out2"], num_parallel_workers=4, + operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"], + python_multiprocessing=True) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + # In this test, the dataset is 2x2 sequential tensors + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + assert np.array_equal(item["out0"], golden) + golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) + assert np.array_equal(item["out1"], golden) + golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]]) + assert np.array_equal(item["out2"], golden) + i = i + 4 + + +def test_case_9(): + """ + Test PyFunc + """ + logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x") + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + data1 = data1.map(input_columns="col0", output_columns="out", operations=[(lambda x: x + x), (lambda x: x + 1), + (lambda x: x + 2)], + num_parallel_workers=4, python_multiprocessing=True) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + # In this test, the dataset is 2x2 sequential tensors + golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]]) + assert np.array_equal(item["out"], golden) + i = i + 4 + + +def test_pyfunc_execption(): + logger.info("Test PyFunc Execption Throw: lambda x : raise Execption()") + + def pyfunc(x): + raise Exception("Pyfunc Throw") + + with pytest.raises(RuntimeError) as info: + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc, + num_parallel_workers=4) + for _ in data1: + pass + assert "Pyfunc Throw" in str(info.value) + + +def test_pyfunc_execption_multiprocess(): + logger.info("Test Multiprocess PyFunc Execption Throw: lambda x : raise Execption()") + 
+ def pyfunc(x): + raise Exception("MP Pyfunc Throw") + + with pytest.raises(RuntimeError) as info: + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc, + num_parallel_workers=4, python_multiprocessing = True) + for _ in data1: + pass + assert "MP Pyfunc Throw" in str(info.value) + + if __name__ == "__main__": test_case_0() test_case_1() @@ -189,3 +290,8 @@ if __name__ == "__main__": test_case_4() test_case_5() test_case_6() + test_case_7() + test_case_8() + test_case_9() + test_pyfunc_execption() + test_pyfunc_execption_multiprocess() diff --git a/tests/ut/python/dataset/test_random_color.py b/tests/ut/python/dataset/test_random_color.py new file mode 100644 index 0000000000..9472b7e35a --- /dev/null +++ b/tests/ut/python/dataset/test_random_color.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_random_color): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_random_color) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_random_color[i]) + plt.title("DE Random Color image") + + plt.show() + + +def test_random_color(degrees=(0.1,1.9), plot=False): + """ + Test RandomColor + """ + logger.info("Test RandomColor") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Random Color Adjusted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_random_color = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.RandomColor(degrees=degrees), + F.ToTensor()]) + + ds_random_color = ds.map(input_columns="image", + operations=transforms_random_color()) + + ds_random_color = ds_random_color.batch(512) + + for idx, (image,label) in enumerate(ds_random_color): + if idx == 0: + images_random_color = np.transpose(image, (0, 2,3,1)) + else: + images_random_color = np.append(images_random_color, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = 
images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_random_color[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_random_color) + + +if __name__ == "__main__": + test_random_color() + test_random_color(plot=True) + test_random_color(degrees=(0.5,1.5), plot=True) diff --git a/tests/ut/python/dataset/test_random_color_adjust.py b/tests/ut/python/dataset/test_random_color_adjust.py index 57c77caf81..dcb7cd48ac 100644 --- a/tests/ut/python/dataset/test_random_color_adjust.py +++ b/tests/ut/python/dataset/test_random_color_adjust.py @@ -182,8 +182,6 @@ def test_random_color_jitter_op_saturation(): ] transform = py_vision.ComposeOp(transforms) data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) - # data2 = data2.map(input_columns=["image"], operations=decode_op) - # data2 = data2.map(input_columns=["image"], operations=c_vision.Decode()) data2 = data2.map(input_columns=["image"], operations=transform()) num_iter = 0 @@ -220,8 +218,6 @@ def test_random_color_jitter_op_hue(): # First dataset data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) decode_op = c_vision.Decode() - # channel_swap_op = c_vision.ChannelSwap() - # change_mode_op = c_vision.ChangeMode() random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2)) diff --git a/tests/ut/python/dataset/test_random_sharpness.py b/tests/ut/python/dataset/test_random_sharpness.py new file mode 100644 index 0000000000..949a658597 --- /dev/null +++ b/tests/ut/python/dataset/test_random_sharpness.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_random_sharpness): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_random_sharpness) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_random_sharpness[i]) + plt.title("DE Random Sharpness image") + + plt.show() + + +def test_random_sharpness(degrees=(0.1,1.9), plot=False): + """ + Test RandomSharpness + """ + logger.info("Test RandomSharpness") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Random Sharpness Adjusted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_random_sharpness = F.ComposeOp([F.Decode(), + 
F.Resize((224,224)), + F.RandomSharpness(degrees=degrees), + F.ToTensor()]) + + ds_random_sharpness = ds.map(input_columns="image", + operations=transforms_random_sharpness()) + + ds_random_sharpness = ds_random_sharpness.batch(512) + + for idx, (image,label) in enumerate(ds_random_sharpness): + if idx == 0: + images_random_sharpness = np.transpose(image, (0, 2,3,1)) + else: + images_random_sharpness = np.append(images_random_sharpness, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_random_sharpness[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_random_sharpness) + + +if __name__ == "__main__": + test_random_sharpness() + test_random_sharpness(plot=True) + test_random_sharpness(degrees=(0.5,1.5), plot=True) diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 7a58249f9c..4efca6f818 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -14,6 +14,7 @@ # ============================================================================== import mindspore.dataset as ds from mindspore import log as logger +import numpy as np # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631] @@ -107,8 +108,64 @@ def test_sampler_py_api(): sampler.get_indices() +def test_python_sampler(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} + + class Sp1(ds.Sampler): + def __iter__(self): + return iter([i for i in range(self.dataset_size)]) + + class Sp2(ds.Sampler): + def __init__(self): + super(Sp2, self).__init__() + # at this stage, self.dataset_size and self.num_samples are not yet known + self.cnt = 0 + + def __iter__(self): # first 
epoch, all 0, second epoch all 1, third all 2 etc.. ... + return iter([self.cnt for i in range(self.num_samples)]) + + def reset(self): + self.cnt = (self.cnt + 1) % self.dataset_size + + def test_config(num_samples, num_repeats, sampler): + data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler) + if num_repeats is not None: + data1 = data1.repeat(num_repeats) + res = [] + for item in data1.create_dict_iterator(): + logger.info("item[image].shape[0]: {}, item[label].item(): {}" + .format(item["image"].shape[0], item["label"].item())) + res.append(map[(item["image"].shape[0], item["label"].item())]) + # print(res) + return res + + def test_generator(): + class MySampler(ds.Sampler): + def __iter__(self): + for i in range(99, -1, -1): + yield i + + data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler = MySampler()) + i = 99 + for data in data1: + assert data[0] == (np.array(i),) + i = i - 1 + + assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] + test_generator() + + sp1 = Sp1().create() + sp1.set_num_rows(5) + sp1.set_num_samples(5) + sp1.initialize() + assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] + + if __name__ == '__main__': test_sequential_sampler(True) test_random_sampler(True) test_random_sampler_multi_iter(True) test_sampler_py_api() + test_python_sampler() \ No newline at end of file diff --git a/tests/ut/python/dataset/test_serdes_dataset.py b/tests/ut/python/dataset/test_serdes_dataset.py index 7fdb0f1dde..0a6f86974b 100644 --- a/tests/ut/python/dataset/test_serdes_dataset.py +++ b/tests/ut/python/dataset/test_serdes_dataset.py @@ -243,7 +243,7 @@ def test_minddataset(add_and_remove_cv_file): assert ds1_json == ds2_json data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): num_iter += 1 
diff --git a/tests/ut/python/dataset/test_skip.py b/tests/ut/python/dataset/test_skip.py index bea7db4e05..59893f6ded 100644 --- a/tests/ut/python/dataset/test_skip.py +++ b/tests/ut/python/dataset/test_skip.py @@ -22,7 +22,11 @@ from mindspore import log as logger DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + def test_tf_skip(): + """ + a simple skip operation. + """ data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) resize_height, resize_width = 32, 32 @@ -37,11 +41,15 @@ def test_tf_skip(): num_iter += 1 assert num_iter == 1 + def generator_md(): - # Create a dataset with [0, 1, 2, 3, 4] + """ + create a dataset with [0, 1, 2, 3, 4] + """ for i in range(5): yield (np.array([i]), ) + def test_generator_skip(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -53,6 +61,7 @@ def test_generator_skip(): buf.append(data[0][0]) assert len(buf) == 2 + def test_skip_1(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -64,6 +73,7 @@ def test_skip_1(): buf.append(data[0][0]) assert len(buf) == 0 + def test_skip_2(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -75,6 +85,7 @@ def test_skip_2(): buf.append(data[0][0]) assert len(buf) == 5 + def test_skip_repeat_1(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -89,6 +100,7 @@ def test_skip_repeat_1(): buf.append(data[0][0]) assert len(buf) == 7 + def test_skip_repeat_2(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -103,6 +115,7 @@ def test_skip_repeat_2(): buf.append(data[0][0]) assert len(buf) == 4 + def test_skip_repeat_3(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -120,6 +133,7 @@ def test_skip_repeat_3(): buf.append(data[0][0]) assert len(buf) == 6 + if __name__ == "__main__": test_tf_skip() test_generator_skip() @@ -127,4 +141,4 @@ if __name__ == "__main__": test_skip_2() test_skip_repeat_1() test_skip_repeat_2() - test_skip_repeat_3() \ No 
newline at end of file + test_skip_repeat_3() diff --git a/tests/ut/python/dataset/test_storage.py b/tests/ut/python/dataset/test_storage.py index b37a52f37d..92a689a689 100644 --- a/tests/ut/python/dataset/test_storage.py +++ b/tests/ut/python/dataset/test_storage.py @@ -37,3 +37,15 @@ def test_case_storage(): filename = "storage_result.npz" save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + + +def test_case_no_rows(): + DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] + SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json" + + dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) + assert dataset.get_dataset_size() == 3 + count = 0 + for data in dataset.create_tuple_iterator(): + count += 1 + assert count == 3 diff --git a/tests/ut/python/dataset/test_sync_wait.py b/tests/ut/python/dataset/test_sync_wait.py new file mode 100644 index 0000000000..277499d9ae --- /dev/null +++ b/tests/ut/python/dataset/test_sync_wait.py @@ -0,0 +1,182 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import mindspore.dataset as ds +from mindspore import log as logger +import time +import numpy as np + + +def gen(): + for i in range(100): + yield np.array(i), + + +class Augment: + def __init__(self, loss): + self.loss = loss + + def preprocess(self, input): + return input + + def update(self, data): + self.loss = data["loss"] + + +def test_simple_sync_wait(): + """ + Test simple sync wait: test sync in dataset pipeline + """ + logger.info("test_simple_sync_wait") + batch_size = 4 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_simple_shuffle_sync(): + """ + Test simple shuffle sync: test shuffle before sync + """ + logger.info("test_simple_shuffle_sync") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.shuffle(shuffle_size) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + #time.sleep(0.5) + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_two_sync(): + """ + Test two sync: dataset pipeline with with two sync_operators + """ + logger.info("test_two_sync") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to 
have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches") + + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + data = {"loss": count} + dataset.sync_update(condition_name="every batch", data=data) + if count % 2 == 0: + dataset.sync_update(condition_name="every 2 batches") + +def test_sync_epoch(): + """ + Test sync wait with epochs: test sync with epochs in dataset pipeline + """ + logger.info("test_sync_epoch") + batch_size = 30 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size, drop_remainder=True) + + for epochs in range(3): + aug.update({"loss": 0}) + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_sync_exception_01(): + """ + Test sync: with shuffle in sync mode + """ + logger.info("test_sync_exception_01") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.shuffle(shuffle_size) + except BaseException as e: + assert "shuffle" in str(e) + dataset = dataset.batch(batch_size) + + +def test_sync_exception_02(): + """ + Test sync: with duplicated condition name + """ + logger.info("test_sync_exception_02") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, 
column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") + except BaseException as e: + assert "name" in str(e) + dataset = dataset.batch(batch_size) + + +if __name__ == "__main__": + test_simple_sync_wait() + test_simple_shuffle_sync() + test_two_sync() + test_sync_exception_01() + test_sync_exception_02() + test_sync_epoch() \ No newline at end of file diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py index 6de14df34e..c5d9471f8b 100644 --- a/tests/ut/python/dataset/test_tfreader_op.py +++ b/tests/ut/python/dataset/test_tfreader_op.py @@ -32,11 +32,41 @@ def test_case_tf_shape(): ds1 = ds.TFRecordDataset(FILES, schema_file) ds1 = ds1.batch(2) for data in ds1.create_dict_iterator(): - print(data) + logger.info(data) output_shape = ds1.output_shapes() assert (len(output_shape[-1]) == 1) +def test_case_tf_read_all_dataset(): + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json" + ds1 = ds.TFRecordDataset(FILES, schema_file) + assert ds1.get_dataset_size() == 12 + count = 0 + for data in ds1.create_tuple_iterator(): + count += 1 + assert count == 12 + + +def test_case_num_samples(): + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" + ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8) + assert ds1.get_dataset_size() == 8 + count = 0 + for data in ds1.create_dict_iterator(): + count += 1 + assert count == 8 + + +def test_case_num_samples2(): + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" + ds1 = ds.TFRecordDataset(FILES, schema_file) + assert ds1.get_dataset_size() == 7 + count = 0 + for data in ds1.create_dict_iterator(): + count 
+= 1 + assert count == 7 + + def test_case_tf_shape_2(): ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE) ds1 = ds1.batch(2) @@ -203,6 +233,32 @@ def test_tf_record_schema_columns_list(): a = row["col_sint32"] assert "col_sint32" in str(info.value) +def test_case_invalid_files(): + valid_file = "../data/dataset/testTFTestAllTypes/test.data" + invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt" + files = [invalid_file, valid_file, SCHEMA_FILE] + + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + + with pytest.raises(RuntimeError) as info: + row = data.create_dict_iterator().get_next() + assert "cannot be opened" in str(info.value) + assert "not valid tfrecord files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file in str(info.value) + assert SCHEMA_FILE in str(info.value) + + nonexistent_file = "this/file/does/not/exist" + files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file] + + with pytest.raises(ValueError) as info: + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + assert "did not match any files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file not in str(info.value) + assert SCHEMA_FILE not in str(info.value) + assert nonexistent_file in str(info.value) + if __name__ == '__main__': test_case_tf_shape() test_case_tf_file() @@ -212,3 +268,4 @@ if __name__ == '__main__': test_tf_record_schema() test_tf_record_shuffle() test_tf_shard_equal_rows() + test_case_invalid_files() diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py new file mode 100644 index 0000000000..ce0490336e --- /dev/null +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -0,0 +1,107 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +def visualize(image_original, image_ua): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_ua) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_ua[i]) + plt.title("DE UniformAugment image") + + plt.show() + + +def test_uniform_augment(plot=False, num_ops=2): + """ + Test UniformAugment + """ + logger.info("Test UniformAugment") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # UniformAugment Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transform_list = [F.RandomRotation(45), + F.RandomColor(), + F.RandomSharpness(), + F.Invert(), + F.AutoContrast(), + F.Equalize()] + + 
transforms_ua = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.UniformAugment(transforms=transform_list, num_ops=num_ops), + F.ToTensor()]) + + ds_ua = ds.map(input_columns="image", + operations=transforms_ua()) + + ds_ua = ds_ua.batch(512) + + for idx, (image,label) in enumerate(ds_ua): + if idx == 0: + images_ua = np.transpose(image, (0, 2,3,1)) + else: + images_ua = np.append(images_ua, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_ua[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_ua) + + +if __name__ == "__main__": + test_uniform_augment(num_ops=1) + diff --git a/tests/ut/python/nn/test_cell_wrapper.py b/tests/ut/python/nn/test_cell_wrapper.py index 3e163c9e4f..148d42ab64 100755 --- a/tests/ut/python/nn/test_cell_wrapper.py +++ b/tests/ut/python/nn/test_cell_wrapper.py @@ -94,10 +94,6 @@ def test_parameter_update_float32(): def test_parameter_update_error(): """ test_parameter_update """ input_np = np.array([1]) - input_parameter = Parameter(np.array([1]), 'input_parameter') with pytest.raises(TypeError): ParameterUpdate(input_np) - - with pytest.raises(TypeError): - ParameterUpdate(input_parameter) diff --git a/tests/ut/python/nn/test_dynamic_lr.py b/tests/ut/python/nn/test_dynamic_lr.py index 96f9d5afde..8d03be1766 100644 --- a/tests/ut/python/nn/test_dynamic_lr.py +++ b/tests/ut/python/nn/test_dynamic_lr.py @@ -41,7 +41,7 @@ class TestInputs: dr.piecewise_constant_lr(milestone1, learning_rates) milestone2 = [1.0, 2.0, True] - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone2, learning_rates) def test_learning_rates1(self): @@ -92,13 +92,13 @@ class TestInputs: def test_total_step1(self): total_step1 = 2.0 - with pytest.raises(ValueError): + with pytest.raises(TypeError): 
dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power) def test_total_step2(self): @@ -114,13 +114,13 @@ class TestInputs: def test_step_per_epoch1(self): step_per_epoch1 = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power) def test_step_per_epoch2(self): @@ -136,13 +136,13 @@ class TestInputs: def test_decay_epoch1(self): decay_epoch1 = 'm' - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power) def test_decay_epoch2(self): diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 49e89e124e..d6bc40ba02 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -61,32 +61,15 @@ def test_parameter_init_illegal(): data_int = 3 data_list = [1, "2", True] data_tuple = (1, 2, 3) - np_arr_int16 = np.ones([1,1], 
dtype=np.int16) - np_arr_int32 = np.ones([1,1], dtype=np.int32) - np_arr_float16 = np.ones([1,1], dtype=np.float16) - np_arr_float32 = np.ones([1,1], dtype=np.float32) - -# with pytest.raises(ValueError): -# Parameter(np_arr_int16[0][0], name=data_str) - Parameter(np_arr_int32[0], name=data_str) - Parameter(np_arr_float16[0], name=data_str) - Parameter(np_arr_float32[0], name=data_str) - Parameter(np_arr_float32, name=data_str) + # test data Parameter(tensor, name=data_str) Parameter(data_int, name=data_str) Parameter(dat, name=data_str) - with pytest.raises(ValueError): - Parameter(data_none, name=data_str) with pytest.raises(ValueError): Parameter(data_bool, name=data_str) - with pytest.raises(ValueError): - Parameter(data_str, name=data_str) - Parameter(data_list, name=data_str) - with pytest.raises(ValueError): - Parameter(data_tuple, name=data_str) - Parameter(tensor, name=data_str) + # test name Parameter(tensor, name=data_none) with pytest.raises(ValueError): Parameter(tensor, name=dat) diff --git a/tests/ut/python/nn/test_pooling.py b/tests/ut/python/nn/test_pooling.py index 10bb7632b2..428e050ea2 100644 --- a/tests/ut/python/nn/test_pooling.py +++ b/tests/ut/python/nn/test_pooling.py @@ -56,3 +56,19 @@ def test_compile_max(): net = MaxNet(3, stride=1, padding=0) x = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32)) _executor.compile(net, x) + + +class Avg1dNet(nn.Cell): + def __init__(self, + kernel_size, + stride=None): + super(Avg1dNet, self).__init__() + self.avg1d = nn.AvgPool1d(kernel_size, stride) + + def construct(self, x): + return self.avg1d(x) + +def test_avg1d(): + net = Avg1dNet(3, 1) + input = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32)) + _executor.compile(net, input) \ No newline at end of file diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py index c1a652de57..77d065b100 100644 --- a/tests/ut/python/nn/test_ssim.py +++ b/tests/ut/python/nn/test_ssim.py @@ -60,7 +60,7 @@ def 
test_ssim_max_val_zero(): net = SSIMNet(max_val) def test_ssim_filter_size_float(): - with pytest.raises(ValueError): + with pytest.raises(TypeError): net = SSIMNet(filter_size=1.1) def test_ssim_filter_size_zero(): @@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value(): with pytest.raises(ValueError): net = SSIMNet(k2=0.0) with pytest.raises(ValueError): - net = SSIMNet(k2=-1.0) \ No newline at end of file + net = SSIMNet(k2=-1.0) diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index 595bf35e2c..b866c7c556 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -30,6 +30,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config from ....mindspore_test_framework.pipeline.forward.verify_exception \ import pipeline_for_verify_exception_for_case_by_case_config +import pytest # pylint: disable=W0613 @@ -81,14 +82,29 @@ def test_sqrt(): assert np.all(result.asnumpy() == expect) +class PowNet(nn.Cell): + def __init__(self): + super(PowNet, self).__init__() + self.pow = P.Pow() + + def construct(self, x, y): + return self.pow(x, y) + + def test_pow(): """ test_pow """ input_tensor = Tensor(np.array([[2, 2], [3, 3]])) power = Tensor(np.array(3.0, np.int64)) + power2 = Tensor(np.array(True, np.bool)) testpow = P.Pow() expect = np.array([[8, 8], [27, 27]]) result = testpow(input_tensor, power) assert np.all(result.asnumpy() == expect) + net = PowNet() + with pytest.raises(TypeError): + net(input_tensor, True) + with pytest.raises(TypeError): + net(input_tensor, power2) def test_exp(): diff --git a/tests/ut/python/ops/test_momentum.py b/tests/ut/python/ops/test_momentum.py index 64b5a9af12..3334f1670a 100644 --- a/tests/ut/python/ops/test_momentum.py +++ b/tests/ut/python/ops/test_momentum.py @@ -31,7 +31,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ run_opt = 
C.MultitypeFuncGraph("run_opt") -@run_opt.register("Function", "Int", "Number", "Number", +@run_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") def tensor_run_opt(opt, iters, learning_rate, momentum, diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index ec5d25cccb..ab6f31095d 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -516,7 +516,7 @@ test_cases = [ test_cases_for_verify_exception = [ ('Conv2d_ValueError_1', { - 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}), + 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}), 'desc_inputs': [0], }), ('Conv2d_ValueError_2', { @@ -540,7 +540,7 @@ test_cases_for_verify_exception = [ 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_1', { - 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_2', { diff --git a/tests/ut/python/ops/test_nn_ops_check.py b/tests/ut/python/ops/test_nn_ops_check.py new file mode 100755 index 0000000000..c2a751aa0c --- /dev/null +++ b/tests/ut/python/ops/test_nn_ops_check.py @@ -0,0 +1,463 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" test ops """ +import functools +import numpy as np +from mindspore import ops +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter +from ..ut_filter import non_graph_engine +from mindspore.common.api import _executor + +from ....mindspore_test_framework.mindspore_test import mindspore_test +from ....mindspore_test_framework.pipeline.forward.compile_forward\ + import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config, + pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +from ....mindspore_test_framework.pipeline.gradient.compile_gradient\ + import pipeline_for_compile_grad_ge_graph_for_case_by_case_config + + +class Conv2DBackpropInputNet(nn.Cell): + def __init__(self, net, x_shape): + super(Conv2DBackpropInputNet, self).__init__() + self.net = net + self.x_shape = x_shape + + def construct(self, dout, w): + return self.net(dout, w, self.x_shape) + + +class TopKNet(nn.Cell): + def __init__(self, net, k): + super(TopKNet, self).__init__() + self.net = net + self.k = k + + def construct(self, x): + return self.net(x, self.k) + + +raise_set = [ + # input is scalar + ('Flatten0', { + 'block': (P.Flatten(), {'exception': TypeError, 'error_keywords': ['Flatten']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # dim of input is zero + ('Flatten1', { + 'block': (P.Flatten(), {'exception': ValueError, 'error_keywords': ['Flatten']}), + 'desc_inputs': [F.scalar_to_tensor(5.0)], + 'skip': ['backward']}), + + # input is scalar + ('Softmax0', { + 'block': (P.Softmax(), {'exception': TypeError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # 
axis is empty tuple + ('Softmax1', { + 'block': (P.Softmax(axis=()), {'exception': ValueError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + # axis value is not in range + ('Softmax2', { + 'block': (P.Softmax(axis=2), {'exception': ValueError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('LogSoftmax0', { + 'block': (P.LogSoftmax(), {'exception': TypeError, 'error_keywords': ['LogSoftmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # axis value is not in range + ('LogSoftmax1', { + 'block': (P.LogSoftmax(axis=2), {'exception': ValueError, 'error_keywords': ['LogSoftmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('ReLU0', { + 'block': (P.ReLU(), {'exception': TypeError, 'error_keywords': ['ReLU']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(Bool) + ('ReLU1', { + 'block': (P.ReLU(), {'exception': TypeError, 'error_keywords': ['ReLU']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], + 'skip': ['backward']}), + + # input is scalar + ('ReLU60', { + 'block': (P.ReLU6(), {'exception': TypeError, 'error_keywords': ['ReLU6']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('ReLU61', { + 'block': (P.ReLU6(), {'exception': TypeError, 'error_keywords': ['ReLU6']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Elu0', { + 'block': (P.Elu(), {'exception': TypeError, 'error_keywords': ['Elu']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('Elu1', { + 'block': (P.Elu(alpha=0.9), {'exception': TypeError, 'error_keywords': ['Elu']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + 
('Sigmoid0', { + 'block': (P.Sigmoid(), {'exception': TypeError, 'error_keywords': ['Sigmoid']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('Sigmoid1', { + 'block': (P.Sigmoid(), {'exception': TypeError, 'error_keywords': ['Sigmoid']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Tanh0', { + 'block': (P.Tanh(), {'exception': TypeError, 'error_keywords': ['Tanh']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + + # input is scalar + ('BatchNorm0', { + 'block': (P.BatchNorm(is_training=False), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [5.0, 5.0, 5.0, 5.0, 5.0], + 'skip': ['backward']}), + # is_training=False and mean=None + ('BatchNorm1', { + 'block': (P.BatchNorm(is_training=False), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32)), + Tensor(np.ones([5, 3]).astype(np.float32)), None, None], + 'skip': ['backward']}), + # is_training=True and mean=None + ('BatchNorm2', { + 'block': (P.BatchNorm(is_training=True), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float16)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # scale and bias rank > 1 + ('BatchNorm3', { + 'block': (P.BatchNorm(is_training=True), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32)), + Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # scale and bias shape not match + ('BatchNorm4', { + 'block': 
(P.BatchNorm(is_training=True), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([7]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # is_training=False, mean and variance shape not match + ('BatchNorm5', { + 'block': (P.BatchNorm(is_training=False), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # is_training=False, mean and scale shape not match + ('BatchNorm6', { + 'block': (P.BatchNorm(is_training=False), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32)), + Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('Conv2D0', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('Conv2D1', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # input x and w type mismatch + ('Conv2D2', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float16))], + 'skip': ['backward']}), + # rank of x is not 4 + ('Conv2D3', { + 'block': (P.Conv2D(2, (5, 5)), 
{'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1, 1]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # rank of 2 is not 4 + ('Conv2D4', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9]).astype(np.float32))], + 'skip': ['backward']}), + # x_shape[1] / group != w_shape[1] + ('Conv2D5', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,2,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # out_channel != w_shape[0] + ('Conv2D6', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # kernel_size != w_shape[2:4] + ('Conv2D7', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([2,1,5,6]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('DepthwiseConv2dNative0', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('DepthwiseConv2dNative1', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # input x and w type mismatch + ('DepthwiseConv2dNative2', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': 
['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float16))], + 'skip': ['backward']}), + # rank of x is not 4 + ('DepthwiseConv2dNative3', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1, 1]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # rank of 2 is not 4 + ('DepthwiseConv2dNative4', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9]).astype(np.float32))], + 'skip': ['backward']}), + # x_shape[1] != w_shape[1] + ('DepthwiseConv2dNative5', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,2,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # kernel_size != w_shape[2:4] + ('DepthwiseConv2dNative6', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([2,1,5,6]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('MaxPoolWithArgmax0', { + 'block': (P.MaxPoolWithArgmax(), {'exception': TypeError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('MaxPoolWithArgmax1', { + 'block': (P.MaxPoolWithArgmax(), {'exception': TypeError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPoolWithArgmax2', { + 'block': (P.MaxPoolWithArgmax(), {'exception': ValueError, 
'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # kernel size is invalid(very large) + ('MaxPoolWithArgmax3', { + 'block': (P.MaxPoolWithArgmax(ksize=50), {'exception': ValueError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('MaxPool0', { + 'block': (P.MaxPool(), {'exception': TypeError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPool1', { + 'block': (P.MaxPool(), {'exception': ValueError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPool2', { + 'block': (P.MaxPool(ksize=50, strides=1), {'exception': ValueError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('AvgPool0', { + 'block': (P.AvgPool(), {'exception': TypeError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # rank of x is not 4 + ('AvgPool1', { + 'block': (P.AvgPool(), {'exception': ValueError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x is not 4 + ('AvgPool2', { + 'block': (P.AvgPool(ksize=50, strides=1), {'exception': ValueError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('Conv2DBackpropInput0', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('Conv2DBackpropInput1', { + 'block': 
(Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of doutput and w mismatch + ('Conv2DBackpropInput2', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # types x_size is not tuple + ('Conv2DBackpropInput3', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), 2), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # types x_size is not tuple(int,...) + ('Conv2DBackpropInput4', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2, 3.0)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('BiasAdd0', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('BiasAdd1', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of x and bias mismatch + ('BiasAdd2', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x less than 2 + ('BiasAdd3', 
{ + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # rank of bias is not equal to 1 + ('BiasAdd4', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32))], + 'skip': ['backward']}), + # b_shape[0] != x_shape[1] + ('BiasAdd5', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input x is scalar + ('TopK0', { + 'block': (TopKNet(P.TopK(), 5), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input x is Tensor(bool) + ('TopK1', { + 'block': (TopKNet(P.TopK(), 5), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [Tensor(np.ones([10]).astype(np.bool_))], + 'skip': ['backward']}), + # k is not integer + ('TopK2', { + 'block': (TopKNet(P.TopK(), 5.0), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('SoftmaxCrossEntropyWithLogits0', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('SoftmaxCrossEntropyWithLogits1', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of logits and labels mismatch + ('SoftmaxCrossEntropyWithLogits2', { + 'block': 
(P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float16)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # shapes of logits and labels mismatch + ('SoftmaxCrossEntropyWithLogits3', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': ValueError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('SparseSoftmaxCrossEntropyWithLogits0', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # logits is Tensor(bool) + ('SparseSoftmaxCrossEntropyWithLogits1', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # labels is Tensor(bool) + ('SparseSoftmaxCrossEntropyWithLogits2', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # logits_shape[0] != labels_shape[0] + ('SparseSoftmaxCrossEntropyWithLogits3', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': ValueError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([3]).astype(np.int32))], + 'skip': ['backward']}), +] + + +@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +def test_check_exception(): + 
return raise_set diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 3345f77862..1dea7b6502 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -250,6 +250,10 @@ test_case_math_ops = [ 'block': P.Exp(), 'desc_inputs': [[2, 3]], 'desc_bprop': [[2, 3]]}), + ('Erf', { + 'block': P.Erf(), + 'desc_inputs': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))], + 'desc_bprop': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))]}), ('Floor', { 'block': P.Floor(), 'desc_inputs': [[2, 512, 56, 56]], @@ -414,6 +418,11 @@ test_case_math_ops = [ 'block': P.NotEqual(), 'desc_inputs': [[4, 1], [2, 3, 4, 5]], 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))]}), + ('NotEqual_0', { + 'block': P.NotEqual(), + 'desc_inputs': [ 1, [2, 3, 4, 5]], + 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))], + 'skip': ['backward']}), ('Greater', { 'block': P.Greater(), 'desc_inputs': [[2, 3, 4, 1], [4, 5]], @@ -573,6 +582,10 @@ test_case_nn_ops = [ 'block': P.ReLU6(), 'desc_inputs': [[1, 3, 4, 4]], 'desc_bprop': [[1, 3, 4, 4]]}), + ('ReLUV2', { + 'block': P.ReLUV2(), + 'desc_inputs': [[1, 3, 4, 4]], + 'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}), ('ReLUGrad', { 'block': G.ReluGrad(), 'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]], @@ -658,7 +671,7 @@ test_case_nn_ops = [ 'skip': []}), ('BatchNormGrad', { 'block': G.BatchNormGrad(), - 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]], + 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], 'skip': ['backward']}), ('ApplyMomentum', { @@ -866,6 +879,14 @@ test_case_nn_ops = [ 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], 'desc_bprop': [3, 3], 'skip': ['backward']}), + ('L2Loss_1', { + 'block': P.L2Loss(), + 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)], + 'desc_bprop': []}), + ('L2Loss_2', { + 'block': P.L2Loss(), + 'desc_inputs': 
[Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)], + 'desc_bprop': []}), ] test_case_array_ops = [ @@ -1117,6 +1138,21 @@ test_case_other_ops = [ 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), Tensor(np.array([1.2]).astype(np.float32))], 'skip': ['backward']}), + ('ConfusionMulGrad_1', { + 'block': P.ConfusionMulGrad(axis = [0], keep_dims = False), + 'desc_inputs': [[3, 2], [3, 2], [3, 2]], + 'desc_bprop': [[3, 2], [2]], + 'skip': ['backward']}), + ('ConfusionMulGrad_2', { + 'block': P.ConfusionMulGrad(axis = [0], keep_dims = True), + 'desc_inputs': [[3, 2], [3, 2], [3, 2]], + 'desc_bprop': [[3, 2], [1, 2]], + 'skip': ['backward']}), + ('ConfusionMulGrad_3', { + 'block': P.ConfusionMulGrad(axis = (), keep_dims = True), + 'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]], + 'desc_bprop': [[2, 3, 4], [1, 1, 1]], + 'skip': ['backward']}), ('HistogramSummary', { 'block': HistogramSummaryNet(), 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 6200d4e163..ddd1fb46a1 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -18,6 +18,7 @@ import pytest from mindspore import Tensor from mindspore import context +from mindspore import dtype as mstype from mindspore.nn import Cell from ....mindspore_test_framework.mindspore_test import mindspore_test @@ -41,6 +42,20 @@ class NetWorkSlicePositive(Cell): return ret0, ret1, ret2, ret3 +class NetWorkSliceEllipsis(Cell): + def __init__(self): + super(NetWorkSliceEllipsis, self).__init__() + self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32)) + self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32)) + self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32)) + + def construct(self, tensor): + ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0 + ret1 = tensor[...] 
+ self.tensor_ret1 + ret2 = tensor[True] + self.tensor_ret2 + return ret0, ret1, ret2 + + class NetWorkReduceDimension(Cell): def __init__(self): super(NetWorkReduceDimension, self).__init__() @@ -79,7 +94,102 @@ class NetWorkReduceToScalar(Cell): return ret +class TensorAssignWithBoolTensorIndex(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex, self).__init__() + self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) + + def construct(self, a, b, c, u_tensor, _scalar): + a[c] = u_scalar + a[b] = u_tensor + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndexError(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndexError, self).__init__() + + def construct(self, a, b, c, u_tensor): + a[b][c] = u_tensor + return a + + +class TensorAssignWithBoolTensorIndex2(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2, self).__init__() + self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) + + def construct(self, a, u_tensor, _scalar): + a[a > 8] = u_tensor + a[a >= 6] = u_scalar + a[a < 3] = u_scalar + a[a <= 5] = u_tensor + a[a == 5] = u_scalar + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndex2Error(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2Error, self).__init__() + + def construct(self, a, u_tensor): + a[a > 8][a > 5] = u_tensor + return a + + +a = np.random.uniform(1, 10, [2, 3]) +b = a > 5 +c = a < 3 +Ta = Tensor(a) +Tb = Tensor(b) +Tc = Tensor(c) +Td = Tensor([True, True]) +u_tensor = Tensor([1]) +u_tensor_error = Tensor([1, 2]) +u_scalar = 5 + + +def test_tensor_assign_bool_index(): + net1 = TensorAssignWithBoolTensorIndex() + net2 = TensorAssignWithBoolTensorIndex2() + + net1(Ta, Tb, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Td, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, u_tensor, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Td, u_tensor, 
u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Ta, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Tc, u_tensor_error, u_scalar) + # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) + with pytest.raises(ValueError): + net2(Ta, u_tensor_error, u_scalar) + net3 = TensorAssignWithBoolTensorIndexError() + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_tensor) + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_scalar) + net4 = TensorAssignWithBoolTensorIndex2Error() + with pytest.raises(AttributeError): + net4(Ta, u_tensor) + with pytest.raises(AttributeError): + net4(Ta, u_scalar) + + test_cases = [ + ('TensorAssignWithBoolTensorIndex', { + 'block': TensorAssignWithBoolTensorIndex(), + 'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar], + }), + ('TensorAssignWithBoolTensorIndex2', { + 'block': TensorAssignWithBoolTensorIndex2(), + 'desc_inputs': [Ta, u_tensor, u_scalar], + }), ('SlicePositive', { 'block': NetWorkSlicePositive(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], @@ -96,7 +206,10 @@ test_cases = [ 'block': NetWorkReduceToScalar(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], }), - + ('NetWorkSliceEllipsis', { + 'block': NetWorkSliceEllipsis(), + 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], + }), ] diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py index 9b4e1fda23..ae7bd952d9 100644 --- a/tests/ut/python/parallel/test_auto_parallel_resnet.py +++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py @@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 - cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) + cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) 
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) @@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192 dev_num = 8 context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) - cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) + cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) set_algo_parameters(elementwise_op_strategy_follow=True) resset_op_id() np.random.seed(6) diff --git a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py index db6190ab89..848c8025cb 100644 --- a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py +++ b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py @@ -86,7 +86,7 @@ def test_two_matmul(): costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha") assert costmodel_alpha == 1.0 costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta") - assert costmodel_beta == 260.0 + assert costmodel_beta == 400.0 costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma") assert costmodel_gamma == 0.001 costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold") diff --git a/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py b/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py new file mode 100644 index 0000000000..b572968a4f --- /dev/null +++ b/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py @@ -0,0 +1,108 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mindspore as ms +from mindspore import Tensor, Parameter, ParameterTuple, context +from mindspore import nn +from mindspore.common.api import _executor +from mindspore.nn.optim import Adam, FTRL +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +import numpy as np + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.mul = P.Mul() + self.relu = P.ReLU() + self.param1 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide") + self.param2 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="deep") + + def construct(self, x): + out = self.mul(x, self.param1) + out = self.mul(out, self.param2) + out = self.relu(out) + return out + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.sum = P.ReduceSum(keep_dims=False).set_strategy(strategy=((4, 1, 1, 1),)) + self.mean = P.ReduceMean(keep_dims=False).set_strategy(strategy=((8, 1, 1, 1),)) + self.net = network + + def construct(self, x): + net_out = self.net(x) + loss1 = self.sum(net_out, -1) + loss2 = self.mean(net_out, -1) + return loss1, loss2 + + +class IthOutputCell(nn.Cell): + def __init__(self, network, output_index): + super(IthOutputCell, self).__init__() + self.network = network + self.output_index = output_index + + def construct(self, x1): + predict = self.network(x1)[self.output_index] + return predict + + +class 
TrainStepWrap(nn.Cell): + def __init__(self, network, sens=1000.0): + super(TrainStepWrap, self).__init__() + self.network = network + self.network.set_train() + self.trainable_params = network.trainable_params() + weights_w = [] + weights_d = [] + for params in self.trainable_params: + weights_w.append(params) + weights_d.append(params) + + self.weights_w = ParameterTuple(weights_w) + self.weights_d = ParameterTuple(weights_d) + self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, + l1=1e-8, l2=1e-8, initial_accum=1.0) + self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8, + loss_scale=sens) + self.hyper_map = C.HyperMap() + self.grad_w = C.GradOperation('grad_w', get_by_list=True, + sens_param=True) + self.grad_d = C.GradOperation('grad_d', get_by_list=True, + sens_param=True) + self.sens = sens + self.loss_net_w = IthOutputCell(network, output_index=0) + self.loss_net_d = IthOutputCell(network, output_index=1) + + def construct(self, x): + weights_w = self.weights_w + weights_d = self.weights_d + loss_w, loss_d = self.network(x) + sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) + sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) + grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w) + grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d) + return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d)) + + +def test_two_subgraphs(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + net = TrainStepWrap(NetWithLoss(Net())) + input_x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32) + _executor.compile(net, input_x) diff --git a/tests/ut/python/pipeline/parse/test_operator.py b/tests/ut/python/pipeline/parse/test_operator.py index a3c5f7e422..19d70b20a1 100644 --- a/tests/ut/python/pipeline/parse/test_operator.py +++ b/tests/ut/python/pipeline/parse/test_operator.py @@ -160,14 +160,17 @@ def 
test_ops(): ret_floor = p // q + q // p ret = ret_pow + ret_mod + ret_floor if self.int > self.float: - if self.str_a + self.str_b == "helloworld": - return ret + if [1, 2, 3] != None: + if self.str_a + self.str_b == "helloworld": + if q == 86: + print("hello world") + return ret return x net = OpsNet(9, 2) x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)) y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)) - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE) net(x, y) diff --git a/tests/ut/python/pynative_mode/test_backend.py b/tests/ut/python/pynative_mode/test_backend.py index 937f7b24ff..fae1974854 100644 --- a/tests/ut/python/pynative_mode/test_backend.py +++ b/tests/ut/python/pynative_mode/test_backend.py @@ -13,16 +13,14 @@ # limitations under the License. # ============================================================================ """ test_backend """ -import numpy as np +import os import pytest from mindspore.ops import operations as P import mindspore.nn as nn -from mindspore import context +from mindspore import context, ms_function from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter -from mindspore._extends.pynative_helper import args_type_check -from mindspore.common.tensor import Tensor -from mindspore.common.api import ms_function +from mindspore._checkparam import args_type_check def setup_module(module): @@ -31,6 +29,7 @@ def setup_module(module): class Net(nn.Cell): """ Net definition """ + def __init__(self): super(Net, self).__init__() self.add = P.TensorAdd() @@ -49,14 +48,17 @@ def test_vm_backend(): output = add() assert output.asnumpy().shape == (1, 3, 3, 4) + def test_vm_set_context(): """ test_vm_set_context """ - context.set_context(save_graphs=True, save_graphs_path="/home/mindspore", mode=context.GRAPH_MODE) + context.set_context(save_graphs=True, 
save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE) assert context.get_context("save_graphs") assert context.get_context("mode") == context.GRAPH_MODE - assert context.get_context("save_graphs_path") == "/home/mindspore" + assert os.path.exists("mindspore_ir_path") + assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 context.set_context(mode=context.PYNATIVE_MODE) + @args_type_check(v_str=str, v_int=int, v_tuple=tuple) def check_input(v_str, v_int, v_tuple): """ check_input """ @@ -74,3 +76,15 @@ def test_args_type_check(): with pytest.raises(TypeError): check_input("name", 100, "age") check_input("name", 100, (10, 10)) + + +def teardown_module(): + dirs = ['mindspore_ir_path'] + for item in dirs: + item_name = './' + item + if not os.path.exists(item_name): + continue + if os.path.isdir(item_name): + os.rmdir(item_name) + elif os.path.isfile(item_name): + os.remove(item_name) diff --git a/tests/ut/python/pynative_mode/test_cell_bprop.py b/tests/ut/python/pynative_mode/test_cell_bprop.py index da1e14974f..c69b80412e 100644 --- a/tests/ut/python/pynative_mode/test_cell_bprop.py +++ b/tests/ut/python/pynative_mode/test_cell_bprop.py @@ -51,7 +51,7 @@ class InlineMulADD(nn.Cell): def __init__(self): super(InlineMulADD, self).__init__() self.mul_add = MulAdd() - self.param = Parameter(2, 'param') + self.param = 2 def construct(self, x, y): return self.mul_add(x, y) + x + self.param * y diff --git a/tests/ut/python/pynative_mode/test_context.py b/tests/ut/python/pynative_mode/test_context.py index 450bf60b90..2425b53f42 100644 --- a/tests/ut/python/pynative_mode/test_context.py +++ b/tests/ut/python/pynative_mode/test_context.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================ """ test_context """ +import os import pytest from mindspore import context # pylint: disable=W0212 @@ -74,11 +75,12 @@ def test_dump_target(): def test_set_context(): """ test_set_context """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", - device_id=0, save_graphs=True, save_graphs_path="/mindspore") + device_id=0, save_graphs=True, save_graphs_path="mindspore_ir_path") assert context.get_context("device_id") == 0 assert context.get_context("device_target") == "Ascend" assert context.get_context("save_graphs") - assert context.get_context("save_graphs_path") == "/mindspore" + assert os.path.exists("mindspore_ir_path") + assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 assert context.get_context("mode") == context.GRAPH_MODE context.set_context(mode=context.PYNATIVE_MODE) @@ -87,3 +89,16 @@ def test_set_context(): with pytest.raises(ValueError): context.set_context(modex="ge") + + +def teardown_module(): + dirs = ['mindspore_ir_path'] + for item in dirs: + item_name = './' + item + if not os.path.exists(item_name): + continue + if os.path.isdir(item_name): + os.rmdir(item_name) + elif os.path.isfile(item_name): + os.remove(item_name) + diff --git a/tests/vm_impl/nn_ops_vm_impl.py b/tests/vm_impl/nn_ops_vm_impl.py index fc1fa95024..0df4b5fbaa 100644 --- a/tests/vm_impl/nn_ops_vm_impl.py +++ b/tests/vm_impl/nn_ops_vm_impl.py @@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self): """Generate vm_impl function for MaxPoolGradWithArgmax""" def vm_impl(x, dout, argmax): - print("buxue") - print(argmax) x = x.asnumpy() dout = dout.asnumpy() arg_max = argmax.asnumpy() @@ -379,8 +377,8 @@ def vm_impl_momentum(self): accumulation = accumulation.asnumpy() variable = variable.asnumpy() shape = accumulation.shape - learning_rate = np.full(shape, learning_rate) - momentum = np.full(shape, momentum) + learning_rate = np.full(shape, 
learning_rate.asnumpy()) + momentum = np.full(shape, momentum.asnumpy()) accumulation = accumulation * momentum + gradient if use_nesterov is True: variable -= gradient * learning_rate + accumulation * momentum * learning_rate diff --git a/third_party/patch/incubator-tvm/CMakeLists.txt b/third_party/patch/incubator-tvm/CMakeLists.txt new file mode 100644 index 0000000000..d8964579cd --- /dev/null +++ b/third_party/patch/incubator-tvm/CMakeLists.txt @@ -0,0 +1,100 @@ +cmake_minimum_required(VERSION 3.2) +project(tvm C CXX) +set(TVM_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +# Utility functions +include(${TVM_DIR}/cmake/util/Util.cmake) +include(${TVM_DIR}/cmake/util/FindCUDA.cmake) + +# include directories +include_directories(AFTER "${TVM_DIR}/include") +include_directories(AFTER "${TVM_DIR}/src") +include_directories(AFTER "${TVM_DIR}") +include_directories(AFTER "${TVM_DIR}/src/schedule") + +include_directories(AFTER "${TVM_DIR}/3rdparty/dmlc-core/include") +include_directories(AFTER "${TVM_DIR}/3rdparty/dlpack/include") +include_directories(AFTER "${TVM_DIR}/3rdparty/compiler-rt") +include_directories(AFTER "${TVM_DIR}/3rdparty/rang/include") + +# lib contain dlopen and dlclose +set(TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) + +# add source group +file(GLOB_RECURSE GROUP_SOURCE "${TVM_DIR}/src/*.cc" "src/*.cc") +file(GLOB_RECURSE GROUP_INCLUDE "${TVM_DIR}/src/*.h" + "${TVM_DIR}/include/*.h" "src/*.h" "include/*.h") +assign_source_group("Source" ${GROUP_SOURCE}) +assign_source_group("Include" ${GROUP_INCLUDE}) + +file(GLOB COMPILER_SRCS + "pre_activate/gpu/*.cc" + ${TVM_DIR}/src/api/*.cc + ${TVM_DIR}/src/arithmetic/*.cc + ${TVM_DIR}/src/autotvm/*.cc + ${TVM_DIR}/src/codegen/*.cc + ${TVM_DIR}/src/lang/*.cc + ${TVM_DIR}/src/pass/*.cc + ${TVM_DIR}/src/op/*.cc + ${TVM_DIR}/src/node/*.cc + ${TVM_DIR}/src/schedule/*.cc + ${TVM_DIR}/src/runtime/*.cc + ${TVM_DIR}/src/runtime/vm/*.cc + ${TVM_DIR}/src/runtime/vm/profiler/*.cc + ${TVM_DIR}/src/codegen/stackvm/*.cc) + 
+file(GLOB_RECURSE RELAY_SRCS ${TVM_DIR}/src/relay/*.cc) +list(APPEND COMPILER_SRCS ${RELAY_SRCS}) + +file(GLOB DATATYPE_SRCS ${TVM_DIR}/src/codegen/datatype/*.cc) +list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) + +file(GLOB COMPILER_VERILOG_SRCS ${TVM_DIR}/src/codegen/verilog/*.cc) +list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) + +file(GLOB TOPI_SRCS ${TVM_DIR}/topi/src/*.cc) + +file(GLOB RUNTIME_SRCS + ${TVM_DIR}/src/runtime/*.cc + ${TVM_DIR}/src/runtime/vm/*.cc + ${TVM_DIR}/src/runtime/stub/*.cc + ${TVM_DIR}/src/runtime/stackvm/*.cc) + + +file(GLOB COMPILER_OFF_SRCS + ${TVM_DIR}/src/codegen/opt/build_*_off.cc) + +list(REMOVE_ITEM COMPILER_OFF_SRCS + ${TVM_DIR}/src/codegen/opt/build_cuda_off.cc) +set(USE_CUDA "ON") +list(APPEND COMPILER_SRCS ${COMPILER_OFF_SRCS}) +# Module rules +include(${TVM_DIR}/cmake/modules/CUDA.cmake) + +set(CMAKE_C_FLAGS_AKG -pipe -Wall -fPIC -fstack-protector-all) +set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) + +set(CMAKE_CXX_FLAGS_AKG -std=c++11 -pipe -Wall -fPIC -fstack-protector-all) +set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) + +if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + message("-- Build in Debug mode") + set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O0 -g -rdynamic) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O0 -g -rdynamic) +else() + message("-- Build in Release mode") + set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O2 -Werror) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O2 -Werror) +endif() +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION + VERSION_GREATER 7.0) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -faligned-new) +endif() + +add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS} ${TOPI_SRCS}) + +target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) +target_compile_options(tvm PRIVATE + $<$:${CMAKE_C_FLAGS_AKG}> + $<$:${CMAKE_CXX_FLAGS_AKG}>) +target_include_directories(tvm PRIVATE 
"${TVM_DIR}/topi/include") +install(TARGETS tvm) \ No newline at end of file diff --git a/third_party/patch/incubator-tvm/find_library.patch b/third_party/patch/incubator-tvm/find_library.patch index e54df2c7cf..f7b2f9af0a 100644 --- a/third_party/patch/incubator-tvm/find_library.patch +++ b/third_party/patch/incubator-tvm/find_library.patch @@ -18,11 +18,11 @@ - lib_path = libinfo.find_lib_path() + """Load library by searching possible path.""" + pwd = os.path.dirname(os.path.realpath(__file__)) -+ path = os.path.realpath(pwd+"/../../../mindspore") ++ path = os.path.realpath(pwd+"/../../../mindspore/lib") + lib_path = [] + files = os.listdir(path) + for f in files: -+ if f.startswith("_c_expression.") and f.endswith(".so"): ++ if f.startswith("libtvm.") and f.endswith(".so"): + lib_path.append(path+"/"+f) + break + if not lib_path: @@ -56,11 +56,11 @@ diff -Npur tvm/topi/python/topi/cpp/impl.py tvm_new/topi/python/topi/cpp/impl.py - return None, None + """Load library by searching possible path.""" + pwd = os.path.dirname(os.path.realpath(__file__)) -+ path = os.path.realpath(pwd+"/../../../mindspore") ++ path = os.path.realpath(pwd+"/../../../mindspore/lib") + lib_path = [] + files = os.listdir(path) + for f in files: -+ if f.startswith("_c_expression.") and f.endswith(".so"): ++ if f.startswith("libtvm.") and f.endswith(".so"): + lib_path.append(path+"/"+f) + break + if not lib_path: