Merge pull request !56 from changzherui/syn-code423tags/v0.3.0-alpha
| @@ -52,7 +52,7 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true | |||
| ConstructorInitializerIndentWidth: 4 | |||
| ContinuationIndentWidth: 2 | |||
| Cpp11BracedListStyle: true | |||
| DerivePointerAlignment: true | |||
| DerivePointerAlignment: false | |||
| DisableFormat: false | |||
| ExperimentalAutoDetectBinPacking: false | |||
| FixNamespaceComments: true | |||
| @@ -94,7 +94,7 @@ PenaltyBreakString: 1000 | |||
| PenaltyBreakTemplateDeclaration: 10 | |||
| PenaltyExcessCharacter: 1000000 | |||
| PenaltyReturnTypeOnItsOwnLine: 200 | |||
| PointerAlignment: Left | |||
| PointerAlignment: Right | |||
| RawStringFormats: | |||
| - Language: Cpp | |||
| Delimiters: | |||
| @@ -1,8 +1,6 @@ | |||
| cmake_minimum_required(VERSION 3.14) | |||
| project (MindSpore) | |||
| include(${CMAKE_SOURCE_DIR}/cmake/options.cmake) | |||
| set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/") | |||
| if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") | |||
| @@ -179,7 +179,7 @@ Check out how MindSpore Open Governance [works](https://gitee.com/mindspore/comm | |||
| - [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers. | |||
| - IRC channel at `#mindspore` (only for meeting minutes logging purpose) | |||
| - Video Conferencing: meet.jit.si | |||
| - Video Conferencing: https://meet.jit.si | |||
| - Mailing-list: https://mailweb.mindspore.cn/postorius/lists | |||
| ## Contributing | |||
| @@ -31,6 +31,7 @@ cd %CD%/mindspore | |||
| cmake -DCMAKE_BUILD_TYPE=Release -DENABLE_CPU=ON -DENABLE_MINDDATA=ON -DUSE_GLOG=ON -G "CodeBlocks - MinGW Makefiles" ../.. | |||
| IF NOT %errorlevel% == 0 ( | |||
| echo "cmake fail." | |||
| goto run_fail | |||
| ) | |||
| @@ -40,6 +41,7 @@ IF "%1%" == "" ( | |||
| cmake --build . --target package -- -j%1% | |||
| ) | |||
| IF NOT %errorlevel% == 0 ( | |||
| echo "build fail." | |||
| goto run_fail | |||
| ) | |||
| @@ -49,6 +51,6 @@ goto run_eof | |||
| :run_fail | |||
| cd %BASEPATH% | |||
| echo "build fail." | |||
| set errorlevel=1 | |||
| :run_eof | |||
| @@ -23,30 +23,30 @@ export BUILD_PATH="${BASEPATH}/build/" | |||
| usage() | |||
| { | |||
| echo "Usage:" | |||
| echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge|cpu] [-m infer|train] \\" | |||
| echo " [-a on|off] [-g on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" | |||
| echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" | |||
| echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" | |||
| echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K]" | |||
| echo "" | |||
| echo "Options:" | |||
| echo " -d Debug mode" | |||
| echo " -r Release mode, default mode" | |||
| echo " -v Display build command" | |||
| echo " -c Enable code coverage switch, default off" | |||
| echo " -t Run testcases switch, default on" | |||
| echo " -c Enable code coverage, default off" | |||
| echo " -t Run testcases, default on" | |||
| echo " -g Use glog to output log, default on" | |||
| echo " -h Print usage" | |||
| echo " -b Select other backend, available: \\" | |||
| echo " ge:graph engine, cpu" | |||
| echo " -m Select mode, available: infer, train, default is infer " | |||
| echo " ge:graph engine" | |||
| echo " -m Select graph engine backend mode, available: infer, train, default is infer" | |||
| echo " -a Enable ASAN, default off" | |||
| echo " -p Enable pipeline profile, default off" | |||
| echo " -p Enable pipeline profile, print to stdout, default off" | |||
| echo " -R Enable pipeline profile, record to json, default off" | |||
| echo " -i Enable increment building, default off" | |||
| echo " -L Enable load ANF-IR as input of 'infer', default off" | |||
| echo " -R Enable the time_line record, default off" | |||
| echo " -j[n] Set the threads when building (Default: -j8)" | |||
| echo " -e Use gpu, d or cpu" | |||
| echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" | |||
| echo " -Q Enable dump end to end, default off" | |||
| echo " -Q Enable dump memory, default off" | |||
| echo " -D Enable dumping of function graph ir, default on" | |||
| echo " -z Compile dataset & mindrecord, default on" | |||
| echo " -M Enable MPI and NCCL for GPU training, default on" | |||
| @@ -64,7 +64,7 @@ set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) | |||
| string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") | |||
| string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") | |||
| # force __FILE__ to show relative path of file, from source directory | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") | |||
| add_subdirectory(${GE_SOURCE_DIR}/src/common/graph) | |||
| if(ENABLE_D) | |||
| add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) | |||
| @@ -1,16 +1,15 @@ | |||
| set(incubator_tvm_gpu_CFLAGS "-pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") | |||
| set(incubator_tvm_gpu_CXXFLAGS "-std=c++11 -pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") | |||
| set(USE_CUDA "ON") | |||
| set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||
| set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") | |||
| mindspore_add_pkg(incubator_tvm_gpu | |||
| VER 0.6.0 | |||
| LIBS tvm | |||
| URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz | |||
| MD5 9cbbd32545a776023acabbba270449fe | |||
| CUSTOM_CMAKE ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/ | |||
| SUBMODULES ${dlpack_DIRPATH} ${dmlc-core_DIRPATH} ${rang_DIRPATH} | |||
| SOURCEMODULES topi/python/topi python/tvm | |||
| PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/find_library.patch | |||
| ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch | |||
| ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch | |||
| CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON) | |||
| include_directories(${incubator_tvm_gpu_INC}) | |||
| add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm) | |||
| ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch | |||
| ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch | |||
| CMAKE_OPTION " ") | |||
| add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm) | |||
| @@ -205,7 +205,7 @@ set(MS_FIND_NO_DEFAULT_PATH ${MS_FIND_NO_DEFAULT_PATH} PARENT_SCOPE) | |||
| function(mindspore_add_pkg pkg_name ) | |||
| set(options ) | |||
| set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH) | |||
| set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH CUSTOM_CMAKE) | |||
| set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES SUBMODULES SOURCEMODULES) | |||
| cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) | |||
| @@ -281,10 +281,6 @@ function(mindspore_add_pkg pkg_name ) | |||
| file(GLOB ${pkg_name}_INSTALL_SUBMODULE ${_SUBMODULE_FILE}/*) | |||
| file(COPY ${${pkg_name}_INSTALL_SUBMODULE} DESTINATION ${${pkg_name}_SOURCE_DIR}/3rdparty/${_SUBMODENAME}) | |||
| endforeach (_SUBMODULE_FILE) | |||
| foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) | |||
| file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) | |||
| file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) | |||
| endforeach (_SUBMODULE_FILE) | |||
| else() | |||
| set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) | |||
| endif () | |||
| @@ -304,12 +300,20 @@ function(mindspore_add_pkg pkg_name ) | |||
| message(FATAL_ERROR "Failed patch: ${_LF_PATCH_FILE}") | |||
| endif() | |||
| endforeach(_PATCH_FILE) | |||
| foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) | |||
| file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) | |||
| file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) | |||
| endforeach (_SUBMODULE_FILE) | |||
| file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600) | |||
| if(NOT ${pkg_name}_LOCK_RET EQUAL "0") | |||
| message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}") | |||
| endif() | |||
| if (PKG_CUSTOM_CMAKE) | |||
| file(GLOB ${pkg_name}_cmake ${PKG_CUSTOM_CMAKE}/CMakeLists.txt) | |||
| file(COPY ${${pkg_name}_cmake} DESTINATION ${${pkg_name}_SOURCE_DIR}) | |||
| endif () | |||
| if(${pkg_name}_SOURCE_DIR) | |||
| if (PKG_HEAD_ONLY) | |||
| file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) | |||
| @@ -0,0 +1,58 @@ | |||
| # AlexNet Example | |||
| ## Description | |||
| Training AlexNet with CIFAR-10 dataset in MindSpore. | |||
| This is a simple tutorial for training AlexNet in MindSpore. | |||
| ## Requirements | |||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||
| - Download the CIFAR-10 dataset at <http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz>. The directory structure is as follows: | |||
| ``` | |||
| ├─cifar-10-batches-bin | |||
| │ | |||
| └─cifar-10-verify-bin | |||
| ``` | |||
| ## Running the example | |||
| ```python | |||
| # train AlexNet, hyperparameter setting in config.py | |||
| python train.py --data_path cifar-10-batches-bin | |||
| ``` | |||
| You can get loss with each step similar to this: | |||
| ```bash | |||
| epoch: 1 step: 1, loss is 2.2791853 | |||
| ... | |||
| epoch: 1 step: 1536, loss is 1.9366643 | |||
| epoch: 1 step: 1537, loss is 1.6983616 | |||
| epoch: 1 step: 1538, loss is 1.0221305 | |||
| ... | |||
| ``` | |||
| Then, test AlexNet using the trained network model | |||
| ```python | |||
| # test AlexNet, 1 epoch training accuracy is up to 51.1%; 10 epoch training accuracy is up to 81.2% | |||
| python eval.py --data_path cifar-10-verify-bin --mode test --ckpt_path checkpoint_alexnet-1_1562.ckpt | |||
| ``` | |||
| ## Note | |||
| There are some optional arguments: | |||
| ```bash | |||
| -h, --help show this help message and exit | |||
| --device_target {Ascend,GPU} | |||
| device where the code will be implemented (default: Ascend) | |||
| --data_path DATA_PATH | |||
| path where the dataset is saved | |||
| --dataset_sink_mode DATASET_SINK_MODE | |||
| dataset_sink_mode is False or True | |||
| ``` | |||
| You can run ```python train.py -h``` or ```python eval.py -h``` to get more information. | |||
| @@ -0,0 +1,46 @@ | |||
| # MindRecord generating guidelines | |||
| <!-- TOC --> | |||
| - [MindRecord generating guidelines](#mindrecord-generating-guidelines) | |||
| - [Create work space](#create-work-space) | |||
| - [Implement data generator](#implement-data-generator) | |||
| - [Run data generator](#run-data-generator) | |||
| <!-- /TOC --> | |||
| ## Create work space | |||
| Assume the dataset name is 'xyz' | |||
| * Create work space from template | |||
| ```shell | |||
| cd ${your_mindspore_home}/example/convert_to_mindrecord | |||
| cp -r template xyz | |||
| ``` | |||
| ## Implement data generator | |||
| Edit dictionary data generator | |||
| * Edit file | |||
| ```shell | |||
| cd ${your_mindspore_home}/example/convert_to_mindrecord | |||
| vi xyz/mr_api.py | |||
| ``` | |||
| Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented | |||
| - 'mindrecord_task_number()' returns the number of tasks. Return 1 if data rows are generated serially. Return N if the generator can be split into N parallel-run tasks. | |||
| - 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' ranges from 0 to N-1, where N is the return value of mindrecord_task_number() | |||
| Tips for parallel runs | |||
| - For imagenet, one directory can be a task. | |||
| - For TFRecord with multiple files, each file can be a task. | |||
| - For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K) | |||
| ## Run data generator | |||
| * run python script | |||
| ```shell | |||
| cd ${your_mindspore_home}/example/convert_to_mindrecord | |||
| python writer.py --mindrecord_script imagenet [...] | |||
| ``` | |||
| @@ -0,0 +1,122 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| User-defined API for MindRecord writer. | |||
| Two API must be implemented, | |||
| 1. mindrecord_task_number() | |||
| # Return number of parallel tasks. return 1 if no parallel | |||
| 2. mindrecord_dict_data(task_id) | |||
| # Yield data for one task | |||
| # task_id is 0..N-1, if N is return value of mindrecord_task_number() | |||
| """ | |||
| import argparse | |||
| import os | |||
| import pickle | |||
| ######## mindrecord_schema begin ########## | |||
| mindrecord_schema = {"label": {"type": "int64"}, | |||
| "data": {"type": "bytes"}, | |||
| "file_name": {"type": "string"}} | |||
| ######## mindrecord_schema end ########## | |||
| ######## Frozen code begin ########## | |||
| with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: | |||
| ARG_LIST = pickle.load(mindrecord_argument_file_handle) | |||
| ######## Frozen code end ########## | |||
| parser = argparse.ArgumentParser(description='Mind record imagenet example') | |||
| parser.add_argument('--label_file', type=str, default="", help='label file') | |||
| parser.add_argument('--image_dir', type=str, default="", help='images directory') | |||
| ######## Frozen code begin ########## | |||
| args = parser.parse_args(ARG_LIST) | |||
| print(args) | |||
| ######## Frozen code end ########## | |||
| def _user_defined_private_func(): | |||
| """ | |||
| Internal function for tasks list | |||
| Return: | |||
| tasks list | |||
| """ | |||
| if not os.path.exists(args.label_file): | |||
| raise IOError("map file {} not exists".format(args.label_file)) | |||
| label_dict = {} | |||
| with open(args.label_file) as file_handle: | |||
| line = file_handle.readline() | |||
| while line: | |||
| labels = line.split(" ") | |||
| label_dict[labels[1]] = labels[0] | |||
| line = file_handle.readline() | |||
| # get all the dir which are n02087046, n02094114, n02109525 | |||
| dir_paths = {} | |||
| for item in label_dict: | |||
| real_path = os.path.join(args.image_dir, label_dict[item]) | |||
| if not os.path.isdir(real_path): | |||
| print("{} dir is not exist".format(real_path)) | |||
| continue | |||
| dir_paths[item] = real_path | |||
| if not dir_paths: | |||
| print("not valid image dir in {}".format(args.image_dir)) | |||
| return {}, {} | |||
| dir_list = [] | |||
| for label in dir_paths: | |||
| dir_list.append(label) | |||
| return dir_list, dir_paths | |||
| dir_list_global, dir_paths_global = _user_defined_private_func() | |||
| def mindrecord_task_number(): | |||
| """ | |||
| Get task size. | |||
| Return: | |||
| number of tasks | |||
| """ | |||
| return len(dir_list_global) | |||
| def mindrecord_dict_data(task_id): | |||
| """ | |||
| Get data dict. | |||
| Yields: | |||
| data (dict): data row which is dict. | |||
| """ | |||
| # get the filename, label and image binary as a dict | |||
| label = dir_list_global[task_id] | |||
| for item in os.listdir(dir_paths_global[label]): | |||
| file_name = os.path.join(dir_paths_global[label], item) | |||
| if not item.endswith("JPEG") and not item.endswith( | |||
| "jpg") and not item.endswith("jpeg"): | |||
| print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name)) | |||
| continue | |||
| data = {} | |||
| data["file_name"] = str(file_name) | |||
| data["label"] = int(label) | |||
| # get the image data | |||
| image_file = open(file_name, "rb") | |||
| image_bytes = image_file.read() | |||
| image_file.close() | |||
| data["data"] = image_bytes | |||
| yield data | |||
| @@ -0,0 +1,8 @@ | |||
| #!/bin/bash | |||
| rm /tmp/imagenet/mr/* | |||
| python writer.py --mindrecord_script imagenet \ | |||
| --mindrecord_file "/tmp/imagenet/mr/m" \ | |||
| --mindrecord_partitions 16 \ | |||
| --label_file "/tmp/imagenet/label.txt" \ | |||
| --image_dir "/tmp/imagenet/jpeg" | |||
| @@ -0,0 +1,6 @@ | |||
| #!/bin/bash | |||
| rm /tmp/template/* | |||
| python writer.py --mindrecord_script template \ | |||
| --mindrecord_file "/tmp/template/m" \ | |||
| --mindrecord_partitions 4 | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| User-defined API for MindRecord writer. | |||
| Two API must be implemented, | |||
| 1. mindrecord_task_number() | |||
| # Return number of parallel tasks. return 1 if no parallel | |||
| 2. mindrecord_dict_data(task_id) | |||
| # Yield data for one task | |||
| # task_id is 0..N-1, if N is return value of mindrecord_task_number() | |||
| """ | |||
| import argparse | |||
| import pickle | |||
| # ## Parse argument | |||
| with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line | |||
| ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line | |||
| parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line | |||
| # ## Your arguments below | |||
| # parser.add_argument(...) | |||
| args = parser.parse_args(ARG_LIST) # Do NOT change this line | |||
| print(args) # Do NOT change this line | |||
| # ## Default mindrecord vars. Comment them unless default value has to be changed. | |||
| # mindrecord_index_fields = ['label'] | |||
| # mindrecord_header_size = 1 << 24 | |||
| # mindrecord_page_size = 1 << 25 | |||
| # define global vars here if necessary | |||
| # ####### Your code below ########## | |||
| mindrecord_schema = {"label": {"type": "int32"}} | |||
| def mindrecord_task_number(): | |||
| """ | |||
| Get task size. | |||
| Return: | |||
| number of tasks | |||
| """ | |||
| return 1 | |||
| def mindrecord_dict_data(task_id): | |||
| """ | |||
| Get data dict. | |||
| Yields: | |||
| data (dict): data row which is dict. | |||
| """ | |||
| print("task is {}".format(task_id)) | |||
| for i in range(256): | |||
| data = {} | |||
| data['label'] = i | |||
| yield data | |||
| @@ -0,0 +1,152 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| ######################## write mindrecord example ######################## | |||
| Write mindrecord by data dictionary: | |||
| python writer.py --mindrecord_script /YourScriptPath ... | |||
| """ | |||
| import argparse | |||
| import os | |||
| import pickle | |||
| import time | |||
| from importlib import import_module | |||
| from multiprocessing import Pool | |||
| from mindspore.mindrecord import FileWriter | |||
| def _exec_task(task_id, parallel_writer=True): | |||
| """ | |||
| Execute task with specified task id | |||
| """ | |||
| print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) | |||
| imagenet_iter = mindrecord_dict_data(task_id) | |||
| batch_size = 2048 | |||
| transform_count = 0 | |||
| while True: | |||
| data_list = [] | |||
| try: | |||
| for _ in range(batch_size): | |||
| data_list.append(imagenet_iter.__next__()) | |||
| transform_count += 1 | |||
| writer.write_raw_data(data_list, parallel_writer=parallel_writer) | |||
| print("transformed {} record...".format(transform_count)) | |||
| except StopIteration: | |||
| if data_list: | |||
| writer.write_raw_data(data_list, parallel_writer=parallel_writer) | |||
| print("transformed {} record...".format(transform_count)) | |||
| break | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='Mind record writer') | |||
| parser.add_argument('--mindrecord_script', type=str, default="template", | |||
| help='path where script is saved') | |||
| parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord", | |||
| help='written file name prefix') | |||
| parser.add_argument('--mindrecord_partitions', type=int, default=1, | |||
| help='number of written files') | |||
| parser.add_argument('--mindrecord_workers', type=int, default=8, | |||
| help='number of parallel workers') | |||
| args = parser.parse_known_args() | |||
| args, other_args = parser.parse_known_args() | |||
| print(args) | |||
| print(other_args) | |||
| with open('mr_argument.pickle', 'wb') as file_handle: | |||
| pickle.dump(other_args, file_handle) | |||
| try: | |||
| mr_api = import_module(args.mindrecord_script + '.mr_api') | |||
| except ModuleNotFoundError: | |||
| raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api')) | |||
| num_tasks = mr_api.mindrecord_task_number() | |||
| print("Write mindrecord ...") | |||
| mindrecord_dict_data = mr_api.mindrecord_dict_data | |||
| # get number of files | |||
| writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) | |||
| start_time = time.time() | |||
| # set the header size | |||
| try: | |||
| header_size = mr_api.mindrecord_header_size | |||
| writer.set_header_size(header_size) | |||
| except AttributeError: | |||
| print("Default header size: {}".format(1 << 24)) | |||
| # set the page size | |||
| try: | |||
| page_size = mr_api.mindrecord_page_size | |||
| writer.set_page_size(page_size) | |||
| except AttributeError: | |||
| print("Default page size: {}".format(1 << 25)) | |||
| # get schema | |||
| try: | |||
| mindrecord_schema = mr_api.mindrecord_schema | |||
| except AttributeError: | |||
| raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") | |||
| # create the schema | |||
| writer.add_schema(mindrecord_schema, "mindrecord_schema") | |||
| # add the index | |||
| try: | |||
| index_fields = mr_api.mindrecord_index_fields | |||
| writer.add_index(index_fields) | |||
| except AttributeError: | |||
| print("Default index fields: all simple fields are indexes.") | |||
| writer.open_and_set_header() | |||
| task_list = list(range(num_tasks)) | |||
| # set number of workers | |||
| num_workers = args.mindrecord_workers | |||
| if num_tasks < 1: | |||
| num_tasks = 1 | |||
| if num_workers > num_tasks: | |||
| num_workers = num_tasks | |||
| if os.name == 'nt': | |||
| for window_task_id in task_list: | |||
| _exec_task(window_task_id, False) | |||
| elif num_tasks > 1: | |||
| with Pool(num_workers) as p: | |||
| p.map(_exec_task, task_list) | |||
| else: | |||
| _exec_task(0, False) | |||
| ret = writer.commit() | |||
| os.remove("{}".format("mr_argument.pickle")) | |||
| end_time = time.time() | |||
| print("--------------------------------------------") | |||
| print("END. Total time: {}".format(end_time - start_time)) | |||
| print("--------------------------------------------") | |||
| @@ -0,0 +1,63 @@ | |||
| # LeNet Example | |||
| ## Description | |||
| Training LeNet with MNIST dataset in MindSpore. | |||
| This is a simple and basic tutorial for constructing a network in MindSpore. | |||
| ## Requirements | |||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||
| - Download the MNIST dataset at <http://yann.lecun.com/exdb/mnist/>. The directory structure is as follows: | |||
| ``` | |||
| └─MNIST_Data | |||
| ├─test | |||
| │ t10k-images.idx3-ubyte | |||
| │ t10k-labels.idx1-ubyte | |||
| │ | |||
| └─train | |||
| train-images.idx3-ubyte | |||
| train-labels.idx1-ubyte | |||
| ``` | |||
| ## Running the example | |||
| ```python | |||
| # train LeNet, hyperparameter setting in config.py | |||
| python train.py --data_path MNIST_Data | |||
| ``` | |||
| You can get loss with each step similar to this: | |||
| ```bash | |||
| epoch: 1 step: 1, loss is 2.3040335 | |||
| ... | |||
| epoch: 1 step: 1739, loss is 0.06952668 | |||
| epoch: 1 step: 1740, loss is 0.05038793 | |||
| epoch: 1 step: 1741, loss is 0.05018193 | |||
| ... | |||
| ``` | |||
| Then, test LeNet using the trained network model | |||
| ```python | |||
| # test LeNet, after 1 epoch training, the accuracy is up to 96.5% | |||
| python eval.py --data_path MNIST_Data --mode test --ckpt_path checkpoint_lenet-1_1875.ckpt | |||
| ``` | |||
| ## Note | |||
| There are some optional arguments: | |||
| ```bash | |||
| -h, --help show this help message and exit | |||
| --device_target {Ascend,GPU,CPU} | |||
| device where the code will be implemented (default: Ascend) | |||
| --data_path DATA_PATH | |||
| path where the dataset is saved | |||
| --dataset_sink_mode DATASET_SINK_MODE | |||
| dataset_sink_mode is False or True | |||
| ``` | |||
| You can run ```python train.py -h``` or ```python eval.py -h``` to get more information. | |||
| @@ -1 +1 @@ | |||
| Subproject commit 70bb745b459ff9a0e7fc1008d15fe4b510f03da7 | |||
| Subproject commit 43a715bc461fd70b7837051a2f47f0a1b19c5859 | |||
| @@ -26,7 +26,12 @@ from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad | |||
| from .mean import SimpleMean, gpu_schedule_SimpleMean | |||
| from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad | |||
| from .mul import Mul, gpu_schedule_Mul | |||
| from .hsigmoid import Hsigmoid, gpu_schedule_Hsigmoid | |||
| from .hsigmoid_grad import HsigmoidGrad, gpu_schedule_HsigmoidGrad | |||
| from .hswish import Hswish, gpu_schedule_Hswish | |||
| from .hswish_grad import HswishGrad, gpu_schedule_HswishGrad | |||
| from .hsigmoid import HSigmoid, gpu_schedule_HSigmoid | |||
| from .hsigmoid_grad import HSigmoidGrad, gpu_schedule_HSigmoidGrad | |||
| from .hswish import HSwish, gpu_schedule_HSwish | |||
| from .hswish_grad import HSwishGrad, gpu_schedule_HSwishGrad | |||
| from .logical_or import LogicalOr, gpu_schedule_LogicalOr | |||
| from .logical_not import LogicalNot, gpu_schedule_LogicalNot | |||
| from .logical_and import LogicalAnd, gpu_schedule_LogicalAnd | |||
| from .sub import Sub, gpu_schedule_Sub | |||
| from .less_equal import LessEqual, gpu_schedule_LessEqual | |||
| @@ -33,9 +33,9 @@ def topi_nn_hsigmoid(x): | |||
| (x(*i) + 3) / 6))) | |||
| def Hsigmoid(x): | |||
| def HSigmoid(x): | |||
| """ | |||
| Hsigmoid | |||
| HSigmoid | |||
| Args: | |||
| x: | |||
| @@ -45,9 +45,9 @@ def Hsigmoid(x): | |||
| return topi_nn_hsigmoid(x) | |||
| def gpu_schedule_Hsigmoid(outs): | |||
| def gpu_schedule_HSigmoid(outs): | |||
| """ | |||
| gpu schedule Hsigmoid | |||
| gpu schedule HSigmoid | |||
| Args: | |||
| outs: | |||
| @@ -12,14 +12,14 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Hsigmoid grad""" | |||
| """HSigmoid grad""" | |||
| import _akg.topi as topi | |||
| import _akg.tvm as tvm | |||
| def HsigmoidGrad(y_grad, x): | |||
| def HSigmoidGrad(y_grad, x): | |||
| """ | |||
| HsigmoidGrad | |||
| HSigmoidGrad | |||
| Args: | |||
| y_grad: | |||
| x: | |||
| @@ -32,7 +32,7 @@ def HsigmoidGrad(y_grad, x): | |||
| y_grad(*i) / 6))) | |||
| def gpu_schedule_HsigmoidGrad(outs): | |||
| def gpu_schedule_HSigmoidGrad(outs): | |||
| """ | |||
| gpu schedule ReLU6Grad | |||
| Args: | |||
| @@ -33,9 +33,9 @@ def topi_nn_hswish(x): | |||
| x(*i) * (x(*i) + 3) / 6))) | |||
| def Hswish(x): | |||
| def HSwish(x): | |||
| """ | |||
| Hswish | |||
| HSwish | |||
| Args: | |||
| x: | |||
| @@ -45,9 +45,9 @@ def Hswish(x): | |||
| return topi_nn_hswish(x) | |||
| def gpu_schedule_Hswish(outs): | |||
| def gpu_schedule_HSwish(outs): | |||
| """ | |||
| gpu schedule Hswish | |||
| gpu schedule HSwish | |||
| Args: | |||
| outs: | |||
| @@ -12,14 +12,14 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """HswishGrad""" | |||
| """HSwishGrad""" | |||
| import _akg.topi as topi | |||
| import _akg.tvm as tvm | |||
| def HswishGrad(y_grad, x): | |||
| def HSwishGrad(y_grad, x): | |||
| """ | |||
| HswishGrad | |||
| HSwishGrad | |||
| Args: | |||
| y_grad: | |||
| x: | |||
| @@ -34,9 +34,9 @@ def HswishGrad(y_grad, x): | |||
| return res6 | |||
| def gpu_schedule_HswishGrad(outs): | |||
| def gpu_schedule_HSwishGrad(outs): | |||
| """ | |||
| gpu schedule HswishGrad | |||
| gpu schedule HSwishGrad | |||
| Args: | |||
| outs: | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """less_equal""" | |||
| import _akg.tvm | |||
| from _akg.ops.math import less_equal | |||
| from _akg.topi.generic import schedule_elemwise | |||
| def LessEqual(x, y): | |||
| """LessEqual.""" | |||
| return less_equal.less_equal(x, y) | |||
| def gpu_schedule_LessEqual(outs): | |||
| """ | |||
| GPU schedule for LessEqual. | |||
| Args: | |||
| outs (tvm.tensor.Tensor): Outputs of compute. | |||
| Returns: | |||
| sch (schedule.Schedule): The created schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = _akg.tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with _akg.tvm.target.create(device): | |||
| sch = schedule_elemwise(outs) | |||
| return sch | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """logical_and""" | |||
| import _akg.tvm | |||
| from _akg.ops.math import logical_and | |||
| from _akg.topi.generic import schedule_elemwise | |||
| def LogicalAnd(x, y): | |||
| """LogicalAnd.""" | |||
| return logical_and.logical_and(x, y) | |||
| def gpu_schedule_LogicalAnd(outs): | |||
| """ | |||
| GPU schedule for LogicalAnd. | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of compute. | |||
| Returns: | |||
| sch (schedule.Schedule): The created schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = _akg.tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with _akg.tvm.target.create(device): | |||
| sch = schedule_elemwise(outs) | |||
| return sch | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """logical_not""" | |||
| import _akg.tvm | |||
| from _akg.ops.math import logical_not | |||
| from _akg.topi.generic import schedule_elemwise | |||
| def LogicalNot(x): | |||
| """LogicalNot.""" | |||
| return logical_not.logical_not(x) | |||
| def gpu_schedule_LogicalNot(outs): | |||
| """ | |||
| GPU schedule for LogicalNot. | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of compute. | |||
| Returns: | |||
| sch (schedule.Schedule): The created schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = _akg.tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with _akg.tvm.target.create(device): | |||
| sch = schedule_elemwise(outs) | |||
| return sch | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """logical_or""" | |||
| import _akg.tvm | |||
| from _akg.ops.math import logical_or | |||
| from _akg.topi.generic import schedule_elemwise | |||
| def LogicalOr(x, y): | |||
| """LogicalOr.""" | |||
| return logical_or.logical_or(x, y) | |||
| def gpu_schedule_LogicalOr(outs): | |||
| """ | |||
| GPU schedule for LogicalOr. | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of compute. | |||
| Returns: | |||
| sch (schedule.Schedule): The created schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = _akg.tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with _akg.tvm.target.create(device): | |||
| sch = schedule_elemwise(outs) | |||
| return sch | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """sub""" | |||
| import _akg.tvm | |||
| from _akg.ops.math import sub | |||
| from _akg.topi.generic import schedule_elemwise | |||
| def Sub(x, y): | |||
| """Sub.""" | |||
| return sub.sub(x, y) | |||
| def gpu_schedule_Sub(outs): | |||
| """ | |||
| GPU schedule for Sub. | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of compute. | |||
| Returns: | |||
| sch (schedule.Schedule): The created schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = _akg.tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with _akg.tvm.target.create(device): | |||
| sch = schedule_elemwise(outs) | |||
| return sch | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """operator dsl function: lessequal""" | |||
| import _akg.tvm | |||
| import _akg.topi | |||
| from _akg.utils.dsl_create import produce_shapes | |||
| from _akg.utils import validation_check as vc_util | |||
| @vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) | |||
| def less_equal(input1, input2): | |||
| """ | |||
| Check whether input1 lessequals to input2. | |||
| Args: | |||
| input1 (tvm.tensor.Tensor): Tensor. | |||
| input2 (tvm.tensor.Tensor): Tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor. If input1 lessequal to input2 return True, else return False. | |||
| """ | |||
| shape1 = [x.value for x in input1.shape] | |||
| shape2 = [x.value for x in input2.shape] | |||
| vc_util.check_shape(shape1) | |||
| vc_util.check_shape(shape2) | |||
| shape1, shape2, shape = produce_shapes(shape1, shape2) | |||
| vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) | |||
| dtype = input1.dtype | |||
| # get lessequal compute | |||
| t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") | |||
| f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") | |||
| input1_bro = _akg.topi.broadcast_to(input1, shape) | |||
| input2_bro = _akg.topi.broadcast_to(input2, shape) | |||
| c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] <= input2_bro[indice], | |||
| t_value[indice], f_value[indice]), name="C") | |||
| res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") | |||
| return res | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """operator dsl function: logical_and""" | |||
| import _akg.tvm | |||
| import _akg.topi | |||
| from _akg.utils import validation_check as vc_util | |||
| @vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) | |||
| def logical_and(input1, input2): | |||
| """ | |||
| Compute logical_and of input1 and input2. | |||
| Args: | |||
| input1 (tvm.tensor.Tensor): Tensor. | |||
| input2 (tvm.tensor.Tensor): Tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor. LogicalAnd of input1 and input2. | |||
| """ | |||
| vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) | |||
| shape1 = [x.value for x in input1.shape] | |||
| shape2 = [x.value for x in input2.shape] | |||
| vc_util.check_shape(shape1) | |||
| vc_util.check_shape(shape2) | |||
| res = _akg.topi.logical_and(input1, input2) | |||
| return res | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """operator dsl function: logical_not""" | |||
| import _akg.tvm | |||
| import _akg.topi | |||
| from _akg.utils import validation_check as vc_util | |||
| @vc_util.check_input_type(_akg.tvm.tensor.Tensor) | |||
| def logical_not(input1): | |||
| """ | |||
| Compute logical_not of input1. | |||
| Args: | |||
| input1 (tvm.tensor.Tensor): Tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor. | |||
| """ | |||
| res = _akg.topi.logical_not(input1) | |||
| return res | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """operator dsl function: logical_or""" | |||
| import _akg.tvm | |||
| import _akg.topi | |||
| from _akg.utils import validation_check as vc_util | |||
| @vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) | |||
| def logical_or(input1, input2): | |||
| """ | |||
| Compute logical_or of input1 and input2. | |||
| Args: | |||
| input1 (tvm.tensor.Tensor): Tensor. | |||
| input2 (tvm.tensor.Tensor): Tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor. LogicalOr of input1 and input2. | |||
| """ | |||
| vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) | |||
| shape1 = [x.value for x in input1.shape] | |||
| shape2 = [x.value for x in input2.shape] | |||
| vc_util.check_shape(shape1) | |||
| vc_util.check_shape(shape2) | |||
| res = _akg.topi.logical_or(input1, input2) | |||
| return res | |||
| @@ -14,10 +14,12 @@ | |||
| # ============================================================================ | |||
| """Check parameters.""" | |||
| import re | |||
| import inspect | |||
| import math | |||
| from enum import Enum | |||
| from functools import reduce | |||
| from functools import reduce, wraps | |||
| from itertools import repeat | |||
| from collections import Iterable | |||
| from collections.abc import Iterable | |||
| import numpy as np | |||
| from mindspore import log as logger | |||
| @@ -98,7 +100,7 @@ class Validator: | |||
| """validator for checking input parameters""" | |||
| @staticmethod | |||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): | |||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): | |||
| """ | |||
| Method for judging relation between two int values or list/tuple made up of ints. | |||
| @@ -108,18 +110,29 @@ class Validator: | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | |||
| msg_prefix = f'For {prim_name} the' if prim_name else "The" | |||
| raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||
| @staticmethod | |||
| def check_integer(arg_name, arg_value, value, rel, prim_name): | |||
| """Integer value judgment.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||
| excp_cls = TypeError if type_mismatch else ValueError | |||
| if type_mismatch or not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' | |||
| f' but got {arg_value}.') | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' | |||
| f' with type `{type(arg_value).__name__}`.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_number(arg_name, arg_value, value, rel, prim_name): | |||
| """Integer value judgment.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| @@ -127,15 +140,53 @@ class Validator: | |||
| """Method for checking whether an int value is in some range.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) | |||
| excp_cls = TypeError if type_mismatch else ValueError | |||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' | |||
| f' but got {arg_value}.') | |||
| raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' | |||
| f' but got `{arg_value}` with type `{type(arg_value).__name__}`.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||
| """Method for checking whether a numeric value is in some range.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_string(arg_name, arg_value, valid_values, prim_name): | |||
| """Checks whether a string is in some value list""" | |||
| if isinstance(arg_value, str) and arg_value in valid_values: | |||
| return arg_value | |||
| if len(valid_values) == 1: | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},' | |||
| f' but got {arg_value}.') | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},' | |||
| f' but got {arg_value}.') | |||
| @staticmethod | |||
| def check_pad_value_by_mode(pad_mode, padding, prim_name): | |||
| """Validates value of padding according to pad_mode""" | |||
| if pad_mode != 'pad' and padding != 0: | |||
| raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") | |||
| return padding | |||
| @staticmethod | |||
| def check_float_positive(arg_name, arg_value, prim_name): | |||
| """Float type judgment.""" | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| if isinstance(arg_value, float): | |||
| if arg_value > 0: | |||
| return arg_value | |||
| raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.") | |||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | |||
| @staticmethod | |||
| def check_subclass(arg_name, type_, template_type, prim_name): | |||
| """Check whether some type is sublcass of another type""" | |||
| """Checks whether some type is subclass of another type""" | |||
| if not isinstance(template_type, Iterable): | |||
| template_type = (template_type,) | |||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||
| @@ -144,32 +195,51 @@ class Validator: | |||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||
| @staticmethod | |||
| def check_tensor_type_same(args, valid_values, prim_name): | |||
| """check whether the element types of input tensors are the same.""" | |||
| def check_const_input(arg_name, arg_value, prim_name): | |||
| """Checks valid value.""" | |||
| if arg_value is None: | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') | |||
| @staticmethod | |||
| def check_type_same(args, valid_values, prim_name): | |||
| """Checks whether the types of inputs are the same.""" | |||
| def _check_tensor_type(arg): | |||
| arg_key, arg_val = arg | |||
| Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) | |||
| elem_type = arg_val.element_type() | |||
| elem_type = arg_val | |||
| if not elem_type in valid_values: | |||
| raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' | |||
| f' but `{arg_key}` is {elem_type}.') | |||
| type_names = [] | |||
| for t in valid_values: | |||
| type_names.append(str(t)) | |||
| types_info = '[' + ", ".join(type_names) + ']' | |||
| raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' | |||
| f' but got {elem_type}.') | |||
| return (arg_key, elem_type) | |||
| def _check_types_same(arg1, arg2): | |||
| arg1_name, arg1_type = arg1 | |||
| arg2_name, arg2_type = arg2 | |||
| if arg1_type != arg2_type: | |||
| raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') | |||
| return arg1 | |||
| elem_types = map(_check_tensor_type, args.items()) | |||
| reduce(_check_types_same, elem_types) | |||
| @staticmethod | |||
| def check_tensor_type_same(args, valid_values, prim_name): | |||
| """Checks whether the element types of input tensors are the same.""" | |||
| tensor_types = [mstype.tensor_type(t) for t in valid_values] | |||
| Validator.check_type_same(args, tensor_types, prim_name) | |||
| @staticmethod | |||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name): | |||
| """check whether the types of inputs are the same. if the input args are tensors, check their element types""" | |||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | |||
| """ | |||
| Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | |||
| If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | |||
| """ | |||
| def _check_argument_type(arg): | |||
| arg_key, arg_val = arg | |||
| if isinstance(arg_val, type(mstype.tensor)): | |||
| @@ -182,16 +252,19 @@ class Validator: | |||
| def _check_types_same(arg1, arg2): | |||
| arg1_name, arg1_type = arg1 | |||
| arg2_name, arg2_type = arg2 | |||
| excp_flag = False | |||
| except_flag = False | |||
| if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): | |||
| arg1_type = arg1_type.element_type() | |||
| arg2_type = arg2_type.element_type() | |||
| elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): | |||
| pass | |||
| elif allow_mix: | |||
| arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type | |||
| arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type | |||
| else: | |||
| excp_flag = True | |||
| except_flag = True | |||
| if excp_flag or arg1_type != arg2_type: | |||
| if except_flag or arg1_type != arg2_type: | |||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||
| return arg1 | |||
| @@ -199,13 +272,15 @@ class Validator: | |||
| @staticmethod | |||
| def check_value_type(arg_name, arg_value, valid_types, prim_name): | |||
| """Check whether a values is instance of some types.""" | |||
| """Checks whether a value is instance of some types.""" | |||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | |||
| def raise_error_msg(): | |||
| """func for raising error message when check failed""" | |||
| type_names = [t.__name__ for t in valid_types] | |||
| num_types = len(valid_types) | |||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' | |||
| f'{"one of " if num_types > 1 else ""}' | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||
| # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | |||
| @@ -216,6 +291,34 @@ class Validator: | |||
| return arg_value | |||
| raise_error_msg() | |||
| @staticmethod | |||
| def check_type_name(arg_name, arg_type, valid_types, prim_name): | |||
| """Checks whether a type in some specified types""" | |||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | |||
| def get_typename(t): | |||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||
| if arg_type in valid_types: | |||
| return arg_type | |||
| type_names = [get_typename(t) for t in valid_types] | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||
| if len(valid_types) == 1: | |||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| @staticmethod | |||
| def check_float_legal_value(arg_name, arg_value, prim_name): | |||
| """Checks whether a legal value of float type""" | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| if isinstance(arg_value, float): | |||
| if math.isinf(arg_value) or math.isnan(arg_value): | |||
| raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.") | |||
| return arg_value | |||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | |||
| class ParamValidator: | |||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | |||
| @@ -268,9 +371,9 @@ class ParamValidator: | |||
| @staticmethod | |||
| def check_isinstance(arg_name, arg_value, classes): | |||
| """Check arg isintance of classes""" | |||
| """Check arg isinstance of classes""" | |||
| if not isinstance(arg_value, classes): | |||
| raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.') | |||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| @@ -284,7 +387,7 @@ class ParamValidator: | |||
| @staticmethod | |||
| def check_subclass(arg_name, type_, template_type, with_type_of=True): | |||
| """Check whether some type is sublcass of another type""" | |||
| """Check whether some type is subclass of another type""" | |||
| if not isinstance(template_type, Iterable): | |||
| template_type = (template_type,) | |||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||
| @@ -302,9 +405,9 @@ class ParamValidator: | |||
| @staticmethod | |||
| def check_bool(arg_name, arg_value): | |||
| """Check arg isintance of bool""" | |||
| """Check arg isinstance of bool""" | |||
| if not isinstance(arg_value, bool): | |||
| raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.') | |||
| raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| @@ -671,3 +774,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII): | |||
| if re.match(reg, target, flag) is None: | |||
| raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) | |||
| return True | |||
| def args_type_check(*type_args, **type_kwargs): | |||
| """Check whether input data type is correct.""" | |||
| def type_check(func): | |||
| sig = inspect.signature(func) | |||
| bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments | |||
| @wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| nonlocal bound_types | |||
| bound_values = sig.bind(*args, **kwargs) | |||
| argument_dict = bound_values.arguments | |||
| if "kwargs" in bound_types: | |||
| bound_types = bound_types["kwargs"] | |||
| if "kwargs" in argument_dict: | |||
| argument_dict = argument_dict["kwargs"] | |||
| for name, value in argument_dict.items(): | |||
| if name in bound_types: | |||
| if value is not None and not isinstance(value, bound_types[name]): | |||
| raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | |||
| return func(*args, **kwargs) | |||
| return wrapper | |||
| return type_check | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Extension functions. | |||
| Extension functions. | |||
| Python functions that will be called in the c++ parts of MindSpore. | |||
| """ | |||
| @@ -83,6 +83,7 @@ convert_object_map = { | |||
| T.mul: multitype_ops.mul, | |||
| T.truediv: multitype_ops.div, | |||
| T.getitem: multitype_ops.getitem, | |||
| T.setitem: multitype_ops.setitem, | |||
| T.floordiv: multitype_ops.floordiv, | |||
| T.mod: multitype_ops.mod, | |||
| T.pow: multitype_ops.pow_, | |||
| @@ -113,12 +114,12 @@ convert_object_map = { | |||
| T.map: C.HyperMap(), | |||
| T.partial: F.partial, | |||
| T.zip: C.zip_operation, | |||
| T.print: F.print_, | |||
| # custom define operation | |||
| T.iter: M.ms_iter, | |||
| T.next: M.ms_next, | |||
| T.hasnext: M.hasnext, | |||
| T.setitem: M.setitem, | |||
| T.make_tuple: F.make_tuple, | |||
| T.make_dict: F.make_dict, | |||
| @@ -27,7 +27,7 @@ from operator import ( # noqa | |||
| # support system function call | |||
| from builtins import ( # noqa | |||
| bool, getattr, setattr, len, iter, next, pow, range, map, zip | |||
| bool, getattr, setattr, len, iter, next, pow, range, map, zip, print | |||
| ) | |||
| # support functools | |||
| @@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', | |||
| 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', | |||
| 'matmul', 'getitem', 'setitem', | |||
| 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', | |||
| 'partial', | |||
| 'partial', 'print', | |||
| 'exp', 'log', 'sin', 'cos', 'tan'] | |||
| @@ -1,44 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Pynative mode help module.""" | |||
| from inspect import signature | |||
| from functools import wraps | |||
| def args_type_check(*type_args, **type_kwargs): | |||
| """Check whether input data type is correct.""" | |||
| def type_check(func): | |||
| sig = signature(func) | |||
| bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments | |||
| @wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| nonlocal bound_types | |||
| bound_values = sig.bind(*args, **kwargs) | |||
| argument_dict = bound_values.arguments | |||
| if "kwargs" in bound_types: | |||
| bound_types = bound_types["kwargs"] | |||
| if "kwargs" in argument_dict: | |||
| argument_dict = argument_dict["kwargs"] | |||
| for name, value in argument_dict.items(): | |||
| if name in bound_types: | |||
| if value is not None and not isinstance(value, bound_types[name]): | |||
| raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | |||
| return func(*args, **kwargs) | |||
| return wrapper | |||
| return type_check | |||
| @@ -394,10 +394,6 @@ if(USE_GLOG) | |||
| target_link_libraries(_c_expression PRIVATE mindspore::glog) | |||
| endif() | |||
| if(ENABLE_GPU) | |||
| target_link_libraries(_c_expression PRIVATE mindspore::tvm) | |||
| endif() | |||
| if(ENABLE_DUMP_PROTO) | |||
| target_link_libraries(_c_expression PRIVATE mindspore::protobuf) | |||
| endif() | |||
| @@ -103,17 +103,39 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{ | |||
| template <typename SrcT, typename DstT> | |||
| void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { | |||
| auto src_id = TypeIdSize(args.src_type); | |||
| auto dst_id = TypeIdSize(args.dst_type); | |||
| if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { | |||
| MS_LOG(EXCEPTION) << "Invalid src or dst data size."; | |||
| } | |||
| for (size_t idx = 0; idx != data_size; idx++) { | |||
| SrcT src_data = static_cast<const SrcT *>(args.data)[idx]; | |||
| static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data); | |||
| } | |||
| } | |||
| template <typename SrcT> | |||
| void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) { | |||
| auto src_id = TypeIdSize(args.src_type); | |||
| auto dst_id = TypeIdSize(args.dst_type); | |||
| if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { | |||
| MS_LOG(EXCEPTION) << "Invalid src or dst data size."; | |||
| } | |||
| auto src_data = static_cast<const SrcT *>(args.data); | |||
| auto half_data = static_cast<Eigen::half *>(dst); | |||
| for (size_t i = 0; i < data_size; i++) { | |||
| half_data[i] = Eigen::half(src_data[i]); | |||
| } | |||
| } | |||
| bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) { | |||
| switch (mode) { | |||
| case FROM_FLOAT_TO_FLOAT16: | |||
| device::FloatToHalf(dst, args.data, data_size); | |||
| break; | |||
| case FROM_INT32_TO_FLOAT16: | |||
| TransDataSrc2Fp16<int32_t>(args, dst, data_size); | |||
| break; | |||
| case FROM_FLOAT16_TO_FLOAT: | |||
| device::HalfToFloat(dst, args.data, data_size); | |||
| break; | |||
| @@ -372,27 +394,27 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { | |||
| } | |||
| bool TransDataType(const TypeIdArgs &args, void *result) { | |||
| MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " | |||
| << TypeIdLabel(args.device_data_type); | |||
| MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type); | |||
| MS_EXCEPTION_IF_NULL(result); | |||
| std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type); | |||
| std::pair<TypeId, TypeId> type_info(args.src_type, args.dst_type); | |||
| auto iter = mode_map.find(type_info); | |||
| if (iter == mode_map.end()) { | |||
| MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type) | |||
| << ", dst_type:" << TypeIdLabel(args.device_data_type); | |||
| MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type) | |||
| << ", dst_type:" << TypeIdLabel(args.dst_type); | |||
| return false; | |||
| } | |||
| auto trans_mode = iter->second; | |||
| auto type_size = TypeIdSize(args.device_data_type); | |||
| if (type_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid host data type."; | |||
| auto src_id = TypeIdSize(args.src_type); | |||
| auto dst_id = TypeIdSize(args.dst_type); | |||
| if (src_id < 1 || dst_id < 1) { | |||
| MS_LOG(ERROR) << "Invalid src or dst data type."; | |||
| return false; | |||
| } | |||
| if (args.host_shape_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid host data size."; | |||
| if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { | |||
| MS_LOG(ERROR) << "Invalid src or dst data size."; | |||
| return false; | |||
| } | |||
| if (!CastKernel(args, result, args.host_shape_size, trans_mode)) { | |||
| if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) { | |||
| MS_LOG(ERROR) << "Failed to trans datatype.."; | |||
| return false; | |||
| } | |||
| @@ -31,9 +31,12 @@ namespace mindspore { | |||
| namespace trans { | |||
| struct TypeIdArgs { | |||
| const void *data; | |||
| size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d | |||
| TypeId host_data_type; | |||
| TypeId device_data_type; | |||
| size_t src_size; | |||
| size_t dst_size; | |||
| TypeId src_type; | |||
| TypeId dst_type; | |||
| size_t src_shape_size; | |||
| size_t dst_shape_size; | |||
| }; | |||
| struct FormatArgs { | |||
| @@ -23,7 +23,7 @@ namespace common { | |||
| const int CACHED_STR_NUM = 1 << 8; | |||
| const int CACHED_STR_MASK = CACHED_STR_NUM - 1; | |||
| std::vector<std::string> STR_HOLDER(CACHED_STR_NUM); | |||
| const char* SafeCStr(const std::string&& str) { | |||
| const char *SafeCStr(const std::string &&str) { | |||
| static std::atomic<uint32_t> index{0}; | |||
| uint32_t cur_index = index++; | |||
| cur_index = cur_index & CACHED_STR_MASK; | |||
| @@ -21,16 +21,16 @@ | |||
| #include <string> | |||
| #define DISABLE_COPY_AND_ASSIGN(ClassType) \ | |||
| ClassType(const ClassType&) = delete; \ | |||
| ClassType& operator=(const ClassType&) = delete; | |||
| ClassType(const ClassType &) = delete; \ | |||
| ClassType &operator=(const ClassType &) = delete; | |||
| namespace mindspore { | |||
| namespace common { | |||
| inline const char* SafeCStr(const std::string& str) { return str.c_str(); } | |||
| const char* SafeCStr(const std::string&& str); | |||
| inline const char *SafeCStr(const std::string &str) { return str.c_str(); } | |||
| const char *SafeCStr(const std::string &&str); | |||
| static inline std::string GetEnv(const std::string& envvar) { | |||
| const char* value = ::getenv(envvar.c_str()); | |||
| static inline std::string GetEnv(const std::string &envvar) { | |||
| const char *value = ::getenv(envvar.c_str()); | |||
| if (value == nullptr) { | |||
| return std::string(); | |||
| @@ -12,10 +12,10 @@ endif() | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") | |||
| if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--image-base -Wl,0x10000000") | |||
| endif() | |||
| ############################# Options ################################ | |||
| if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") | |||
| add_definitions(-D _CRT_RAND_S) | |||
| endif () | |||
| if (ENABLE_GPUQUE) | |||
| add_definitions(-D ENABLE_GPUQUE) | |||
| message(STATUS "GPU queue is enabled") | |||
| @@ -28,10 +28,11 @@ | |||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include "mindrecord/include/shard_category.h" | |||
| #include "mindrecord/include/shard_sample.h" | |||
| #include "mindrecord/include/shard_shuffle.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -45,7 +46,9 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D | |||
| {kShuffle, &DEPipeline::ParseShuffleOp}, | |||
| {kMindrecord, &DEPipeline::ParseMindRecordOp}, | |||
| {kMap, &DEPipeline::ParseMapOp}, | |||
| {kFilter, &DEPipeline::ParseFilterOp}, | |||
| {kBatch, &DEPipeline::ParseBatchOp}, | |||
| {kBarrier, &DEPipeline::ParseBarrierOp}, | |||
| {kRepeat, &DEPipeline::ParseRepeatOp}, | |||
| {kSkip, &DEPipeline::ParseSkipOp}, | |||
| {kZip, &DEPipeline::ParseZipOp}, | |||
| @@ -61,7 +64,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D | |||
| {kVoc, &DEPipeline::ParseVOCOp}, | |||
| {kCifar10, &DEPipeline::ParseCifar10Op}, | |||
| {kCifar100, &DEPipeline::ParseCifar100Op}, | |||
| {kCelebA, &DEPipeline::ParseCelebAOp}}; | |||
| {kCelebA, &DEPipeline::ParseCelebAOp}, | |||
| {kTextFile, &DEPipeline::ParseTextFileOp}}; | |||
| DEPipeline::DEPipeline() : iterator_(nullptr) { | |||
| try { | |||
| @@ -501,6 +505,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| std::shared_ptr<FilterOp::Builder> builder = std::make_shared<FilterOp::Builder>(); | |||
| if (args["predicate"].is_none()) { | |||
| RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); | |||
| } | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "predicate") { | |||
| py::handle op = args["predicate"]; | |||
| if (!py::isinstance<py::function>(op)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); | |||
| } | |||
| py::function predicate_func = op.cast<py::function>(); | |||
| (void)builder->SetPredicateFunc(std::move(predicate_func)); | |||
| } else if (key == "input_columns") { | |||
| std::vector<std::string> in_col_names = ToStringVector(args["input_columns"]); | |||
| (void)builder->SetInColNames(in_col_names); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<FilterOp> op; | |||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| if (args["count"].is_none()) { | |||
| std::string err_msg = "Error: count is invalid or not set."; | |||
| @@ -589,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>(); | |||
| // Right now barrier should only take num_rows_per_buffer = 1 | |||
| // The reason for this is because having it otherwise can lead to blocking issues | |||
| // See barrier_op.h for more details | |||
| (void)builder->SetRowsPerBuffer(1); | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "condition_name") { | |||
| (void)builder->SetConditionName(ToString(value)); | |||
| } else if (key == "condition_func") { | |||
| (void)builder->SetConditionFunc(value.cast<py::function>()); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<BarrierOp> op; | |||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| int32_t prefetch_size = 0; | |||
| if (args.contains("prefetch_size")) { | |||
| @@ -670,8 +733,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||
| return Status::OK(); | |||
| } | |||
| DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); } | |||
| Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| // Required arguments | |||
| std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>(); | |||
| @@ -985,5 +1046,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| // Required arguments | |||
| std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>(); | |||
| if (!args["dataset_files"].is_none()) { | |||
| (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"])); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); | |||
| } | |||
| // Optional arguments | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "shuffle_files") { | |||
| (void)builder->SetShuffleFiles(ToBool(value)); | |||
| } else if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_shards") { | |||
| (void)builder->SetNumDevices(ToInt(value)); | |||
| } else if (key == "shard_id") { | |||
| (void)builder->SetDeviceId(ToInt(value)); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<TextFileOp> op; | |||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -40,6 +40,7 @@ enum OpName { | |||
| kShuffle, | |||
| kMindrecord, | |||
| kBatch, | |||
| kBarrier, | |||
| kCache, | |||
| kRepeat, | |||
| kSkip, | |||
| @@ -58,7 +59,8 @@ enum OpName { | |||
| kVoc, | |||
| kCifar10, | |||
| kCifar100, | |||
| kCelebA | |||
| kCelebA, | |||
| kTextFile | |||
| }; | |||
| // The C++ binder class that we expose to the python script. | |||
| @@ -106,12 +108,16 @@ class DEPipeline { | |||
| Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| @@ -120,8 +126,6 @@ class DEPipeline { | |||
| Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| DsOpPtr ParseFilterOp(const py::dict &args) const; | |||
| Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| @@ -148,6 +152,8 @@ class DEPipeline { | |||
| Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| private: | |||
| // Execution tree that links the dataset operators. | |||
| std::shared_ptr<ExecutionTree> tree_; | |||
| @@ -24,7 +24,6 @@ | |||
| #endif | |||
| #include "dataset/kernels/image/cut_out_op.h" | |||
| #include "dataset/kernels/image/decode_op.h" | |||
| #include "dataset/kernels/image/distort_bounding_box_crop_op.h" | |||
| #include "dataset/kernels/image/hwc_to_chw_op.h" | |||
| #include "dataset/kernels/image/image_utils.h" | |||
| #include "dataset/kernels/image/normalize_op.h" | |||
| @@ -40,6 +39,7 @@ | |||
| #include "dataset/kernels/image/rescale_op.h" | |||
| #include "dataset/kernels/image/resize_bilinear_op.h" | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| #include "dataset/kernels/image/uniform_aug_op.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| @@ -53,11 +53,14 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/python_sampler.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/jagged_connector.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/kernels/data/to_float16_op.h" | |||
| #include "dataset/util/random.h" | |||
| #include "mindrecord/include/shard_operator.h" | |||
| #include "mindrecord/include/shard_pk_sample.h" | |||
| #include "mindrecord/include/shard_sample.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| @@ -150,9 +153,14 @@ void bindDatasetOps(py::module *m) { | |||
| }); | |||
| (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp") | |||
| .def_static("get_num_rows", [](const std::string &path) { | |||
| .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) { | |||
| int64_t count = 0; | |||
| THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count)); | |||
| std::shared_ptr<mindrecord::ShardOperator> op; | |||
| if (py::hasattr(sampler, "_create_for_minddataset")) { | |||
| auto create = sampler.attr("_create_for_minddataset"); | |||
| op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||
| } | |||
| THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count)); | |||
| return count; | |||
| }); | |||
| @@ -176,6 +184,17 @@ void bindDatasetOps(py::module *m) { | |||
| THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count)); | |||
| return count; | |||
| }); | |||
| (void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp") | |||
| .def_static("get_num_rows", [](const py::list &files) { | |||
| int64_t count = 0; | |||
| std::vector<std::string> filenames; | |||
| for (auto file : files) { | |||
| !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); | |||
| } | |||
| THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); | |||
| return count; | |||
| }); | |||
| } | |||
| void bindTensor(py::module *m) { | |||
| (void)py::class_<GlobalContext>(*m, "GlobalContext") | |||
| @@ -251,6 +270,10 @@ void bindTensorOps1(py::module *m) { | |||
| .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"), | |||
| py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); | |||
| (void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>( | |||
| *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") | |||
| .def(py::init<py::list, int32_t>(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps); | |||
| (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>( | |||
| *m, "ResizeBilinearOp", | |||
| "Tensor operation to resize an image using " | |||
| @@ -345,18 +368,6 @@ void bindTensorOps3(py::module *m) { | |||
| } | |||
| void bindTensorOps4(py::module *m) { | |||
| (void)py::class_<DistortBoundingBoxCropOp, TensorOp, std::shared_ptr<DistortBoundingBoxCropOp>>( | |||
| *m, "DistortBoundingBoxCropOp", | |||
| "Tensor operator to crop an image randomly as long as the cropped image has sufficient " | |||
| "overlap with any one bounding box associated with original image" | |||
| "Takes aspect ratio of the generated crop box, the intersection ratio of crop box and bounding box," | |||
| "crop ratio lower and upper bounds" | |||
| "Optional parameters: number of attempts for crop, number of attempts of crop box generation") | |||
| .def(py::init<float, float, float, float, int32_t, int32_t>(), py::arg("aspect_ratio"), py::arg("intersect_ratio"), | |||
| py::arg("crop_ratio_lower_bound"), py::arg("crop_ratio_upper_bound"), | |||
| py::arg("max_attempts") = DistortBoundingBoxCropOp::kDefMaxAttempts, | |||
| py::arg("box_gen_attempts") = DistortBoundingBoxCropOp::kDefBoxGenAttempts); | |||
| (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>( | |||
| *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") | |||
| .def(py::init<DataType>(), py::arg("data_type")) | |||
| @@ -415,16 +426,30 @@ void bindSamplerOps(py::module *m) { | |||
| (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler") | |||
| .def(py::init<>()); | |||
| (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") | |||
| .def(py::init<std::vector<int64_t>>(), py::arg("indices")); | |||
| (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>( | |||
| *m, "MindrecordSubsetRandomSampler") | |||
| .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); | |||
| (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>( | |||
| *m, "MindrecordPkSampler") | |||
| .def(py::init([](int64_t kVal, bool shuffle) { | |||
| if (shuffle == true) { | |||
| return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(), | |||
| GetSeed()); | |||
| } else { | |||
| return std::make_shared<mindrecord::ShardPkSample>("label", kVal); | |||
| } | |||
| })); | |||
| (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | |||
| .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), | |||
| py::arg("replacement")); | |||
| (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler") | |||
| .def(py::init<py::object>(), py::arg("pySampler")); | |||
| } | |||
| void bindInfoObjects(py::module *m) { | |||
| @@ -443,6 +468,7 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("STORAGE", OpName::kStorage) | |||
| .value("SHUFFLE", OpName::kShuffle) | |||
| .value("BATCH", OpName::kBatch) | |||
| .value("BARRIER", OpName::kBarrier) | |||
| .value("MINDRECORD", OpName::kMindrecord) | |||
| .value("CACHE", OpName::kCache) | |||
| .value("REPEAT", OpName::kRepeat) | |||
| @@ -463,7 +489,8 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("VOC", OpName::kVoc) | |||
| .value("CIFAR10", OpName::kCifar10) | |||
| .value("CIFAR100", OpName::kCifar100) | |||
| .value("CELEBA", OpName::kCelebA); | |||
| .value("CELEBA", OpName::kCelebA) | |||
| .value("TEXTFILE", OpName::kTextFile); | |||
| (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic()) | |||
| .value("DE_INTER_LINEAR", InterpolationMode::kLinear) | |||
| @@ -25,12 +25,14 @@ | |||
| #include "dataset/core/tensor_shape.h" | |||
| #include "dataset/engine/data_schema.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/barrier_op.h" | |||
| #include "dataset/engine/datasetops/batch_op.h" | |||
| #include "dataset/engine/datasetops/dataset_op.h" | |||
| #include "dataset/engine/datasetops/device_queue_op.h" | |||
| #include "dataset/engine/datasetops/map_op.h" | |||
| #include "dataset/engine/datasetops/project_op.h" | |||
| #include "dataset/engine/datasetops/rename_op.h" | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/datasetops/skip_op.h" | |||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||
| @@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c | |||
| DS_ASSERT(data_); | |||
| switch (type_.value()) { | |||
| CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t); | |||
| CASE_PRINT_HEX(DataType::DE_BOOL, bool); | |||
| CASE_PRINT_HEX(DataType::DE_INT8, int8_t); | |||
| @@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT | |||
| dataset_op.cc | |||
| parallel_op.cc | |||
| pipeline_op.cc | |||
| barrier_op.cc | |||
| batch_op.cc | |||
| device_queue_op.cc | |||
| map_op.cc | |||
| @@ -14,5 +15,6 @@ add_library(engine-datasetops OBJECT | |||
| take_op.cc | |||
| shuffle_op.cc | |||
| zip_op.cc | |||
| filter_op.cc | |||
| ) | |||
| @@ -0,0 +1,235 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/barrier_op.h" | |||
| #include <utility> | |||
| #include "dataset/core/constants.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| BarrierOp::Builder::Builder() { | |||
| // Some arguments to the BarrierOp constructor have a default argument that is taken | |||
| // from the client config. | |||
| // The user may choose to change these values for the construction of the BarrierOp by | |||
| // using the various builder set methods. | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| builder_op_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
| Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } | |||
| Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, | |||
| builder_condition_func_); | |||
| return Status::OK(); | |||
| } | |||
| // Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions | |||
| BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, | |||
| py::function condition_func) | |||
| : PipelineOp(op_connector_size), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| buffer_id_(0), | |||
| clean_up_(false), | |||
| eof_(false), | |||
| condition_name_(condition_name), | |||
| condition_function_(condition_func) {} | |||
| // destructor | |||
| BarrierOp::~BarrierOp() {} | |||
| // Entry point for Barrier, called by launch() | |||
| Status BarrierOp::operator()() { | |||
| // The children_num_ parameter needs to be put here | |||
| // Synchronize with TaskManager once the thread is created. | |||
| TaskManager::FindMe()->Post(); | |||
| // create child iterator, right now this barrier is a pipeline operator | |||
| int32_t worker_id = 0; | |||
| int32_t child_idx = 0; | |||
| child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx); | |||
| // Loop until eof is true | |||
| while (!eof_) { | |||
| // Create new table to put the new tensor rows | |||
| std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>(); | |||
| RETURN_IF_NOT_OK(prepare(curr_table.get())); | |||
| // If an eof got picked up during the above prepare, then we're done | |||
| if (eof_) { | |||
| break; | |||
| } | |||
| // we have to output new buffer with possibly different buffer size, possibly one row | |||
| while (!clean_up_) { | |||
| // 1. If a previous loop iteration sent the current table out, then create a new one. | |||
| if (curr_table == nullptr) { | |||
| curr_table = std::make_unique<TensorQTable>(); | |||
| } | |||
| // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished | |||
| RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); | |||
| // 3 create and update buffer and send it to the out connector | |||
| if (!curr_table->empty()) { | |||
| std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone); | |||
| curr_buffer->set_tensor_table(std::move(curr_table)); | |||
| curr_buffer->set_column_name_map(col_name_id_map_); | |||
| MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " | |||
| << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << "."; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); | |||
| buffer_id_++; | |||
| } | |||
| } | |||
| // 4 handle drain state. | |||
| if (clean_up_) { | |||
| MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; | |||
| // Send the eoe up. | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); | |||
| } | |||
| } | |||
| // 5 handle eof | |||
| // propagate eof here. | |||
| MS_LOG(INFO) << "Barrier operator got EOF, propagating."; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); | |||
| return Status::OK(); | |||
| } | |||
| // Handles preprocessing of the main loop, used when starting new epoch | |||
| Status BarrierOp::prepare(TensorQTable *const table) { | |||
| MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; | |||
| clean_up_ = false; | |||
| buffer_id_ = 0; | |||
| if (table == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); | |||
| } | |||
| // fill initial row | |||
| TensorRow new_row = {}; | |||
| // use iterator to get next row and invoke pyfunc wait | |||
| RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); | |||
| // If the first row fetching resulted in eof, then we are done. | |||
| if (eof_) { | |||
| return Status::OK(); | |||
| } | |||
| if (new_row.empty()) { | |||
| // This epoch is empty | |||
| return Status::OK(); | |||
| } | |||
| // Pack this first row into our tensor table | |||
| // first row we also have to check if we should block | |||
| RETURN_IF_NOT_OK(blockCond()); | |||
| table->push_back(std::move(new_row)); | |||
| // At this point we have 1 row produced, we take the old column map id and use it in the new table | |||
| // Initializing col_name_id_map_ from the first data buffer. | |||
| col_name_id_map_ = child_iterator_->col_name_id_map(); | |||
| // the update code below shouldn't do anything bad if the column name already exists. | |||
| return Status::OK(); | |||
| } | |||
| // fillBuffer always expects a new table to fill | |||
| Status BarrierOp::fillBuffer(TensorQTable *const table) { | |||
| if (table == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); | |||
| } | |||
| TensorRow new_row = {}; | |||
| while (table->size() < static_cast<size_t>(rows_per_buffer_)) { | |||
| RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); | |||
| // Early exit the loop if we got empty row from any of our child iterations | |||
| if (new_row.empty()) { | |||
| return Status::OK(); | |||
| } | |||
| // else we got a row so pack it into the tensor table. | |||
| RETURN_IF_NOT_OK(blockCond()); | |||
| table->push_back(std::move(new_row)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // function executes a py_func and blocks until condition becomes true. | |||
| Status BarrierOp::blockCond() { | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| // we have condition name, however the flexibility is in python today | |||
| try { | |||
| // Invoke python function | |||
| py::object ret_py_obj = condition_function_(); | |||
| // Process the return value | |||
| if (!py::isinstance<py::bool_>(ret_py_obj)) { | |||
| return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); | |||
| } | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // fetches next Barrier buffer row | |||
| Status BarrierOp::getNextTensorRow(TensorRow *new_row) { | |||
| // iterate over all iterators and generate a row | |||
| RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); | |||
| // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row | |||
| if (new_row->empty()) { | |||
| // If we did not get a row from any of the children, then it's the end of an epoch and we can move | |||
| // to drain state. | |||
| MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; | |||
| clean_up_ = true; | |||
| // If we picked up an eof here, then we are completely done. | |||
| if ((child_iterator_)->eof_handled()) { | |||
| MS_LOG(INFO) << "Barrier operator iterator got EOF."; | |||
| eof_ = true; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // A function that prints info about the Operator | |||
| void BarrierOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call base class printer first | |||
| PipelineOp::Print(out, show_all); | |||
| out << "\nBarrierOp:\n" | |||
| << "\nCondition " << condition_name_ << "\n\n"; | |||
| } | |||
// Base-class override for eof handling: barrier takes no action on eof.
Status BarrierOp::EofReceived(int32_t) {
  MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
  return Status::OK();
}
// Base-class override for eoe handling: move this operator to the idle state.
Status BarrierOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,172 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/pipeline_op.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Forward declare | |||
| class DataBuffer; | |||
| class ExecutionTree; | |||
// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has
// been received. This signal is given from the python layer. The current barrier design respects the
// rows-per-buffer design and will only output a buffer with rows once it has received rows-per-buffer
// signals from python.
class BarrierOp : public PipelineOp {
 public:
  // The nested builder class inside of the BarrierOp is used to help manage all of
  // the arguments for constructing it. Use the builder by setting each argument
  // with the provided set methods, and then finally call the build method to execute
  // the actual construction.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Setter method.
    // @param rows_per_buffer - number of rows per output buffer
    // @return Builder setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @param int32_t op_connector_size
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method.
    // @param const std::string & condition_name
    // @return Builder setter method returns reference to the builder.
    Builder &SetConditionName(const std::string &condition_name) {
      builder_condition_name_ = condition_name;
      return *this;
    }

    // Setter method.
    // @param py::function condition_func - blocking condition function
    // @return Builder setter method returns reference to the builder.
    Builder &SetConditionFunc(py::function condition_func) {
      builder_condition_func_ = condition_func;
      return *this;
    }

    // The builder "build" method creates the BarrierOp dataset Operator.
    // @return shared_ptr to the new BarrierOp object
    Status Build(std::shared_ptr<BarrierOp> *);

   private:
    int32_t builder_rows_per_buffer_;
    int32_t builder_op_connector_size_;
    std::string builder_condition_name_;
    py::function builder_condition_func_;

    // Validates the builder arguments before construction.
    Status SanityCheck() const;
  };

  // Constructor for BarrierOp
  // @param rows_per_buffer - number of rows in output buffer
  // @param op_connector_size - connector size
  // @param condition_name - the condition name associated with this operator
  // @param condition_func - the blocking condition check per row
  // @note - currently rows_per_buffer should = 1 for barrier.
  // The reason for this is having other values would complicate how the pipeline behaves with other operators
  // One example of such case is having batch after barrier. Batch would be waiting for data and having
  // rows per buffer in this case can result in hanging
  BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
            py::function condition_func);

  // Destructor
  ~BarrierOp();

  // Base-class override: handle eof (no-op for barrier).
  Status EofReceived(int32_t) override;

  // Base-class override: handle eoe (moves the operator to idle state).
  Status EoeReceived(int32_t) override;

  // Print function for Barrier
  // @param out - output stream to print to
  // @param show_all - if it should print everything
  void Print(std::ostream &out, bool show_all) const override;

  // Provide stream operator for displaying it
  friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) {
    bo.Print(out, false);
    return out;
  }

  // Class functor operator () override.
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - The error code return
  Status operator()() override;

  // Handles preprocessing of the main loop, used when starting new epoch
  // @param table - a table of tensors to be moved into a buffer
  Status prepare(TensorQTable *const table);

  // This function repeatedly adds rows to a table until it reaches rows_per_buffer_ rows
  // or the child runs out of rows for the epoch.
  // @param table - a table of tensors to be moved into a buffer
  Status fillBuffer(TensorQTable *const table);

  // Gets next tensor row and sets control signals (clean_up_ / eof_)
  Status getNextTensorRow(TensorRow *new_row);

  // This function runs the python wait function on the condition
  Status blockCond();

 private:
  // clean up variable to return incomplete buffer (set at end of epoch)
  bool clean_up_;
  // end of file state, we stop reading data and shut down
  bool eof_;
  // rows per buffer
  int32_t rows_per_buffer_;
  // buffer_id
  int32_t buffer_id_;
  // local variable to keep track of the buffer information (column name -> column index)
  std::unordered_map<std::string, int32_t> col_name_id_map_;
  // iterator to pull new rows, we only have one child
  std::unique_ptr<ChildIterator> child_iterator_;
  // condition name, to support multiple barriers
  std::string condition_name_;
  // Function pointer of blocking function
  py::function condition_function_;
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ | |||
| @@ -0,0 +1,253 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include <algorithm> | |||
| #include <cstring> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/constants.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status FilterOp::Builder::SanityCheck() { | |||
| std::string err; | |||
| err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; | |||
| err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; | |||
| return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); | |||
| } | |||
| FilterOp::Builder::Builder() { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_op_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
// The builder "build" method: validates the accumulated arguments and creates the FilterOp.
// @param ptr - receives the newly constructed FilterOp.
// @return Status - non-OK when sanity checks fail.
Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  // NOTE(review): build_in_col_names_ is moved from here, so the builder should not
  // be reused to build a second FilterOp.
  *ptr = std::make_shared<FilterOp>(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_,
                                    builder_predicate_func_);
  return Status::OK();
}
// Constructor for FilterOp.
// @param in_col_names - columns fed to the predicate; empty means apply to all columns.
// @param num_workers - number of worker threads.
// @param op_queue_size - connector queue capacity.
// @param predicate_func - python callable returning a boolean per row.
FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                   py::function predicate_func)
    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}
// Class functor entry point: initializes and registers the internal per-worker
// queues, launches the worker threads, then runs the Collector loop on this thread.
Status FilterOp::operator()() {
  // The operator class just starts off threads by calling the tree_ function.
  RETURN_UNEXPECTED_IF_NULL(tree_);
  // Synchronize with TaskManager.
  TaskManager::FindMe()->Post();
  // NOTE(review): unlike Register/LaunchWorkers below, the result of Init is not
  // checked -- confirm whether QueueList::Init reports failures via Status and, if
  // so, wrap it in RETURN_IF_NOT_OK.
  filter_queues_.Init(num_workers_, oc_queue_size_);
  RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)));
  RETURN_IF_NOT_OK(Collector());
  return Status::OK();
}
// Base-class override for eof: the eof buffer is forwarded by Collector, so no action here.
Status FilterOp::EofReceived(int32_t) { return Status::OK(); }
// Base-class override for eoe: eoe buffers pass straight through Collector, so no action here.
Status FilterOp::EoeReceived(int32_t) { return Status::OK(); }
| // Validating if each of the input_columns exists in the DataBuffer. | |||
| Status FilterOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map, | |||
| const std::vector<std::string> *input_columns) { | |||
| for (const auto &inCol : *input_columns) { | |||
| bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; | |||
| if (!found) { | |||
| std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // A print method typically used for debugging. | |||
| void FilterOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call base class printer first. | |||
| ParallelOp::Print(out, show_all); | |||
| // Then display our own stuff. | |||
| out << "\nFilterOp:"; | |||
| out << "\n Input column names:"; | |||
| for (size_t i = 0; i < in_columns_.size(); i++) { | |||
| out << " " << in_columns_[i]; | |||
| } | |||
| } | |||
// Worker thread main loop: pulls buffers from the child, runs the predicate over
// each row, and pushes the (buffer, control-flag) pair into this worker's filter
// queue for the Collector to reorder. The per-worker queue protocol is what lets
// Collector restore deterministic buffer order.
// @param worker_id - index of this worker, also selects its filter queue.
Status FilterOp::WorkerEntry(int32_t worker_id) {
  // Handshake with TaskManager that thread creation is successful.
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> in_buffer;
  bool worker_stop = false;
  while (worker_stop == false) {
    // Getting a databuffer to work on.
    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
    if (in_buffer->eoe()) {
      // eoe is forwarded unchanged so Collector can pass it downstream in order.
      filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
      continue;
    } else if (in_buffer->eof()) {
      // eof terminates this worker after being queued for the Collector.
      filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));
      worker_stop = true;
      continue;
    }
    // Reject empty buffers and unknown column names before filtering.
    RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_));
    // if the databuffer was all filtered, it is marked as kFilterEmpty.
    // if the databuffer was partially filtered, it is marked as kFilterPartial.
    // if the databuffer was not filtered, it is marked as kFilterFull.
    // NOTE(review): num_rows is int32_t while size() is size_t below -- signed/unsigned
    // comparison; benign for realistic row counts but worth confirming.
    int32_t num_rows = in_buffer->NumRows();
    std::unique_ptr<TensorQTable> new_tensor_table;
    RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table));

    if (new_tensor_table->empty()) {
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty)));
    } else if (new_tensor_table->size() == num_rows) {
      in_buffer->set_tensor_table(std::move(new_tensor_table));
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull)));
    } else {  // kFilterPartial
      in_buffer->set_tensor_table(std::move(new_tensor_table));
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial)));
    }
  }
  return Status::OK();
}
| Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out) { | |||
| *out = std::make_unique<TensorQTable>(); | |||
| int32_t num_rows = in_buffer->NumRows(); | |||
| for (int32_t i = 0; i < num_rows; i++) { | |||
| TensorRow to_process; | |||
| TensorRow cur_row; | |||
| RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); | |||
| if (in_columns_.empty() == true) { | |||
| MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; | |||
| to_process = cur_row; | |||
| } else { | |||
| std::unordered_map<std::string, int32_t> col_map = in_buffer->column_name_map(); | |||
| (void)std::transform( | |||
| in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process), | |||
| [&cur_row, &col_map](const auto &it) -> std::shared_ptr<Tensor> { return cur_row[col_map[it]]; }); | |||
| } | |||
| bool predicate = true; | |||
| RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); | |||
| if (predicate) { | |||
| (*out)->push_back(std::move(cur_row)); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
// If each filtered DataBuffer were written directly to out_connector_, the consumer
// thread popping output queues in round-robin order could block on a queue whose
// buffer was entirely filtered away. Collector therefore pops the per-worker filter
// queues in order and re-packs the surviving buffers densely onto out_connector_.
// Example with two worker queues:
//   filter_queues_:
//     queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
//     queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
//   after reorder in out_connector_:
//     queue1: DB(data2) DB(data4) DB(eof)
//     queue2: DB(eoe) DB(eoe)
Status FilterOp::Collector() {
  bool collector_stop = false;
  uint64_t task_id_cnt = 0;  // index of the next worker queue to pop from (round robin)
  uint64_t out_id_cnt = 0;   // count of buffers forwarded; selects the output queue
  std::pair<std::unique_ptr<DataBuffer>, filterCtrl> in_pair;
  while (collector_stop == false) {
    uint32_t w_id = task_id_cnt % num_workers_;
    RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
    if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial ||
        in_pair.second == filterCtrl::kFilterEoe) {
      // Surviving data buffers and eoe markers are forwarded to the next output queue.
      uint32_t out_task_id = out_id_cnt % num_workers_;
      RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
      out_id_cnt++;
      task_id_cnt++;
    } else if (in_pair.second == filterCtrl::kFilterEof) {
      // Forward the eof downstream and stop collecting.
      uint32_t out_task_id = out_id_cnt % num_workers_;
      RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
      collector_stop = true;
    } else {  // kFilterEmpty: the whole buffer was filtered out; drop it silently.
      task_id_cnt++;
    }
  }
  return Status::OK();
}
| // Private function for checking the column legality. | |||
| Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector<std::string> *input_columns) { | |||
| int32_t num_rows = in_buf->NumRows(); | |||
| int32_t num_cols = in_buf->NumCols(); | |||
| if (num_rows == 0 || num_cols == 0) { | |||
| RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); | |||
| } | |||
| std::unordered_map<std::string, int32_t> col_name_id_map = in_buf->column_name_map(); | |||
| // Check if there is invalid column name in the inColumns. | |||
| RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns)); | |||
| return Status::OK(); | |||
| } | |||
| Status FilterOp::CheckInput(const TensorRow &input) const { | |||
| for (auto &item : input) { | |||
| if (item == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("input is null."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { | |||
| RETURN_IF_NOT_OK(CheckInput(input)); | |||
| // Acquire Python GIL. | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| // Transform input tensor vector into numpy array vector. | |||
| py::tuple input_args(input.size()); | |||
| for (size_t i = 0; i < input.size(); i++) { | |||
| py::array new_data; | |||
| RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); | |||
| input_args[i] = new_data; | |||
| } | |||
| // Invoke python function. | |||
| py::object ret_py_obj = predicate_func_(*input_args); | |||
| *out_predicate = ret_py_obj.cast<py::bool_>(); | |||
| } catch (const py::error_already_set &e) { | |||
| std::stringstream ss; | |||
| ss << e.what() << std::endl; | |||
| ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool."; | |||
| return Status(StatusCode::kPyFuncException, ss.str()); | |||
| } | |||
| return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,181 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ | |||
| #include <memory> | |||
| #include <queue> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/engine/datasetops/parallel_op.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/queue.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
// FilterOp runs a python predicate over each row pulled from its child and forwards
// only the rows for which the predicate returns true, preserving buffer order.
class FilterOp : public ParallelOp {
 public:
  // The nested builder class inside of the FilterOp is used to help manage all of
  // the arguments for constructing it. Use the builder by setting each argument
  // with the provided set methods, and then finally call the build method to execute
  // the actual construction.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args.
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Setter method.
    // @param func - python callable applied per row; must return a boolean.
    // @return Builder setter method returns reference to the builder.
    Builder &SetPredicateFunc(py::function func) {
      builder_predicate_func_ = std::move(func);
      return *this;
    }

    // Setter method.
    // @param in_col_names - columns fed to the predicate; empty means all columns.
    // @return Builder setter method returns reference to the builder.
    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
      build_in_col_names_ = in_col_names;
      return *this;
    }

    // Setter method.
    // @param num_workers - number of worker threads.
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method.
    // @param connector_size - size of each queue in the output connector.
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      builder_op_connector_size_ = connector_size;
      return *this;
    }

    // The builder "build" method creates the final object.
    // @param ptr The shared_ptr to the new FilterOp object.
    // @return Status.
    Status Build(std::shared_ptr<FilterOp> *ptr);

   private:
    // Sanity check for builder class args.
    // @return Status - The error code return.
    Status SanityCheck();

    std::vector<std::string> build_in_col_names_;
    py::function builder_predicate_func_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
  };

  // Control flags attached to each buffer in the per-worker queues so Collector
  // knows whether to forward, drop, or terminate.
  enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };

  // Constructor of FilterOp
  // @note The builder class should be used to call it.
  // @param in_col_names A list of input column names; when it is empty the predicate will be
  //     applied to all columns in the dataset.
  // @param num_workers The number of worker threads.
  // @param op_connector_size The size of each queue in the connector.
  // @param predicate_func python callable which returns a boolean value.
  FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
           py::function predicate_func);

  // Destructor
  ~FilterOp() = default;

  // Class functor operator () override.
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work.
  // @return Status The error code return
  Status operator()() override;

  // Base-class override for eof handling (no-op; eof is forwarded by Collector).
  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EofReceived(int32_t) override;

  // Base-class override for eoe handling (no-op; eoe is forwarded by Collector).
  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EoeReceived(int32_t) override;

  // A print method typically used for debugging.
  // @param out The output stream to write output to.
  // @param show_all A bool to control if you want to show all info or just a summary.
  void Print(std::ostream &out, bool show_all) const override;

 private:
  // predicate_func python callable which returns a boolean value.
  py::function predicate_func_;
  // Variable to store the column names that will feed the predicate function.
  std::vector<std::string> in_columns_;
  // Internal per-worker queues carrying (buffer, control-flag) pairs for Collector.
  QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;

  // Private function for worker/thread to loop continuously. It comprises the main
  // logic of FilterOp: getting the data from the previous Op, validating user specified
  // column names, applying the predicate to each row, and filtering out rows for which
  // the predicate result is false.
  // @param worker_id The id assigned to this thread/worker upon creation.
  // @return Status The error code return.
  Status WorkerEntry(int32_t worker_id) override;  // In: workerId assigned by tree_

  // Filter the data by predicate function.
  // @param in_buffer input data buffer.
  // @param out tensor table of rows that passed the predicate.
  // @return Status The error code return.
  Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);

  // Collector thread loop: re-packs surviving buffers from the per-worker queues onto
  // the output connector in deterministic order.
  // @return Status The error code return.
  Status Collector();

  // Verifies no tensor in the row is null.
  // @param input tensor vector.
  // @return Status - The error code return.
  Status CheckInput(const TensorRow &input) const;

  // Invoke python predicate function on one row.
  // @param input tensor vector.
  // @param out_predicate the boolean result of the predicate.
  // @return Status - The error code return.
  Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);

  // Private function for validating if each of the user specified input column names
  // exist in the DataBuffer.
  // @param col_name_id_map The column name to index mapping obtained from DataBuffer.
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The error code return.
  Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
                           const std::vector<std::string> *input_columns);

  // Private function for checking the column legality.
  // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory
  //     and is not shared with other threads.
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The error code return.
  Status CheckColumns(const DataBuffer *in_buf, const std::vector<std::string> *input_columns);
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -65,9 +65,6 @@ MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std | |||
| tfuncs_(std::move(tensor_funcs)), | |||
| in_columns_(in_col_names), | |||
| out_columns_(out_col_names), | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| eof_worker_id_(0), | |||
| #endif | |||
| perf_mode_(perf_mode) { | |||
| // If caller didn't specify the out_col_names, assume they are same as the in_columns. | |||
| if (out_columns_.empty() || out_columns_[0].empty()) { | |||
| @@ -123,17 +120,6 @@ Status MapOp::operator()() { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | |||
| is_eof = buff->eof(); | |||
| RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| if (is_eof) { | |||
| eof_worker_id_ = que_id; | |||
| for (int32_t id = 0; id < num_workers_; id++) { | |||
| if (id != eof_worker_id_) { | |||
| auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||
| RETURN_IF_NOT_OK(local_queues_[id]->Add(std::move(eof_buffer))); | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| que_id = (que_id + 1) % num_workers_; | |||
| } | |||
| } | |||
| @@ -173,14 +159,6 @@ Status MapOp::WorkerEntry(int32_t worker_id) { | |||
| continue; | |||
| } else if (in_buffer->eof()) { | |||
| // Calling base class EofReceived to forward eof buffer. | |||
| #if defined(_WIN32) || defined(_Win64) | |||
| if (perf_mode_) { | |||
| if (eof_worker_id_ == worker_id) { | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| } | |||
| break; | |||
| } | |||
| #endif | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| break; | |||
| } | |||
| @@ -193,10 +193,6 @@ class MapOp : public ParallelOp { | |||
| // cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order. | |||
| bool perf_mode_; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| // EOF worker id is only work on Performance mode, to record the worker id of queue which gets EOF | |||
| int32_t eof_worker_id_; | |||
| #endif | |||
| // Private function for worker/thread to loop continuously. It comprises the main | |||
| // logic of MapOp: getting the data from previous Op, validating user specified column names, | |||
| // applying a list of TensorOps to each of the data, process the results and then | |||
| @@ -13,6 +13,9 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| #include <stdlib.h> | |||
| #endif | |||
| #include <securec.h> | |||
| #include <algorithm> | |||
| #include <chrono> | |||
| @@ -86,7 +89,9 @@ Status ShuffleOp::SelfReset() { | |||
| rng_ = std::mt19937_64(shuffle_seed_); | |||
| } else { | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| std::random_device random_device; | |||
| unsigned int number; | |||
| rand_s(&number); | |||
| std::mt19937 random_device{static_cast<uint32_t>(number)}; | |||
| #else | |||
| std::random_device random_device("/dev/urandom"); | |||
| #endif | |||
| @@ -67,9 +67,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work | |||
| } | |||
| std::unique_ptr<DataBuffer> buf; | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| // Drop first max_skips_ rows | |||
| while (skip_count_ < max_skips_) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| if (buf->eoe() || buf->eof()) { | |||
| break; | |||
| } | |||
| @@ -77,31 +78,24 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work | |||
| // Consider the rows of buffer more than 1 | |||
| TensorRow drop_row; | |||
| int row_num = buf->NumRows(); | |||
| for (int i = 0; i < row_num; i++) { | |||
| int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; | |||
| skip_count_ += drop_num; | |||
| for (int i = 0; i < drop_num; i++) { | |||
| RETURN_IF_NOT_OK(buf->PopRow(&drop_row)); | |||
| if (++skip_count_ == max_skips_) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| // If buffer is none or the rows of buffer is 0, | |||
| // then get a buffer from child. | |||
| if (!buf || buf->NumRows() == 0) { | |||
| if (buf && buf->eof()) { | |||
| *p_buffer = std::move(buf); | |||
| return Status::OK(); | |||
| if (buf->NumRows() == 0) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| } | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| } | |||
| // Handling eoe and eof | |||
| if (buf->eoe() || buf->eof()) { | |||
| // Handling eoe | |||
| if (buf->eoe()) { | |||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | |||
| if (state_ == OpState::kDeOpIdle) { | |||
| *p_buffer = std::move(buf); | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| // Handling eof | |||
| if (buf->eof()) { | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| } | |||
| *p_buffer = std::move(buf); | |||
| @@ -125,7 +119,7 @@ Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is a | |||
| // Base-class override for handling cases when an eof is received. | |||
| Status SkipOp::EofReceived(int32_t worker_id) { | |||
| MS_LOG(INFO) << "Skip operator EOF received, do nothing now."; | |||
| MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT | |||
| manifest_op.cc | |||
| cifar_op.cc | |||
| celeba_op.cc | |||
| text_file_op.cc | |||
| ) | |||
| add_dependencies(engine-datasetops-source mindspore::protobuf) | |||
| @@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() { | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordOp::CountTotalRows(const std::string dataset_path, int64_t *count) { | |||
| Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op, | |||
| int64_t *count) { | |||
| std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); | |||
| MSRStatus rc = shard_reader->CountTotalRows(dataset_path, count); | |||
| MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count); | |||
| if (rc == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); | |||
| } | |||
| @@ -171,7 +171,8 @@ class MindRecordOp : public ParallelOp { | |||
| int32_t num_rows() const { return num_rows_; } | |||
| // Getter method | |||
| static Status CountTotalRows(const std::string dataset_path, int64_t *count); | |||
| static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op, | |||
| int64_t *count); | |||
| // Getter method | |||
| int32_t rows_per_buffer() const { return rows_per_buffer_; } | |||
| @@ -1,6 +1,7 @@ | |||
| add_library(engine-datasetops-source-sampler OBJECT | |||
| distributed_sampler.cc | |||
| pk_sampler.cc | |||
| python_sampler.cc | |||
| random_sampler.cc | |||
| sampler.cc | |||
| sequential_sampler.cc | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/source/sampler/python_sampler.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer) | |||
| : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||
| Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (need_to_reset_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone); | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| py::object py_ret = py_sampler_instance.attr("_get_indices")(); | |||
| py::array np_sample_ids = py_ret.cast<py::array>(); | |||
| Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } catch (const py::cast_error &e) { | |||
| return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index"); | |||
| } | |||
| } | |||
| TensorRow row(1, sample_ids); | |||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||
| need_to_reset_ = true; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PythonSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PythonSampler::Reset() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); | |||
| need_to_reset_ = false; | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| py_sampler_instance.attr("reset")(); | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ | |||
| #include <limits> | |||
| #include <memory> | |||
| #include "dataset/engine/datasetops/source/sampler/sampler.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class PythonSampler : public Sampler { | |||
| public: | |||
| // Constructor | |||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit PythonSampler(py::object py_sampler_instance, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~PythonSampler() = default; | |||
| // Initialize the sampler. | |||
| // @return Status | |||
| Status InitSampler() override; | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| Status Reset() override; | |||
| // Op calls this to get next Buffer that contains all the sampleIds | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| private: | |||
| bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | |||
| py::object py_sampler_instance; // The handle to the py_sampler python object | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ | |||
| @@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| // check samples_per_buffer is properly set and doesn't overflow | |||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid"); | |||
| // A call to derived class to get sample ids wrapped inside a buffer | |||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||
| // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch | |||
| @@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) | |||
| } | |||
| Status SequentialSampler::InitSampler() { | |||
| num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| return Status::OK(); | |||
| @@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const { | |||
| std::ifstream in(schemaFile); | |||
| nlohmann::json js; | |||
| in >> js; | |||
| num_rows = js.value("numRows", 0); | |||
| if (js.find("numRows") == js.end()) { | |||
| num_rows = MAX_INTEGER_INT32; | |||
| } else { | |||
| num_rows = js.value("numRows", 0); | |||
| } | |||
| if (num_rows == 0) { | |||
| std::string err_msg = | |||
| "Storage client has not properly done dataset " | |||
| @@ -0,0 +1,459 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "common/utils.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/util/wait_post.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| TextFileOp::Builder::Builder() | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| builder_rows_per_buffer_ = config_manager->rows_per_buffer(); | |||
| builder_worker_connector_size_ = config_manager->worker_connector_size(); | |||
| } | |||
| Status TextFileOp::Builder::ValidateInputs() const { | |||
| std::string err_msg; | |||
| err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : ""; | |||
| err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; | |||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||
| RETURN_IF_NOT_OK(ValidateInputs()); | |||
| // Throttle the number of workers if we have more workers than files! | |||
| if (static_cast<size_t>(builder_num_workers_) > builder_text_files_list_.size()) { | |||
| builder_num_workers_ = builder_text_files_list_.size(); | |||
| MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK( | |||
| builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, | |||
| std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, | |||
| builder_num_devices_, builder_device_id_); | |||
| RETURN_IF_NOT_OK(text_file_op->Init()); | |||
| *op = std::move(text_file_op); | |||
| return Status::OK(); | |||
| } | |||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| device_id_(device_id), | |||
| num_devices_(num_device), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| num_samples_(num_samples), | |||
| text_files_list_(std::move(text_files_list)), | |||
| shuffle_files_(shuffle_files), | |||
| data_schema_(std::move(schema)), | |||
| all_num_rows_(0), | |||
| num_rows_per_shard_(0), | |||
| filename_index_(std::make_unique<StringIndex>()), | |||
| finished_reading_dataset_(false), | |||
| load_io_block_queue_(true), | |||
| load_jagged_connector_(true) { | |||
| worker_connector_size_ = worker_connector_size; | |||
| } | |||
| Status TextFileOp::Init() { | |||
| RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); | |||
| int32_t safe_queue_size = static_cast<int32_t>(std::ceil(text_files_list_.size() / num_workers_) + 1); | |||
| io_block_queues_.Init(num_workers_, safe_queue_size); | |||
| for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { | |||
| col_name_map_[data_schema_->column(i).name()] = i; | |||
| } | |||
| RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); | |||
| jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_); | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::Reset() { | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||
| NotifyToFillIOBlockQueue(); | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) { | |||
| TensorRow tRow(1, nullptr); | |||
| (*tensor_table)->push_back(std::move(tRow)); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK( | |||
| Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(), | |||
| TensorShape(std::vector<dsize_t>(1, line.size())), data_schema_->column(0).type(), | |||
| const_cast<unsigned char *>(reinterpret_cast<const unsigned char *>(common::SafeCStr(line))))); | |||
| (**tensor_table)[row][0] = std::move(tensor); | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||
| const int32_t worker_id) { | |||
| std::ifstream handle(file); | |||
| if (!handle.is_open()) { | |||
| RETURN_STATUS_UNEXPECTED("Failed to open file " + file); | |||
| } | |||
| int64_t rows_each_buffer = 0; | |||
| int64_t rows_total = 0; | |||
| std::string line; | |||
| std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | |||
| cur_buffer->set_column_name_map(col_name_map_); | |||
| std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>(); | |||
| while (getline(handle, line)) { | |||
| // If read to the end offset of this file, break. | |||
| if (rows_total >= end_offset) { | |||
| break; | |||
| } | |||
| // Skip line before start offset. | |||
| if (rows_total < start_offset) { | |||
| rows_total++; | |||
| continue; | |||
| } | |||
| RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); | |||
| rows_each_buffer++; | |||
| rows_total++; | |||
| if (rows_each_buffer == rows_per_buffer_) { | |||
| cur_buffer->set_tensor_table(std::move(tensor_table)); | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); | |||
| cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | |||
| cur_buffer->set_column_name_map(col_name_map_); | |||
| tensor_table = std::make_unique<TensorQTable>(); | |||
| rows_each_buffer = 0; | |||
| } | |||
| } | |||
| if (rows_each_buffer > 0) { | |||
| cur_buffer->set_tensor_table(std::move(tensor_table)); | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<FilenameBlock> io_block; | |||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||
| while (!io_block->eof()) { | |||
| if (!io_block->eoe()) { | |||
| if (load_jagged_connector_) { | |||
| std::string filename; | |||
| RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); | |||
| int64_t start_offset = io_block->GetStartOffset(); | |||
| int64_t end_offset = io_block->GetEndOffset(); | |||
| RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); | |||
| } | |||
| } else { | |||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); | |||
| } | |||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Pops an element from a queue in io_block_queues | |||
| Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||
| return Status::OK(); | |||
| } | |||
| // Pushes an element to a queue in io_block_queues | |||
| Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||
| return Status::OK(); | |||
| } | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||
| Status TextFileOp::PostEndOfData() { | |||
| for (int i = 0; i < num_workers_; ++i) { | |||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||
| Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { | |||
| for (int i = 0; i < num_workers_; ++i) { | |||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) { | |||
| std::mt19937 rng(seed); | |||
| std::shuffle(i_keys->begin(), i_keys->end(), rng); | |||
| } | |||
| bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||
| const int64_t &pre_count) { | |||
| *start_offset = 0; | |||
| *end_offset = 0; | |||
| bool push = false; | |||
| int64_t start_index = device_id_ * num_rows_per_shard_; | |||
| if (device_id_ + 1 < 0) { | |||
| MS_LOG(ERROR) << "Device id is invalid"; | |||
| return false; | |||
| } | |||
| int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||
| if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { | |||
| *start_offset = start_index - pre_count; | |||
| push = true; | |||
| if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { | |||
| *end_offset = end_index - pre_count; | |||
| } else { | |||
| *end_offset = filename_numrows_[file_name]; | |||
| } | |||
| } | |||
| if (pre_count >= start_index && pre_count < end_index) { | |||
| *start_offset = 0; | |||
| push = true; | |||
| if (pre_count + filename_numrows_[file_name] >= end_index) { | |||
| *end_offset = end_index - pre_count; | |||
| } else { | |||
| *end_offset = filename_numrows_[file_name]; | |||
| } | |||
| } | |||
| return push; | |||
| } | |||
| Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||
| int32_t queue_index = 0; | |||
| int64_t pre_count = 0; | |||
| int64_t start_offset = 0; | |||
| int64_t end_offset = 0; | |||
| bool finish = false; | |||
| while (!finish) { | |||
| std::vector<std::pair<std::string, int64_t>> file_index; | |||
| if (!i_keys.empty()) { | |||
| for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { | |||
| { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| } | |||
| auto file_it = filename_index_->Search(*it); | |||
| file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it)); | |||
| } | |||
| } else { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| } | |||
| file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key())); | |||
| } | |||
| } | |||
| for (auto file_info : file_index) { | |||
| if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { | |||
| auto ioBlock = | |||
| std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); | |||
| queue_index = (queue_index + 1) % num_workers_; | |||
| } | |||
| pre_count += filename_numrows_[file_info.first]; | |||
| } | |||
| if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { | |||
| finish = false; | |||
| } else { | |||
| finish = true; | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::WaitToFillIOBlockQueue() { | |||
| // must be called first if called by worker spanwed by taskgroup | |||
| TaskManager::FindMe()->Post(); | |||
| std::vector<int64_t> i_keys; | |||
| if (shuffle_files_) { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| i_keys.push_back(it.key()); | |||
| } | |||
| } | |||
| uint32_t seed = 0; | |||
| while (true) { | |||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||
| io_block_queue_wait_post_.Clear(); | |||
| if (finished_reading_dataset_) { | |||
| break; | |||
| } | |||
| if (shuffle_files_) { | |||
| ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||
| } | |||
| RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } | |||
| Status TextFileOp::operator()() { | |||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||
| // launch one thread, responsible for filling IoBlockQueue | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); | |||
| // Read data from disk into buffers | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); | |||
| // must be called after launching workers. | |||
| TaskManager::FindMe()->Post(); | |||
| io_block_queue_wait_post_.Register(tree_->AllTasks()); | |||
| NotifyToFillIOBlockQueue(); | |||
| while (!finished_reading_dataset_) { | |||
| int64_t buffer_id = 0; | |||
| int32_t workers_done = 0; | |||
| int64_t rows_read = 0; | |||
| load_io_block_queue_ = true; | |||
| while (workers_done < num_workers_) { | |||
| std::unique_ptr<DataBuffer> buffer; | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); | |||
| if (buffer->eoe()) { | |||
| workers_done++; | |||
| } else if (num_samples_ == 0 || rows_read < num_samples_) { | |||
| if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { | |||
| int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); | |||
| RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); | |||
| } | |||
| rows_read += buffer->NumRows(); | |||
| buffer->set_id(buffer_id++); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); | |||
| } else { | |||
| // end of epoch | |||
| load_jagged_connector_ = false; | |||
| load_io_block_queue_ = false; | |||
| } | |||
| } | |||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||
| finished_reading_dataset_ = true; | |||
| NotifyToFillIOBlockQueue(); | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| } | |||
| } | |||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||
| return Status::OK(); | |||
| } | |||
| int64_t TextFileOp::CountTotalRows(const std::string &file) { | |||
| std::ifstream handle(file); | |||
| if (!handle.is_open()) { | |||
| MS_LOG(ERROR) << "Failed to open file: " << file; | |||
| return 0; | |||
| } | |||
| std::string line; | |||
| int64_t count = 0; | |||
| while (getline(handle, line)) { | |||
| count++; | |||
| } | |||
| return count; | |||
| } | |||
| Status TextFileOp::CalculateNumRowsPerShard() { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| int64_t count = CountTotalRows(it.value()); | |||
| filename_numrows_[it.value()] = count; | |||
| all_num_rows_ += count; | |||
| } | |||
| if (all_num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Number of rows can not be zero"); | |||
| } | |||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_)); | |||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) { | |||
| std::shared_ptr<TextFileOp> op; | |||
| *count = 0; | |||
| RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); | |||
| for (auto file : files) { | |||
| *count += op->CountTotalRows(file); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,263 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ | |||
| #include <memory> | |||
| #include <map> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "dataset/util/status.h" | |||
| #include "dataset/util/auto_index.h" | |||
| #include "dataset/engine/data_schema.h" | |||
| #include "dataset/engine/datasetops/parallel_op.h" | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| #include "dataset/util/queue.h" | |||
| #include "dataset/util/wait_post.h" | |||
| #include "dataset/engine/jagged_connector.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| using StringIndex = AutoIndexObj<std::string>; | |||
| class TextFileOp : public ParallelOp { | |||
| public: | |||
| class Builder { | |||
| public: | |||
| // Builder constructor. Creates the builder object. | |||
| // @note No default args | |||
| // @return This is a constructor. | |||
| Builder(); | |||
| // Default destructor | |||
| ~Builder() = default; | |||
| // Checks if the inputs of the builder is valid. | |||
| // @return Status - the error code returned. | |||
| Status ValidateInputs() const; | |||
| // Create the final object. | |||
| // @param op - dataset op. | |||
| // @return - the error code return. | |||
| Status Build(std::shared_ptr<TextFileOp> *op); | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| builder_num_workers_ = num_workers; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetOpConnectorSize(int32_t op_connector_size) { | |||
| builder_op_connector_size_ = op_connector_size; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { | |||
| builder_rows_per_buffer_ = rows_per_buffer; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumDevices(int64_t num_dev) { | |||
| builder_num_devices_ = num_dev; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetDeviceId(int64_t dev_id) { | |||
| builder_device_id_ = dev_id; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetTextFilesList(const std::vector<std::string> &files_list) { | |||
| builder_text_files_list_ = files_list; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetShuffleFiles(bool shuffle_files) { | |||
| builder_shuffle_files_ = shuffle_files; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| private: | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int64_t builder_rows_per_buffer_; | |||
| int64_t builder_num_samples_; | |||
| int32_t builder_worker_connector_size_; | |||
| std::vector<std::string> builder_text_files_list_; | |||
| bool builder_shuffle_files_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| }; | |||
| // Constructor of TextFileOp | |||
| // @note The builder class should be used to call this constructor. | |||
| // @param num_workers - number of worker threads reading data from tf_file files. | |||
| // @param rows_per_buffer - number of rows that a full buffer will contain. | |||
| // @param total_num_rows - number of rows to read | |||
| // @param dataset_files_list - list of filepaths for the dataset files. | |||
| // @param data_schema - the data schema object. | |||
| // @param op_connector_size - size of each queue in the connector that the child operator pulls from. | |||
| // @param columns_to_load - the names of the columns to load data from. | |||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | |||
| // @param equal_rows_per_shard - whether or not to get equal rows for each process. | |||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| // Default destructor | |||
| ~TextFileOp() = default; | |||
| // Instantiates the internal queues and connectors | |||
| // @return Status - the error code returned | |||
| Status Init(); | |||
| // Class functor operator () override. | |||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - the error code returned. | |||
| Status operator()() override; | |||
| // Overrides base class reset method. Cleans up any state info from it's previous execution | |||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||
| // @return Status - the error code returned. | |||
| Status Reset() override; | |||
| // Get total rows in files. | |||
| // @param files - all text files. | |||
| // @param count - number of rows. | |||
| // @return Status - the error coed returned. | |||
| static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count); | |||
| private: | |||
| // The entry point for when workers are launched. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| // @return Status - the error code returned. | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Parses a single row and puts the data into a tensor table. | |||
| // @param line - the content of the row. | |||
| // @param tensor_table - the tensor table to put the parsed data in. | |||
| // @param row - the id of the row filled in the tensor table. | |||
| // @return Status - the error code returned. | |||
| Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row); | |||
| // Reads a text file and loads the data into multiple buffers. | |||
| // @param file - the file to read. | |||
| // @param start_offset - the start offset of file. | |||
| // @param end_offset - the end offset of file. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| // @return Status - the error code returned. | |||
| Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||
| const int32_t worker_id); | |||
| // Calculate number of rows in each shard. | |||
| // @return Status - the error code returned. | |||
| Status CalculateNumRowsPerShard(); | |||
| // Count number of rows in each file. | |||
| // @param filename - text file name. | |||
| // @return int64_t - the total number of rows in file. | |||
| int64_t CountTotalRows(const std::string &file); | |||
| // Notifies the thread which called FillIoBlockQueue to resume execution | |||
| void NotifyToFillIOBlockQueue(); | |||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||
| // @return Status - the error code returned. | |||
| Status WaitToFillIOBlockQueue(); | |||
| // Fill the IOBlockQueue. | |||
| // @para i_keys - keys of file to fill to the IOBlockQueue | |||
| // @return Status - the error code returned. | |||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); | |||
| // Select file and push it to the block queue. | |||
| // @param file_name - File name. | |||
| // @param start_file - If file contains the first sample of data. | |||
| // @param end_file - If file contains the end sample of data. | |||
| // @param pre_count - Total rows of previous files. | |||
| // @return Status - the error code returned. | |||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||
| const int64_t &pre_count); | |||
| // Pops an element from a queue in IOBlockQueue. | |||
| // @param index - the index of the queue to pop from. | |||
| // @param out_block - the popped element. | |||
| // @return Status - the error code returned. | |||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||
| // Pushes an element to a queue in IOBlockQueue. | |||
| // @param index - the index of the queue to push to. | |||
| // @param io_block - the element to push onto the queue. | |||
| // @return Status - the error code returned. | |||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||
| // @return Status - the error code returned. | |||
| Status PostEndOfData(); | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||
| // @return Status - the error code returned. | |||
| Status PostEndOfEpoch(int32_t queue_index); | |||
| int32_t device_id_; | |||
| int32_t num_devices_; | |||
| int64_t rows_per_buffer_; | |||
| int64_t num_samples_; | |||
| std::vector<std::string> text_files_list_; | |||
| bool shuffle_files_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| int64_t all_num_rows_; | |||
| int64_t num_rows_per_shard_; | |||
| std::map<std::string, int64_t> filename_numrows_; | |||
| std::unique_ptr<StringIndex> filename_index_; | |||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||
| WaitPost io_block_queue_wait_post_; | |||
| bool finished_reading_dataset_; | |||
| bool load_io_block_queue_; | |||
| bool load_jagged_connector_; | |||
| std::unordered_map<std::string, int32_t> col_name_map_; | |||
| std::unique_ptr<JaggedConnector> jagged_buffer_connector_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ | |||
| @@ -42,6 +42,7 @@ | |||
| #include "dataset/util/status.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/util/wait_post.h" | |||
| #include "utils/system/crc32c.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder() | |||
| builder_data_schema_ = std::make_unique<DataSchema>(); | |||
| } | |||
| bool ValidateFirstRowCrc(const std::string &filename) { | |||
| std::ifstream reader; | |||
| reader.open(filename); | |||
| if (!reader) { | |||
| return false; | |||
| } | |||
| // read data | |||
| int64_t record_length = 0; | |||
| (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t))); | |||
| // read crc from file | |||
| uint32_t masked_crc = 0; | |||
| (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t))); | |||
| // generate crc from data | |||
| uint32_t generated_crc = | |||
| system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t)); | |||
| return masked_crc == generated_crc; | |||
| } | |||
| Status TFReaderOp::Builder::ValidateInputs() const { | |||
| std::string err_msg; | |||
| err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : ""; | |||
| if (!builder_equal_rows_per_shard_) { | |||
| err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_) | |||
| ? "No enough tf_file files provided\n" | |||
| : ""; | |||
| if (builder_num_workers_ <= 0) { | |||
| err_msg += "Number of parallel workers is smaller or equal to 0\n"; | |||
| } | |||
| if (!builder_equal_rows_per_shard_ && | |||
| builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) { | |||
| err_msg += "Not enough tfrecord files provided\n"; | |||
| } | |||
| if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { | |||
| err_msg += "Wrong sharding configs\n"; | |||
| } | |||
| std::vector<std::string> invalid_files(builder_dataset_files_list_.size()); | |||
| auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), | |||
| [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); | |||
| invalid_files.resize(std::distance(invalid_files.begin(), it)); | |||
| if (!invalid_files.empty()) { | |||
| err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; | |||
| std::string accumulated_filenames = std::accumulate( | |||
| invalid_files.begin(), invalid_files.end(), std::string(""), | |||
| [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); | |||
| err_msg += accumulated_filenames; | |||
| } | |||
| err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; | |||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| @@ -119,6 +163,9 @@ Status TFReaderOp::Init() { | |||
| if (total_rows_ == 0) { | |||
| total_rows_ = data_schema_->num_rows(); | |||
| } | |||
| if (total_rows_ < 0) { | |||
| RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0"); | |||
| } | |||
| // Build the index with our files such that each file corresponds to a key id. | |||
| RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_)); | |||
| @@ -523,6 +570,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off | |||
| RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); | |||
| rows_read++; | |||
| } | |||
| // ignore crc footer | |||
| (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t))); | |||
| rows_total++; | |||
| @@ -67,7 +67,7 @@ Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work | |||
| bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat); | |||
| if (take_count_ == max_takes_) { | |||
| if (state_ == OpState::kDeOpRunning) { | |||
| MS_LOG(INFO) << "meet max count and push-back eoe buffer."; | |||
| MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer."; | |||
| auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| *p_buffer = std::move(eoe_buffer); | |||
| state_ = OpState::kDeOpIdle; | |||
| @@ -80,11 +80,13 @@ Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(INFO) << "meet max count and push-back eof buffer."; | |||
| } else if (state_ == OpState::kDeOpIdle) { | |||
| MS_LOG(DEBUG) << "Meet max count and push-back eof buffer."; | |||
| auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||
| *p_buffer = std::move(eof_buffer); | |||
| take_count_ = 0; | |||
| } else { | |||
| MS_LOG(WARNING) << "Invalid OpState: " << state_; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -116,7 +118,7 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D | |||
| *data_buffer = std::move(*buffer); | |||
| take_count_ = take_count_ + buffer_size; | |||
| } else { | |||
| MS_LOG(INFO) << "In last buffer: Push one buffer."; | |||
| MS_LOG(DEBUG) << "In last buffer: Push one buffer."; | |||
| std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>(); | |||
| while (take_count_ < max_takes_) { | |||
| TensorRow new_row; | |||
| @@ -34,7 +34,7 @@ class DataBuffer; | |||
| class ZipOp : public PipelineOp { | |||
| public: | |||
| // The nested builder class inside of the BatchOp is used to help manage all of | |||
| // The nested builder class inside of the ZipOp is used to help manage all of | |||
| // the arguments for constructing it. Use the builder by setting each argument | |||
| // with the provided set methods, and then finally call the build method to execute | |||
| // the actual construction. | |||
| @@ -76,8 +76,8 @@ class ZipOp : public PipelineOp { | |||
| }; | |||
| // Constructor for ZipOp | |||
| // @param rows_per_buffer number of rows in output buffer | |||
| // @param op_connector_size connector | |||
| // @param rows_per_buffer - number of rows in output buffer | |||
| // @param op_connector_size - connector size | |||
| ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); | |||
| // Destructor | |||
| @@ -88,8 +88,8 @@ class ZipOp : public PipelineOp { | |||
| Status EoeReceived(int32_t) override; | |||
| // Print function for Zip | |||
| // @param out output stream to print to | |||
| // @param show_all if it should print everything | |||
| // @param out - output stream to print to | |||
| // @param show_all - if it should print everything | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // Provide stream operator for displaying it | |||
| @@ -113,14 +113,14 @@ class ZipOp : public PipelineOp { | |||
| Status fillBuffer(TensorQTable *const table); | |||
| // Special handle case where an empty row has been received from child iterator | |||
| // @note we need to drain eoe signals from all children connectors. | |||
| // @details when this function is called, then we encountered eoe at child iterator | |||
| // @note - we need to drain eoe signals from all children connectors. | |||
| // @details - when this function is called, then we encountered eoe at child iterator | |||
| // we have to drain rows from other child iterators until we hit eoe from all other child iterators | |||
| Status drainPipeline(); | |||
| // Merges 1 row from each childIterator together | |||
| // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty | |||
| // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true | |||
| // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty | |||
| // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true | |||
| // @details merge rows from iterator together. This is the main functionality for ZipOp | |||
| // this function takes one row and fills it with tensors from rows fetched | |||
| // from childIterators. | |||
| @@ -3,7 +3,6 @@ if (WIN32) | |||
| center_crop_op.cc | |||
| cut_out_op.cc | |||
| decode_op.cc | |||
| distort_bounding_box_crop_op.cc | |||
| hwc_to_chw_op.cc | |||
| image_utils.cc | |||
| normalize_op.cc | |||
| @@ -19,6 +18,7 @@ if (WIN32) | |||
| rescale_op.cc | |||
| resize_bilinear_op.cc | |||
| resize_op.cc | |||
| uniform_aug_op.cc | |||
| ) | |||
| else() | |||
| add_library(kernels-image OBJECT | |||
| @@ -26,7 +26,6 @@ else() | |||
| change_mode_op.cc | |||
| cut_out_op.cc | |||
| decode_op.cc | |||
| distort_bounding_box_crop_op.cc | |||
| hwc_to_chw_op.cc | |||
| image_utils.cc | |||
| normalize_op.cc | |||
| @@ -42,5 +41,6 @@ else() | |||
| rescale_op.cc | |||
| resize_bilinear_op.cc | |||
| resize_op.cc | |||
| uniform_aug_op.cc | |||
| ) | |||
| endif() | |||
| endif() | |||
| @@ -33,7 +33,8 @@ const uint8_t CutOutOp::kDefFillB = 0; | |||
| // constructor | |||
| CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r, | |||
| uint8_t fill_g, uint8_t fill_b) | |||
| : box_height_(box_height), | |||
| : rnd_(GetSeed()), | |||
| box_height_(box_height), | |||
| box_width_(box_width), | |||
| num_patches_(num_patches), | |||
| random_color_(random_color), | |||
| @@ -46,8 +47,8 @@ Status CutOutOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T | |||
| IO_CHECK(input, output); | |||
| std::shared_ptr<CVTensor> inputCV = CVTensor::AsCVTensor(input); | |||
| // cut out will clip the erasing area if the box is near the edge of the image and the boxes are black | |||
| RETURN_IF_NOT_OK( | |||
| Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, fill_r_, fill_g_, fill_b_)); | |||
| RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_, | |||
| fill_g_, fill_b_)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -62,6 +62,7 @@ class CutOutOp : public TensorOp { | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| private: | |||
| std::mt19937 rnd_; | |||
| int32_t box_height_; | |||
| int32_t box_width_; | |||
| int32_t num_patches_; | |||
| @@ -34,11 +34,11 @@ class DecodeOp : public TensorOp { | |||
| ~DecodeOp() = default; | |||
| Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override; | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| void Print(std::ostream& out) const override { out << "DecodeOp"; } | |||
| Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; | |||
| Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; | |||
| void Print(std::ostream &out) const override { out << "DecodeOp"; } | |||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| private: | |||
| bool is_rgb_format_ = true; | |||
| @@ -1,117 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/image/distort_bounding_box_crop_op.h" | |||
| #include <random> | |||
| #include "dataset/core/cv_tensor.h" | |||
| #include "dataset/kernels/image/image_utils.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const int32_t DistortBoundingBoxCropOp::kDefMaxAttempts = 100; | |||
| const int32_t DistortBoundingBoxCropOp::kDefBoxGenAttempts = 10; | |||
| DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, | |||
| float crop_ratio_ub, int32_t max_attempts, int32_t box_gen_attempts) | |||
| : max_attempts_(max_attempts), | |||
| box_gen_attempts_(box_gen_attempts), | |||
| aspect_ratio_(aspect_ratio), | |||
| intersect_ratio_(intersect_ratio), | |||
| crop_ratio_lb_(crop_ratio_lb), | |||
| crop_ratio_ub_(crop_ratio_ub) { | |||
| seed_ = GetSeed(); | |||
| rnd_.seed(seed_); | |||
| } | |||
| Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input, | |||
| std::vector<std::shared_ptr<Tensor>>* output) { | |||
| IO_CHECK_VECTOR(input, output); | |||
| if (input.size() != NumInput()) | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input[1]->shape().Size() >= 1, "The shape of the second tensor is abnormal"); | |||
| int64_t num_boxes = 0; | |||
| for (uint64_t i = 1; i < input.size(); i++) { | |||
| if (i == 1) num_boxes = input[i]->shape()[0]; | |||
| if (num_boxes != input[i]->shape()[0]) | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Numbers of boxes do not match"); | |||
| if (input[i]->type() != DataType::DE_FLOAT32) | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "boxes' type is not DE_FLOAT21"); | |||
| } | |||
| // assume input Tensor vector in the order of [img, bbox_y_min, bbox_y_max, bbox_x_min, bbox_x_max] | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of the first tensor is abnormal"); | |||
| int h_in = input[0]->shape()[0]; | |||
| int w_in = input[0]->shape()[1]; | |||
| std::vector<cv::Rect> bounding_boxes; | |||
| for (int64_t i = 0; i < num_boxes; ++i) { | |||
| // bbox coordinates are floats relative to the image width and height | |||
| float y_min, y_max, x_min, x_max; | |||
| RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&y_min, {i})); | |||
| RETURN_IF_NOT_OK(input[2]->GetItemAt<float>(&y_max, {i})); | |||
| RETURN_IF_NOT_OK(input[3]->GetItemAt<float>(&x_min, {i})); | |||
| RETURN_IF_NOT_OK(input[4]->GetItemAt<float>(&x_max, {i})); | |||
| bounding_boxes.emplace_back(static_cast<int>(x_min * w_in), static_cast<int>(y_min * h_in), | |||
| static_cast<int>((x_max - x_min) * w_in), static_cast<int>((y_max - y_min) * h_in)); | |||
| } | |||
| cv::Rect output_box; | |||
| bool should_crop = false; | |||
| // go over iterations, if no satisfying box found we return the original image | |||
| for (int32_t t = 0; t < max_attempts_; ++t) { | |||
| // try to generate random box | |||
| RETURN_IF_NOT_OK(GenerateRandomCropBox(h_in, w_in, aspect_ratio_, crop_ratio_lb_, crop_ratio_ub_, | |||
| box_gen_attempts_, // int maxIter, should not be needed here | |||
| &output_box, seed_)); | |||
| RETURN_IF_NOT_OK(CheckOverlapConstraint(output_box, | |||
| bounding_boxes, // have to change, should take tensor or add bbox logic | |||
| intersect_ratio_, &should_crop)); | |||
| if (should_crop) { | |||
| // found a box to crop | |||
| break; | |||
| } | |||
| } | |||
| // essentially we have to check this again at the end to return original tensor | |||
| if (should_crop) { | |||
| std::shared_ptr<Tensor> out; | |||
| RETURN_IF_NOT_OK(Crop(input[0], &out, output_box.x, output_box.y, output_box.width, output_box.height)); | |||
| output->push_back(out); | |||
| } else { | |||
| output->push_back(input[0]); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs, | |||
| std::vector<TensorShape>& outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | |||
| outputs.clear(); | |||
| TensorShape out = TensorShape{-1, -1}; | |||
| if (inputs[0].Rank() == 2) outputs.emplace_back(out); | |||
| if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); | |||
| if (!outputs.empty()) return Status::OK(); | |||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | |||
| } | |||
| Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); | |||
| outputs[0] = inputs[0]; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,72 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ | |||
| #define DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ | |||
| #include <memory> | |||
| #include <random> | |||
| #include <vector> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DistortBoundingBoxCropOp : public TensorOp { | |||
| public: | |||
| // Default values, also used by python_bindings.cc | |||
| static const int32_t kDefMaxAttempts; | |||
| static const int32_t kDefBoxGenAttempts; | |||
| // Constructor for DistortBoundingBoxCropOp | |||
| // @param max_attempts tries before the crop happens | |||
| // @param box_gen_attempts crop box generation attempts | |||
| // @param aspect_ratio aspect ratio of the generated crop box | |||
| // @param intersect_ratio area overlap ratio, condition for crop only if area over lap between the generated | |||
| // crop box has sufficient overlap with any 1 bounding box | |||
| // @param crop_ratio_lb the crop ratio lower bound | |||
| // @param crop_ratio_ub the crop ratio upper bound | |||
| // @param seed | |||
| DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, float crop_ratio_ub, | |||
| int32_t max_attempts = kDefMaxAttempts, int32_t box_gen_attempts = kDefBoxGenAttempts); | |||
| ~DistortBoundingBoxCropOp() override = default; | |||
| void Print(std::ostream& out) const override { | |||
| out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; | |||
| } | |||
| Status Compute(const std::vector<std::shared_ptr<Tensor>>& input, | |||
| std::vector<std::shared_ptr<Tensor>>* output) override; | |||
| uint32_t NumInput() override { return 5; } | |||
| Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; | |||
| Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; | |||
| private: | |||
| int32_t max_attempts_; | |||
| int32_t box_gen_attempts_; | |||
| float aspect_ratio_; | |||
| float intersect_ratio_; | |||
| float crop_ratio_lb_; | |||
| float crop_ratio_ub_; | |||
| std::mt19937 rnd_; | |||
| uint32_t seed_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ | |||
| @@ -636,76 +636,10 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||
| return Status::OK(); | |||
| } | |||
| Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, | |||
| cv::Rect *crop_box, uint32_t seed) { | |||
| try { | |||
| std::mt19937 rnd; | |||
| rnd.seed(GetSeed()); | |||
| if (input_height <= 0 || input_width <= 0 || ratio <= 0.0 || lb <= 0.0 || lb > ub) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid inputs GenerateRandomCropBox"); | |||
| } | |||
| std::uniform_real_distribution<float> rd_crop_ratio(lb, ub); | |||
| float crop_ratio; | |||
| int crop_width, crop_height; | |||
| bool crop_success = false; | |||
| int64_t input_area = input_height * input_width; | |||
| for (auto i = 0; i < max_itr; i++) { | |||
| crop_ratio = rd_crop_ratio(rnd); | |||
| crop_width = static_cast<int32_t>(std::round(std::sqrt(input_area * static_cast<double>(crop_ratio) / ratio))); | |||
| crop_height = static_cast<int32_t>(std::round(crop_width * ratio)); | |||
| if (crop_width <= input_width && crop_height <= input_height) { | |||
| crop_success = true; | |||
| break; | |||
| } | |||
| } | |||
| if (crop_success == false) { | |||
| ratio = static_cast<float>(input_height) / input_width; | |||
| crop_ratio = rd_crop_ratio(rnd); | |||
| crop_width = static_cast<int>(std::lround(std::sqrt(input_area * static_cast<double>(crop_ratio) / ratio))); | |||
| crop_height = static_cast<int>(std::lround(crop_width * ratio)); | |||
| crop_height = (crop_height > input_height) ? input_height : crop_height; | |||
| crop_width = (crop_width > input_width) ? input_width : crop_width; | |||
| } | |||
| std::uniform_int_distribution<> rd_x(0, input_width - crop_width); | |||
| std::uniform_int_distribution<> rd_y(0, input_height - crop_height); | |||
| *crop_box = cv::Rect(rd_x(rnd), rd_y(rnd), crop_width, crop_height); | |||
| return Status::OK(); | |||
| } catch (const cv::Exception &e) { | |||
| RETURN_STATUS_UNEXPECTED("error in GenerateRandomCropBox."); | |||
| } | |||
| } | |||
| Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Rect> &bounding_boxes, | |||
| float min_intersect_ratio, bool *is_satisfied) { | |||
| try { | |||
| // not satisfied if the crop box contains no pixel | |||
| if (crop_box.area() < 1.0) { | |||
| *is_satisfied = false; | |||
| } | |||
| for (const auto &b_box : bounding_boxes) { | |||
| const float b_box_area = b_box.area(); | |||
| // not satisfied if the bounding box contains no pixel | |||
| if (b_box_area < 1.0) { | |||
| continue; | |||
| } | |||
| const float intersect_ratio = (crop_box & b_box).area() / b_box_area; | |||
| if (intersect_ratio >= min_intersect_ratio) { | |||
| *is_satisfied = true; | |||
| break; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } catch (const cv::Exception &e) { | |||
| RETURN_STATUS_UNEXPECTED("error in CheckOverlapConstraint."); | |||
| } | |||
| } | |||
| Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height, | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r, uint8_t fill_g, | |||
| uint8_t fill_b) { | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, | |||
| uint8_t fill_g, uint8_t fill_b) { | |||
| try { | |||
| std::mt19937 rnd; | |||
| rnd.seed(GetSeed()); | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||
| if (input_cv->mat().data == nullptr || (input_cv->Rank() != 3 && input_cv->shape()[2] != 3)) { | |||
| RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); | |||
| @@ -731,8 +665,8 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp | |||
| // rows in cv mat refers to the height of the cropped box | |||
| // we determine h_start and w_start using two different distributions as erasing is used by two different | |||
| // image augmentations. The bounds are also different in each case. | |||
| int32_t h_start = (bounded) ? height_distribution_bound(rnd) : (height_distribution_unbound(rnd) - box_height); | |||
| int32_t w_start = (bounded) ? width_distribution_bound(rnd) : (width_distribution_unbound(rnd) - box_width); | |||
| int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); | |||
| int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); | |||
| int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; | |||
| int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; | |||
| @@ -744,9 +678,9 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp | |||
| for (int x = h_start; x < max_height; x++) { | |||
| if (random_color) { | |||
| // fill each box with a random value | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = static_cast<int32_t>(normal_distribution(rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = static_cast<int32_t>(normal_distribution(rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[2] = static_cast<int32_t>(normal_distribution(rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = static_cast<int32_t>(normal_distribution(*rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = static_cast<int32_t>(normal_distribution(*rnd)); | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[2] = static_cast<int32_t>(normal_distribution(*rnd)); | |||
| } else { | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = fill_r; | |||
| input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = fill_g; | |||
| @@ -196,12 +196,6 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te | |||
| // @param output: Adjusted image of same shape and type. | |||
| Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue); | |||
| Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, | |||
| cv::Rect *crop_box, uint32_t seed = std::mt19937::default_seed); | |||
| Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Rect> &bounding_boxes, | |||
| float min_intersect_ratio, bool *is_satisfied); | |||
| // Masks out a random section from the image with set dimension | |||
| // @param input: input Tensor | |||
| // @param output: cutOut Tensor | |||
| @@ -214,8 +208,8 @@ Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Re | |||
| // @param fill_g: green fill value for erase | |||
| // @param fill_b: blue fill value for erase. | |||
| Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height, | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r = 0, | |||
| uint8_t fill_g = 0, uint8_t fill_b = 0); | |||
| int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, | |||
| uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); | |||
| // Pads the input image and puts the padded image in the output | |||
| // @param input: input Tensor | |||
| @@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ | |||
| rnd_.seed(GetSeed()); | |||
| } | |||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) { | |||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); | |||
| @@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std: | |||
| (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); | |||
| return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); | |||
| } | |||
| Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) { | |||
| Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | |||
| outputs.clear(); | |||
| TensorShape out = TensorShape{target_height_, target_width_}; | |||
| @@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs | |||
| if (!outputs.empty()) return Status::OK(); | |||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | |||
| } | |||
| Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { | |||
| Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { | |||
| double scale, aspect; | |||
| *crop_width = w_in; | |||
| *crop_height = h_in; | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/image/uniform_aug_op.h" | |||
| #include "dataset/kernels/py_func_op.h" | |||
| #include "dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const int UniformAugOp::kDefNumOps = 2; | |||
| UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) { | |||
| std::shared_ptr<TensorOp> tensor_op; | |||
| // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ | |||
| for (auto op : op_list) { | |||
| if (py::isinstance<py::function>(op)) { | |||
| // python op | |||
| tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>()); | |||
| } else if (py::isinstance<TensorOp>(op)) { | |||
| // C++ op | |||
| tensor_op = op.cast<std::shared_ptr<TensorOp>>(); | |||
| } | |||
| tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); | |||
| } | |||
| rnd_.seed(GetSeed()); | |||
| } | |||
| // compute method to apply uniformly random selected augmentations from a list | |||
| Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||
| std::vector<std::shared_ptr<Tensor>> *output) { | |||
| IO_CHECK_VECTOR(input, output); | |||
| // variables to generate random number to select ops from the list | |||
| std::vector<int> random_indexes; | |||
| // variables to copy the result to output if it is not already | |||
| std::vector<std::shared_ptr<Tensor>> even_out; | |||
| std::vector<std::shared_ptr<Tensor>> *even_out_ptr = &even_out; | |||
| int count = 1; | |||
| // select random indexes for candidates to be applied | |||
| for (int i = 0; i < num_ops_; ++i) { | |||
| random_indexes.insert(random_indexes.end(), | |||
| std::uniform_int_distribution<int>(0, tensor_op_list_.size() - 1)(rnd_)); | |||
| } | |||
| for (auto it = random_indexes.begin(); it != random_indexes.end(); ++it) { | |||
| // Do NOT apply the op, if second random generator returned zero | |||
| if (std::uniform_int_distribution<int>(0, 1)(rnd_)) { | |||
| continue; | |||
| } | |||
| std::shared_ptr<TensorOp> tensor_op = tensor_op_list_[*it]; | |||
| // apply python/C++ op | |||
| if (count == 1) { | |||
| (*tensor_op).Compute(input, output); | |||
| } else if (count % 2 == 0) { | |||
| (*tensor_op).Compute(*output, even_out_ptr); | |||
| } else { | |||
| (*tensor_op).Compute(even_out, output); | |||
| } | |||
| count++; | |||
| } | |||
| // copy the result to output if it is not in output | |||
| if (count == 1) { | |||
| *output = input; | |||
| } else if ((count % 2 == 1)) { | |||
| (*output).swap(even_out); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ | |||
| #define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ | |||
| #include <memory> | |||
| #include <random> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/status.h" | |||
| #include "dataset/kernels/py_func_op.h" | |||
| #include "pybind11/stl.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class UniformAugOp : public TensorOp { | |||
| public: | |||
| // Default number of Operations to be applied | |||
| static const int kDefNumOps; | |||
| // Constructor for UniformAugOp | |||
| // @param list op_list: list of candidate python operations | |||
| // @param list num_ops: number of augemtation operations to applied | |||
| UniformAugOp(py::list op_list, int32_t num_ops); | |||
| ~UniformAugOp() override = default; | |||
| void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; } | |||
| // Overrides the base class compute function | |||
| // @return Status - The error code return | |||
| Status Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||
| std::vector<std::shared_ptr<Tensor>> *output) override; | |||
| private: | |||
| int32_t num_ops_; | |||
| std::vector<std::shared_ptr<TensorOp>> tensor_op_list_; | |||
| std::mt19937 rnd_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ | |||
| @@ -18,6 +18,9 @@ | |||
| #include "dataset/util/random.h" | |||
| #if defined(_WIN32) || defined(_WIn64) | |||
| #include <stdlib.h> | |||
| #endif | |||
| #include <limits> | |||
| #include <memory> | |||
| #include <random> | |||
| @@ -33,7 +36,9 @@ uint32_t GetSeed() { | |||
| uint32_t seed = GlobalContext::config_manager()->seed(); | |||
| if (seed == std::mt19937::default_seed) { | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| std::random_device random_device; | |||
| unsigned int number; | |||
| rand_s(&number); | |||
| std::mt19937 random_device{static_cast<uint32_t>(number)}; | |||
| #else | |||
| std::random_device random_device("/dev/urandom"); | |||
| #endif | |||
| @@ -18,6 +18,8 @@ | |||
| #include <limits.h> | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include <sys/syscall.h> | |||
| #else | |||
| #include <stdlib.h> | |||
| #endif | |||
| #include <unistd.h> | |||
| #include <random> | |||
| @@ -49,7 +51,9 @@ int Services::GetLWP() { return syscall(SYS_gettid); } | |||
| std::string Services::GetUniqueID() { | |||
| const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| std::mt19937 gen{std::random_device{}()}; | |||
| unsigned int number; | |||
| rand_s(&number); | |||
| std::mt19937 gen{static_cast<uint32_t>(number)}; | |||
| #else | |||
| std::mt19937 gen{std::random_device{"/dev/urandom"}()}; | |||
| #endif | |||
| @@ -22,7 +22,7 @@ | |||
| namespace mindspore { | |||
| constexpr char PARALLEL_STRATEGY[] = "strategy"; | |||
| void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); | |||
| void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); | |||
| } // namespace mindspore | |||
| @@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; | |||
| // get MindSpore Intermediate Representation Path | |||
| std::string GetMsIrPath(void) { | |||
| std::string path; | |||
| const char* path_ptr = getenv("MS_IR_PATH"); | |||
| const char *path_ptr = getenv("MS_IR_PATH"); | |||
| if (path_ptr != nullptr) { | |||
| path = path_ptr; | |||
| char real_path[PATH_MAX] = {0}; | |||
| @@ -62,13 +62,13 @@ std::string GetMsIrPath(void) { | |||
| return path; | |||
| } | |||
| std::string dump_obj(const py::object& obj, const std::string& path) { | |||
| std::string dump_obj(const py::object &obj, const std::string &path) { | |||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | |||
| py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); | |||
| return py::str(name); | |||
| } | |||
| py::object load_obj(const std::string& path) { | |||
| py::object load_obj(const std::string &path) { | |||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | |||
| py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); | |||
| return obj; | |||
| @@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) { | |||
| // ============================================= MindSpore IR Exporter ============================================= | |||
| std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||
| std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) { | |||
| abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); | |||
| TypePtr type = dyn_cast<Type>(nd->Type()); | |||
| std::ostringstream oss; | |||
| @@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const { | |||
| std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const { | |||
| std::string pkl_path = GetMsIrPath(); | |||
| // if not specified env 'MS_IR_PATH', do not create any files | |||
| if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { | |||
| @@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca | |||
| return file_prefix + file_name; | |||
| } | |||
| int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) { | |||
| int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { | |||
| if (func_graph == nullptr || param == nullptr) { | |||
| return -1; | |||
| } | |||
| @@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& | |||
| // try to find index of parameter for SymbolicKeyInstance from all exported graphs | |||
| // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different | |||
| int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { | |||
| int AnfExporter::GetParamIndexFromExported(const AnfNodePtr ¶m) { | |||
| if (param == nullptr) { | |||
| return -1; | |||
| } | |||
| int ret = -1; | |||
| for (const auto& item : exported) { | |||
| for (const auto &item : exported) { | |||
| auto pram_iter = item.second.find(param); | |||
| if (pram_iter != item.second.end()) { | |||
| return pram_iter->second; | |||
| @@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { | |||
| return ret; | |||
| } | |||
| std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) { | |||
| std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| return GetValueText(fg, node->value()); | |||
| } | |||
| std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) { | |||
| std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) { | |||
| auto py_funcs = mt_func_graph->GetPyFunctions(); | |||
| if (py_funcs.empty()) { | |||
| return ""; | |||
| @@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||
| oss << "{"; | |||
| bool is_first = true; | |||
| for (const auto& py_func : py_funcs) { | |||
| for (const auto &py_func : py_funcs) { | |||
| if (is_first) { | |||
| is_first = false; | |||
| } else { | |||
| @@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||
| * ├── GradOperation | |||
| * └── TupleAdd | |||
| */ | |||
| std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) { | |||
| std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) { | |||
| if (meta_func_graph == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_ | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||
| std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) { | |||
| std::ostringstream oss; | |||
| if (prim == nullptr) { | |||
| return oss.str(); | |||
| @@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||
| if (prim->isa<prim::DoSignaturePrimitive>()) { | |||
| auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim); | |||
| auto& func = do_signature->function(); | |||
| auto &func = do_signature->function(); | |||
| if (func->isa<Primitive>()) { | |||
| auto sig_prim = dyn_cast<Primitive>(func); | |||
| oss << sig_prim->GetAttrsText(); | |||
| @@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { | |||
| std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) { | |||
| std::ostringstream oss; | |||
| if (ns == nullptr) { | |||
| return oss.str(); | |||
| @@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, | |||
| const SymbolicKeyInstancePtr& sym_inst) { | |||
| std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, | |||
| const SymbolicKeyInstancePtr &sym_inst) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(sym_inst); | |||
| AnfNodePtr sym_node = sym_inst->node(); | |||
| @@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||
| std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| // output ValueList, ValueTuple | |||
| ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value); | |||
| @@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||
| std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>(); | |||
| oss << "{"; | |||
| bool first_flag = true; | |||
| for (const auto& elem : dict->value()) { | |||
| for (const auto &elem : dict->value()) { | |||
| if (first_flag) { | |||
| first_flag = false; | |||
| } else { | |||
| @@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { | |||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| if (check_integrity_) { | |||
| @@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||
| std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| bool is_null_ptr = (func_graph == nullptr || value == nullptr); | |||
| if (is_null_ptr) { | |||
| @@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu | |||
| } | |||
| // this function is used to output node in CNode's inputs | |||
| std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, int>& apply_map) { | |||
| std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, int> &apply_map) { | |||
| std::ostringstream oss; | |||
| if (func_graph == nullptr || node == nullptr) { | |||
| return oss.str(); | |||
| @@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An | |||
| return oss.str(); | |||
| } | |||
| void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map) { | |||
| void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) { | |||
| bool first_flag = true; | |||
| for (const AnfNodePtr& param : parameters) { | |||
| for (const AnfNodePtr ¶m : parameters) { | |||
| if (first_flag) { | |||
| first_flag = false; | |||
| ofs << " "; | |||
| @@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNode | |||
| } | |||
| } | |||
| void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& node) { | |||
| void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| // output type of each input argument | |||
| auto& inputs = node->inputs(); | |||
| auto &inputs = node->inputs(); | |||
| if (inputs.size() > 1) { | |||
| ofs << " #("; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| @@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod | |||
| ofs << " #scope: " << node->scope()->name(); | |||
| } | |||
| void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, | |||
| const FuncGraphPtr& func_graph) { | |||
| void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, | |||
| const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| int idx = 1; | |||
| std::map<AnfNodePtr, int> apply_map; | |||
| for (const AnfNodePtr& node : nodes) { | |||
| for (const AnfNodePtr &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| @@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr> | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto& inputs = cnode->inputs(); | |||
| auto &inputs = cnode->inputs(); | |||
| std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); | |||
| // non-return node | |||
| if (node != func_graph->get_return()) { | |||
| @@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr> | |||
| } | |||
| } | |||
| void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) { | |||
| void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun | |||
| ofs << "}\n"; | |||
| } | |||
| void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { | |||
| void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt | |||
| ofs.close(); | |||
| } | |||
| void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||
| void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||
| if (graphs.empty()) { | |||
| return; | |||
| } | |||
| @@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector | |||
| param_index = 1; | |||
| for (const auto& tagged_graph : graphs) { | |||
| for (const auto &tagged_graph : graphs) { | |||
| tagged_cnodes_ = tagged_graph.second; | |||
| ExportOneFuncGraph(ofs, tagged_graph.first); | |||
| tagged_cnodes_.clear(); | |||
| @@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector | |||
| } | |||
| #ifdef ENABLE_DUMP_IR | |||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) { | |||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap | |||
| ChangeFileMode(filename, S_IRUSR); | |||
| } | |||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||
| AnfExporter exporter("", false); | |||
| ChangeFileMode(filename, S_IRWXU); | |||
| exporter.ExportFuncGraph(filename, graphs); | |||
| @@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graph | |||
| ChangeFileMode(filename, S_IRUSR); | |||
| } | |||
| #else | |||
| void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { | |||
| void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { | |||
| << "please recompile source to enable it. See help of building script."; | |||
| } | |||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -732,7 +732,7 @@ enum Token : int { | |||
| TOK_ERROR // file read error | |||
| }; | |||
| std::map<Token, const char*> token_text = { | |||
| std::map<Token, const char *> token_text = { | |||
| {TOK_INVALID, "invalid"}, // invalid token | |||
| {TOK_LPARENTHESIS, "("}, // ( left parenthesis | |||
| {TOK_RPARENTHESIS, ")"}, // ) right parenthesis | |||
| @@ -761,14 +761,14 @@ std::map<Token, const char*> token_text = { | |||
| class Lexer { | |||
| public: | |||
| // filename is checked in ImportIR; | |||
| explicit Lexer(const char* filename) : fin(filename) {} | |||
| explicit Lexer(const char *filename) : fin(filename) {} | |||
| ~Lexer() { | |||
| try { | |||
| if (fin.is_open()) { | |||
| fin.close(); | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Exception when closing file"; | |||
| } catch (...) { | |||
| std::string exName(abi::__cxa_current_exception_type()->name()); | |||
| @@ -776,7 +776,7 @@ class Lexer { | |||
| } | |||
| } | |||
| bool IsSingleCharToken(char ch, Token* token_ptr) { | |||
| bool IsSingleCharToken(char ch, Token *token_ptr) { | |||
| // clang-format off | |||
| std::unordered_map<char, Token> char_to_token = { | |||
| {'(', TOK_LPARENTHESIS}, | |||
| @@ -806,7 +806,7 @@ class Lexer { | |||
| Token GetNextToken() { | |||
| #ifdef DEBUG | |||
| Token token = GetNextTokenInner(); | |||
| const char* str = token_text[token]; | |||
| const char *str = token_text[token]; | |||
| std::string text = (str == nullptr ? GetTokenText() : str); | |||
| MS_LOG(DEBUG) << "------Parse token] " << text; | |||
| return token; | |||
| @@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE; | |||
| class IrParser { | |||
| public: | |||
| explicit IrParser(const char* filename) : lexer_(filename) {} | |||
| explicit IrParser(const char *filename) : lexer_(filename) {} | |||
| ~IrParser() {} | |||
| py::object LoadObject(const std::string& file_name) const { | |||
| py::object LoadObject(const std::string &file_name) const { | |||
| std::string pkl_path = GetMsIrPath(); | |||
| py::object default_obj = load_obj(pkl_path + "/" + file_name); | |||
| return default_obj; | |||
| @@ -1087,7 +1087,7 @@ class IrParser { | |||
| MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); | |||
| } | |||
| Token ParseParent(FuncGraphPtr* const parent_ptr) { | |||
| Token ParseParent(FuncGraphPtr *const parent_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_IDENTIFIER) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1168,7 +1168,7 @@ class IrParser { | |||
| return func_graph; | |||
| } | |||
| FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) { | |||
| FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) { | |||
| Token tok = lexer_.SkipWhiteToken(); | |||
| while (tok == TOK_VARIABLE) { | |||
| if (ParseStatement(func_graph) == nullptr) { | |||
| @@ -1264,56 +1264,56 @@ class IrParser { | |||
| return func_graph; | |||
| } | |||
| void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const { | |||
| void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = dtype; | |||
| } | |||
| void SetTupleType(TypePtr* ptr) { | |||
| void SetTupleType(TypePtr *ptr) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<Tuple>(); | |||
| } | |||
| void SetTupleType(TypePtr* ptr, const TypePtrList& elems) { | |||
| void SetTupleType(TypePtr *ptr, const TypePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<Tuple>(elems); | |||
| } | |||
| void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector<int>&) { | |||
| void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<TensorType>(elem_type); | |||
| } | |||
| void SetListType(TypePtr* ptr) { | |||
| void SetListType(TypePtr *ptr) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<List>(); | |||
| } | |||
| void SetListType(TypePtr* ptr, const TypePtrList& elems) { | |||
| void SetListType(TypePtr *ptr, const TypePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<List>(elems); | |||
| } | |||
| void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) { | |||
| void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<JTagged>(elem); | |||
| } | |||
| void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const { | |||
| void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| @@ -1321,45 +1321,45 @@ class IrParser { | |||
| } | |||
| // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} | |||
| void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const { | |||
| void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractNone>(); | |||
| } | |||
| void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {} | |||
| void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {} | |||
| void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {} | |||
| void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {} | |||
| void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { | |||
| void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| // if one of elems is nullptr, just return | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractTuple>(elems); | |||
| } | |||
| void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector<int>& shape) { | |||
| void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape); | |||
| } | |||
| void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { | |||
| void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractList>(elems); | |||
| } | |||
| void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) { | |||
| void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| @@ -1367,7 +1367,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) { | |||
| Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) { | |||
| if (tok != TOK_LBRACKET) { | |||
| MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; | |||
| return tok; | |||
| @@ -1415,7 +1415,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { | |||
| Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| if (ptr != nullptr) { | |||
| SetBasicType(ptr, std::make_shared<TensorType>()); | |||
| @@ -1454,7 +1454,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| bool IsNumberType(const std::string& type, TypeId* typeid_ptr) { | |||
| bool IsNumberType(const std::string &type, TypeId *typeid_ptr) { | |||
| // clang-format off | |||
| static std::unordered_map<std::string, TypeId> basic_types = { | |||
| {"Bool", kNumberTypeBool}, | |||
| @@ -1486,7 +1486,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) { | |||
| void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) { | |||
| TypePtr dtype = nullptr; | |||
| std::unordered_map<int, TypePtr> type_map = { | |||
| @@ -1519,7 +1519,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) { | |||
| Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) { | |||
| if (type == "NoneType") { | |||
| SetBasicType(ptr, std::make_shared<TypeNone>()); | |||
| return lexer_.GetNextToken(); | |||
| @@ -1541,7 +1541,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { | |||
| Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { | |||
| if (tok != TOK_IDENTIFIER) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1588,11 +1588,11 @@ class IrParser { | |||
| } | |||
| } | |||
| Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) { | |||
| Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) { | |||
| return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); | |||
| } | |||
| Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { | |||
| Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { | |||
| Token tok = ParseAttribute(func_graph, prim); | |||
| while (tok == TOK_COMMA) { | |||
| tok = ParseAttribute(func_graph, prim); | |||
| @@ -1603,7 +1603,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { | |||
| Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { | |||
| Token tok = lexer_.GetNextToken(); | |||
| if (tok != TOK_IDENTIFIER) { | |||
| return TOK_ERROR; | |||
| @@ -1670,7 +1670,7 @@ class IrParser { | |||
| return tok == TOK_RPARENTHESIS ? func_graph : nullptr; | |||
| } | |||
| FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { | |||
| FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) { | |||
| Token tok = ParseArgument(func_graph, inputs_ptr); | |||
| while (tok == TOK_COMMA) { | |||
| tok = ParseArgument(func_graph, inputs_ptr); | |||
| @@ -1681,9 +1681,9 @@ class IrParser { | |||
| return func_graph; | |||
| } | |||
| AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) { | |||
| AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string ¶m_name) { | |||
| while (func_graph != nullptr) { | |||
| for (auto& ptr : func_graph->parameters()) { | |||
| for (auto &ptr : func_graph->parameters()) { | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| ParameterPtr param = ptr->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| @@ -1701,12 +1701,12 @@ class IrParser { | |||
| return nullptr; | |||
| } | |||
| bool Match(const std::string& str, const std::string& pattern) const { | |||
| bool Match(const std::string &str, const std::string &pattern) const { | |||
| return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; | |||
| } | |||
| template <typename T, typename V> | |||
| Token ParseScalar(ValuePtr* const val_ptr) { | |||
| Token ParseScalar(ValuePtr *const val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_NUMBER) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1725,7 +1725,7 @@ class IrParser { | |||
| } | |||
| template <typename VT, typename V, typename T> | |||
| Token ParseScalar(ValuePtr* const val_ptr, Token tok) { | |||
| Token ParseScalar(ValuePtr *const val_ptr, Token tok) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| *val_ptr = std::make_shared<T>(); | |||
| return tok; | |||
| @@ -1735,7 +1735,7 @@ class IrParser { | |||
| } | |||
| template <typename VT, typename V, typename T, const unsigned nbits> | |||
| Token ParseScalar(ValuePtr* const val_ptr, Token tok) { | |||
| Token ParseScalar(ValuePtr *const val_ptr, Token tok) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| *val_ptr = std::make_shared<T>(nbits); | |||
| return tok; | |||
| @@ -1745,7 +1745,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| T StringToScalar(const std::string& text) { | |||
| T StringToScalar(const std::string &text) { | |||
| std::stringstream ss; | |||
| T value; | |||
| ss << text; | |||
| @@ -1753,7 +1753,7 @@ class IrParser { | |||
| return value; | |||
| } | |||
| Token ParseTensor(ValuePtr* const val_ptr) { | |||
| Token ParseTensor(ValuePtr *const val_ptr) { | |||
| // parse type | |||
| TypeId type; | |||
| if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | |||
| @@ -1803,7 +1803,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParsePrimType(Token tok, PrimType* prim_type_ptr) { | |||
| Token ParsePrimType(Token tok, PrimType *prim_type_ptr) { | |||
| if (tok != TOK_LBRACE) { | |||
| return tok; | |||
| } | |||
| @@ -1830,7 +1830,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { | |||
| Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1855,7 +1855,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { | |||
| Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { | |||
| if (tok != TOK_LBRACE) { | |||
| return tok; | |||
| } | |||
| @@ -1868,7 +1868,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseBoolValue(const std::string& key, bool* val_ptr) { | |||
| Token ParseBoolValue(const std::string &key, bool *val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1892,7 +1892,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) { | |||
| Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_LBRACE) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1920,7 +1920,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) { | |||
| Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) { | |||
| if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1951,7 +1951,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) { | |||
| Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_AT_FILE) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1984,7 +1984,7 @@ class IrParser { | |||
| return next; | |||
| } | |||
| Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) { | |||
| Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) { | |||
| if (Match(id, "MultitypeFuncGraph::")) { | |||
| std::string name = id.substr(strlen("MultitypeFuncGraph::")); | |||
| auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name); | |||
| @@ -2024,8 +2024,8 @@ class IrParser { | |||
| } | |||
| } | |||
| Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr, | |||
| AnfNodePtr* const node_ptr = nullptr) { | |||
| Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr, | |||
| AnfNodePtr *const node_ptr = nullptr) { | |||
| if (id == "None") { | |||
| *val_ptr = std::make_shared<None>(); | |||
| return lexer_.GetNextToken(); | |||
| @@ -2075,9 +2075,9 @@ class IrParser { | |||
| } | |||
| } | |||
| Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid, | |||
| const std::vector<ValuePtr>& elems, const std::vector<AnfNodePtr>& nodes, | |||
| ValuePtr* const val_ptr, AnfNodePtr* node_ptr) { | |||
| Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid, | |||
| const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes, | |||
| ValuePtr *const val_ptr, AnfNodePtr *node_ptr) { | |||
| if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { | |||
| if (node_is_valid && node_ptr != nullptr) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -2097,8 +2097,8 @@ class IrParser { | |||
| } | |||
| } | |||
| Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, | |||
| AnfNodePtr* node_ptr = nullptr) { | |||
| Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, | |||
| AnfNodePtr *node_ptr = nullptr) { | |||
| Token left_tok = tok; | |||
| std::vector<ValuePtr> elems; | |||
| @@ -2138,7 +2138,7 @@ class IrParser { | |||
| return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); | |||
| } | |||
| Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) { | |||
| Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) { | |||
| // tuple or list | |||
| if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { | |||
| return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); | |||
| @@ -2152,7 +2152,7 @@ class IrParser { | |||
| return TOK_ERROR; | |||
| } | |||
| Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr, | |||
| Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr, | |||
| Token tok = TOK_INVALID) { | |||
| if (tok == TOK_INVALID) { | |||
| tok = lexer_.GetNextToken(); | |||
| @@ -2193,7 +2193,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseArgument(const FuncGraphPtr& func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { | |||
| Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) { | |||
| Token tok = lexer_.GetNextToken(); | |||
| if (tok == TOK_RPARENTHESIS) { | |||
| return tok; | |||
| @@ -2208,7 +2208,7 @@ class IrParser { | |||
| return tok; | |||
| } | |||
| const std::vector<FuncGraphPtr>& GetFuncGraphs() const { return func_graphs_; } | |||
| const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; } | |||
| private: | |||
| Lexer lexer_; | |||
| @@ -2226,14 +2226,14 @@ class IrParser { | |||
| std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter | |||
| }; | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string& filename) { | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string &filename) { | |||
| IrParser parser(filename.c_str()); | |||
| parser.ParseFile(); | |||
| return parser.GetFuncGraphs(); | |||
| } | |||
| #ifdef ENABLE_DUMP_IR | |||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Func graph is nullptr"; | |||
| return; | |||
| @@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||
| return; | |||
| } | |||
| char real_path[PATH_MAX] = {0}; | |||
| char* real_path_ret = nullptr; | |||
| char *real_path_ret = nullptr; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); | |||
| #else | |||
| @@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||
| ChangeFileMode(file_path, S_IRUSR); | |||
| } | |||
| #else | |||
| void DumpIRProto(const FuncGraphPtr&, const std::string&) { | |||
| void DumpIRProto(const FuncGraphPtr &, const std::string &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -39,7 +39,7 @@ | |||
| namespace mindspore { | |||
| struct ParamPtrEqual { | |||
| bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { | |||
| bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const { | |||
| const ParameterPtr param1 = dyn_cast<Parameter>(t1); | |||
| const ParameterPtr param2 = dyn_cast<Parameter>(t2); | |||
| @@ -52,7 +52,7 @@ struct ParamPtrEqual { | |||
| }; | |||
| struct ParamPtrHasher { | |||
| std::size_t operator()(AnfNodePtr const& param) const { | |||
| std::size_t operator()(AnfNodePtr const ¶m) const { | |||
| const ParameterPtr parameter = dyn_cast<Parameter>(param); | |||
| if (parameter == nullptr) { | |||
| return 0; | |||
| @@ -64,39 +64,39 @@ struct ParamPtrHasher { | |||
| class AnfExporter { | |||
| public: | |||
| explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) | |||
| explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false) | |||
| : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { | |||
| func_graph_set.clear(); | |||
| exported.clear(); | |||
| } | |||
| virtual ~AnfExporter() {} | |||
| void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | |||
| void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); | |||
| void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); | |||
| void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs); | |||
| protected: | |||
| virtual std::string GetNodeType(const AnfNodePtr& nd); | |||
| int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); | |||
| int GetParamIndexFromExported(const AnfNodePtr& param); | |||
| std::string DumpObject(const py::object& obj, const std::string& category) const; | |||
| std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); | |||
| std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); | |||
| std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); | |||
| std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetPrimitiveText(const PrimitivePtr& prim); | |||
| std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetNameSpaceText(const parse::NameSpacePtr& ns); | |||
| std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); | |||
| std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, int>& apply_map); | |||
| void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); | |||
| void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map); | |||
| void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); | |||
| void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph); | |||
| virtual std::string GetNodeType(const AnfNodePtr &nd); | |||
| int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); | |||
| int GetParamIndexFromExported(const AnfNodePtr ¶m); | |||
| std::string DumpObject(const py::object &obj, const std::string &category) const; | |||
| std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); | |||
| std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); | |||
| std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); | |||
| std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetPrimitiveText(const PrimitivePtr &prim); | |||
| std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetNameSpaceText(const parse::NameSpacePtr &ns); | |||
| std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph); | |||
| std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, int> &apply_map); | |||
| void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); | |||
| void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map); | |||
| void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); | |||
| void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); | |||
| int param_index; | |||
| OrderedSet<FuncGraphPtr> func_graph_set{}; | |||
| @@ -108,16 +108,16 @@ class AnfExporter { | |||
| abstract::AnfNodeConfigPtr node_cfg_ = nullptr; | |||
| }; | |||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); | |||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs); | |||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph); | |||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs); | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string& filename); | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string &filename); | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); | |||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); | |||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | |||
| @@ -34,7 +34,7 @@ namespace draw { | |||
| namespace { | |||
| // Only for ValueNode | |||
| std::string ValueType(const ValueNodePtr& node) { | |||
| std::string ValueType(const ValueNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) { | |||
| return v->type_name(); | |||
| } | |||
| std::string ReplaceSpecialChar(const std::string& str) { | |||
| std::string ReplaceSpecialChar(const std::string &str) { | |||
| std::ostringstream oss; | |||
| for (size_t i = 0; i < str.size(); i++) { | |||
| if (str[i] == '<') { | |||
| @@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) { | |||
| } // namespace | |||
| // API of debug utils | |||
| void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs, | |||
| void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs, | |||
| bool is_user) { | |||
| if (sub_graphs == nullptr) { | |||
| return; | |||
| } | |||
| for (auto& nd : nodes) { | |||
| for (auto &nd : nodes) { | |||
| MS_EXCEPTION_IF_NULL(nd); | |||
| auto sub_graph = nd->func_graph(); | |||
| if (sub_graph != nullptr) { | |||
| @@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st | |||
| } | |||
| } | |||
| void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, | |||
| OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) { | |||
| void DrawValueNodes(const std::vector<AnfNodePtr> &nodes, | |||
| OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) { | |||
| if (sub_graphs == nullptr) { | |||
| return; | |||
| } | |||
| int dup_idx = 0; | |||
| for (auto& nd : nodes) { | |||
| for (auto& t : SuccIncoming(nd)) { | |||
| for (auto &nd : nodes) { | |||
| for (auto &t : SuccIncoming(nd)) { | |||
| MS_EXCEPTION_IF_NULL(t); | |||
| MS_EXCEPTION_IF_NULL(nd); | |||
| if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { | |||
| @@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, | |||
| } | |||
| } | |||
| void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) { | |||
| void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) { | |||
| if (digraph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD | |||
| } | |||
| // Draw edge | |||
| for (auto& nd : nodes) { | |||
| for (auto &nd : nodes) { | |||
| auto succs = SuccIncoming(nd); | |||
| auto num = succs.size(); | |||
| for (size_t i = 0; i < num; i++) { | |||
| auto& t = succs.at(i); | |||
| auto &t = succs.at(i); | |||
| MS_EXCEPTION_IF_NULL(t); | |||
| if (t->isa<ValueNode>() || t->isa<Parameter>()) { | |||
| if ((!is_user) || (i != 0)) { | |||
| @@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD | |||
| } | |||
| } | |||
| void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) { | |||
| void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use | |||
| DrawValueNodes(nodes, &sub_graphs); | |||
| // Draw subgraph | |||
| for (const auto& gsub : sub_graphs) { | |||
| for (const auto &gsub : sub_graphs) { | |||
| digraph->SubGraph(gsub.first, gsub.second); | |||
| } | |||
| @@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use | |||
| } | |||
| #ifdef ENABLE_DUMP_IR | |||
| void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { | |||
| void Draw(const std::string &filename, const FuncGraphPtr &func_graph) { | |||
| const std::string dot_suffix = ".dot"; | |||
| std::string filename_with_suffix = | |||
| (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; | |||
| DrawByOpt(filename_with_suffix, func_graph, false); | |||
| } | |||
| void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { | |||
| void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { | |||
| DrawByOpt(filename, func_graph, true); | |||
| } | |||
| #else | |||
| void Draw(const std::string&, const FuncGraphPtr&) { | |||
| void Draw(const std::string &, const FuncGraphPtr &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) { | |||
| << "please recompile source to enable it. See help of building script."; | |||
| } | |||
| void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { | |||
| void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) { | |||
| return "plaintext"; | |||
| } | |||
| std::string Graphviz::Color(const AnfNodePtr& node) { | |||
| std::string Graphviz::Color(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -259,7 +259,7 @@ void BaseDigraph::Start() { | |||
| buffer_ << "compound=true" << std::endl; | |||
| } | |||
| void BaseDigraph::Head(const AnfNodePtr& node, int id) { | |||
| void BaseDigraph::Head(const AnfNodePtr &node, int id) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| @@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) { | |||
| } | |||
| } | |||
| void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { | |||
| void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| @@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { | |||
| buffer_ << ":" << idx; | |||
| } | |||
| void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { | |||
| void BaseDigraph::Tail(const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -304,12 +304,12 @@ void BaseDigraph::End() { | |||
| } | |||
| } | |||
| void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { | |||
| void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { | |||
| buffer_ << "parameters_" << key << "[shape=plaintext "; | |||
| buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>"; | |||
| buffer_ << "<tr><td>parameters</td></tr>"; | |||
| int count = 0; | |||
| for (auto& parameter : key->parameters()) { | |||
| for (auto ¶meter : key->parameters()) { | |||
| buffer_ << "<tr><td>"; | |||
| buffer_ << parameter->ToString(); | |||
| auto py_p = dyn_cast<Parameter>(parameter)->default_param(); | |||
| @@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { | |||
| buffer_ << "</table>>,];"; | |||
| } | |||
| void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) { | |||
| void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) { | |||
| if (key == nullptr || gsub == nullptr) { | |||
| return; | |||
| } | |||
| @@ -361,12 +361,12 @@ Digraph::~Digraph() { | |||
| if (fout_.is_open()) { | |||
| fout_.close(); | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Exception when closing file " << filename_; | |||
| } | |||
| } | |||
| static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { | |||
| static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) { | |||
| size_t start_pos = 0; | |||
| while ((start_pos = str.find(from, start_pos)) != std::string::npos) { | |||
| (void)str.replace(start_pos, from.length(), to); | |||
| @@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st | |||
| return str; | |||
| } | |||
| static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||
| static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(graph_obj); | |||
| graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node) | |||
| << "'>"; | |||
| @@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||
| graph_obj->buffer() << "</td></tr>"; | |||
| graph_obj->buffer() << "<tr><td align='left'>"; | |||
| int i = 0; | |||
| for (const auto& attr : attrs) { | |||
| for (const auto &attr : attrs) { | |||
| if (i != 0) { | |||
| graph_obj->buffer() << "<br/>"; | |||
| } | |||
| @@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||
| graph_obj->buffer() << "</table>>,"; | |||
| } | |||
| static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { | |||
| if (graph_obj == nullptr || node == nullptr) { | |||
| return; | |||
| } | |||
| @@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| } | |||
| } | |||
| static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) { | |||
| if (graph_obj == nullptr || node == nullptr || node->size() == 0) { | |||
| return; | |||
| } | |||
| @@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| } | |||
| graph_obj->buffer() << ">"; | |||
| int i = 0; | |||
| for (auto& attr : attrs) { | |||
| for (auto &attr : attrs) { | |||
| if (i != 0) { | |||
| graph_obj->buffer() << "<br/>"; | |||
| } | |||
| @@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() { | |||
| if (fout_.is_open()) { | |||
| fout_.close(); | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "exception when closing file " << filename_; | |||
| } | |||
| } | |||
| @@ -31,9 +31,9 @@ namespace parse = mindspore::parse; | |||
| class Graphviz { | |||
| public: | |||
| Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} | |||
| Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {} | |||
| explicit Graphviz(const std::string& name) : name_(name) {} | |||
| explicit Graphviz(const std::string &name) : name_(name) {} | |||
| virtual ~Graphviz() {} | |||
| @@ -41,8 +41,8 @@ class Graphviz { | |||
| virtual void End() {} | |||
| virtual std::string Shape(AnfNodePtr node); | |||
| std::string Color(const AnfNodePtr& node); | |||
| std::ostringstream& buffer() { return buffer_; } | |||
| std::string Color(const AnfNodePtr &node); | |||
| std::ostringstream &buffer() { return buffer_; } | |||
| std::ostringstream buffer_; | |||
| protected: | |||
| @@ -53,8 +53,8 @@ class Graphviz { | |||
| class BaseDigraph : public Graphviz { | |||
| public: | |||
| BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} | |||
| explicit BaseDigraph(const std::string& name) : Graphviz(name) {} | |||
| BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {} | |||
| explicit BaseDigraph(const std::string &name) : Graphviz(name) {} | |||
| ~BaseDigraph() override = default; | |||
| virtual void Node(AnfNodePtr node, int id = 0) = 0; | |||
| @@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz { | |||
| void Start() override; | |||
| void End() override; | |||
| virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); | |||
| void FuncGraphParameters(const FuncGraphPtr& key); | |||
| void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub); | |||
| void FuncGraphParameters(const FuncGraphPtr &key); | |||
| void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub); | |||
| const std::string& name() const { return name_; } | |||
| const std::string &name() const { return name_; } | |||
| protected: | |||
| void Head(const AnfNodePtr& node, int id = 0); | |||
| void Tail(const AnfNodePtr& node, int idx, int id = 0); | |||
| void Tail(const FuncGraphPtr& func_graph); | |||
| void Head(const AnfNodePtr &node, int id = 0); | |||
| void Tail(const AnfNodePtr &node, int idx, int id = 0); | |||
| void Tail(const FuncGraphPtr &func_graph); | |||
| }; | |||
| class Digraph : public BaseDigraph { | |||
| public: | |||
| Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} | |||
| explicit Digraph(const std::string& name) : BaseDigraph(name) {} | |||
| Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} | |||
| explicit Digraph(const std::string &name) : BaseDigraph(name) {} | |||
| ~Digraph() override; | |||
| void Node(AnfNodePtr node, int id = 0) override; | |||
| @@ -86,8 +86,8 @@ class Digraph : public BaseDigraph { | |||
| class ModelDigraph : public BaseDigraph { | |||
| public: | |||
| ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} | |||
| explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} | |||
| ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} | |||
| explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {} | |||
| ~ModelDigraph() override; | |||
| std::string Shape(AnfNodePtr node) override; | |||
| @@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph { | |||
| }; | |||
| // API to draw | |||
| void Draw(const std::string& filename, const FuncGraphPtr& func_graph); | |||
| void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | |||
| void Draw(const std::string &filename, const FuncGraphPtr &func_graph); | |||
| void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); | |||
| } // namespace draw | |||
| } // namespace mindspore | |||
| @@ -33,38 +33,38 @@ class ProtoExporter { | |||
| ProtoExporter() {} | |||
| ~ProtoExporter() {} | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||
| private: | |||
| void InitModelInfo(); | |||
| void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto); | |||
| std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr); | |||
| void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto); | |||
| void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto); | |||
| void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto); | |||
| void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto); | |||
| void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto); | |||
| void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto); | |||
| void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); | |||
| void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); | |||
| void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr); | |||
| void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto); | |||
| void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, std::map<AnfNodePtr, size_t>* const_map_ptr, | |||
| irpb::GraphProto* graph_proto); | |||
| void ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto); | |||
| void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto); | |||
| std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr); | |||
| void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto); | |||
| void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto); | |||
| void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto); | |||
| void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto); | |||
| void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto); | |||
| void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto); | |||
| void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); | |||
| void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); | |||
| void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr); | |||
| void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto); | |||
| void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr, | |||
| irpb::GraphProto *graph_proto); | |||
| void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto); | |||
| static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } | |||
| irpb::ModelProto model_; | |||
| }; | |||
| static irpb::DataType GetNumberDataType(const TypePtr& type) { | |||
| static irpb::DataType GetNumberDataType(const TypePtr &type) { | |||
| switch (type->type_id()) { | |||
| case kNumberTypeBool: | |||
| return irpb::DT_BOOL; | |||
| @@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) { | |||
| } | |||
| } | |||
| void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) { | |||
| void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { | |||
| if (type_proto == nullptr) { | |||
| return; | |||
| } | |||
| @@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||
| type_proto->set_data_type(irpb::DT_TENSOR); | |||
| if (shape != nullptr && shape->isa<abstract::Shape>()) { | |||
| abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); | |||
| for (const auto& elem : shape_info->shape()) { | |||
| for (const auto &elem : shape_info->shape()) { | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | |||
| } | |||
| } | |||
| } else if (type->isa<Tuple>()) { | |||
| TuplePtr tuple_type = dyn_cast<Tuple>(type); | |||
| type_proto->set_data_type(irpb::DT_TUPLE); | |||
| for (const auto& elem_type : tuple_type->elements()) { | |||
| for (const auto &elem_type : tuple_type->elements()) { | |||
| SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | |||
| } | |||
| } else if (type->isa<TypeType>()) { | |||
| @@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||
| } else if (type->isa<List>()) { | |||
| ListPtr list_type = dyn_cast<List>(type); | |||
| type_proto->set_data_type(irpb::DT_LIST); | |||
| for (const auto& elem_type : list_type->elements()) { | |||
| for (const auto &elem_type : list_type->elements()) { | |||
| SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | |||
| } | |||
| } else if (type->isa<TypeAnything>()) { | |||
| @@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||
| } | |||
| } | |||
| void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) { | |||
| void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) { | |||
| if (node == nullptr || type_proto == nullptr) { | |||
| return; | |||
| } | |||
| SetNodeOutputType(node->Type(), node->Shape(), type_proto); | |||
| } | |||
| void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| if (val->isa<StringImm>()) { | |||
| const StringImmPtr& value = dyn_cast<StringImm>(val); | |||
| const StringImmPtr &value = dyn_cast<StringImm>(val); | |||
| value_proto->set_dtype(irpb::DT_STRING); | |||
| value_proto->set_str_val(value->value()); | |||
| } else if (val->isa<Scalar>()) { | |||
| @@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value | |||
| } else if (val->isa<tensor::Tensor>()) { | |||
| tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val); | |||
| value_proto->set_dtype(irpb::DT_TENSOR); | |||
| irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val(); | |||
| irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); | |||
| tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); | |||
| for (auto& elem : tensor_ptr->shape()) { | |||
| for (auto &elem : tensor_ptr->shape()) { | |||
| tensor_proto->add_dims(elem); | |||
| } | |||
| } else if (val->isa<TensorType>()) { | |||
| value_proto->set_dtype(irpb::DT_TYPE); | |||
| irpb::TypeProto* type_proto = value_proto->mutable_type_val(); | |||
| irpb::TypeProto *type_proto = value_proto->mutable_type_val(); | |||
| type_proto->set_data_type(irpb::DT_TENSOR); | |||
| TypePtr elem_type = dyn_cast<TensorType>(val)->element(); | |||
| type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); | |||
| @@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value | |||
| } | |||
| } | |||
| void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| if (val->isa<BoolImm>()) { | |||
| const BoolImmPtr& value = dyn_cast<BoolImm>(val); | |||
| const BoolImmPtr &value = dyn_cast<BoolImm>(val); | |||
| value_proto->set_dtype(irpb::DT_BOOL); | |||
| value_proto->set_bool_val(value->value()); | |||
| } else if (val->isa<Int8Imm>()) { | |||
| const Int8ImmPtr& value = dyn_cast<Int8Imm>(val); | |||
| const Int8ImmPtr &value = dyn_cast<Int8Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT8); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<Int16Imm>()) { | |||
| const Int16ImmPtr& value = dyn_cast<Int16Imm>(val); | |||
| const Int16ImmPtr &value = dyn_cast<Int16Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT16); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<Int32Imm>()) { | |||
| const Int32ImmPtr& value = dyn_cast<Int32Imm>(val); | |||
| const Int32ImmPtr &value = dyn_cast<Int32Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT32); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<Int64Imm>()) { | |||
| const Int64ImmPtr& value = dyn_cast<Int64Imm>(val); | |||
| const Int64ImmPtr &value = dyn_cast<Int64Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT64); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<UInt8Imm>()) { | |||
| const UInt8ImmPtr& value = dyn_cast<UInt8Imm>(val); | |||
| const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT8); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<UInt16Imm>()) { | |||
| const UInt16ImmPtr& value = dyn_cast<UInt16Imm>(val); | |||
| const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT16); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<UInt32Imm>()) { | |||
| const UInt32ImmPtr& value = dyn_cast<UInt32Imm>(val); | |||
| const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT32); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<UInt64Imm>()) { | |||
| const UInt64ImmPtr& value = dyn_cast<UInt64Imm>(val); | |||
| const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT64); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<FP32Imm>()) { | |||
| const FP32ImmPtr& value = dyn_cast<FP32Imm>(val); | |||
| const FP32ImmPtr &value = dyn_cast<FP32Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_FLOAT32); | |||
| value_proto->set_float_val(value->value()); | |||
| } else if (val->isa<FP64Imm>()) { | |||
| const FP64ImmPtr& value = dyn_cast<FP64Imm>(val); | |||
| const FP64ImmPtr &value = dyn_cast<FP64Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_FLOAT64); | |||
| value_proto->set_double_val(value->value()); | |||
| } else { | |||
| @@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val | |||
| } | |||
| } | |||
| void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| if (val->isa<ValueTuple>()) { | |||
| const ValueTuplePtr& value = dyn_cast<ValueTuple>(val); | |||
| const ValueTuplePtr &value = dyn_cast<ValueTuple>(val); | |||
| value_proto->set_dtype(irpb::DT_TUPLE); | |||
| for (const auto& item : value->value()) { | |||
| for (const auto &item : value->value()) { | |||
| SetValueToProto(item, value_proto->add_values()); | |||
| } | |||
| } else if (val->isa<ValueList>()) { | |||
| const ValueListPtr& value = dyn_cast<ValueList>(val); | |||
| const ValueListPtr &value = dyn_cast<ValueList>(val); | |||
| value_proto->set_dtype(irpb::DT_LIST); | |||
| for (const auto& item : value->value()) { | |||
| for (const auto &item : value->value()) { | |||
| SetValueToProto(item, value_proto->add_values()); | |||
| } | |||
| } | |||
| } | |||
| void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| value_proto->set_dtype(irpb::DT_DICT); | |||
| for (const auto& item : val->value()) { | |||
| irpb::NamedValueProto* named_val = value_proto->add_dict_val(); | |||
| for (const auto &item : val->value()) { | |||
| irpb::NamedValueProto *named_val = value_proto->add_dict_val(); | |||
| named_val->set_key(item.first); | |||
| SetValueToProto(item.second, named_val->mutable_value()); | |||
| } | |||
| } | |||
| void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) { | |||
| void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) { | |||
| if (node == nullptr || node_proto == nullptr) { | |||
| return; | |||
| } | |||
| @@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& | |||
| MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); | |||
| } | |||
| const PrimitivePtr& prim = GetValueNode<PrimitivePtr>(node); | |||
| const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node); | |||
| node_proto->set_op_type(prim->name()); | |||
| for (const auto& attr : prim->attrs()) { | |||
| irpb::AttributeProto* attr_proto = node_proto->add_attribute(); | |||
| for (const auto &attr : prim->attrs()) { | |||
| irpb::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name(attr.first); | |||
| SetValueToProto(attr.second, attr_proto->mutable_value()); | |||
| } | |||
| node_proto->set_scope(node->scope()->name()); | |||
| } | |||
| std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr) { | |||
| std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr) { | |||
| if (node == nullptr || const_map_ptr == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt | |||
| MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; | |||
| } | |||
| std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { | |||
| std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return ""; | |||
| } | |||
| InitModelInfo(); | |||
| irpb::GraphProto* graph_proto = model_.mutable_graph(); | |||
| irpb::GraphProto *graph_proto = model_.mutable_graph(); | |||
| ExportFuncGraph(func_graph, graph_proto); | |||
| return model_.SerializeAsString(); | |||
| } | |||
| void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { | |||
| if (func_graph == nullptr || graph_proto == nullptr) { | |||
| return; | |||
| } | |||
| @@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP | |||
| ExportValueNodes(const_map, graph_proto); | |||
| } | |||
| void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { | |||
| if (func_graph == nullptr || graph_proto == nullptr) { | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> parameters = func_graph->parameters(); | |||
| for (auto& param : parameters) { | |||
| irpb::ParameterProto* param_proto = graph_proto->add_parameters(); | |||
| for (auto ¶m : parameters) { | |||
| irpb::ParameterProto *param_proto = graph_proto->add_parameters(); | |||
| param_proto->set_name(param->ToString()); | |||
| SetNodeOutputType(param, param_proto->mutable_type()); | |||
| @@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph | |||
| } | |||
| } | |||
| void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr) { | |||
| void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr) { | |||
| if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { | |||
| return; | |||
| } | |||
| // topo sort nodes | |||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | |||
| std::map<AnfNodePtr, size_t> apply_map; | |||
| for (const AnfNodePtr& node : nodes) { | |||
| for (const AnfNodePtr &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| @@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt | |||
| } | |||
| } | |||
| void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) { | |||
| if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || | |||
| graph_proto == nullptr) { | |||
| return; | |||
| @@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||
| auto apply_idx = apply_map_ptr->size() + 1; | |||
| (*apply_map_ptr)[node] = apply_idx; | |||
| auto& inputs = node->inputs(); | |||
| auto &inputs = node->inputs(); | |||
| if (inputs.size() < 1) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr op = inputs[0]; | |||
| irpb::NodeProto* node_proto = graph_proto->add_node(); | |||
| irpb::NodeProto *node_proto = graph_proto->add_node(); | |||
| // CNode/ConstGraph/Const/Parameter | |||
| if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) { | |||
| @@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||
| // process OP inputs | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| irpb::InputProto* input_proto = node_proto->add_input(); | |||
| irpb::InputProto *input_proto = node_proto->add_input(); | |||
| input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); | |||
| std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); | |||
| input_proto->set_name(id); | |||
| @@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||
| } | |||
| } | |||
| void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) { | |||
| if (ret_node == nullptr || !ret_node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Graph return node is illegal"; | |||
| } | |||
| @@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const | |||
| if (graph_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | |||
| } | |||
| irpb::OutputProto* output_proto = graph_proto->add_outputs(); | |||
| irpb::OutputProto *output_proto = graph_proto->add_outputs(); | |||
| if (output_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "output_proto is nullptr"; | |||
| } | |||
| @@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const | |||
| SetNodeOutputType(arg, output_proto->mutable_type()); | |||
| } | |||
| static bool CompareValue(const std::pair<AnfNodePtr, size_t>& x, const std::pair<AnfNodePtr, size_t>& y) { | |||
| static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) { | |||
| return x.second < y.second; | |||
| } | |||
| void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) { | |||
| std::vector<std::pair<AnfNodePtr, size_t>> nodes; | |||
| (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), | |||
| [](const std::pair<AnfNodePtr, size_t>& item) { return item; }); | |||
| [](const std::pair<AnfNodePtr, size_t> &item) { return item; }); | |||
| sort(nodes.begin(), nodes.end(), CompareValue); | |||
| for (auto& item : nodes) { | |||
| for (auto &item : nodes) { | |||
| if (graph_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | |||
| } | |||
| irpb::NamedValueProto* named_value = graph_proto->add_const_vals(); | |||
| irpb::NamedValueProto *named_value = graph_proto->add_const_vals(); | |||
| MS_EXCEPTION_IF_NULL(named_value); | |||
| named_value->set_key(GetConstNodeId(item.second)); | |||
| SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); | |||
| @@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_m | |||
| void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { | |||
| ProtoExporter exporter; | |||
| return exporter.GetFuncGraphProtoString(func_graph); | |||
| } | |||
| @@ -36,7 +36,7 @@ Dump::Dump() | |||
| dump_iter_(0), | |||
| cur_iter_(0) {} | |||
| bool Dump::IsKernelNeedDump(const std::string& kernel_name) { | |||
| bool Dump::IsKernelNeedDump(const std::string &kernel_name) { | |||
| if (dump_mode_ == 0) { | |||
| // Dump All Kernels mode | |||
| return true; | |||
| @@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) { | |||
| return false; | |||
| } | |||
| bool Dump::ParseDumpConfig(const std::string& dump_config_file) { | |||
| bool Dump::ParseDumpConfig(const std::string &dump_config_file) { | |||
| std::ifstream jsonFile(dump_config_file); | |||
| if (!jsonFile.is_open()) { | |||
| MS_LOG(ERROR) << dump_config_file << " open failed."; | |||
| @@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) { | |||
| return true; | |||
| } | |||
| bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { | |||
| bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) { | |||
| if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || | |||
| dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || | |||
| dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || | |||
| @@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { | |||
| return true; | |||
| } | |||
| bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { | |||
| bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { | |||
| auto trans_flag = dumpSettings.at("trans_flag"); | |||
| auto enable = dumpSettings.at("enable"); | |||
| auto mode = dumpSettings.at("mode"); | |||
| @@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { | |||
| dump_path_ = path; | |||
| dump_net_name_ = net_name; | |||
| dump_iter_ = iteration; | |||
| for (const auto& kernel : kernels) { | |||
| for (const auto &kernel : kernels) { | |||
| dump_kernels_.push_back(kernel); | |||
| } | |||
| return true; | |||
| } | |||
| bool Dump::SetDumpConfFromJsonFile() { | |||
| const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); | |||
| const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); | |||
| if (config_path_str != nullptr) { | |||
| MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; | |||
| } else { | |||
| @@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() { | |||
| return ParseDumpConfig(dump_config_file); | |||
| } | |||
| bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { | |||
| bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) { | |||
| if (filename.empty() || data == nullptr || len == 0) { | |||
| MS_LOG(ERROR) << "Incorrect parameter."; | |||
| return false; | |||
| @@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) | |||
| MS_LOG(ERROR) << "Open file " << realpath << " fail."; | |||
| return false; | |||
| } | |||
| (void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len)); | |||
| (void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len)); | |||
| fd.close(); | |||
| return true; | |||
| } | |||
| bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { | |||
| bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) { | |||
| MS_EXCEPTION_IF_NULL(outpath); | |||
| auto path_split_pos = inpath.find_last_of('/'); | |||
| if (path_split_pos == std::string::npos) { | |||
| @@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { | |||
| return true; | |||
| } | |||
| bool Dump::CreateNotExistDirs(const std::string& path) { | |||
| bool Dump::CreateNotExistDirs(const std::string &path) { | |||
| std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem(); | |||
| MS_EXCEPTION_IF_NULL(fs); | |||
| char temp_path[PATH_MAX] = {0}; | |||
| @@ -43,11 +43,11 @@ class Dump { | |||
| uint32_t cur_iter() const { return cur_iter_; } | |||
| bool IsKernelNeedDump(const std::string& kernel_name); | |||
| bool IsKernelNeedDump(const std::string &kernel_name); | |||
| bool SetDumpConfFromJsonFile(); | |||
| static bool DumpToFile(const std::string& filename, const void* data, size_t len); | |||
| static bool DumpToFile(const std::string &filename, const void *data, size_t len); | |||
| protected: | |||
| bool dump_enable_; | |||
| @@ -59,14 +59,14 @@ class Dump { | |||
| uint32_t cur_iter_; | |||
| std::vector<std::string> dump_kernels_; | |||
| static bool GetRealPath(const std::string& inpath, std::string* outpath); | |||
| static bool GetRealPath(const std::string &inpath, std::string *outpath); | |||
| static bool CreateNotExistDirs(const std::string& path); | |||
| static bool CreateNotExistDirs(const std::string &path); | |||
| private: | |||
| bool ParseDumpConfig(const std::string& dump_config_file); | |||
| bool IsConfigExist(const nlohmann::json& dumpSettings); | |||
| bool IsConfigValid(const nlohmann::json& dumpSettings); | |||
| bool ParseDumpConfig(const std::string &dump_config_file); | |||
| bool IsConfigExist(const nlohmann::json &dumpSettings); | |||
| bool IsConfigValid(const nlohmann::json &dumpSettings); | |||
| }; | |||
| using DumpConfPtr = std::shared_ptr<Dump>; | |||
| @@ -23,7 +23,7 @@ | |||
| #include "pipeline/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { | |||
| std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { | |||
| std::string temp_line = line; | |||
| if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && | |||
| tip != kSourceLineTipDiscard) { | |||
| @@ -101,14 +101,14 @@ DebugInfo::DebugInfo() { | |||
| name_ = ""; | |||
| } | |||
| DebugInfo::DebugInfo(const std::string& name) { | |||
| DebugInfo::DebugInfo(const std::string &name) { | |||
| InitValueFromContext(); | |||
| unique_id_ = gen_unique_id(); | |||
| debug_id_ = -1; | |||
| name_ = name; | |||
| } | |||
| DebugInfo::DebugInfo(const LocationPtr& loc) { | |||
| DebugInfo::DebugInfo(const LocationPtr &loc) { | |||
| InitValueFromContext(); | |||
| unique_id_ = gen_unique_id(); | |||
| debug_id_ = -1; | |||
| @@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() { | |||
| } | |||
| int64_t DebugInfo::unique_id_through_copy() const { | |||
| TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info(); | |||
| TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info(); | |||
| if (trace_info != nullptr) { | |||
| if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) { | |||
| return trace_info->debug_info()->unique_id_through_copy(); | |||
| @@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() { | |||
| } | |||
| return DebugInfo::location(); | |||
| } | |||
| void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } | |||
| void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } | |||
| TraceContextPtr TraceManager::CurrentContextInfo() { | |||
| if (!TraceManager::trace_context_stack_.empty()) { | |||
| @@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() { | |||
| return nullptr; | |||
| } | |||
| void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { | |||
| void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { | |||
| TraceContextPtr context = std::make_shared<TraceContext>(location); | |||
| context->set_func_name(func_name); | |||
| TraceManager::trace_context_stack_.push(context); | |||
| } | |||
| void TraceManager::DebugTrace(const LocationPtr& location) { | |||
| void TraceManager::DebugTrace(const LocationPtr &location) { | |||
| TraceContextPtr context = std::make_shared<TraceContext>(location); | |||
| TraceManager::trace_context_stack_.push(context); | |||
| } | |||
| void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { | |||
| void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { | |||
| if (trace_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | |||
| } | |||
| @@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { | |||
| TraceManager::trace_context_stack_.push(context); | |||
| } | |||
| void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { | |||
| void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { | |||
| if (trace_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | |||
| } | |||
| @@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou | |||
| // Location class record the location in source code. | |||
| class Location { | |||
| public: | |||
| Location(const std::string& file_name, int line, int column, int line_end, int column_end) | |||
| Location(const std::string &file_name, int line, int column, int line_end, int column_end) | |||
| : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} | |||
| Location(const Location& loc) | |||
| Location(const Location &loc) | |||
| : file_name_(loc.file_name_), | |||
| line_(loc.line_), | |||
| column_(loc.column_), | |||
| @@ -77,21 +77,21 @@ class TraceManager { | |||
| TraceManager() = default; | |||
| ~TraceManager() = default; | |||
| static TraceContextPtr CurrentContextInfo(); | |||
| static void DebugTrace(const std::string& func_name, const LocationPtr& location); | |||
| static void DebugTrace(const LocationPtr& location); | |||
| static void DebugTrace(const TraceInfoPtr& trace_info); | |||
| static void DebugTrace(const std::string &func_name, const LocationPtr &location); | |||
| static void DebugTrace(const LocationPtr &location); | |||
| static void DebugTrace(const TraceInfoPtr &trace_info); | |||
| // debug trace with a cloned trace info with debug_info | |||
| static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); | |||
| static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); | |||
| static void EndTrace(); | |||
| static std::stack<TraceContextPtr> trace_context_stack_; | |||
| }; | |||
| class TraceGuard { | |||
| public: | |||
| explicit TraceGuard(const std::string func_name, const LocationPtr& location) { | |||
| explicit TraceGuard(const std::string func_name, const LocationPtr &location) { | |||
| TraceManager::DebugTrace(func_name, location); | |||
| } | |||
| explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } | |||
| explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } | |||
| ~TraceGuard() { TraceManager::EndTrace(); } | |||
| }; | |||
| @@ -106,23 +106,23 @@ class TraceContext { | |||
| public: | |||
| ~TraceContext() = default; | |||
| explicit TraceContext(const LocationPtr& loc) { | |||
| explicit TraceContext(const LocationPtr &loc) { | |||
| ProcessAttributeFromContext(); | |||
| location_ = loc; | |||
| } | |||
| explicit TraceContext(const std::string& func_name) { | |||
| explicit TraceContext(const std::string &func_name) { | |||
| ProcessAttributeFromContext(); | |||
| func_name_ = func_name; | |||
| } | |||
| explicit TraceContext(const TraceInfoPtr& trace_info) { | |||
| explicit TraceContext(const TraceInfoPtr &trace_info) { | |||
| ProcessAttributeFromContext(); | |||
| trace_info_ = trace_info; | |||
| } | |||
| void set_location(const LocationPtr& loc) { location_ = loc; } | |||
| void set_location(const LocationPtr &loc) { location_ = loc; } | |||
| LocationPtr location() { return location_; } | |||
| void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } | |||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | |||
| TraceInfoPtr trace_info() { return trace_info_; } | |||
| void set_func_name(const std::string& func_name) { func_name_ = func_name; } | |||
| void set_func_name(const std::string &func_name) { func_name_ = func_name; } | |||
| std::string func_name() { return func_name_; } | |||
| }; | |||
| @@ -130,9 +130,9 @@ class DebugInfo : public Base { | |||
| public: | |||
| DebugInfo(); | |||
| explicit DebugInfo(const std::string& name); | |||
| explicit DebugInfo(const std::string &name); | |||
| explicit DebugInfo(const LocationPtr& loc); | |||
| explicit DebugInfo(const LocationPtr &loc); | |||
| virtual ~DebugInfo() = default; | |||
| MS_DECLARE_PARENT(DebugInfo, Base); | |||
| @@ -141,12 +141,12 @@ class DebugInfo : public Base { | |||
| int64_t unique_id_through_copy() const; | |||
| std::string get_id() { return std::to_string(debug_id()); } | |||
| void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } | |||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | |||
| TraceInfoPtr trace_info() { return trace_info_; } | |||
| void set_location(const LocationPtr& loc) { location_ = loc; } | |||
| void set_location(const LocationPtr &loc) { location_ = loc; } | |||
| virtual LocationPtr location() { return location_; } | |||
| std::string name() { return name_; } | |||
| void set_name(const std::string& name) { name_ = name; } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| virtual std::string debug_name(); | |||
| virtual std::string get_python_func_belonged() { return ""; } | |||
| @@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo { | |||
| py_func_belonged_ = context_info->func_name(); | |||
| } | |||
| } | |||
| explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { | |||
| explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { | |||
| if (TraceManager::CurrentContextInfo() != nullptr) { | |||
| auto context_info = TraceManager::CurrentContextInfo(); | |||
| py_func_belonged_ = context_info->func_name(); | |||
| @@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo { | |||
| ~NodeDebugInfo() override = default; | |||
| std::string debug_name() override; | |||
| void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); } | |||
| void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); } | |||
| std::shared_ptr<AnfNode> get_node() const { return node_.lock(); } | |||
| void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } | |||
| void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } | |||
| std::string get_python_func_belonged() override { return py_func_belonged_; } | |||
| AnfNodeWeakPtr node_; | |||
| std::string py_func_belonged_; | |||
| @@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo { | |||
| } | |||
| } | |||
| explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { | |||
| explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { | |||
| if (TraceManager::CurrentContextInfo() != nullptr) { | |||
| auto context_info = TraceManager::CurrentContextInfo(); | |||
| py_func_name_ = context_info->func_name(); | |||
| @@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo { | |||
| std::string debug_name() override; | |||
| LocationPtr location() override; | |||
| LocationPtr deco_location() { return deco_loc_; } | |||
| void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } | |||
| void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } | |||
| FuncGraphPtr get_graph() const { return func_graph_.lock(); } | |||
| void set_full_name(const std::string& name) { full_name_ = name; } | |||
| void set_full_name(const std::string &name) { full_name_ = name; } | |||
| std::string get_full_name() { return full_name_; } | |||
| void set_deco_location(const LocationPtr& deco_list_loc); | |||
| void set_deco_location(const LocationPtr &deco_list_loc); | |||
| std::string get_python_func_belonged() override { return py_func_name_; } | |||
| FuncGraphWeakPtr func_graph_; | |||
| LocationPtr deco_loc_; | |||
| @@ -31,7 +31,7 @@ struct NameWithTrace { | |||
| std::string name; | |||
| std::vector<std::string> trace_labels; | |||
| }; | |||
| static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { | |||
| static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { | |||
| switch (trace_label) { | |||
| case TraceLabelType::kShortSymbol: | |||
| return trace_info->symbol(); | |||
| @@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t | |||
| } | |||
| } | |||
| NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||
| NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||
| NameWithTrace trace_name; | |||
| // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node | |||
| auto temp_info = debug_info; | |||
| @@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe | |||
| return trace_name; | |||
| } | |||
| std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) { | |||
| std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) { | |||
| std::string tags = ""; | |||
| for (auto& itr : trace_labels) { | |||
| for (auto &itr : trace_labels) { | |||
| std::string symbol = itr; | |||
| tags = tags + symbol; | |||
| } | |||
| @@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st | |||
| } | |||
| // get the label name of the node debug info | |||
| std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||
| std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||
| NameWithTrace trace_name = RootName(debug_info, trace_label); | |||
| return CombineTraceTypes(trace_name.name, trace_name.trace_labels); | |||
| } | |||
| std::string CombineUniqueID(const DebugInfoPtr& debug_info) { | |||
| std::string CombineUniqueID(const DebugInfoPtr &debug_info) { | |||
| auto temp_info = debug_info; | |||
| std::string label = ""; | |||
| while (temp_info != nullptr) { | |||
| @@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) { | |||
| } | |||
| // get trace with unique id chain | |||
| std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); } | |||
| std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); } | |||
| std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||
| std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||
| if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { | |||
| return LabelStringUnique(debug_info); | |||
| } | |||
| @@ -29,7 +29,7 @@ namespace label_manage { | |||
| enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId }; | |||
| TraceLabelType GetGlobalTraceLabelType(); | |||
| void SetGlobalTraceLabelType(TraceLabelType label_type); | |||
| std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); | |||
| std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); | |||
| } // namespace label_manage | |||
| } // namespace mindspore | |||
| @@ -37,7 +37,7 @@ | |||
| namespace mindspore { | |||
| // namespace to support debug trace infomation | |||
| namespace trace { | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) { | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) { | |||
| if (abs == nullptr) { | |||
| return "Null Abstract"; | |||
| } | |||
| @@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { | |||
| return debug_with_loc_vec; | |||
| } | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { | |||
| auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); | |||
| if (debug_with_loc_vec.size() > 0) { | |||
| return debug_with_loc_vec[0]; | |||
| @@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { | |||
| } | |||
| } | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { | |||
| if (info == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| // a trace info identifies a node transform, so we can trace the node transform through | |||
| // a link of trace info and debug info | |||
| std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) { | |||
| std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) { | |||
| if (info_vec.size() < 1) { | |||
| return ""; | |||
| } | |||
| @@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL | |||
| return traced_info; | |||
| } | |||
| std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { | |||
| if (info == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| return ""; | |||
| } | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { | |||
| std::ostringstream oss; | |||
| if (info == nullptr) { | |||
| return ""; | |||
| @@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So | |||
| return oss.str(); | |||
| } | |||
| std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { | |||
| std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) { | |||
| std::ostringstream oss; | |||
| oss << "graph:" << graph->ToString() << " with args["; | |||
| auto params = graph->parameters(); | |||
| @@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas | |||
| return oss.str(); | |||
| } | |||
| void DumpInferStack(std::ostringstream& oss) { | |||
| auto& infer_stack = GetCurrenGraphInferStack(); | |||
| void DumpInferStack(std::ostringstream &oss) { | |||
| auto &infer_stack = GetCurrenGraphInferStack(); | |||
| if (infer_stack.empty()) { | |||
| return; | |||
| } | |||
| @@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) { | |||
| } | |||
| std::reverse(infer_vec.begin(), infer_vec.end()); | |||
| int index = 0; | |||
| for (auto& item : infer_vec) { | |||
| for (auto &item : infer_vec) { | |||
| auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first); | |||
| if (graph_infer == nullptr) { | |||
| MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; | |||
| @@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) { | |||
| } | |||
| void TraceGraphInfer() { | |||
| auto& infer_stack = GetCurrenGraphInferStack(); | |||
| auto &infer_stack = GetCurrenGraphInferStack(); | |||
| std::ostringstream oss; | |||
| if (infer_stack.empty()) { | |||
| return; | |||
| @@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter { | |||
| AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} | |||
| ~AnalyzedFuncGraphExporter() override = default; | |||
| void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs); | |||
| void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs); | |||
| private: | |||
| std::string GetNodeType(const AnfNodePtr& nd) override; | |||
| std::string GetNodeType(const AnfNodePtr &nd) override; | |||
| }; | |||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { | |||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; | |||
| auto& list = GetCNodeDebugStack(); | |||
| auto &list = GetCNodeDebugStack(); | |||
| for (size_t i = 0; i < list.size(); ++i) { | |||
| auto node_cfg = list[i]; | |||
| auto fg = node_cfg->context()->func_graph(); | |||
| @@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() { | |||
| exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); | |||
| } | |||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { | |||
| if (node_cfg_ == nullptr) { | |||
| return AnfExporter::GetNodeType(node); | |||
| } | |||
| @@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||
| return oss.str(); | |||
| } | |||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||
| const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) { | |||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, | |||
| const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) { | |||
| if (node_cfgs.empty()) { | |||
| MS_LOG(DEBUG) << "Node configs is empty"; | |||
| return; | |||
| @@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||
| auto tagged_func_graphs = CalcTaggedFuncGraphs(); | |||
| // first output graph on the analysis stack | |||
| for (const auto& node_cfg : node_cfgs) { | |||
| for (const auto &node_cfg : node_cfgs) { | |||
| auto fg = node_cfg->context()->func_graph(); | |||
| // the graph is already output, skip it | |||
| if (exported.find(fg) != exported.end()) { | |||
| @@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||
| ofs.close(); | |||
| } | |||
| void GetInferStackInfo(std::ostringstream& oss) { | |||
| void GetInferStackInfo(std::ostringstream &oss) { | |||
| MS_LOG(INFO) << "Get graph analysis information begin"; | |||
| auto stack = GetCNodeDebugStack(); | |||
| if (stack.empty()) { | |||
| @@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) { | |||
| static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; | |||
| // trace the cnode infer debug info | |||
| static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { | |||
| if (eval == nullptr) { | |||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | |||
| } | |||
| @@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An | |||
| } | |||
| } | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { | |||
| if (eval == nullptr) { | |||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | |||
| } | |||
| @@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { | |||
| } | |||
| } | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||
| void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } | |||
| std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; } | |||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; } | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() { | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() { | |||
| return graph_infer_stack; | |||
| } | |||
| void ClearTraceStack() { | |||
| @@ -31,19 +31,19 @@ | |||
| namespace mindspore { | |||
| namespace trace { | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, | |||
| SourceLineTip tip = kSourceLineTipNextLine); | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); | |||
| void TraceGraphInfer(); | |||
| void GetInferStackInfo(std::ostringstream& oss); | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); | |||
| void GetInferStackInfo(std::ostringstream &oss); | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval); | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); | |||
| void TraceInferCNodeLeave(); | |||
| std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack(); | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack(); | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); | |||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack(); | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack(); | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); | |||
| void ClearTraceStack(); | |||
| } // namespace trace | |||
| } // namespace mindspore | |||
| @@ -23,7 +23,7 @@ | |||
| #include "pipeline/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { | |||
| std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { | |||
| if (info == nullptr) { | |||
| return ""; | |||
| } | |||