Browse Source

!56 Synchronization code423 to ms-incubator

Merge pull request !56 from changzherui/syn-code423
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
6844ea633d
100 changed files with 3689 additions and 915 deletions
  1. +2
    -2
      .clang-format
  2. +0
    -2
      CMakeLists.txt
  3. +1
    -1
      README.md
  4. +3
    -1
      build.bat
  5. +9
    -9
      build.sh
  6. +1
    -1
      cmake/dependency_graphengine.cmake
  7. +7
    -8
      cmake/external_libs/tvm_gpu.cmake
  8. +10
    -6
      cmake/utils.cmake
  9. +58
    -0
      example/alexnet_cifar10/README.md
  10. +46
    -0
      example/convert_to_mindrecord/README.md
  11. +0
    -0
      example/convert_to_mindrecord/imagenet/__init__.py
  12. +122
    -0
      example/convert_to_mindrecord/imagenet/mr_api.py
  13. +8
    -0
      example/convert_to_mindrecord/run_imagenet.sh
  14. +6
    -0
      example/convert_to_mindrecord/run_template.sh
  15. +0
    -0
      example/convert_to_mindrecord/template/__init__.py
  16. +73
    -0
      example/convert_to_mindrecord/template/mr_api.py
  17. +152
    -0
      example/convert_to_mindrecord/writer.py
  18. +63
    -0
      example/lenet_mnist/README.md
  19. +1
    -1
      graphengine
  20. +9
    -4
      mindspore/_akg/gpu/__init__.py
  21. +4
    -4
      mindspore/_akg/gpu/hsigmoid.py
  22. +4
    -4
      mindspore/_akg/gpu/hsigmoid_grad.py
  23. +4
    -4
      mindspore/_akg/gpu/hswish.py
  24. +5
    -5
      mindspore/_akg/gpu/hswish_grad.py
  25. +40
    -0
      mindspore/_akg/gpu/less_equal.py
  26. +40
    -0
      mindspore/_akg/gpu/logical_and.py
  27. +40
    -0
      mindspore/_akg/gpu/logical_not.py
  28. +40
    -0
      mindspore/_akg/gpu/logical_or.py
  29. +40
    -0
      mindspore/_akg/gpu/sub.py
  30. +54
    -0
      mindspore/_akg/ops/math/less_equal.py
  31. +41
    -0
      mindspore/_akg/ops/math/logical_and.py
  32. +32
    -0
      mindspore/_akg/ops/math/logical_not.py
  33. +41
    -0
      mindspore/_akg/ops/math/logical_or.py
  34. +161
    -31
      mindspore/_checkparam.py
  35. +1
    -1
      mindspore/_extends/__init__.py
  36. +2
    -1
      mindspore/_extends/parse/resources.py
  37. +2
    -2
      mindspore/_extends/parse/trope.py
  38. +0
    -44
      mindspore/_extends/pynative_helper.py
  39. +0
    -4
      mindspore/ccsrc/CMakeLists.txt
  40. +33
    -11
      mindspore/ccsrc/common/trans.cc
  41. +6
    -3
      mindspore/ccsrc/common/trans.h
  42. +1
    -1
      mindspore/ccsrc/common/utils.cc
  43. +6
    -6
      mindspore/ccsrc/common/utils.h
  44. +3
    -3
      mindspore/ccsrc/dataset/CMakeLists.txt
  45. +97
    -4
      mindspore/ccsrc/dataset/api/de_pipeline.cc
  46. +9
    -3
      mindspore/ccsrc/dataset/api/de_pipeline.h
  47. +43
    -16
      mindspore/ccsrc/dataset/api/python_bindings.cc
  48. +2
    -0
      mindspore/ccsrc/dataset/core/client.h
  49. +1
    -1
      mindspore/ccsrc/dataset/core/tensor.cc
  50. +2
    -0
      mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
  51. +235
    -0
      mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc
  52. +172
    -0
      mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h
  53. +253
    -0
      mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
  54. +181
    -0
      mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
  55. +0
    -22
      mindspore/ccsrc/dataset/engine/datasetops/map_op.cc
  56. +0
    -4
      mindspore/ccsrc/dataset/engine/datasetops/map_op.h
  57. +6
    -1
      mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc
  58. +15
    -21
      mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
  59. +1
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt
  60. +3
    -2
      mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
  61. +2
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
  62. +1
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt
  63. +85
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
  64. +58
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h
  65. +0
    -3
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
  66. +1
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
  67. +5
    -1
      mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc
  68. +459
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
  69. +263
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h
  70. +54
    -6
      mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
  71. +6
    -4
      mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
  72. +9
    -9
      mindspore/ccsrc/dataset/engine/datasetops/zip_op.h
  73. +3
    -3
      mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt
  74. +4
    -3
      mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc
  75. +1
    -0
      mindspore/ccsrc/dataset/kernels/image/cut_out_op.h
  76. +4
    -4
      mindspore/ccsrc/dataset/kernels/image/decode_op.h
  77. +0
    -117
      mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc
  78. +0
    -72
      mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h
  79. +7
    -73
      mindspore/ccsrc/dataset/kernels/image/image_utils.cc
  80. +2
    -8
      mindspore/ccsrc/dataset/kernels/image/image_utils.h
  81. +3
    -3
      mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc
  82. +87
    -0
      mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
  83. +60
    -0
      mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h
  84. +6
    -1
      mindspore/ccsrc/dataset/util/random.cc
  85. +5
    -1
      mindspore/ccsrc/dataset/util/services.cc
  86. +1
    -1
      mindspore/ccsrc/debug/anf_ir_dump.h
  87. +108
    -108
      mindspore/ccsrc/debug/anf_ir_utils.cc
  88. +33
    -33
      mindspore/ccsrc/debug/anf_ir_utils.h
  89. +32
    -32
      mindspore/ccsrc/debug/draw.cc
  90. +18
    -18
      mindspore/ccsrc/debug/draw.h
  91. +85
    -85
      mindspore/ccsrc/debug/dump_proto.cc
  92. +10
    -10
      mindspore/ccsrc/debug/e2e_dump.cc
  93. +7
    -7
      mindspore/ccsrc/debug/e2e_dump.h
  94. +9
    -9
      mindspore/ccsrc/debug/info.cc
  95. +26
    -26
      mindspore/ccsrc/debug/info.h
  96. +8
    -8
      mindspore/ccsrc/debug/label.cc
  97. +1
    -1
      mindspore/ccsrc/debug/label.h
  98. +24
    -24
      mindspore/ccsrc/debug/trace.cc
  99. +10
    -10
      mindspore/ccsrc/debug/trace.h
  100. +1
    -1
      mindspore/ccsrc/debug/trace_info.cc

+ 2
- 2
.clang-format View File

@@ -52,7 +52,7 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 2
Cpp11BracedListStyle: true
DerivePointerAlignment: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
@@ -94,7 +94,7 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
PointerAlignment: Right
RawStringFormats:
- Language: Cpp
Delimiters:


+ 0
- 2
CMakeLists.txt View File

@@ -1,8 +1,6 @@
cmake_minimum_required(VERSION 3.14)
project (MindSpore)

include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")

if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")


+ 1
- 1
README.md View File

@@ -179,7 +179,7 @@ Check out how MindSpore Open Governance [works](https://gitee.com/mindspore/comm

- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers.
- IRC channel at `#mindspore` (only for meeting minutes logging purpose)
- Video Conferencing: meet.jit.si
- Video Conferencing: https://meet.jit.si
- Mailing-list: https://mailweb.mindspore.cn/postorius/lists

## Contributing


+ 3
- 1
build.bat View File

@@ -31,6 +31,7 @@ cd %CD%/mindspore
cmake -DCMAKE_BUILD_TYPE=Release -DENABLE_CPU=ON -DENABLE_MINDDATA=ON -DUSE_GLOG=ON -G "CodeBlocks - MinGW Makefiles" ../..
IF NOT %errorlevel% == 0 (
echo "cmake fail."
goto run_fail
)
@@ -40,6 +41,7 @@ IF "%1%" == "" (
cmake --build . --target package -- -j%1%
)
IF NOT %errorlevel% == 0 (
echo "build fail."
goto run_fail
)

@@ -49,6 +51,6 @@ goto run_eof

:run_fail
cd %BASEPATH%
echo "build fail."
set errorlevel=1

:run_eof

+ 9
- 9
build.sh View File

@@ -23,30 +23,30 @@ export BUILD_PATH="${BASEPATH}/build/"
usage()
{
echo "Usage:"
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge|cpu] [-m infer|train] \\"
echo " [-a on|off] [-g on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\"
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\"
echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\"
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K]"
echo ""
echo "Options:"
echo " -d Debug mode"
echo " -r Release mode, default mode"
echo " -v Display build command"
echo " -c Enable code coverage switch, default off"
echo " -t Run testcases switch, default on"
echo " -c Enable code coverage, default off"
echo " -t Run testcases, default on"
echo " -g Use glog to output log, default on"
echo " -h Print usage"
echo " -b Select other backend, available: \\"
echo " ge:graph engine, cpu"
echo " -m Select mode, available: infer, train, default is infer "
echo " ge:graph engine"
echo " -m Select graph engine backend mode, available: infer, train, default is infer"
echo " -a Enable ASAN, default off"
echo " -p Enable pipeline profile, default off"
echo " -p Enable pipeline profile, print to stdout, default off"
echo " -R Enable pipeline profile, record to json, default off"
echo " -i Enable increment building, default off"
echo " -L Enable load ANF-IR as input of 'infer', default off"
echo " -R Enable the time_line record, default off"
echo " -j[n] Set the threads when building (Default: -j8)"
echo " -e Use gpu, d or cpu"
echo " -P Enable dump anf graph to file in ProtoBuffer format, default on"
echo " -Q Enable dump end to end, default off"
echo " -Q Enable dump memory, default off"
echo " -D Enable dumping of function graph ir, default on"
echo " -z Compile dataset & mindrecord, default on"
echo " -M Enable MPI and NCCL for GPU training, default on"


+ 1
- 1
cmake/dependency_graphengine.cmake View File

@@ -64,7 +64,7 @@ set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
# force __FILE__ to show relative path of file, from source directory
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"' -Wno-builtin-macro-redefined")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined")
add_subdirectory(${GE_SOURCE_DIR}/src/common/graph)
if(ENABLE_D)
add_subdirectory(${GE_SOURCE_DIR}/src/ge/common)


+ 7
- 8
cmake/external_libs/tvm_gpu.cmake View File

@@ -1,16 +1,15 @@
set(incubator_tvm_gpu_CFLAGS "-pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2")
set(incubator_tvm_gpu_CXXFLAGS "-std=c++11 -pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2")
set(USE_CUDA "ON")
set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(incubator_tvm_gpu
VER 0.6.0
LIBS tvm
URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz
MD5 9cbbd32545a776023acabbba270449fe
CUSTOM_CMAKE ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/
SUBMODULES ${dlpack_DIRPATH} ${dmlc-core_DIRPATH} ${rang_DIRPATH}
SOURCEMODULES topi/python/topi python/tvm
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/find_library.patch
${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch
${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch
CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON)
include_directories(${incubator_tvm_gpu_INC})
add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm)
${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch
${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch
CMAKE_OPTION " ")
add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm)

+ 10
- 6
cmake/utils.cmake View File

@@ -205,7 +205,7 @@ set(MS_FIND_NO_DEFAULT_PATH ${MS_FIND_NO_DEFAULT_PATH} PARENT_SCOPE)
function(mindspore_add_pkg pkg_name )

set(options )
set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH)
set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH CUSTOM_CMAKE)
set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES SUBMODULES SOURCEMODULES)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )

@@ -281,10 +281,6 @@ function(mindspore_add_pkg pkg_name )
file(GLOB ${pkg_name}_INSTALL_SUBMODULE ${_SUBMODULE_FILE}/*)
file(COPY ${${pkg_name}_INSTALL_SUBMODULE} DESTINATION ${${pkg_name}_SOURCE_DIR}/3rdparty/${_SUBMODENAME})
endforeach (_SUBMODULE_FILE)
foreach(_SOURCE_DIR ${PKG_SOURCEMODULES})
file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*)
file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/)
endforeach (_SUBMODULE_FILE)
else()
set(${pkg_name}_SOURCE_DIR ${PKG_DIR})
endif ()
@@ -304,12 +300,20 @@ function(mindspore_add_pkg pkg_name )
message(FATAL_ERROR "Failed patch: ${_LF_PATCH_FILE}")
endif()
endforeach(_PATCH_FILE)
foreach(_SOURCE_DIR ${PKG_SOURCEMODULES})
file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*)
file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/)
endforeach (_SUBMODULE_FILE)
file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600)
if(NOT ${pkg_name}_LOCK_RET EQUAL "0")
message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}")
endif()

if (PKG_CUSTOM_CMAKE)
file(GLOB ${pkg_name}_cmake ${PKG_CUSTOM_CMAKE}/CMakeLists.txt)
file(COPY ${${pkg_name}_cmake} DESTINATION ${${pkg_name}_SOURCE_DIR})
endif ()

if(${pkg_name}_SOURCE_DIR)
if (PKG_HEAD_ONLY)
file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*)


+ 58
- 0
example/alexnet_cifar10/README.md View File

@@ -0,0 +1,58 @@
# AlexNet Example

## Description

Training AlexNet with CIFAR-10 dataset in MindSpore.

This is the simple tutorial for training AlexNet in MindSpore.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).

- Download the CIFAR-10 dataset at <http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz>. The directory structure is as follows:

```
├─cifar-10-batches-bin
└─cifar-10-verify-bin
```

## Running the example

```python
# train AlexNet, hyperparameter setting in config.py
python train.py --data_path cifar-10-batches-bin
```

You can get loss with each step similar to this:

```bash
epoch: 1 step: 1, loss is 2.2791853
...
epoch: 1 step: 1536, loss is 1.9366643
epoch: 1 step: 1537, loss is 1.6983616
epoch: 1 step: 1538, loss is 1.0221305
...
```

Then, test AlexNet according to network model
```python
# test AlexNet, 1 epoch training accuracy is up to 51.1%; 10 epoch training accuracy is up to 81.2%
python eval.py --data_path cifar-10-verify-bin --mode test --ckpt_path checkpoint_alexnet-1_1562.ckpt
```

## Note
There are some optional arguments:

```bash
-h, --help show this help message and exit
--device_target {Ascend,GPU}
device where the code will be implemented (default: Ascend)
--data_path DATA_PATH
path where the dataset is saved
--dataset_sink_mode DATASET_SINK_MODE
dataset_sink_mode is False or True
```

You can run ```python train.py -h``` or ```python eval.py -h``` to get more information.

+ 46
- 0
example/convert_to_mindrecord/README.md View File

@@ -0,0 +1,46 @@
# MindRecord generating guidelines

<!-- TOC -->

- [MindRecord generating guidelines](#mindrecord-generating-guidelines)
- [Create work space](#create-work-space)
- [Implement data generator](#implement-data-generator)
- [Run data generator](#run-data-generator)

<!-- /TOC -->

## Create work space

Assume the dataset name is 'xyz'
* Create work space from template
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
cp -r template xyz
```

## Implement data generator

Edit dictionary data generator
* Edit file
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
vi xyz/mr_api.py
```

Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented
- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks.
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number()


Tips for parallel runs
- For imagenet, one directory can be a task.
- For TFRecord with multiple files, each file can be a task.
- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K)


## Run data generator
* run python script
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
python writer.py --mindrecord_script imagenet [...]
```

+ 0
- 0
example/convert_to_mindrecord/imagenet/__init__.py View File


+ 122
- 0
example/convert_to_mindrecord/imagenet/mr_api.py View File

@@ -0,0 +1,122 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two API must be implemented,
1. mindrecord_task_number()
# Return number of parallel tasks. return 1 if no parallel
2. mindrecord_dict_data(task_id)
# Yield data for one task
# task_id is 0..N-1, if N is return value of mindrecord_task_number()
"""
import argparse
import os
import pickle

######## mindrecord_schema begin ##########
mindrecord_schema = {"label": {"type": "int64"},
"data": {"type": "bytes"},
"file_name": {"type": "string"}}
######## mindrecord_schema end ##########

######## Frozen code begin ##########
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:
ARG_LIST = pickle.load(mindrecord_argument_file_handle)
######## Frozen code end ##########

parser = argparse.ArgumentParser(description='Mind record imagenet example')
parser.add_argument('--label_file', type=str, default="", help='label file')
parser.add_argument('--image_dir', type=str, default="", help='images directory')

######## Frozen code begin ##########
args = parser.parse_args(ARG_LIST)
print(args)
######## Frozen code end ##########


def _user_defined_private_func():
    """
    Build the task list from the label file and image directory.

    Reads args.label_file (lines of "<dir_name> <label>") and keeps only the
    labels whose directory exists under args.image_dir.

    Return:
        tuple: (dir_list, dir_paths) where dir_list is the list of labels
        (one task per label directory) and dir_paths maps label -> path.
        Returns ([], {}) when no valid image directory is found.

    Raises:
        IOError: if args.label_file does not exist.
    """
    if not os.path.exists(args.label_file):
        raise IOError("map file {} not exists".format(args.label_file))

    # map label -> directory name, e.g. "0" -> "n02087046"
    label_dict = {}
    with open(args.label_file) as file_handle:
        for line in file_handle:
            # strip() removes the trailing newline that previously leaked
            # into the label keys via the raw readline() loop
            line = line.strip()
            if not line:
                continue
            labels = line.split(" ")
            if len(labels) < 2:
                # skip malformed lines instead of raising IndexError
                continue
            label_dict[labels[1]] = labels[0]

    # get all the dir which are n02087046, n02094114, n02109525
    dir_paths = {}
    for item in label_dict:
        real_path = os.path.join(args.image_dir, label_dict[item])
        if not os.path.isdir(real_path):
            print("{} dir is not exist".format(real_path))
            continue
        dir_paths[item] = real_path

    if not dir_paths:
        print("not valid image dir in {}".format(args.image_dir))
        # empty list/dict pair keeps the return types consistent with the
        # normal path (the original returned two dicts here)
        return [], {}

    # one parallel task per label
    dir_list = list(dir_paths)
    return dir_list, dir_paths


dir_list_global, dir_paths_global = _user_defined_private_func()

def mindrecord_task_number():
    """
    Report how many parallel tasks this generator provides.

    Return:
        int: number of tasks (one per label directory discovered at import).
    """
    task_count = len(dir_list_global)
    return task_count


def mindrecord_dict_data(task_id):
    """
    Yield one data row (dict) at a time for the given task.

    Args:
        task_id (int): task index in [0, mindrecord_task_number()); selects
            which label directory this task walks.

    Yields:
        data (dict): {"file_name": str, "label": int, "data": bytes}.
    """

    # get the filename, label and image binary as a dict
    label = dir_list_global[task_id]
    for item in os.listdir(dir_paths_global[label]):
        file_name = os.path.join(dir_paths_global[label], item)
        # tuple form replaces the three chained endswith() calls
        if not item.endswith(("JPEG", "jpg", "jpeg")):
            print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name))
            continue
        data = {}
        data["file_name"] = str(file_name)
        data["label"] = int(label)

        # "with" guarantees the handle is closed even if read() raises;
        # the original open()/read()/close() leaked the handle on error
        with open(file_name, "rb") as image_file:
            data["data"] = image_file.read()
        yield data

+ 8
- 0
example/convert_to_mindrecord/run_imagenet.sh View File

@@ -0,0 +1,8 @@
#!/bin/bash
rm /tmp/imagenet/mr/*

python writer.py --mindrecord_script imagenet \
--mindrecord_file "/tmp/imagenet/mr/m" \
--mindrecord_partitions 16 \
--label_file "/tmp/imagenet/label.txt" \
--image_dir "/tmp/imagenet/jpeg"

+ 6
- 0
example/convert_to_mindrecord/run_template.sh View File

@@ -0,0 +1,6 @@
#!/bin/bash
rm /tmp/template/*

python writer.py --mindrecord_script template \
--mindrecord_file "/tmp/template/m" \
--mindrecord_partitions 4

+ 0
- 0
example/convert_to_mindrecord/template/__init__.py View File


+ 73
- 0
example/convert_to_mindrecord/template/mr_api.py View File

@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two API must be implemented,
1. mindrecord_task_number()
# Return number of parallel tasks. return 1 if no parallel
2. mindrecord_dict_data(task_id)
# Yield data for one task
# task_id is 0..N-1, if N is return value of mindrecord_task_number()
"""
import argparse
import pickle

# ## Parse argument

with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line
ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line
parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line

# ## Your arguments below
# parser.add_argument(...)

args = parser.parse_args(ARG_LIST) # Do NOT change this line
print(args) # Do NOT change this line


# ## Default mindrecord vars. Comment them unless default value has to be changed.
# mindrecord_index_fields = ['label']
# mindrecord_header_size = 1 << 24
# mindrecord_page_size = 1 << 25


# define global vars here if necessary


# ####### Your code below ##########
mindrecord_schema = {"label": {"type": "int32"}}

def mindrecord_task_number():
    """
    Report the number of parallel tasks.

    Return:
        int: always 1 — this template generates its data serially.
    """
    serial_task_count = 1
    return serial_task_count


def mindrecord_dict_data(task_id):
    """
    Yield 256 sample rows, each a dict carrying a single 'label' field.

    Yields:
        data (dict): data row which is dict.
    """
    print("task is {}".format(task_id))
    for label_value in range(256):
        yield {'label': label_value}

+ 152
- 0
example/convert_to_mindrecord/writer.py View File

@@ -0,0 +1,152 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by data dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import os
import pickle
import time
from importlib import import_module
from multiprocessing import Pool

from mindspore.mindrecord import FileWriter


def _exec_task(task_id, parallel_writer=True):
    """
    Pull rows from the user generator for one task and write them in batches.

    Args:
        task_id (int): task index forwarded to mindrecord_dict_data.
        parallel_writer (bool): forwarded to FileWriter.write_raw_data;
            passed as False for the Windows / single-task paths in __main__.
    """
    print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
    # generator supplied by the user script (bound at module level in __main__)
    imagenet_iter = mindrecord_dict_data(task_id)
    batch_size = 2048
    transform_count = 0
    while True:
        data_list = []
        try:
            for _ in range(batch_size):
                # next() is the idiomatic spelling of .__next__()
                data_list.append(next(imagenet_iter))
                transform_count += 1
            writer.write_raw_data(data_list, parallel_writer=parallel_writer)
            print("transformed {} record...".format(transform_count))
        except StopIteration:
            # flush the final partial batch before finishing
            if data_list:
                writer.write_raw_data(data_list, parallel_writer=parallel_writer)
                print("transformed {} record...".format(transform_count))
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Mind record writer')
    parser.add_argument('--mindrecord_script', type=str, default="template",
                        help='path where script is saved')

    parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
                        help='written file name prefix')

    parser.add_argument('--mindrecord_partitions', type=int, default=1,
                        help='number of written files')

    parser.add_argument('--mindrecord_workers', type=int, default=8,
                        help='number of parallel workers')

    # known args belong to the writer; unknown args are forwarded to the
    # user script through the pickle file below.  (The original called
    # parse_known_args() twice, discarding the first result.)
    args, other_args = parser.parse_known_args()

    print(args)
    print(other_args)

    with open('mr_argument.pickle', 'wb') as file_handle:
        pickle.dump(other_args, file_handle)

    try:
        mr_api = import_module(args.mindrecord_script + '.mr_api')
    except ModuleNotFoundError:
        raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))

    num_tasks = mr_api.mindrecord_task_number()

    print("Write mindrecord ...")

    # expose the generator as a module-level global for _exec_task
    mindrecord_dict_data = mr_api.mindrecord_dict_data

    # get number of files
    writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

    start_time = time.time()

    # set the header size (optional override from the user script)
    try:
        header_size = mr_api.mindrecord_header_size
        writer.set_header_size(header_size)
    except AttributeError:
        print("Default header size: {}".format(1 << 24))

    # set the page size (optional override from the user script)
    try:
        page_size = mr_api.mindrecord_page_size
        writer.set_page_size(page_size)
    except AttributeError:
        print("Default page size: {}".format(1 << 25))

    # get schema (mandatory in the user script)
    try:
        mindrecord_schema = mr_api.mindrecord_schema
    except AttributeError:
        raise RuntimeError("mindrecord_schema is not defined in mr_api.py.")

    # create the schema
    writer.add_schema(mindrecord_schema, "mindrecord_schema")

    # add the index (optional override from the user script)
    try:
        index_fields = mr_api.mindrecord_index_fields
        writer.add_index(index_fields)
    except AttributeError:
        print("Default index fields: all simple fields are indexes.")

    writer.open_and_set_header()

    task_list = list(range(num_tasks))

    # set number of workers
    num_workers = args.mindrecord_workers

    if num_tasks < 1:
        num_tasks = 1

    if num_workers > num_tasks:
        num_workers = num_tasks

    if os.name == 'nt':
        # run serially on Windows, with parallel_writer disabled
        for window_task_id in task_list:
            _exec_task(window_task_id, False)
    elif num_tasks > 1:
        with Pool(num_workers) as p:
            p.map(_exec_task, task_list)
    else:
        _exec_task(0, False)

    writer.commit()

    os.remove("{}".format("mr_argument.pickle"))

    end_time = time.time()
    print("--------------------------------------------")
    print("END. Total time: {}".format(end_time - start_time))
    print("--------------------------------------------")

+ 63
- 0
example/lenet_mnist/README.md View File

@@ -0,0 +1,63 @@
# LeNet Example

## Description

Training LeNet with MNIST dataset in MindSpore.

This is the simple and basic tutorial for constructing a network in MindSpore.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).

- Download the MNIST dataset at <http://yann.lecun.com/exdb/mnist/>. The directory structure is as follows:

```
└─MNIST_Data
├─test
│ t10k-images.idx3-ubyte
│ t10k-labels.idx1-ubyte
└─train
train-images.idx3-ubyte
train-labels.idx1-ubyte
```

## Running the example

```python
# train LeNet, hyperparameter setting in config.py
python train.py --data_path MNIST_Data
```

You can get loss with each step similar to this:

```bash
epoch: 1 step: 1, loss is 2.3040335
...
epoch: 1 step: 1739, loss is 0.06952668
epoch: 1 step: 1740, loss is 0.05038793
epoch: 1 step: 1741, loss is 0.05018193
...
```

Then, test LeNet according to network model
```python
# test LeNet, after 1 epoch training, the accuracy is up to 96.5%
python eval.py --data_path MNIST_Data --mode test --ckpt_path checkpoint_lenet-1_1875.ckpt
```

## Note
There are some optional arguments:

```bash
-h, --help show this help message and exit
--device_target {Ascend,GPU,CPU}
device where the code will be implemented (default: Ascend)
--data_path DATA_PATH
path where the dataset is saved
--dataset_sink_mode DATASET_SINK_MODE
dataset_sink_mode is False or True
```

You can run ```python train.py -h``` or ```python eval.py -h``` to get more information.

+ 1
- 1
graphengine

@@ -1 +1 @@
Subproject commit 70bb745b459ff9a0e7fc1008d15fe4b510f03da7
Subproject commit 43a715bc461fd70b7837051a2f47f0a1b19c5859

+ 9
- 4
mindspore/_akg/gpu/__init__.py View File

@@ -26,7 +26,12 @@ from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad
from .mean import SimpleMean, gpu_schedule_SimpleMean
from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad
from .mul import Mul, gpu_schedule_Mul
from .hsigmoid import Hsigmoid, gpu_schedule_Hsigmoid
from .hsigmoid_grad import HsigmoidGrad, gpu_schedule_HsigmoidGrad
from .hswish import Hswish, gpu_schedule_Hswish
from .hswish_grad import HswishGrad, gpu_schedule_HswishGrad
from .hsigmoid import HSigmoid, gpu_schedule_HSigmoid
from .hsigmoid_grad import HSigmoidGrad, gpu_schedule_HSigmoidGrad
from .hswish import HSwish, gpu_schedule_HSwish
from .hswish_grad import HSwishGrad, gpu_schedule_HSwishGrad
from .logical_or import LogicalOr, gpu_schedule_LogicalOr
from .logical_not import LogicalNot, gpu_schedule_LogicalNot
from .logical_and import LogicalAnd, gpu_schedule_LogicalAnd
from .sub import Sub, gpu_schedule_Sub
from .less_equal import LessEqual, gpu_schedule_LessEqual

+ 4
- 4
mindspore/_akg/gpu/hsigmoid.py View File

@@ -33,9 +33,9 @@ def topi_nn_hsigmoid(x):
(x(*i) + 3) / 6)))


def Hsigmoid(x):
def HSigmoid(x):
"""
Hsigmoid
HSigmoid
Args:
x:

@@ -45,9 +45,9 @@ def Hsigmoid(x):
return topi_nn_hsigmoid(x)


def gpu_schedule_Hsigmoid(outs):
def gpu_schedule_HSigmoid(outs):
"""
gpu schedule Hsigmoid
gpu schedule HSigmoid
Args:
outs:



+ 4
- 4
mindspore/_akg/gpu/hsigmoid_grad.py View File

@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hsigmoid grad"""
"""HSigmoid grad"""
import _akg.topi as topi
import _akg.tvm as tvm


def HsigmoidGrad(y_grad, x):
def HSigmoidGrad(y_grad, x):
"""
HsigmoidGrad
HSigmoidGrad
Args:
y_grad:
x:
@@ -32,7 +32,7 @@ def HsigmoidGrad(y_grad, x):
y_grad(*i) / 6)))


def gpu_schedule_HsigmoidGrad(outs):
def gpu_schedule_HSigmoidGrad(outs):
"""
gpu schedule ReLU6Grad
Args:


+ 4
- 4
mindspore/_akg/gpu/hswish.py View File

@@ -33,9 +33,9 @@ def topi_nn_hswish(x):
x(*i) * (x(*i) + 3) / 6)))


def Hswish(x):
def HSwish(x):
"""
Hswish
HSwish
Args:
x:

@@ -45,9 +45,9 @@ def Hswish(x):
return topi_nn_hswish(x)


def gpu_schedule_Hswish(outs):
def gpu_schedule_HSwish(outs):
"""
gpu schedule Hswish
gpu schedule HSwish
Args:
outs:



+ 5
- 5
mindspore/_akg/gpu/hswish_grad.py View File

@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""HswishGrad"""
"""HSwishGrad"""
import _akg.topi as topi
import _akg.tvm as tvm


def HswishGrad(y_grad, x):
def HSwishGrad(y_grad, x):
"""
HswishGrad
HSwishGrad
Args:
y_grad:
x:
@@ -34,9 +34,9 @@ def HswishGrad(y_grad, x):
return res6


def gpu_schedule_HswishGrad(outs):
def gpu_schedule_HSwishGrad(outs):
"""
gpu schedule HswishGrad
gpu schedule HSwishGrad
Args:
outs:



+ 40
- 0
mindspore/_akg/gpu/less_equal.py View File

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""less_equal"""
import _akg.tvm
from _akg.ops.math import less_equal
from _akg.topi.generic import schedule_elemwise

def LessEqual(x, y):
"""LessEqual."""
return less_equal.less_equal(x, y)


def gpu_schedule_LessEqual(outs):
"""
GPU schedule for LessEqual.

Args:
outs (tvm.tensor.Tensor): Outputs of compute.

Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = _akg.tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with _akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch

+ 40
- 0
mindspore/_akg/gpu/logical_and.py View File

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""logical_and"""
import _akg.tvm
from _akg.ops.math import logical_and
from _akg.topi.generic import schedule_elemwise

def LogicalAnd(x, y):
"""LogicalAnd."""
return logical_and.logical_and(x, y)


def gpu_schedule_LogicalAnd(outs):
"""
GPU schedule for LogicalAnd.

Args:
outs (tvm.tensor.Tensor): outputs of compute.

Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = _akg.tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with _akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch

+ 40
- 0
mindspore/_akg/gpu/logical_not.py View File

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""logical_not"""
import _akg.tvm
from _akg.ops.math import logical_not
from _akg.topi.generic import schedule_elemwise

def LogicalNot(x):
"""LogicalNot."""
return logical_not.logical_not(x)


def gpu_schedule_LogicalNot(outs):
"""
GPU schedule for LogicalNot.

Args:
outs (tvm.tensor.Tensor): outputs of compute.

Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = _akg.tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with _akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch

+ 40
- 0
mindspore/_akg/gpu/logical_or.py View File

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""logical_or"""
import _akg.tvm
from _akg.ops.math import logical_or
from _akg.topi.generic import schedule_elemwise

def LogicalOr(x, y):
"""LogicalOr."""
return logical_or.logical_or(x, y)


def gpu_schedule_LogicalOr(outs):
"""
GPU schedule for LogicalOr.

Args:
outs (tvm.tensor.Tensor): outputs of compute.

Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = _akg.tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with _akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch

+ 40
- 0
mindspore/_akg/gpu/sub.py View File

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""sub"""
import _akg.tvm
from _akg.ops.math import sub
from _akg.topi.generic import schedule_elemwise

def Sub(x, y):
"""Sub."""
return sub.sub(x, y)


def gpu_schedule_Sub(outs):
"""
GPU schedule for Sub.

Args:
outs (tvm.tensor.Tensor): outputs of compute.

Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = _akg.tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with _akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch

+ 54
- 0
mindspore/_akg/ops/math/less_equal.py View File

@@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: lessequal"""
import _akg.tvm
import _akg.topi
from _akg.utils.dsl_create import produce_shapes
from _akg.utils import validation_check as vc_util


@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
def less_equal(input1, input2):
"""
Check whether input1 lessequals to input2.

Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.

Returns:
tvm.tensor.Tensor. If input1 lessequal to input2 return True, else return False.
"""
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)

shape1, shape2, shape = produce_shapes(shape1, shape2)

vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
dtype = input1.dtype

# get lessequal compute
t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T")
f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F")

input1_bro = _akg.topi.broadcast_to(input1, shape)
input2_bro = _akg.topi.broadcast_to(input2, shape)
c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] <= input2_bro[indice],
t_value[indice], f_value[indice]), name="C")
res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")

return res

+ 41
- 0
mindspore/_akg/ops/math/logical_and.py View File

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: logical_and"""
import _akg.tvm
import _akg.topi
from _akg.utils import validation_check as vc_util

@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
def logical_and(input1, input2):
"""
Compute logical_and of input1 and input2.

Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.

Returns:
tvm.tensor.Tensor. LogicalAnd of input1 and input2.
"""

vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)

shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)

res = _akg.topi.logical_and(input1, input2)
return res

+ 32
- 0
mindspore/_akg/ops/math/logical_not.py View File

@@ -0,0 +1,32 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: logical_not"""
import _akg.tvm
import _akg.topi
from _akg.utils import validation_check as vc_util

@vc_util.check_input_type(_akg.tvm.tensor.Tensor)
def logical_not(input1):
"""
Compute logical_not of input1.

Args:
input1 (tvm.tensor.Tensor): Tensor.

Returns:
tvm.tensor.Tensor.
"""
res = _akg.topi.logical_not(input1)
return res

+ 41
- 0
mindspore/_akg/ops/math/logical_or.py View File

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: logical_or"""
import _akg.tvm
import _akg.topi
from _akg.utils import validation_check as vc_util

@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
def logical_or(input1, input2):
"""
Compute logical_or of input1 and input2.

Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.

Returns:
tvm.tensor.Tensor. LogicalOr of input1 and input2.
"""

vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)

shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)

res = _akg.topi.logical_or(input1, input2)
return res

+ 161
- 31
mindspore/_checkparam.py View File

@@ -14,10 +14,12 @@
# ============================================================================
"""Check parameters."""
import re
import inspect
import math
from enum import Enum
from functools import reduce
from functools import reduce, wraps
from itertools import repeat
from collections import Iterable
from collections.abc import Iterable

import numpy as np
from mindspore import log as logger
@@ -98,7 +100,7 @@ class Validator:
"""validator for checking input parameters"""

@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
"""
Method for judging relation between two int values or list/tuple made up of ints.

@@ -108,18 +110,29 @@ class Validator:
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
msg_prefix = f'For {prim_name} the' if prim_name else "The"
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')

@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
f' but got {arg_value}.')
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
f' with type `{type(arg_value).__name__}`.')
return arg_value

@staticmethod
def check_number(arg_name, arg_value, value, rel, prim_name):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
return arg_value

@staticmethod
@@ -127,15 +140,53 @@ class Validator:
"""Method for checking whether an int value is in some range."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
f' but got {arg_value}.')
raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
return arg_value

@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
"""Method for checking whether a numeric value is in some range."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value

@staticmethod
def check_string(arg_name, arg_value, valid_values, prim_name):
"""Checks whether a string is in some value list"""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')

@staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name):
"""Validates value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding

@staticmethod
def check_float_positive(arg_name, arg_value, prim_name):
"""Float type judgment."""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")

@staticmethod
def check_subclass(arg_name, type_, template_type, prim_name):
"""Check whether some type is sublcass of another type"""
"""Checks whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
@@ -144,32 +195,51 @@ class Validator:
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')

@staticmethod
def check_tensor_type_same(args, valid_values, prim_name):
"""check whether the element types of input tensors are the same."""
def check_const_input(arg_name, arg_value, prim_name):
"""Checks valid value."""
if arg_value is None:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')

@staticmethod
def check_type_same(args, valid_values, prim_name):
"""Checks whether the types of inputs are the same."""
def _check_tensor_type(arg):
arg_key, arg_val = arg
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
elem_type = arg_val.element_type()
elem_type = arg_val
if not elem_type in valid_values:
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {elem_type}.')
type_names = []
for t in valid_values:
type_names.append(str(t))
types_info = '[' + ", ".join(type_names) + ']'
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
f' but got {elem_type}.')
return (arg_key, elem_type)

def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
if arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
return arg1

elem_types = map(_check_tensor_type, args.items())
reduce(_check_types_same, elem_types)

@staticmethod
def check_tensor_type_same(args, valid_values, prim_name):
"""Checks whether the element types of input tensors are the same."""
tensor_types = [mstype.tensor_type(t) for t in valid_values]
Validator.check_type_same(args, tensor_types, prim_name)

@staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
"""
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.

If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
"""

def _check_argument_type(arg):
arg_key, arg_val = arg
if isinstance(arg_val, type(mstype.tensor)):
@@ -182,16 +252,19 @@ class Validator:
def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
excp_flag = False
except_flag = False
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
arg1_type = arg1_type.element_type()
arg2_type = arg2_type.element_type()
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
pass
elif allow_mix:
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
else:
excp_flag = True
except_flag = True

if excp_flag or arg1_type != arg2_type:
if except_flag or arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
return arg1
@@ -199,13 +272,15 @@ class Validator:

@staticmethod
def check_value_type(arg_name, arg_value, valid_types, prim_name):
"""Check whether a values is instance of some types."""
"""Checks whether a value is instance of some types."""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
f'{"one of " if num_types > 1 else ""}'
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
@@ -216,6 +291,34 @@ class Validator:
return arg_value
raise_error_msg()

@staticmethod
def check_type_name(arg_name, arg_type, valid_types, prim_name):
"""Checks whether a type in some specified types"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)

if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
if len(valid_types) == 1:
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')

@staticmethod
def check_float_legal_value(arg_name, arg_value, prim_name):
"""Checks whether a legal value of float type"""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if math.isinf(arg_value) or math.isnan(arg_value):
raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.")
return arg_value
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")


class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@@ -268,9 +371,9 @@ class ParamValidator:

@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isintance of classes"""
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value

@staticmethod
@@ -284,7 +387,7 @@ class ParamValidator:

@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is sublcass of another type"""
"""Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
@@ -302,9 +405,9 @@ class ParamValidator:

@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isintance of bool"""
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value

@staticmethod
@@ -671,3 +774,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True


def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""

def type_check(func):
sig = inspect.signature(func)
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments

@wraps(func)
def wrapper(*args, **kwargs):
nonlocal bound_types
bound_values = sig.bind(*args, **kwargs)
argument_dict = bound_values.arguments
if "kwargs" in bound_types:
bound_types = bound_types["kwargs"]
if "kwargs" in argument_dict:
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)

return wrapper

return type_check

+ 1
- 1
mindspore/_extends/__init__.py View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
Extension functions.
Extension functions.

Python functions that will be called in the c++ parts of MindSpore.
"""


+ 2
- 1
mindspore/_extends/parse/resources.py View File

@@ -83,6 +83,7 @@ convert_object_map = {
T.mul: multitype_ops.mul,
T.truediv: multitype_ops.div,
T.getitem: multitype_ops.getitem,
T.setitem: multitype_ops.setitem,
T.floordiv: multitype_ops.floordiv,
T.mod: multitype_ops.mod,
T.pow: multitype_ops.pow_,
@@ -113,12 +114,12 @@ convert_object_map = {
T.map: C.HyperMap(),
T.partial: F.partial,
T.zip: C.zip_operation,
T.print: F.print_,

# custom define operation
T.iter: M.ms_iter,
T.next: M.ms_next,
T.hasnext: M.hasnext,
T.setitem: M.setitem,

T.make_tuple: F.make_tuple,
T.make_dict: F.make_dict,


+ 2
- 2
mindspore/_extends/parse/trope.py View File

@@ -27,7 +27,7 @@ from operator import ( # noqa

# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
)

# support functools
@@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial',
'partial', 'print',
'exp', 'log', 'sin', 'cos', 'tan']




+ 0
- 44
mindspore/_extends/pynative_helper.py View File

@@ -1,44 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pynative mode help module."""
from inspect import signature
from functools import wraps


def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""

def type_check(func):
sig = signature(func)
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments

@wraps(func)
def wrapper(*args, **kwargs):
nonlocal bound_types
bound_values = sig.bind(*args, **kwargs)
argument_dict = bound_values.arguments
if "kwargs" in bound_types:
bound_types = bound_types["kwargs"]
if "kwargs" in argument_dict:
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)

return wrapper

return type_check

+ 0
- 4
mindspore/ccsrc/CMakeLists.txt View File

@@ -394,10 +394,6 @@ if(USE_GLOG)
target_link_libraries(_c_expression PRIVATE mindspore::glog)
endif()

if(ENABLE_GPU)
target_link_libraries(_c_expression PRIVATE mindspore::tvm)
endif()

if(ENABLE_DUMP_PROTO)
target_link_libraries(_c_expression PRIVATE mindspore::protobuf)
endif()


+ 33
- 11
mindspore/ccsrc/common/trans.cc View File

@@ -103,17 +103,39 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{

template <typename SrcT, typename DstT>
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
auto src_id = TypeIdSize(args.src_type);
auto dst_id = TypeIdSize(args.dst_type);
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
}
for (size_t idx = 0; idx != data_size; idx++) {
SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
}
}

template <typename SrcT>
void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
auto src_id = TypeIdSize(args.src_type);
auto dst_id = TypeIdSize(args.dst_type);
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
}
auto src_data = static_cast<const SrcT *>(args.data);
auto half_data = static_cast<Eigen::half *>(dst);
for (size_t i = 0; i < data_size; i++) {
half_data[i] = Eigen::half(src_data[i]);
}
}

bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
switch (mode) {
case FROM_FLOAT_TO_FLOAT16:
device::FloatToHalf(dst, args.data, data_size);
break;
case FROM_INT32_TO_FLOAT16:
TransDataSrc2Fp16<int32_t>(args, dst, data_size);
break;
case FROM_FLOAT16_TO_FLOAT:
device::HalfToFloat(dst, args.data, data_size);
break;
@@ -372,27 +394,27 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
}

bool TransDataType(const TypeIdArgs &args, void *result) {
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
<< TypeIdLabel(args.device_data_type);
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type);
MS_EXCEPTION_IF_NULL(result);
std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
std::pair<TypeId, TypeId> type_info(args.src_type, args.dst_type);
auto iter = mode_map.find(type_info);
if (iter == mode_map.end()) {
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
<< ", dst_type:" << TypeIdLabel(args.device_data_type);
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type)
<< ", dst_type:" << TypeIdLabel(args.dst_type);
return false;
}
auto trans_mode = iter->second;
auto type_size = TypeIdSize(args.device_data_type);
if (type_size < 1) {
MS_LOG(ERROR) << "Invalid host data type.";
auto src_id = TypeIdSize(args.src_type);
auto dst_id = TypeIdSize(args.dst_type);
if (src_id < 1 || dst_id < 1) {
MS_LOG(ERROR) << "Invalid src or dst data type.";
return false;
}
if (args.host_shape_size < 1) {
MS_LOG(ERROR) << "Invalid host data size.";
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) {
MS_LOG(ERROR) << "Invalid src or dst data size.";
return false;
}
if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) {
MS_LOG(ERROR) << "Failed to trans datatype..";
return false;
}


+ 6
- 3
mindspore/ccsrc/common/trans.h View File

@@ -31,9 +31,12 @@ namespace mindspore {
namespace trans {
struct TypeIdArgs {
const void *data;
size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
TypeId host_data_type;
TypeId device_data_type;
size_t src_size;
size_t dst_size;
TypeId src_type;
TypeId dst_type;
size_t src_shape_size;
size_t dst_shape_size;
};

struct FormatArgs {


+ 1
- 1
mindspore/ccsrc/common/utils.cc View File

@@ -23,7 +23,7 @@ namespace common {
const int CACHED_STR_NUM = 1 << 8;
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
const char* SafeCStr(const std::string&& str) {
const char *SafeCStr(const std::string &&str) {
static std::atomic<uint32_t> index{0};
uint32_t cur_index = index++;
cur_index = cur_index & CACHED_STR_MASK;


+ 6
- 6
mindspore/ccsrc/common/utils.h View File

@@ -21,16 +21,16 @@
#include <string>

#define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType&) = delete; \
ClassType& operator=(const ClassType&) = delete;
ClassType(const ClassType &) = delete; \
ClassType &operator=(const ClassType &) = delete;

namespace mindspore {
namespace common {
inline const char* SafeCStr(const std::string& str) { return str.c_str(); }
const char* SafeCStr(const std::string&& str);
inline const char *SafeCStr(const std::string &str) { return str.c_str(); }
const char *SafeCStr(const std::string &&str);

static inline std::string GetEnv(const std::string& envvar) {
const char* value = ::getenv(envvar.c_str());
static inline std::string GetEnv(const std::string &envvar) {
const char *value = ::getenv(envvar.c_str());

if (value == nullptr) {
return std::string();


+ 3
- 3
mindspore/ccsrc/dataset/CMakeLists.txt View File

@@ -12,10 +12,10 @@ endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes")

if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--image-base -Wl,0x10000000")
endif()
############################# Options ################################
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
add_definitions(-D _CRT_RAND_S)
endif ()
if (ENABLE_GPUQUE)
add_definitions(-D ENABLE_GPUQUE)
message(STATUS "GPU queue is enabled")


+ 97
- 4
mindspore/ccsrc/dataset/api/de_pipeline.cc View File

@@ -28,10 +28,11 @@
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"

#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
@@ -45,7 +46,9 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
@@ -61,7 +64,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kVoc, &DEPipeline::ParseVOCOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp}};
{kCelebA, &DEPipeline::ParseCelebAOp},
{kTextFile, &DEPipeline::ParseTextFileOp}};

DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
@@ -501,6 +505,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
return Status::OK();
}

// Builds a FilterOp from the python argument dict.
// Required key: "predicate" (a python callable). Optional: "num_parallel_workers",
// "input_columns". Any other non-none key is rejected.
// @param args - python dict of keyword arguments
// @param ptr - output: the constructed FilterOp
// @return Status - error if predicate is missing/not callable or a key is unhandled
Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<FilterOp::Builder> builder = std::make_shared<FilterOp::Builder>();

  if (args["predicate"].is_none()) {
    RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n");
  }

  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
      } else if (key == "predicate") {
        // Use the handle from the current iteration instead of a redundant second dict lookup.
        if (!py::isinstance<py::function>(value)) {
          RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc).");
        }
        py::function predicate_func = value.cast<py::function>();
        (void)builder->SetPredicateFunc(std::move(predicate_func));
      } else if (key == "input_columns") {
        std::vector<std::string> in_col_names = ToStringVector(value);
        (void)builder->SetInColNames(in_col_names);
      } else {
        RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
      }
    }
  }

  std::shared_ptr<FilterOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}

Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
if (args["count"].is_none()) {
std::string err_msg = "Error: count is invalid or not set.";
@@ -589,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}

// Constructs a BarrierOp from the python argument dictionary.
// Recognized optional keys: "condition_name" and "condition_func"; other keys are ignored.
// @param args - python dict of keyword arguments
// @param ptr - output: the constructed BarrierOp
Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  auto builder = std::make_shared<BarrierOp::Builder>();
  // Right now barrier should only take num_rows_per_buffer = 1
  // The reason for this is because having it otherwise can lead to blocking issues
  // See barrier_op.h for more details
  (void)builder->SetRowsPerBuffer(1);
  for (auto item : args) {
    const std::string arg_name = py::str(item.first);
    py::handle arg_value = item.second;
    if (arg_value.is_none()) {
      continue;  // skip unset keyword arguments
    }
    if (arg_name == "condition_name") {
      (void)builder->SetConditionName(ToString(arg_value));
    } else if (arg_name == "condition_func") {
      (void)builder->SetConditionFunc(arg_value.cast<py::function>());
    }
  }

  std::shared_ptr<BarrierOp> barrier_op;
  RETURN_IF_NOT_OK(builder->Build(&barrier_op));
  *ptr = barrier_op;
  return Status::OK();
}

Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
int32_t prefetch_size = 0;
if (args.contains("prefetch_size")) {
@@ -670,8 +733,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
return Status::OK();
}

DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); }

Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
// Required arguments
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
@@ -985,5 +1046,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
*ptr = op;
return Status::OK();
}

// Creates a TextFileOp from the python argument dictionary.
// "dataset_files" is required; the remaining recognized keys are optional.
// @param args - python dict of keyword arguments
// @param ptr - output: the constructed TextFileOp
Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  auto builder = std::make_shared<TextFileOp::Builder>();
  // Required arguments
  if (args["dataset_files"].is_none()) {
    RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
  }
  (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"]));
  // Optional arguments
  for (auto item : args) {
    const std::string arg_name = py::str(item.first);
    py::handle arg_value = item.second;
    if (arg_value.is_none()) {
      continue;  // unset keyword argument
    }
    if (arg_name == "num_parallel_workers") {
      (void)builder->SetNumWorkers(ToInt(arg_value));
    } else if (arg_name == "shuffle_files") {
      (void)builder->SetShuffleFiles(ToBool(arg_value));
    } else if (arg_name == "num_samples") {
      (void)builder->SetNumSamples(ToInt(arg_value));
    } else if (arg_name == "num_shards") {
      (void)builder->SetNumDevices(ToInt(arg_value));
    } else if (arg_name == "shard_id") {
      (void)builder->SetDeviceId(ToInt(arg_value));
    }
  }
  std::shared_ptr<TextFileOp> text_file_op;
  RETURN_IF_NOT_OK(builder->Build(&text_file_op));
  *ptr = text_file_op;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 9
- 3
mindspore/ccsrc/dataset/api/de_pipeline.h View File

@@ -40,6 +40,7 @@ enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBarrier,
kCache,
kRepeat,
kSkip,
@@ -58,7 +59,8 @@ enum OpName {
kVoc,
kCifar10,
kCifar100,
kCelebA
kCelebA,
kTextFile
};

// The C++ binder class that we expose to the python script.
@@ -106,12 +108,16 @@ class DEPipeline {

Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -120,8 +126,6 @@ class DEPipeline {

Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

DsOpPtr ParseFilterOp(const py::dict &args) const;

Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -148,6 +152,8 @@ class DEPipeline {

Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;


+ 43
- 16
mindspore/ccsrc/dataset/api/python_bindings.cc View File

@@ -24,7 +24,6 @@
#endif
#include "dataset/kernels/image/cut_out_op.h"
#include "dataset/kernels/image/decode_op.h"
#include "dataset/kernels/image/distort_bounding_box_crop_op.h"
#include "dataset/kernels/image/hwc_to_chw_op.h"
#include "dataset/kernels/image/image_utils.h"
#include "dataset/kernels/image/normalize_op.h"
@@ -40,6 +39,7 @@
#include "dataset/kernels/image/rescale_op.h"
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
@@ -53,11 +53,14 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_pk_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
@@ -150,9 +153,14 @@ void bindDatasetOps(py::module *m) {
});

(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
.def_static("get_num_rows", [](const std::string &path) {
.def_static("get_num_rows", [](const std::string &path, const py::object &sampler) {
int64_t count = 0;
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count));
std::shared_ptr<mindrecord::ShardOperator> op;
if (py::hasattr(sampler, "_create_for_minddataset")) {
auto create = sampler.attr("_create_for_minddataset");
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
}
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count));
return count;
});

@@ -176,6 +184,17 @@ void bindDatasetOps(py::module *m) {
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
return count;
});

(void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
.def_static("get_num_rows", [](const py::list &files) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
!file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
}
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
return count;
});
}
void bindTensor(py::module *m) {
(void)py::class_<GlobalContext>(*m, "GlobalContext")
@@ -251,6 +270,10 @@ void bindTensorOps1(py::module *m) {
.def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);

(void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
*m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
.def(py::init<py::list, int32_t>(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps);

(void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
*m, "ResizeBilinearOp",
"Tensor operation to resize an image using "
@@ -345,18 +368,6 @@ void bindTensorOps3(py::module *m) {
}

void bindTensorOps4(py::module *m) {
(void)py::class_<DistortBoundingBoxCropOp, TensorOp, std::shared_ptr<DistortBoundingBoxCropOp>>(
*m, "DistortBoundingBoxCropOp",
"Tensor operator to crop an image randomly as long as the cropped image has sufficient "
"overlap with any one bounding box associated with original image"
"Takes aspect ratio of the generated crop box, the intersection ratio of crop box and bounding box,"
"crop ratio lower and upper bounds"
"Optional parameters: number of attempts for crop, number of attempts of crop box generation")
.def(py::init<float, float, float, float, int32_t, int32_t>(), py::arg("aspect_ratio"), py::arg("intersect_ratio"),
py::arg("crop_ratio_lower_bound"), py::arg("crop_ratio_upper_bound"),
py::arg("max_attempts") = DistortBoundingBoxCropOp::kDefMaxAttempts,
py::arg("box_gen_attempts") = DistortBoundingBoxCropOp::kDefBoxGenAttempts);

(void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
*m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
.def(py::init<DataType>(), py::arg("data_type"))
@@ -415,16 +426,30 @@ void bindSamplerOps(py::module *m) {

(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>());

(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));

(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
*m, "MindrecordSubsetRandomSampler")
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
*m, "MindrecordPkSampler")
.def(py::init([](int64_t kVal, bool shuffle) {
if (shuffle == true) {
return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(),
GetSeed());
} else {
return std::make_shared<mindrecord::ShardPkSample>("label", kVal);
}
}));

(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
py::arg("replacement"));

(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
.def(py::init<py::object>(), py::arg("pySampler"));
}

void bindInfoObjects(py::module *m) {
@@ -443,6 +468,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("STORAGE", OpName::kStorage)
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
.value("REPEAT", OpName::kRepeat)
@@ -463,7 +489,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("VOC", OpName::kVoc)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("CELEBA", OpName::kCelebA);
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);

(void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)


+ 2
- 0
mindspore/ccsrc/dataset/core/client.h View File

@@ -25,12 +25,14 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"


+ 1
- 1
mindspore/ccsrc/dataset/core/tensor.cc View File

@@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c
DS_ASSERT(data_);

switch (type_.value()) {
CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t);
CASE_PRINT_HEX(DataType::DE_BOOL, bool);

CASE_PRINT_HEX(DataType::DE_INT8, int8_t);



+ 2
- 0
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt View File

@@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT
dataset_op.cc
parallel_op.cc
pipeline_op.cc
barrier_op.cc
batch_op.cc
device_queue_op.cc
map_op.cc
@@ -14,5 +15,6 @@ add_library(engine-datasetops OBJECT
take_op.cc
shuffle_op.cc
zip_op.cc
filter_op.cc
)


+ 235
- 0
mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc View File

@@ -0,0 +1,235 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/barrier_op.h"
#include <utility>
#include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
// Seed the builder defaults from the global client configuration.
// The user may override any of these through the builder's setter
// methods before Build() is called.
BarrierOp::Builder::Builder() {
  std::shared_ptr<ConfigManager> config = GlobalContext::config_manager();
  builder_rows_per_buffer_ = config->rows_per_buffer();
  builder_op_connector_size_ = config->op_connector_size();
}

// No restrictions on the builder's arguments today; always succeeds.
Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }

// Validates the accumulated builder settings, then constructs the BarrierOp.
// @param ptr - output: shared pointer to the newly built BarrierOp
Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
                                     builder_condition_func_);
  return Status::OK();
}

// Construct BarrierOp here, local variables initialized in operator() due to tree construction restrictions.
// Initializers are listed in member declaration order (see barrier_op.h): members are
// always initialized in declaration order, so a mismatched list triggers -Wreorder and
// misleads readers about the actual initialization sequence.
BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
                     py::function condition_func)
    : PipelineOp(op_connector_size),
      clean_up_(false),
      eof_(false),
      rows_per_buffer_(rows_per_buffer),
      buffer_id_(0),
      condition_name_(condition_name),
      condition_function_(std::move(condition_func)) {}  // sink param: move, don't copy the py::function

// destructor - members release their resources via their own destructors (RAII)
BarrierOp::~BarrierOp() {}

// Entry point for Barrier, called by launch(). Master loop: per epoch it prepares the
// first row, repeatedly fills and emits buffers until the child drains (clean_up_),
// sends an EOE per epoch, and finally propagates EOF when the child is exhausted.
// @return Status - the error code returned
Status BarrierOp::operator()() {
  // The children_num_ parameter needs to be put here
  // Synchronize with TaskManager once the thread is created.
  TaskManager::FindMe()->Post();

  // create child iterator, right now this barrier is a pipeline operator with a single child
  int32_t worker_id = 0;
  int32_t child_idx = 0;
  child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);

  // Loop until eof is true
  while (!eof_) {
    // Create new table to put the new tensor rows
    std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
    RETURN_IF_NOT_OK(prepare(curr_table.get()));

    // If an eof got picked up during the above prepare, then we're done
    if (eof_) {
      break;
    }

    // we have to output new buffer with possibly different buffer size, possibly one row
    while (!clean_up_) {
      // 1. If a previous loop iteration sent the current table out (moved it into a
      //    buffer below), then create a new one.
      if (curr_table == nullptr) {
        curr_table = std::make_unique<TensorQTable>();
      }

      // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished
      RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));

      // 3 create and update buffer and send it to the out connector
      if (!curr_table->empty()) {
        std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
        curr_buffer->set_tensor_table(std::move(curr_table));
        curr_buffer->set_column_name_map(col_name_id_map_);
        MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
                      << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
        buffer_id_++;
      }
    }

    // 4 handle drain state.
    if (clean_up_) {
      MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal.";
      // Send the eoe up. make_unique already yields a prvalue, so no std::move is needed.
      RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
    }
  }
  // 5 handle eof
  // propagate eof here.
  MS_LOG(INFO) << "Barrier operator got EOF, propagating.";
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
  return Status::OK();
}

// Handles preprocessing of the main loop, used when starting new epoch.
// Resets per-epoch state, fetches (and blocks on) the first row of the epoch, and
// captures the column-name map from the child. May set eof_/clean_up_ via getNextTensorRow.
// @param table - a table of tensors the first row is pushed into
Status BarrierOp::prepare(TensorQTable *const table) {
  MS_LOG(DEBUG) << "Barrier operator prepares for new epoch.";
  // Reset per-epoch state before fetching anything.
  clean_up_ = false;
  buffer_id_ = 0;
  if (table == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
  }
  // fill initial row
  TensorRow new_row = {};
  // use iterator to get next row and invoke pyfunc wait
  RETURN_IF_NOT_OK(getNextTensorRow(&new_row));

  // If the first row fetching resulted in eof, then we are done.
  if (eof_) {
    return Status::OK();
  }
  if (new_row.empty()) {
    // This epoch is empty
    return Status::OK();
  }
  // Pack this first row into our tensor table
  // first row we also have to check if we should block
  RETURN_IF_NOT_OK(blockCond());

  table->push_back(std::move(new_row));
  // At this point we have 1 row produced, we take the old column map id and use it in the new table
  // Initializing col_name_id_map_ from the first data buffer.
  col_name_id_map_ = child_iterator_->col_name_id_map();
  // the update code below shouldn't do anything bad if the column name already exists.
  return Status::OK();
}

// fillBuffer always expects a new table to fill. Adds rows until the table reaches
// rows_per_buffer_ or the child drains for this epoch (empty row).
// @param table - the table of tensors to append rows to
Status BarrierOp::fillBuffer(TensorQTable *const table) {
  if (table == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
  }
  const auto target_rows = static_cast<size_t>(rows_per_buffer_);
  while (table->size() < target_rows) {
    TensorRow row;
    RETURN_IF_NOT_OK(getNextTensorRow(&row));
    // An empty row means this epoch has drained; stop early with a partial table.
    if (row.empty()) {
      break;
    }
    // Block on the python-side condition before committing the row.
    RETURN_IF_NOT_OK(blockCond());
    table->push_back(std::move(row));
  }
  return Status::OK();
}

// function executes a py_func and blocks until condition becomes true.
// The python callable itself is expected to do the waiting; this method only
// validates that it ran and returned a bool.
Status BarrierOp::blockCond() {
  {
    // Acquire the GIL for the duration of the python call; released at scope exit.
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    // we have condition name, however the flexibility is in python today
    try {
      // Invoke python function
      py::object ret_py_obj = condition_function_();
      // Process the return value
      if (!py::isinstance<py::bool_>(ret_py_obj)) {
        return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false");
      }
    } catch (const py::error_already_set &e) {
      // Surface any python exception as a pyfunc error status.
      return Status(StatusCode::kPyFuncException, e.what());
    }
  }
  return Status::OK();
}

// fetches next Barrier buffer row and sets the control signals.
// An empty row from the child marks end of epoch (clean_up_); if the child also
// handled an eof, the whole pipeline is done (eof_).
// @param new_row - output: the fetched row (may be empty)
Status BarrierOp::getNextTensorRow(TensorRow *new_row) {
  // Pull the next row from our single child iterator.
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(new_row));
  if (!new_row->empty()) {
    return Status::OK();
  }
  // If we did not get a row from any of the children, then it's the end of an epoch and we can move
  // to drain state.
  MS_LOG(INFO) << "Barrier operator child iterator produced empty row.";
  clean_up_ = true;
  // If we picked up an eof here, then we are completely done.
  if (child_iterator_->eof_handled()) {
    MS_LOG(INFO) << "Barrier operator iterator got EOF.";
    eof_ = true;
  }
  return Status::OK();
}

// A function that prints info about the Operator.
// @param out - output stream to print to
// @param show_all - if it should print everything (forwarded to the base class)
void BarrierOp::Print(std::ostream &out, bool show_all) const {
  // Call base class printer first so the shared operator state is shown.
  PipelineOp::Print(out, show_all);
  out << "\nBarrierOp:\n"
      << "\nCondition " << condition_name_ << "\n\n";
}

// overwrite function and handle eof.
// EOF propagation is already done inside operator(), so nothing extra is needed here.
Status BarrierOp::EofReceived(int32_t) {
  MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
  return Status::OK();
}

// overwrite function and handle eoe.
// Marks the operator idle at end of epoch; operator() drives the next epoch itself.
Status BarrierOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 172
- 0
mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h View File

@@ -0,0 +1,172 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/pipeline_op.h"
#include "dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {
// Forward declare
class DataBuffer;
class ExecutionTree;

// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has
// been received. This signal is given from python layer. The current barrier design respects the
// rows per buffer design and will only output a buffer with rows once it has received rows per buffer
// signals from python.

class BarrierOp : public PipelineOp {
public:
// The nested builder class inside of the BarrierOp is used to help manage all of
// the arguments for constructing it. Use the builder by setting each argument
// with the provided set methods, and then finally call the build method to execute
// the actual construction.

class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();

// Default destructor
~Builder() = default;

// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}

// Setter method.
// @param int32_t op_connector_size
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}

// Setter method.
// @param const std::string & condition_name
// @return Builder setter method returns reference to the builder.
Builder &SetConditionName(const std::string &condition_name) {
builder_condition_name_ = condition_name;
return *this;
}

// Setter method.
// @param py::function condition_func - blocking condition function
// @return Builder setter method returns reference to the builder.
Builder &SetConditionFunc(py::function condition_func) {
builder_condition_func_ = condition_func;
return *this;
}

// The builder "build" method creates the BarrierOp dataset Operator.
// @return shared_ptr to the new BarrierOp object
Status Build(std::shared_ptr<BarrierOp> *);

private:
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::string builder_condition_name_;
py::function builder_condition_func_;

Status SanityCheck() const;
};

// Constructor for BarrierOp
// @param rows_per_buffer - number of rows in output buffer
// @param op_connector_size - connector size
// @param condition_name - the condition name associated with this operator
// @param condition_func - the blocking condition check per row
// @note - currently rows_per_buffer should = 1 for barrier.
// The reason for this is having other values would complicate how the pipeline behaves with other operators
// One example of such case is having batch after barrier. Batch would be waiting for data and having
// rows per buffer in this case can result in hanging
BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
py::function condition_func);

// Destructor
~BarrierOp();

Status EofReceived(int32_t) override;

Status EoeReceived(int32_t) override;

// Print function for Barrier
// @param out - output stream to print to
// @param show_all - if it should print everything
void Print(std::ostream &out, bool show_all) const override;

// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) {
bo.Print(out, false);
return out;
}

// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
Status operator()() override;

// Handles preprocessing of the main loop, used when starting new epoch
// @param table - a table of tensors to be moved into a buffer
Status prepare(TensorQTable *const table);

// This function calls takes a table repeatedly adds rows to it.
// @param table - a table of tensors to be moved into a buffer
Status fillBuffer(TensorQTable *const table);

// Gets next tensor row and sets control signals
Status getNextTensorRow(TensorRow *new_row);

// This function runs the wait function on condition
Status blockCond();

private:
// clean up variable to return incomplete buffer
bool clean_up_;
// end of file state, we stop reading data and shut down
bool eof_;
// rows per buffer
int32_t rows_per_buffer_;
// buffer_id
int32_t buffer_id_;
// local variable to keep track of the buffer information
std::unordered_map<std::string, int32_t> col_name_id_map_;
// iterator to pull new rows, we only have one child
std::unique_ptr<ChildIterator> child_iterator_;
// condition name, to support multiple barriers
std::string condition_name_;
// Function pointer of blocking function
py::function condition_function_;
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_

+ 253
- 0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc View File

@@ -0,0 +1,253 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/filter_op.h"
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {

// Verify builder arguments before constructing the FilterOp.
// @return Status - kUnexpectedError listing every violated constraint, OK otherwise.
Status FilterOp::Builder::SanityCheck() {
  std::string err;
  if (builder_op_connector_size_ <= 0) {
    err += "connector size <= 0\n";
  }
  if (builder_num_workers_ <= 0) {
    err += "filter num_parallel_workers <= 0\n";
  }
  if (err.empty()) {
    return Status::OK();
  }
  return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}

// Builder constructor: seed worker count and connector size from the global config manager.
FilterOp::Builder::Builder() {
  std::shared_ptr<ConfigManager> config = GlobalContext::config_manager();
  builder_num_workers_ = config->num_parallel_workers();
  builder_op_connector_size_ = config->op_connector_size();
}

// Validates the builder arguments, then constructs the FilterOp instance.
// @param ptr Output shared_ptr receiving the newly built operator.
// @return Status - error from SanityCheck, OK otherwise.
Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  auto op = std::make_shared<FilterOp>(std::move(build_in_col_names_), builder_num_workers_,
                                       builder_op_connector_size_, builder_predicate_func_);
  *ptr = std::move(op);
  return Status::OK();
}

// FilterOp constructor: stores the python predicate callable and the names of the
// columns to feed it, and sizes the parallel op's queues via the ParallelOp base.
FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                   py::function predicate_func)
    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}

// Master entry point: launches one worker thread per parallelism unit, then runs
// the Collector loop on the current thread to restore output buffer ordering.
Status FilterOp::operator()() {
  // The operator class just starts off threads by calling the tree_ function.
  RETURN_UNEXPECTED_IF_NULL(tree_);
  // Synchronize with TaskManager.
  TaskManager::FindMe()->Post();
  // One intermediate queue per worker; Collector drains them round-robin.
  filter_queues_.Init(num_workers_, oc_queue_size_);
  RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)));
  // Collector returns only after the eof buffer has been forwarded downstream.
  RETURN_IF_NOT_OK(Collector());
  return Status::OK();
}

// Base-class override. The eof buffer is forwarded by Collector, so there is
// nothing extra to do when a worker observes eof.
Status FilterOp::EofReceived(int32_t) {
  return Status::OK();
}

// Base-class override. The eoe buffer is forwarded by Collector, so there is
// nothing extra to do when a worker observes eoe.
Status FilterOp::EoeReceived(int32_t) {
  return Status::OK();
}

// Validating if each of the input_columns exists in the DataBuffer.
// @param col_name_id_map Column-name-to-index mapping taken from the current buffer.
// @param input_columns The user-specified column names to be fed to the predicate.
// @return Status - kUnexpectedError if any requested column is absent, OK otherwise.
Status FilterOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
                                   const std::vector<std::string> *input_columns) {
  for (const auto &in_col : *input_columns) {
    // find() != end() already yields a bool; the original `? true : false` ternary was redundant.
    if (col_name_id_map.find(in_col) == col_name_id_map.end()) {
      std::string err_msg = "input column name: " + in_col + " doesn't exist in the dataset columns.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
  }
  return Status::OK();
}

// A print method typically used for debugging.
// @param out The stream to write to.
// @param show_all Whether the base class should print full detail or a summary.
void FilterOp::Print(std::ostream &out, bool show_all) const {
  // Call base class printer first.
  ParallelOp::Print(out, show_all);

  // Then display our own stuff: the columns fed to the predicate.
  out << "\nFilterOp:";
  out << "\n Input column names:";
  for (const auto &col_name : in_columns_) {
    out << " " << col_name;
  }
}

// Per-worker loop: pull buffers from the child, run the predicate over each row,
// and push the (buffer, classification) pair into this worker's filter queue.
// Fixes: the eoe/eof EmplaceBack calls dropped their Status (all other EmplaceBack
// calls in this function are checked); the table-size comparison mixed signed and
// unsigned operands.
// @param worker_id The id assigned to this thread/worker upon creation.
// @return Status The error code return.
Status FilterOp::WorkerEntry(int32_t worker_id) {
  // Handshake with TaskManager that thread creation is successful.
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> in_buffer;
  bool worker_stop = false;
  while (!worker_stop) {
    // Getting a databuffer to work on.
    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
    if (in_buffer->eoe()) {
      // Flow-control buffers are never filtered; forward them with their marker.
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)));
      continue;
    } else if (in_buffer->eof()) {
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)));
      worker_stop = true;
      continue;
    }

    RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_));

    // if the databuffer was all filtered, it is marked as kFilterEmpty.
    // if the databuffer was partially filtered, it is marked as kFilterPartial.
    // if the databuffer was not filtered, it is marked as kFilterFull.
    int32_t num_rows = in_buffer->NumRows();
    std::unique_ptr<TensorQTable> new_tensor_table;
    RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table));

    if (new_tensor_table->empty()) {
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty)));
    } else if (new_tensor_table->size() == static_cast<size_t>(num_rows)) {
      in_buffer->set_tensor_table(std::move(new_tensor_table));
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull)));
    } else {  // kFilterPartial
      in_buffer->set_tensor_table(std::move(new_tensor_table));
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial)));
    }
  }
  return Status::OK();
}

// Filter the rows of a buffer through the python predicate.
// Fixes: the column-name map was re-fetched on every row even though it is a
// property of the buffer (loop-invariant), and the "empty input columns" INFO
// message was logged once per row instead of once per buffer.
// @param in_buffer Input data buffer whose rows are popped and evaluated.
// @param[out] out Table receiving only the rows for which the predicate returned True.
// @return Status The error code return.
Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out) {
  *out = std::make_unique<TensorQTable>();
  int32_t num_rows = in_buffer->NumRows();
  // Fetch the loop-invariant column map once; it is only needed when the user
  // restricted the predicate to specific columns.
  std::unordered_map<std::string, int32_t> col_map;
  if (in_columns_.empty()) {
    MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table.";
  } else {
    col_map = in_buffer->column_name_map();
  }
  for (int32_t i = 0; i < num_rows; i++) {
    TensorRow to_process;
    TensorRow cur_row;
    RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row));
    if (in_columns_.empty()) {
      // No column restriction: the whole row is handed to the predicate.
      to_process = cur_row;
    } else {
      // Project only the requested columns, in the user-specified order.
      (void)std::transform(
        in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process),
        [&cur_row, &col_map](const auto &it) -> std::shared_ptr<Tensor> { return cur_row[col_map[it]]; });
    }
    bool predicate = true;
    RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate));
    // Keep the row only when the python predicate returned True.
    if (predicate) {
      (*out)->push_back(std::move(cur_row));
    }
  }
  return Status::OK();
}

// if the filtered DataBuffer is written directly to out_connector_,
// the thread fetching data will block in a queue.
// Collector function will reorder the DataBuffer in order (dropping the
// fully-filtered kFilterEmpty buffers along the way).
// for example in two work queues:
// in filter_queues_:
// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
// after reorder in out_connector_:
// queue1: DB(data2) DB(data4) DB(eof)
// queue2: DB(eoe) DB(eoe)
Status FilterOp::Collector() {
  bool collector_stop = false;
  // task_id_cnt walks the per-worker input queues round-robin (matching the
  // round-robin order in which workers received buffers from the child op);
  // out_id_cnt advances only when a buffer is actually forwarded, so the
  // output connector sees a dense, order-preserving sequence.
  uint64_t task_id_cnt = 0;
  uint64_t out_id_cnt = 0;
  std::pair<std::unique_ptr<DataBuffer>, filterCtrl> in_pair;
  while (collector_stop == false) {
    uint32_t w_id = task_id_cnt % num_workers_;
    RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
    if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial ||
        in_pair.second == filterCtrl::kFilterEoe) {
      // Buffer survived filtering (or is an eoe marker): forward it downstream.
      uint32_t out_task_id = out_id_cnt % num_workers_;
      RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
      out_id_cnt++;
      task_id_cnt++;
    } else if (in_pair.second == filterCtrl::kFilterEof) {
      // eof terminates the pipeline; forward it and stop collecting.
      uint32_t out_task_id = out_id_cnt % num_workers_;
      RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
      collector_stop = true;
    } else {  // kFilterEmpty
      // Every row was filtered out: drop the buffer, but still advance the
      // input cursor so queue rotation stays aligned with the workers.
      task_id_cnt++;
    }
  }
  return Status::OK();
}

// Private function for checking the column legality: the buffer must be
// non-empty and must contain every user-specified input column.
// @param in_buf Raw pointer to the buffer under inspection (not owned here).
// @param input_columns The vector of input column names used in the current thread.
// @return Status The error code return.
Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector<std::string> *input_columns) {
  if (in_buf->NumRows() == 0 || in_buf->NumCols() == 0) {
    RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer.");
  }
  // Check if there is invalid column name in the inColumns.
  return ValidateInColumns(in_buf->column_name_map(), input_columns);
}

// Reject rows containing a null tensor before handing them to the python predicate.
// @param input The tensor row about to be passed to python.
// @return Status - kUnexpectedError when any tensor in the row is null.
Status FilterOp::CheckInput(const TensorRow &input) const {
  const bool has_null = std::any_of(input.begin(), input.end(), [](const auto &tensor) { return tensor == nullptr; });
  if (has_null) {
    RETURN_STATUS_UNEXPECTED("input is null.");
  }
  return Status::OK();
}

// Convert the row to numpy arrays, call the user's python predicate under the
// GIL, and cast its return value to bool.
// Fix: cast<py::bool_>() raises py::cast_error (not py::error_already_set) when
// the return value is not a bool, so the original single catch let that
// exception escape and crash the worker thread — even though its own error text
// describes exactly that failure. A dedicated cast_error handler is added,
// mirroring the dual-catch pattern used by the python sampler code.
// @param input Tensor row fed to the predicate.
// @param[out] out_predicate Result of the predicate call.
// @return Status - kPyFuncException on python errors, OK otherwise.
Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) {
  RETURN_IF_NOT_OK(CheckInput(input));
  // Acquire Python GIL.
  py::gil_scoped_acquire gil_acquire;
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  try {
    // Transform input tensor vector into numpy array vector.
    py::tuple input_args(input.size());
    for (size_t i = 0; i < input.size(); i++) {
      py::array new_data;
      RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
      input_args[i] = new_data;
    }
    // Invoke python function.
    py::object ret_py_obj = predicate_func_(*input_args);
    *out_predicate = ret_py_obj.cast<py::bool_>();
  } catch (const py::error_already_set &e) {
    // The predicate itself raised a python exception.
    std::stringstream ss;
    ss << e.what() << std::endl;
    ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool.";
    return Status(StatusCode::kPyFuncException, ss.str());
  } catch (const py::cast_error &e) {
    // The predicate returned a non-bool value.
    return Status(StatusCode::kPyFuncException,
                  "The type of the return value of python predicate function is not bool, or can not be convert to "
                  "bool.");
  }
  return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
}
} // namespace dataset
} // namespace mindspore

+ 181
- 0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h View File

@@ -0,0 +1,181 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_

#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/queue.h"

namespace mindspore {
namespace dataset {

// FilterOp runs each input row through a user-supplied python predicate and
// forwards only the rows for which the predicate returns True.
class FilterOp : public ParallelOp {
 public:
  // The nested builder class inside of the FilterOp is used to help manage all of
  // the arguments for constructing it. Use the builder by setting each argument
  // with the provided set methods, and then finally call the build method to execute
  // the actual construction.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args.
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Setter method for the python predicate callable.
    // @return Builder setter method returns reference to the builder.
    Builder &SetPredicateFunc(py::function func) {
      builder_predicate_func_ = std::move(func);
      return *this;
    }

    // Setter method for the input column names fed to the predicate.
    // @return Builder setter method returns reference to the builder.
    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
      build_in_col_names_ = in_col_names;
      return *this;
    }

    // Setter method for the number of worker threads.
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method for the connector queue size.
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      builder_op_connector_size_ = connector_size;
      return *this;
    }

    // The builder "build" method creates the final object.
    // @param ptr The shared_ptr to the new FilterOp object.
    // @return Status.
    Status Build(std::shared_ptr<FilterOp> *ptr);

   private:
    // Sanity check for builder class args.
    // @return Status - The error code return.
    Status SanityCheck();
    std::vector<std::string> build_in_col_names_;
    py::function builder_predicate_func_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
  };

  // Classification of a worker-produced buffer, used by Collector to decide
  // whether to forward or drop it.
  enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };

  // Constructor of FilterOp
  // @note The builder class should be used to call it.
  // @param in_col_names A list of input column names; when it is empty the predicate will be
  //     applied to all columns in the dataset.
  // @param num_workers The number of worker threads.
  // @param op_connector_size The size of each queue in the connector.
  // @param predicate_func python callable which returns a boolean value.
  FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
           py::function predicate_func);

  // Destructor
  ~FilterOp() = default;

  // Class functor operator () override.
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work.
  // @return Status The error code return
  Status operator()() override;

  // Base-class override; no special eof handling is needed (Collector forwards eof).
  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EofReceived(int32_t) override;

  // Base-class override; no special eoe handling is needed (Collector forwards eoe).
  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EoeReceived(int32_t) override;

  // A print method typically used for debugging.
  // @param out The output stream to write output to.
  // @param show_all A bool to control if you want to show all info or just a summary.
  void Print(std::ostream &out, bool show_all) const override;

 private:
  // predicate_func python callable which returns a boolean value.
  py::function predicate_func_;

  // Variable to store the column name that will feed to predicate function.
  std::vector<std::string> in_columns_;

  // Internal queue for filter: one queue per worker. Each entry pairs a buffer
  // with its filterCtrl classification so Collector can restore output order.
  QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;

  // Private function for worker/thread to loop continuously. It comprises the main
  // logic of FilterOp, getting the data from previous Op, validating user specified column names,
  // applying predicate to each of the data, filter the data when predicate result is false.
  // @param worker_id The id assigned to this thread/worker upon creation.
  // @return Status The error code return.
  Status WorkerEntry(int32_t worker_id) override;  // In: workerId assigned by tree_

  // Filter the data by predicate function.
  // @param in_buffer input data buffer.
  // @param out Table receiving only the rows that passed the predicate.
  // @return Status The error code return.
  Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);

  // Drains the per-worker queues in round-robin order and pushes surviving
  // buffers to the output connector, preserving the original buffer order.
  // @return Status The error code return.
  Status Collector();

  // Verifies that no tensor in the row is null before invoking python.
  // @param input tensor vector.
  // @return Status - The error code return.
  Status CheckInput(const TensorRow &input) const;

  // Invoke python func.
  // @param input tensor vector.
  // @param out_predicate the result of the predicate call.
  // @return Status - The error code return.
  Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);

  // Private function for validating if each of the user specified input column names
  // exist in the DataBuffer.
  // @param col_name_id_map The column name to index mapping obtained from DataBuffer.
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The error code return.
  Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
                           const std::vector<std::string> *input_columns);

  // Private function for checking the column legality
  // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory
  //     and is not shared with other threads.
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The error code return.
  Status CheckColumns(const DataBuffer *in_buf, const std::vector<std::string> *input_columns);
};

} // namespace dataset
} // namespace mindspore
#endif

+ 0
- 22
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc View File

@@ -65,9 +65,6 @@ MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std
tfuncs_(std::move(tensor_funcs)),
in_columns_(in_col_names),
out_columns_(out_col_names),
#if defined(_WIN32) || defined(_WIN64)
eof_worker_id_(0),
#endif
perf_mode_(perf_mode) {
// If caller didn't specify the out_col_names, assume they are same as the in_columns.
if (out_columns_.empty() || out_columns_[0].empty()) {
@@ -123,17 +120,6 @@ Status MapOp::operator()() {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
is_eof = buff->eof();
RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff)));
#if defined(_WIN32) || defined(_WIN64)
if (is_eof) {
eof_worker_id_ = que_id;
for (int32_t id = 0; id < num_workers_; id++) {
if (id != eof_worker_id_) {
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(local_queues_[id]->Add(std::move(eof_buffer)));
}
}
}
#endif
que_id = (que_id + 1) % num_workers_;
}
}
@@ -173,14 +159,6 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
continue;
} else if (in_buffer->eof()) {
// Calling base class EofReceived to forward eof buffer.
#if defined(_WIN32) || defined(_Win64)
if (perf_mode_) {
if (eof_worker_id_ == worker_id) {
RETURN_IF_NOT_OK(EofReceived(worker_id));
}
break;
}
#endif
RETURN_IF_NOT_OK(EofReceived(worker_id));
break;
}


+ 0
- 4
mindspore/ccsrc/dataset/engine/datasetops/map_op.h View File

@@ -193,10 +193,6 @@ class MapOp : public ParallelOp {
// cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order.
bool perf_mode_;

#if defined(_WIN32) || defined(_WIN64)
// EOF worker id is only work on Performance mode, to record the worker id of queue which gets EOF
int32_t eof_worker_id_;
#endif
// Private function for worker/thread to loop continuously. It comprises the main
// logic of MapOp: getting the data from previous Op, validating user specified column names,
// applying a list of TensorOps to each of the data, process the results and then


+ 6
- 1
mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc View File

@@ -13,6 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(_WIN32) || defined(_WIN64)
#include <stdlib.h>
#endif
#include <securec.h>
#include <algorithm>
#include <chrono>
@@ -86,7 +89,9 @@ Status ShuffleOp::SelfReset() {
rng_ = std::mt19937_64(shuffle_seed_);
} else {
#if defined(_WIN32) || defined(_WIN64)
std::random_device random_device;
unsigned int number;
rand_s(&number);
std::mt19937 random_device{static_cast<uint32_t>(number)};
#else
std::random_device random_device("/dev/urandom");
#endif


+ 15
- 21
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc View File

@@ -67,9 +67,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
}

std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));

// Drop first max_skips_ rows
while (skip_count_ < max_skips_) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
if (buf->eoe() || buf->eof()) {
break;
}
@@ -77,31 +78,24 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
// Consider the rows of buffer more than 1
TensorRow drop_row;
int row_num = buf->NumRows();
for (int i = 0; i < row_num; i++) {
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
skip_count_ += drop_num;
for (int i = 0; i < drop_num; i++) {
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
if (++skip_count_ == max_skips_) {
break;
}
}
}

// If buffer is none or the rows of buffer is 0,
// then get a buffer from child.
if (!buf || buf->NumRows() == 0) {
if (buf && buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
if (buf->NumRows() == 0) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}

// Handling eoe and eof
if (buf->eoe() || buf->eof()) {
// Handling eoe
if (buf->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
if (state_ == OpState::kDeOpIdle) {
*p_buffer = std::move(buf);
return Status::OK();
}
}

// Handling eof
if (buf->eof()) {
RETURN_IF_NOT_OK(EofReceived(worker_id));
}

*p_buffer = std::move(buf);
@@ -125,7 +119,7 @@ Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is a

// Base-class override for handling cases when an eof is received.
Status SkipOp::EofReceived(int32_t worker_id) {
MS_LOG(INFO) << "Skip operator EOF received, do nothing now.";
MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
return Status::OK();
}
} // namespace dataset


+ 1
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt View File

@@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT
manifest_op.cc
cifar_op.cc
celeba_op.cc
text_file_op.cc
)

add_dependencies(engine-datasetops-source mindspore::protobuf)

+ 3
- 2
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc View File

@@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
return Status::OK();
}

Status MindRecordOp::CountTotalRows(const std::string dataset_path, int64_t *count) {
Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
int64_t *count) {
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, count);
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count);
if (rc == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
}


+ 2
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h View File

@@ -171,7 +171,8 @@ class MindRecordOp : public ParallelOp {
int32_t num_rows() const { return num_rows_; }

// Getter method
static Status CountTotalRows(const std::string dataset_path, int64_t *count);
static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
int64_t *count);

// Getter method
int32_t rows_per_buffer() const { return rows_per_buffer_; }


+ 1
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt View File

@@ -1,6 +1,7 @@
add_library(engine-datasetops-source-sampler OBJECT
distributed_sampler.cc
pk_sampler.cc
python_sampler.cc
random_sampler.cc
sampler.cc
sequential_sampler.cc


+ 85
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc View File

@@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"

#include <memory>

namespace mindspore {
namespace dataset {

// Constructor: keeps a handle to the user's python sampler object. The python
// side is not touched here; interaction starts in InitSampler()/GetNextBuffer().
PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer)
    : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}

// Fetch the full set of sample ids from the python sampler (one buffer per
// epoch) or, on the second call of an epoch, the EOE marker.
// Fix: the Status returned by Tensor::CreateTensor was silently ignored — a
// failed numpy-to-tensor copy would have propagated a bad tensor downstream.
// @param[out] out_buffer Buffer receiving either the id tensor or the EOE flag.
// @return Status - python/interpreter failures mapped to dataset Status codes.
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (need_to_reset_) {
    // Ids were already handed out for this epoch; signal end-of-epoch.
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
    std::shared_ptr<Tensor> sample_ids;
    {
      // All python interaction happens under the GIL; the scope block releases
      // it before the tensor row is assembled.
      py::gil_scoped_acquire gil_acquire;
      (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
      if (Py_IsInitialized() == 0) {
        return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
      }
      try {
        py::object py_ret = py_sampler_instance.attr("_get_indices")();
        py::array np_sample_ids = py_ret.cast<py::array>();
        // Check the copy's Status (previously dropped on the floor).
        RETURN_IF_NOT_OK(Tensor::CreateTensor(&sample_ids, np_sample_ids));  // copy numpy to tensor
      } catch (const py::error_already_set &e) {
        return Status(StatusCode::kPyFuncException, e.what());
      } catch (const py::cast_error &e) {
        return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index");
      }
    }
    TensorRow row(1, sample_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
    need_to_reset_ = true;
  }
  return Status::OK();
}

// Hand the dataset's row/sample counts to the python sampler via its
// _handshake hook so the python side can size its index generation.
// Requires num_rows_ to have been set (> 0) by the owning dataset op.
Status PythonSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
  {
    // Scope block bounds the GIL acquisition to the python call only.
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    try {
      py_sampler_instance.attr("_handshake")(num_rows_, num_samples_);
    } catch (const py::error_already_set &e) {
      return Status(StatusCode::kPyFuncException, e.what());
    }
  }
  return Status::OK();
}

// Prepare for the next epoch: clear the end-of-epoch latch and tell the python
// sampler to reset its iteration state. May only be called after a full epoch
// has been consumed (i.e. after GetNextBuffer returned the EOE buffer).
Status PythonSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
  need_to_reset_ = false;
  // GIL held for the remainder of the function; the python call is last.
  py::gil_scoped_acquire gil_acquire;
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  try {
    py_sampler_instance.attr("reset")();
  } catch (const py::error_already_set &e) {
    return Status(StatusCode::kPyFuncException, e.what());
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 58
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h View File

@@ -0,0 +1,58 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_

#include <limits>
#include <memory>

#include "dataset/engine/datasetops/source/sampler/sampler.h"

namespace mindspore {
namespace dataset {
// Sampler that delegates sample-id generation to a user-provided python sampler
// object. All interaction with the python object is performed under the GIL.
class PythonSampler : public Sampler {
 public:
  // Constructor
  // @param py_sampler_instance - The python sampler object to delegate to.
  // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
  explicit PythonSampler(py::object py_sampler_instance,
                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

  // Destructor.
  ~PythonSampler() = default;

  // Initialize the sampler.
  // @return Status
  Status InitSampler() override;

  // for next epoch of sampleIds
  // @return - The error code return
  Status Reset() override;

  // Op calls this to get next Buffer that contains all the sampleIds
  // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
  // @param int32_t workerId - not meant to be used
  // @return - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

 private:
  bool need_to_reset_;  // Whether Reset() should be called before calling GetNextBuffer()

  // NOTE(review): member name lacks the trailing-underscore convention used elsewhere;
  // renaming would touch the .cc file, so it is only flagged here.
  py::object py_sampler_instance;  // The handle to the py_sampler python object
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_

+ 0
- 3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc View File

@@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids;

// check samples_per_buffer is properly set and doesn't overflow
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid");

// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK(GetNextBuffer(&db));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch


+ 1
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc View File

@@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
}

Status SequentialSampler::InitSampler() {
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();


+ 5
- 1
mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc View File

@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
std::ifstream in(schemaFile);
nlohmann::json js;
in >> js;
num_rows = js.value("numRows", 0);
if (js.find("numRows") == js.end()) {
num_rows = MAX_INTEGER_INT32;
} else {
num_rows = js.value("numRows", 0);
}
if (num_rows == 0) {
std::string err_msg =
"Storage client has not properly done dataset "


+ 459
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc View File

@@ -0,0 +1,459 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>

#include "common/utils.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/core/config_manager.h"
#include "dataset/util/task_manager.h"
#include "dataset/util/wait_post.h"
#include "dataset/util/random.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {
// Builder constructor: seeds every tunable from the global dataset configuration
// so callers only need to override the settings they care about.
TextFileOp::Builder::Builder()
    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
  const std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_worker_connector_size_ = cfg->worker_connector_size();
  builder_rows_per_buffer_ = cfg->rows_per_buffer();
  builder_op_connector_size_ = cfg->op_connector_size();
  builder_num_workers_ = cfg->num_parallel_workers();
}

// Checks that the builder's settings are consistent before Build() runs.
// Collects every problem into one message so the user sees all of them at once.
// @return Status - OK, or kUnexpectedError carrying the accumulated message.
Status TextFileOp::Builder::ValidateInputs() const {
  std::string err_msg;
  // Fixed typo in user-facing message: "greate" -> "greater".
  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
  err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
  return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

// Validates the configuration, builds the single-column schema, constructs the
// TextFileOp, and runs its Init() before handing ownership to the caller.
// @param op - output parameter receiving the constructed operator.
// @return Status - the error code returned.
Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
  RETURN_IF_NOT_OK(ValidateInputs());

  // Throttle the number of workers if we have more workers than files!
  if (static_cast<size_t>(builder_num_workers_) > builder_text_files_list_.size()) {
    builder_num_workers_ = builder_text_files_list_.size();
    MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
  }

  // Every row is exposed through one flexible rank-1 uint8 column named "text".
  builder_schema_ = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(
    builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));

  std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
    std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
    builder_num_devices_, builder_device_id_);
  // Init() sets up the file index, IO block queues and connectors before the op is returned.
  RETURN_IF_NOT_OK(text_file_op->Init());
  *op = std::move(text_file_op);

  return Status::OK();
}

// Constructor. Stores the configuration; the heavier setup (file index, queues,
// connectors) happens later in Init(). worker_connector_size_ is stashed so
// Init() can size the worker connectors.
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
                       std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
    : ParallelOp(num_workers, op_connector_size),
      device_id_(device_id),
      num_devices_(num_device),
      rows_per_buffer_(rows_per_buffer),
      num_samples_(num_samples),
      text_files_list_(std::move(text_files_list)),
      shuffle_files_(shuffle_files),
      data_schema_(std::move(schema)),
      all_num_rows_(0),
      num_rows_per_shard_(0),
      filename_index_(std::make_unique<StringIndex>()),
      finished_reading_dataset_(false),
      load_io_block_queue_(true),
      load_jagged_connector_(true) {
  worker_connector_size_ = worker_connector_size;
}

// Instantiates the internal queues and connectors.
// Builds the filename index, sizes the per-worker IO block queues, records the
// column name map, and creates the worker and jagged connectors.
// @return Status - the error code returned.
Status TextFileOp::Init() {
  RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_));

  // Bug fix: the original computed std::ceil(size / num_workers_) with integer
  // operands, so the division truncated before ceil could round up. Use
  // floating-point division so each queue is sized for its true share of files.
  int32_t safe_queue_size =
    static_cast<int32_t>(std::ceil(static_cast<double>(text_files_list_.size()) / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
    col_name_map_[data_schema_->column(i).name()] = i;
  }

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));

  jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
  return Status::OK();
}

// Re-arms the op for another epoch: re-enable block loading and buffer
// forwarding, reset the parallel-op state, then wake the IO-block filler
// thread so it repopulates the queues.
Status TextFileOp::Reset() {
  load_io_block_queue_ = true;
  load_jagged_connector_ = true;

  RETURN_IF_NOT_OK(ParallelOp::Reset());
  NotifyToFillIOBlockQueue();
  return Status::OK();
}

// Parses a single line into a rank-1 uint8 tensor and places it at the given
// row of the tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
// @param row - the id of the row filled in the tensor table (a fresh row is pushed first,
//              so callers must pass the index of that new row).
// @return Status - the error code returned.
Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
  TensorRow tRow(1, nullptr);
  (*tensor_table)->push_back(std::move(tRow));

  std::shared_ptr<Tensor> tensor;
  // Shape is [line.size()] with the column's dtype (uint8 per the builder's schema).
  // NOTE(review): `line` is a caller-owned temporary — this assumes CreateTensor
  // copies the bytes rather than aliasing them; confirm against Tensor's contract.
  RETURN_IF_NOT_OK(
    Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(),
                         TensorShape(std::vector<dsize_t>(1, line.size())), data_schema_->column(0).type(),
                         const_cast<unsigned char *>(reinterpret_cast<const unsigned char *>(common::SafeCStr(line)))));
  (**tensor_table)[row][0] = std::move(tensor);
  return Status::OK();
}

// Reads rows [start_offset, end_offset) of one text file and pushes them to the
// jagged connector in buffers of rows_per_buffer_ rows (plus one final partial buffer).
// @param file - the file to read.
// @param start_offset - first row (0-based) of the file to load.
// @param end_offset - one past the last row of the file to load.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
                            const int32_t worker_id) {
  std::ifstream handle(file);
  if (!handle.is_open()) {
    RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
  }

  int64_t rows_each_buffer = 0;  // rows accumulated in the current buffer
  int64_t rows_total = 0;        // rows consumed from the file so far (read or skipped)
  std::string line;
  std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
  cur_buffer->set_column_name_map(col_name_map_);
  std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();

  while (getline(handle, line)) {
    // If read to the end offset of this file, break.
    if (rows_total >= end_offset) {
      break;
    }
    // Skip line before start offset.
    if (rows_total < start_offset) {
      rows_total++;
      continue;
    }

    RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer));
    rows_each_buffer++;
    rows_total++;
    // Buffer is full: ship it downstream and start a fresh one.
    if (rows_each_buffer == rows_per_buffer_) {
      cur_buffer->set_tensor_table(std::move(tensor_table));
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));

      cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
      cur_buffer->set_column_name_map(col_name_map_);
      tensor_table = std::make_unique<TensorQTable>();
      rows_each_buffer = 0;
    }
  }

  // Flush the trailing partial buffer, if any rows remain.
  if (rows_each_buffer > 0) {
    cur_buffer->set_tensor_table(std::move(tensor_table));
    RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
  }

  return Status::OK();
}

// Worker thread entry point: pops IO blocks from this worker's queue and loads
// the referenced file slice until an EOF block arrives. EOE blocks are relayed
// to the jagged connector so the master loop can count finished workers.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status TextFileOp::WorkerEntry(int32_t worker_id) {
  // Required handshake with the task manager before doing any work.
  TaskManager::FindMe()->Post();

  std::unique_ptr<FilenameBlock> io_block;
  RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  while (!io_block->eof()) {
    if (!io_block->eoe()) {
      // load_jagged_connector_ is cleared by the master loop once num_samples_
      // is reached; remaining blocks are then drained without loading.
      if (load_jagged_connector_) {
        std::string filename;
        RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
        int64_t start_offset = io_block->GetStartOffset();
        int64_t end_offset = io_block->GetEndOffset();
        RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
      }
    } else {
      // Propagate end-of-epoch downstream through the jagged connector.
      std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
    }

    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  }
  return Status::OK();
}

// Retrieves the next IO block from the indexed worker's queue, blocking until
// one is available.
Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  return io_block_queues_[index]->PopFront(out_block);
}

// Appends an IO block to the indexed worker's queue, transferring ownership.
Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  return io_block_queues_[index]->Add(std::move(io_block));
}

// Pushes an EOF control indicator onto every worker's IOBlockQueue. A worker
// that pops this indicator shuts itself down gracefully.
Status TextFileOp::PostEndOfData() {
  for (int32_t worker = 0; worker < num_workers_; ++worker) {
    RETURN_IF_NOT_OK(PushIoBlockQueue(worker, std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof)));
  }

  return Status::OK();
}

// Pushes an EOE control indicator onto every worker's IOBlockQueue, starting
// at queue_index and wrapping round-robin. A worker that pops this indicator
// waits for the next epoch before resuming.
Status TextFileOp::PostEndOfEpoch(int32_t queue_index) {
  for (int32_t offset = 0; offset < num_workers_; ++offset) {
    const int32_t target_queue = (queue_index + offset) % num_workers_;
    RETURN_IF_NOT_OK(PushIoBlockQueue(target_queue, std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe)));
  }

  return Status::OK();
}

// Deterministically permutes the file keys in place using a Mersenne-Twister
// engine seeded with `seed` (identical seed => identical order on every shard).
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::mt19937 generator(seed);
  std::shuffle(i_keys->begin(), i_keys->end(), generator);
}

// Decides whether a file overlaps this shard's global row window
// [device_id_ * num_rows_per_shard_, (device_id_ + 1) * num_rows_per_shard_)
// and, if so, computes which rows of the file to read.
// @param file_name - the file being considered.
// @param start_offset - output: first row of the file to read.
// @param end_offset - output: one past the last row of the file to read.
// @param pre_count - total rows of all files considered before this one.
// @return bool - true if some slice of this file belongs to the shard.
bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                          const int64_t &pre_count) {
  *start_offset = 0;
  *end_offset = 0;
  bool push = false;
  int64_t start_index = device_id_ * num_rows_per_shard_;
  // NOTE(review): this guard only fires when device_id_ + 1 overflows int32;
  // presumably it is defending against a corrupted device id — confirm intent.
  if (device_id_ + 1 < 0) {
    MS_LOG(ERROR) << "Device id is invalid";
    return false;
  }

  int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  // Case 1: the shard window starts inside this file.
  if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
    *start_offset = start_index - pre_count;
    push = true;
    if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  // Case 2: the file starts inside the shard window. When both cases hold
  // (pre_count == start_index) this overwrites case 1 with the same offsets.
  if (pre_count >= start_index && pre_count < end_index) {
    *start_offset = 0;
    push = true;
    if (pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  return push;
}

// Fills the IO block queues with one FilenameBlock per file slice belonging to
// this shard, round-robining blocks across worker queues. Files are revisited
// in passes until the shard's row quota is covered (a shard may need more rows
// than exist once, when rows don't divide evenly across devices). Ends by
// posting an EOE block to every worker.
// @param i_keys - shuffled file keys; empty means iterate files in index order.
// @return Status - the error code returned.
Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;  // cumulative rows of all files handed out so far
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    // Build this pass's (filename, key) list, honoring shuffle order if given.
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        {
          // A reset/stop request aborts queue filling early.
          if (!load_io_block_queue_) {
            break;
          }
        }
        auto file_it = filename_index_->Search(*it);
        file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it));
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        {
          // A reset/stop request aborts queue filling early.
          if (!load_io_block_queue_) {
            break;
          }
        }
        file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
      }
    }
    for (auto file_info : file_index) {
      // Only files overlapping this shard's row window produce IO blocks.
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto ioBlock =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
        queue_index = (queue_index + 1) % num_workers_;
      }

      pre_count += filename_numrows_[file_info.first];
    }

    // Loop again if the cumulative row count hasn't reached this shard's end index.
    if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
      finish = false;
    } else {
      finish = true;
    }
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}

// Dedicated thread body: blocks on io_block_queue_wait_post_ and refills the
// IO block queues each time it is signalled, until the dataset is finished.
// @return Status - the error code returned.
Status TextFileOp::WaitToFillIOBlockQueue() {
  // must be called first if called by worker spawned by taskgroup
  TaskManager::FindMe()->Post();

  std::vector<int64_t> i_keys;
  if (shuffle_files_) {
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      i_keys.push_back(it.key());
    }
  }
  uint32_t seed = 0;
  while (true) {
    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
    io_block_queue_wait_post_.Clear();

    // Final wake-up: the master loop sets this flag before signalling.
    if (finished_reading_dataset_) {
      break;
    }

    // Single device: use the global seed; multiple devices: advance a shared
    // counter so every shard shuffles into the same order each epoch.
    if (shuffle_files_) {
      ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
    }
    RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
  }
  return Status::OK();
}

// Wakes the WaitToFillIOBlockQueue() thread so it refills the IO block queues.
void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }

// Master thread loop. Launches the IO-block filler thread and the reading
// workers, then drains buffers from the jagged connector into the output
// connector, enforcing num_samples_ (0 = unlimited), emitting one EOE per
// epoch and a final EOF.
// @return Status - the error code returned.
Status TextFileOp::operator()() {
  RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

  // launch one thread, responsible for filling IoBlockQueue
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this)));

  // Read data from disk into buffers
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1)));

  // must be called after launching workers.
  TaskManager::FindMe()->Post();

  io_block_queue_wait_post_.Register(tree_->AllTasks());
  NotifyToFillIOBlockQueue();
  while (!finished_reading_dataset_) {
    int64_t buffer_id = 0;
    int32_t workers_done = 0;  // workers that have relayed their EOE this epoch
    int64_t rows_read = 0;
    load_io_block_queue_ = true;

    while (workers_done < num_workers_) {
      std::unique_ptr<DataBuffer> buffer;
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
      if (buffer->eoe()) {
        workers_done++;
      } else if (num_samples_ == 0 || rows_read < num_samples_) {
        // Trim the buffer that crosses the num_samples_ boundary.
        if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
          int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
          RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
        }
        rows_read += buffer->NumRows();
        buffer->set_id(buffer_id++);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
      } else {
        // end of epoch: sample quota reached, tell workers to stop loading and
        // the filler thread to stop queueing (excess buffers are discarded).
        load_jagged_connector_ = false;
        load_io_block_queue_ = false;
      }
    }

    std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

    // Last repeat (or no repeat): finish up; otherwise reset for the next epoch.
    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
      finished_reading_dataset_ = true;
      NotifyToFillIOBlockQueue();
    } else {
      jagged_buffer_connector_->DoReset();
      buffer_id = 0;
    }
  }

  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

  // Tell every worker to shut down gracefully.
  RETURN_IF_NOT_OK(PostEndOfData());

  return Status::OK();
}

// Counts the newline-delimited rows in one text file.
// Returns 0 (after logging an error) when the file cannot be opened.
int64_t TextFileOp::CountTotalRows(const std::string &file) {
  std::ifstream reader(file);
  if (!reader.is_open()) {
    MS_LOG(ERROR) << "Failed to open file: " << file;
    return 0;
  }

  int64_t row_count = 0;
  for (std::string buf; getline(reader, buf);) {
    ++row_count;
  }

  return row_count;
}

// Counts the rows of every dataset file, caches the per-file counts, and
// derives how many rows each shard (device) must deliver.
// @return Status - error if the dataset contains no rows at all.
Status TextFileOp::CalculateNumRowsPerShard() {
  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    int64_t count = CountTotalRows(it.value());
    filename_numrows_[it.value()] = count;
    all_num_rows_ += count;
  }
  if (all_num_rows_ == 0) {
    RETURN_STATUS_UNEXPECTED("Number of rows can not be zero");
  }

  // Integer ceiling division: the former std::ceil(all_num_rows_ * 1.0 / num_devices_)
  // round-trips through double and loses precision once all_num_rows_ exceeds 2^53.
  num_rows_per_shard_ = (all_num_rows_ + num_devices_ - 1) / num_devices_;
  MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
  return Status::OK();
}

// Static helper: builds a throwaway TextFileOp and sums the row counts of all
// given text files.
// @param files - all text files.
// @param count - output: total number of rows across the files.
// @return Status - the error code returned.
Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
  std::shared_ptr<TextFileOp> op;
  *count = 0;
  RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op));
  // const reference: the original `auto file` copied every filename string.
  for (const auto &file : files) {
    *count += op->CountTotalRows(file);
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 263
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h View File

@@ -0,0 +1,263 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_

#include <memory>
#include <map>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "dataset/util/status.h"
#include "dataset/util/auto_index.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/jagged_connector.h"

namespace mindspore {
namespace dataset {
using StringIndex = AutoIndexObj<std::string>;

// Parallel dataset source operator that reads plain text files line by line,
// exposing each line as one row with a single uint8 "text" column.
class TextFileOp : public ParallelOp {
 public:
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Checks if the inputs of the builder is valid.
    // @return Status - the error code returned.
    Status ValidateInputs() const;

    // Create the final object.
    // @param op - dataset op.
    // @return - the error code return.
    Status Build(std::shared_ptr<TextFileOp> *op);

    // Setter method.
    // @param num_workers - number of parallel worker threads.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method.
    // @param op_connector_size - size of the output connector queues.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method.
    // @param rows_per_buffer - number of rows per output buffer.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @param num_dev - number of shards. NOTE(review): accepts int64_t but the
    // stored field is int32_t, so large values narrow silently.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumDevices(int64_t num_dev) {
      builder_num_devices_ = num_dev;
      return *this;
    }

    // Setter method.
    // @param dev_id - id of this shard. NOTE(review): same int64_t -> int32_t
    // narrowing as SetNumDevices.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetDeviceId(int64_t dev_id) {
      builder_device_id_ = dev_id;
      return *this;
    }

    // Setter method.
    // @param files_list - list of text file paths to read.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetTextFilesList(const std::vector<std::string> &files_list) {
      builder_text_files_list_ = files_list;
      return *this;
    }

    // Setter method.
    // @param shuffle_files - whether to shuffle file order before reading.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetShuffleFiles(bool shuffle_files) {
      builder_shuffle_files_ = shuffle_files;
      return *this;
    }

    // Setter method.
    // @param num_samples - cap on rows to produce (0 means no cap).
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumSamples(int64_t num_samples) {
      builder_num_samples_ = num_samples;
      return *this;
    }

   private:
    int32_t builder_device_id_;
    int32_t builder_num_devices_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    int64_t builder_rows_per_buffer_;
    int64_t builder_num_samples_;
    int32_t builder_worker_connector_size_;
    std::vector<std::string> builder_text_files_list_;
    bool builder_shuffle_files_;
    std::unique_ptr<DataSchema> builder_schema_;
  };

  // Constructor of TextFileOp
  // @note The builder class should be used to call this constructor.
  // @param num_workers - number of worker threads reading data from text files.
  // @param rows_per_buffer - number of rows that a full buffer will contain.
  // @param num_samples - cap on rows to produce (0 means no cap).
  // @param worker_connector_size - size of each internal worker connector queue.
  // @param (unnamed DataSchema) - the data schema object.
  // @param text_files_list - list of filepaths for the dataset files.
  // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
  // @param shuffle_files - whether or not to shuffle the files before reading data.
  // @param num_devices - number of shards the rows are split across.
  // @param device_id - id of the shard this op serves.
  TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
             std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
             bool shuffle_files, int32_t num_devices, int32_t device_id);

  // Default destructor
  ~TextFileOp() = default;

  // Instantiates the internal queues and connectors
  // @return Status - the error code returned
  Status Init();

  // Class functor operator () override.
  // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - the error code returned.
  Status operator()() override;

  // Overrides base class reset method. Cleans up any state info from it's previous execution
  // reinitializes itself so that it can be executed again, as if it was just created.
  // @return Status - the error code returned.
  Status Reset() override;

  // Get total rows in files.
  // @param files - all text files.
  // @param count - number of rows.
  // @return Status - the error code returned.
  static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);

 private:
  // The entry point for when workers are launched.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status WorkerEntry(int32_t worker_id) override;

  // Parses a single row and puts the data into a tensor table.
  // @param line - the content of the row.
  // @param tensor_table - the tensor table to put the parsed data in.
  // @param row - the id of the row filled in the tensor table.
  // @return Status - the error code returned.
  Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);

  // Reads a text file and loads the data into multiple buffers.
  // @param file - the file to read.
  // @param start_offset - the start offset of file.
  // @param end_offset - the end offset of file.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
                  const int32_t worker_id);

  // Calculate number of rows in each shard.
  // @return Status - the error code returned.
  Status CalculateNumRowsPerShard();

  // Count number of rows in each file.
  // @param file - text file name.
  // @return int64_t - the total number of rows in file.
  int64_t CountTotalRows(const std::string &file);

  // Notifies the thread which called FillIoBlockQueue to resume execution
  void NotifyToFillIOBlockQueue();

  // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
  // @return Status - the error code returned.
  Status WaitToFillIOBlockQueue();

  // Fill the IOBlockQueue.
  // @param i_keys - keys of file to fill to the IOBlockQueue
  // @return Status - the error code returned.
  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);

  // Select file and push it to the block queue.
  // @param file_name - File name.
  // @param start_offset - output: first row of the file to read.
  // @param end_offset - output: one past the last row of the file to read.
  // @param pre_count - Total rows of previous files.
  // @return bool - true if the file overlaps this shard and should be queued.
  bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                const int64_t &pre_count);

  // Pops an element from a queue in IOBlockQueue.
  // @param index - the index of the queue to pop from.
  // @param out_block - the popped element.
  // @return Status - the error code returned.
  Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

  // Pushes an element to a queue in IOBlockQueue.
  // @param index - the index of the queue to push to.
  // @param io_block - the element to push onto the queue.
  // @return Status - the error code returned.
  Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
  // When the worker pops this control indicator, it will shut itself down gracefully.
  // @return Status - the error code returned.
  Status PostEndOfData();

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
  // pops this control indicator, it will wait until the next epoch starts and then resume execution.
  // @return Status - the error code returned.
  Status PostEndOfEpoch(int32_t queue_index);

  int32_t device_id_;                               // shard id served by this op
  int32_t num_devices_;                             // total number of shards
  int64_t rows_per_buffer_;                         // rows per output DataBuffer
  int64_t num_samples_;                             // row cap (0 = unlimited)
  std::vector<std::string> text_files_list_;        // dataset file paths
  bool shuffle_files_;                              // shuffle file order per epoch
  std::unique_ptr<DataSchema> data_schema_;         // single "text" column schema
  int64_t all_num_rows_;                            // total rows across all files
  int64_t num_rows_per_shard_;                      // rows each shard must deliver
  std::map<std::string, int64_t> filename_numrows_; // cached per-file row counts
  std::unique_ptr<StringIndex> filename_index_;     // key <-> filename index
  QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;  // per-worker IO block queues
  WaitPost io_block_queue_wait_post_;               // wakes the filler thread
  bool finished_reading_dataset_;                   // master loop exit flag
  bool load_io_block_queue_;                        // filler thread keep-going flag
  bool load_jagged_connector_;                      // workers keep-loading flag
  std::unordered_map<std::string, int32_t> col_name_map_;      // column name -> index
  std::unique_ptr<JaggedConnector> jagged_buffer_connector_;   // worker -> master buffer funnel
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_

+ 54
- 6
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc View File

@@ -42,6 +42,7 @@
#include "dataset/util/status.h"
#include "dataset/util/task_manager.h"
#include "dataset/util/wait_post.h"
#include "utils/system/crc32c.h"

namespace mindspore {
namespace dataset {
@@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder()
builder_data_schema_ = std::make_unique<DataSchema>();
}

// Checks whether a file starts like a valid TFRecord: reads the 8-byte record
// length and its 4-byte masked crc32c, recomputes the masked crc over the
// length bytes, and compares the two.
// @param filename - path of the candidate tfrecord file.
// @return bool - true when the file opens and the first-row crc matches.
bool ValidateFirstRowCrc(const std::string &filename) {
  std::ifstream reader;
  reader.open(filename);
  if (!reader) {
    return false;
  }

  // read data
  int64_t record_length = 0;
  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));

  // read crc from file
  uint32_t masked_crc = 0;
  (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));

  // generate crc from data
  uint32_t generated_crc =
    system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));

  return masked_crc == generated_crc;
}

Status TFReaderOp::Builder::ValidateInputs() const {
std::string err_msg;
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : "";
if (!builder_equal_rows_per_shard_) {
err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)
? "No enough tf_file files provided\n"
: "";

if (builder_num_workers_ <= 0) {
err_msg += "Number of parallel workers is smaller or equal to 0\n";
}

if (!builder_equal_rows_per_shard_ &&
builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
err_msg += "Not enough tfrecord files provided\n";
}

if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
err_msg += "Wrong sharding configs\n";
}

std::vector<std::string> invalid_files(builder_dataset_files_list_.size());
auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(),
[](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
invalid_files.resize(std::distance(invalid_files.begin(), it));

if (!invalid_files.empty()) {
err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n";

std::string accumulated_filenames = std::accumulate(
invalid_files.begin(), invalid_files.end(), std::string(""),
[](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; });
err_msg += accumulated_filenames;
}
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

@@ -119,6 +163,9 @@ Status TFReaderOp::Init() {
if (total_rows_ == 0) {
total_rows_ = data_schema_->num_rows();
}
if (total_rows_ < 0) {
RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0");
}

// Build the index with our files such that each file corresponds to a key id.
RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
@@ -523,6 +570,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read));
rows_read++;
}

// ignore crc footer
(void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
rows_total++;


+ 6
- 4
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc View File

@@ -67,7 +67,7 @@ Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
if (take_count_ == max_takes_) {
if (state_ == OpState::kDeOpRunning) {
MS_LOG(INFO) << "meet max count and push-back eoe buffer.";
MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer.";
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
*p_buffer = std::move(eoe_buffer);
state_ = OpState::kDeOpIdle;
@@ -80,11 +80,13 @@ Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
}
} else {
MS_LOG(INFO) << "meet max count and push-back eof buffer.";
} else if (state_ == OpState::kDeOpIdle) {
MS_LOG(DEBUG) << "Meet max count and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
*p_buffer = std::move(eof_buffer);
take_count_ = 0;
} else {
MS_LOG(WARNING) << "Invalid OpState: " << state_;
}
return Status::OK();
}
@@ -116,7 +118,7 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
*data_buffer = std::move(*buffer);
take_count_ = take_count_ + buffer_size;
} else {
MS_LOG(INFO) << "In last buffer: Push one buffer.";
MS_LOG(DEBUG) << "In last buffer: Push one buffer.";
std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
while (take_count_ < max_takes_) {
TensorRow new_row;


+ 9
- 9
mindspore/ccsrc/dataset/engine/datasetops/zip_op.h View File

@@ -34,7 +34,7 @@ class DataBuffer;

class ZipOp : public PipelineOp {
public:
// The nested builder class inside of the BatchOp is used to help manage all of
// The nested builder class inside of the ZipOp is used to help manage all of
// the arguments for constructing it. Use the builder by setting each argument
// with the provided set methods, and then finally call the build method to execute
// the actual construction.
@@ -76,8 +76,8 @@ class ZipOp : public PipelineOp {
};

// Constructor for ZipOp
// @param rows_per_buffer number of rows in output buffer
// @param op_connector_size connector
// @param rows_per_buffer - number of rows in output buffer
// @param op_connector_size - connector size
ZipOp(int32_t rows_per_buffer, int32_t op_connector_size);

// Destructor
@@ -88,8 +88,8 @@ class ZipOp : public PipelineOp {
Status EoeReceived(int32_t) override;

// Print function for Zip
// @param out output stream to print to
// @param show_all if it should print everything
// @param out - output stream to print to
// @param show_all - if it should print everything
void Print(std::ostream &out, bool show_all) const override;

// Provide stream operator for displaying it
@@ -113,14 +113,14 @@ class ZipOp : public PipelineOp {
Status fillBuffer(TensorQTable *const table);

// Special handle case where an empty row has been received from child iterator
// @note we need to drain eoe signals from all children connectors.
// @details when this function is called, then we encountered eoe at child iterator
// @note - we need to drain eoe signals from all children connectors.
// @details - when this function is called, then we encountered eoe at child iterator
// we have to drain rows from other child iterators until we hit eoe from all other child iterators
Status drainPipeline();

// Merges 1 row from each childIterator together
// @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty
// @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true
// @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
// @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true
// @details merge rows from iterator together. This is the main functionality for ZipOp
// this function takes one row and fills it with tensors from rows fetched
// from childIterators.


+ 3
- 3
mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt View File

@@ -3,7 +3,6 @@ if (WIN32)
center_crop_op.cc
cut_out_op.cc
decode_op.cc
distort_bounding_box_crop_op.cc
hwc_to_chw_op.cc
image_utils.cc
normalize_op.cc
@@ -19,6 +18,7 @@ if (WIN32)
rescale_op.cc
resize_bilinear_op.cc
resize_op.cc
uniform_aug_op.cc
)
else()
add_library(kernels-image OBJECT
@@ -26,7 +26,6 @@ else()
change_mode_op.cc
cut_out_op.cc
decode_op.cc
distort_bounding_box_crop_op.cc
hwc_to_chw_op.cc
image_utils.cc
normalize_op.cc
@@ -42,5 +41,6 @@ else()
rescale_op.cc
resize_bilinear_op.cc
resize_op.cc
uniform_aug_op.cc
)
endif()
endif()

+ 4
- 3
mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc View File

@@ -33,7 +33,8 @@ const uint8_t CutOutOp::kDefFillB = 0;
// constructor
CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r,
uint8_t fill_g, uint8_t fill_b)
: box_height_(box_height),
: rnd_(GetSeed()),
box_height_(box_height),
box_width_(box_width),
num_patches_(num_patches),
random_color_(random_color),
@@ -46,8 +47,8 @@ Status CutOutOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
IO_CHECK(input, output);
std::shared_ptr<CVTensor> inputCV = CVTensor::AsCVTensor(input);
// cut out will clip the erasing area if the box is near the edge of the image and the boxes are black
RETURN_IF_NOT_OK(
Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, fill_r_, fill_g_, fill_b_));
RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_,
fill_g_, fill_b_));
return Status::OK();
}
} // namespace dataset


+ 1
- 0
mindspore/ccsrc/dataset/kernels/image/cut_out_op.h View File

@@ -62,6 +62,7 @@ class CutOutOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

private:
std::mt19937 rnd_;
int32_t box_height_;
int32_t box_width_;
int32_t num_patches_;


+ 4
- 4
mindspore/ccsrc/dataset/kernels/image/decode_op.h View File

@@ -34,11 +34,11 @@ class DecodeOp : public TensorOp {

~DecodeOp() = default;

Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

void Print(std::ostream& out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
void Print(std::ostream &out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

private:
bool is_rgb_format_ = true;


+ 0
- 117
mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc View File

@@ -1,117 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/image/distort_bounding_box_crop_op.h"
#include <random>
#include "dataset/core/cv_tensor.h"
#include "dataset/kernels/image/image_utils.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
const int32_t DistortBoundingBoxCropOp::kDefMaxAttempts = 100;
const int32_t DistortBoundingBoxCropOp::kDefBoxGenAttempts = 10;

DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb,
float crop_ratio_ub, int32_t max_attempts, int32_t box_gen_attempts)
: max_attempts_(max_attempts),
box_gen_attempts_(box_gen_attempts),
aspect_ratio_(aspect_ratio),
intersect_ratio_(intersect_ratio),
crop_ratio_lb_(crop_ratio_lb),
crop_ratio_ub_(crop_ratio_ub) {
seed_ = GetSeed();
rnd_.seed(seed_);
}

// Crops the image to a randomly generated box that sufficiently overlaps at least one of the
// supplied bounding boxes; if no satisfying box is found the original image is passed through.
// @param input - Tensor vector in the order [img, bbox_y_min, bbox_y_max, bbox_x_min, bbox_x_max]
// @param output - receives either the cropped image or the untouched input image
// @return Status - the error code returned
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input,
                                         std::vector<std::shared_ptr<Tensor>>* output) {
  IO_CHECK_VECTOR(input, output);
  if (input.size() != NumInput())
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5");

  CHECK_FAIL_RETURN_UNEXPECTED(input[1]->shape().Size() >= 1, "The shape of the second tensor is abnormal");
  int64_t num_boxes = 0;
  for (uint64_t i = 1; i < input.size(); i++) {
    // all four coordinate tensors must describe the same number of boxes
    if (i == 1) num_boxes = input[i]->shape()[0];
    if (num_boxes != input[i]->shape()[0])
      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Numbers of boxes do not match");

    // fix: the message previously said "DE_FLOAT21" although the check is for DE_FLOAT32
    if (input[i]->type() != DataType::DE_FLOAT32)
      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "boxes' type is not DE_FLOAT32");
  }

  // assume input Tensor vector in the order of [img, bbox_y_min, bbox_y_max, bbox_x_min, bbox_x_max]
  CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of the first tensor is abnormal");
  int h_in = input[0]->shape()[0];
  int w_in = input[0]->shape()[1];

  std::vector<cv::Rect> bounding_boxes;
  for (int64_t i = 0; i < num_boxes; ++i) {
    // bbox coordinates are floats relative to the image width and height
    float y_min, y_max, x_min, x_max;
    RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&y_min, {i}));
    RETURN_IF_NOT_OK(input[2]->GetItemAt<float>(&y_max, {i}));
    RETURN_IF_NOT_OK(input[3]->GetItemAt<float>(&x_min, {i}));
    RETURN_IF_NOT_OK(input[4]->GetItemAt<float>(&x_max, {i}));
    bounding_boxes.emplace_back(static_cast<int>(x_min * w_in), static_cast<int>(y_min * h_in),
                                static_cast<int>((x_max - x_min) * w_in), static_cast<int>((y_max - y_min) * h_in));
  }
  cv::Rect output_box;
  bool should_crop = false;

  // go over iterations, if no satisfying box found we return the original image
  for (int32_t t = 0; t < max_attempts_; ++t) {
    // try to generate random box
    RETURN_IF_NOT_OK(GenerateRandomCropBox(h_in, w_in, aspect_ratio_, crop_ratio_lb_, crop_ratio_ub_,
                                           box_gen_attempts_,  // int maxIter, should not be needed here
                                           &output_box, seed_));
    RETURN_IF_NOT_OK(CheckOverlapConstraint(output_box,
                                            bounding_boxes,  // have to change, should take tensor or add bbox logic
                                            intersect_ratio_, &should_crop));
    if (should_crop) {
      // found a box to crop
      break;
    }
  }
  // essentially we have to check this again at the end to return original tensor
  if (should_crop) {
    std::shared_ptr<Tensor> out;
    RETURN_IF_NOT_OK(Crop(input[0], &out, output_box.x, output_box.y, output_box.width, output_box.height));
    output->push_back(out);
  } else {
    output->push_back(input[0]);
  }
  return Status::OK();
}

// Computes the output shape: height/width become unknown (-1, -1); a rank-3 input
// keeps its channel dimension. Any other rank is rejected as a wrong shape.
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs,
                                             std::vector<TensorShape>& outputs) {
  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
  outputs.clear();
  const TensorShape unknown_hw = TensorShape{-1, -1};
  const auto rank = inputs[0].Rank();
  if (rank == 2) {
    outputs.emplace_back(unknown_hw);
  } else if (rank == 3) {
    outputs.emplace_back(unknown_hw.AppendDim(inputs[0][2]));
  }
  if (outputs.empty()) {
    return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
  }
  return Status::OK();
}
// The output image keeps the data type of the input image tensor.
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) {
  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
  outputs.front() = inputs.front();
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 72
mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h View File

@@ -1,72 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_
#define DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_

#include <memory>
#include <random>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
// TensorOp that crops an image to a randomly generated box which overlaps at least one of the
// provided bounding boxes by a minimum intersection ratio (see Compute for the input layout).
class DistortBoundingBoxCropOp : public TensorOp {
 public:
  // Default values, also used by python_bindings.cc
  static const int32_t kDefMaxAttempts;
  static const int32_t kDefBoxGenAttempts;

  // Constructor for DistortBoundingBoxCropOp
  // @param aspect_ratio aspect ratio of the generated crop box
  // @param intersect_ratio area overlap ratio, condition for crop only if area over lap between the generated
  //     crop box has sufficient overlap with any 1 bounding box
  // @param crop_ratio_lb the crop ratio lower bound
  // @param crop_ratio_ub the crop ratio upper bound
  // @param max_attempts tries before the crop happens
  // @param box_gen_attempts crop box generation attempts
  DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, float crop_ratio_ub,
                           int32_t max_attempts = kDefMaxAttempts, int32_t box_gen_attempts = kDefBoxGenAttempts);

  ~DistortBoundingBoxCropOp() override = default;

  // Prints the op name together with its key parameters.
  void Print(std::ostream& out) const override {
    out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_;
  }

  // @param input expects 5 tensors: [img, bbox_y_min, bbox_y_max, bbox_x_min, bbox_x_max]
  Status Compute(const std::vector<std::shared_ptr<Tensor>>& input,
                 std::vector<std::shared_ptr<Tensor>>* output) override;

  uint32_t NumInput() override { return 5; }
  Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
  Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;

 private:
  int32_t max_attempts_;      // attempts at finding a crop box that satisfies the overlap constraint
  int32_t box_gen_attempts_;  // attempts per random crop-box generation
  float aspect_ratio_;
  float intersect_ratio_;
  float crop_ratio_lb_;
  float crop_ratio_ub_;
  std::mt19937 rnd_;  // random generator, seeded with seed_ in the constructor
  uint32_t seed_;     // seed obtained from GetSeed() at construction time
};
} // namespace mindspore

#endif // DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_

+ 7
- 73
mindspore/ccsrc/dataset/kernels/image/image_utils.cc View File

@@ -636,76 +636,10 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
return Status::OK();
}

// Generates a random crop box with the requested aspect ratio whose area is a random
// fraction (in [lb, ub]) of the input area; if no fitting box is found within max_itr
// attempts, falls back to the image's own aspect ratio clipped to the image bounds.
// @param ratio aspect ratio (height/width) of the generated box
// @param lb/ub lower/upper bound on the crop-area fraction
// @param crop_box output rectangle in pixel coordinates
// @param seed seed for the local random generator
Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr,
                             cv::Rect *crop_box, uint32_t seed) {
  try {
    std::mt19937 rnd;
    // fix: seed with the caller-provided value; previously the parameter was ignored
    // in favor of GetSeed(), making the `seed` argument dead
    rnd.seed(seed);
    if (input_height <= 0 || input_width <= 0 || ratio <= 0.0 || lb <= 0.0 || lb > ub) {
      RETURN_STATUS_UNEXPECTED("Invalid inputs GenerateRandomCropBox");
    }
    std::uniform_real_distribution<float> rd_crop_ratio(lb, ub);
    float crop_ratio;
    int crop_width, crop_height;
    bool crop_success = false;
    // fix: widen before multiplying so large images cannot overflow int
    int64_t input_area = static_cast<int64_t>(input_height) * input_width;
    for (auto i = 0; i < max_itr; i++) {
      crop_ratio = rd_crop_ratio(rnd);
      crop_width = static_cast<int32_t>(std::round(std::sqrt(input_area * static_cast<double>(crop_ratio) / ratio)));
      crop_height = static_cast<int32_t>(std::round(crop_width * ratio));
      if (crop_width <= input_width && crop_height <= input_height) {
        crop_success = true;
        break;
      }
    }
    if (!crop_success) {
      // fallback: use the image's own aspect ratio and clip the box to the image
      ratio = static_cast<float>(input_height) / input_width;
      crop_ratio = rd_crop_ratio(rnd);
      crop_width = static_cast<int>(std::lround(std::sqrt(input_area * static_cast<double>(crop_ratio) / ratio)));
      crop_height = static_cast<int>(std::lround(crop_width * ratio));
      crop_height = (crop_height > input_height) ? input_height : crop_height;
      crop_width = (crop_width > input_width) ? input_width : crop_width;
    }
    // place the box uniformly at random within the remaining slack
    std::uniform_int_distribution<> rd_x(0, input_width - crop_width);
    std::uniform_int_distribution<> rd_y(0, input_height - crop_height);
    *crop_box = cv::Rect(rd_x(rnd), rd_y(rnd), crop_width, crop_height);
    return Status::OK();
  } catch (const cv::Exception &e) {
    RETURN_STATUS_UNEXPECTED("error in GenerateRandomCropBox.");
  }
}

// Checks whether crop_box overlaps any of the bounding boxes by at least min_intersect_ratio.
// @param is_satisfied output flag; always written (false unless a qualifying overlap is found)
Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Rect> &bounding_boxes,
                              float min_intersect_ratio, bool *is_satisfied) {
  try {
    // fix: always initialize the out-parameter; the original only wrote it on some paths
    // and relied on the caller pre-setting it to false
    *is_satisfied = false;
    // fix: a crop box containing no pixel can never satisfy the constraint — return
    // immediately instead of falling through (previously a zero-area crop could still
    // "satisfy" when min_intersect_ratio <= 0)
    if (crop_box.area() < 1.0) {
      return Status::OK();
    }
    for (const auto &b_box : bounding_boxes) {
      const float b_box_area = b_box.area();
      // skip degenerate bounding boxes that contain no pixel
      if (b_box_area < 1.0) {
        continue;
      }
      const float intersect_ratio = (crop_box & b_box).area() / b_box_area;
      if (intersect_ratio >= min_intersect_ratio) {
        *is_satisfied = true;
        break;
      }
    }
    return Status::OK();
  } catch (const cv::Exception &e) {
    RETURN_STATUS_UNEXPECTED("error in CheckOverlapConstraint.");
  }
}

Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height,
int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r, uint8_t fill_g,
uint8_t fill_b) {
int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r,
uint8_t fill_g, uint8_t fill_b) {
try {
std::mt19937 rnd;
rnd.seed(GetSeed());
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (input_cv->mat().data == nullptr || (input_cv->Rank() != 3 && input_cv->shape()[2] != 3)) {
RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase");
@@ -731,8 +665,8 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp
// rows in cv mat refers to the height of the cropped box
// we determine h_start and w_start using two different distributions as erasing is used by two different
// image augmentations. The bounds are also different in each case.
int32_t h_start = (bounded) ? height_distribution_bound(rnd) : (height_distribution_unbound(rnd) - box_height);
int32_t w_start = (bounded) ? width_distribution_bound(rnd) : (width_distribution_unbound(rnd) - box_width);
int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height);
int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width);

int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width;
int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height;
@@ -744,9 +678,9 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp
for (int x = h_start; x < max_height; x++) {
if (random_color) {
// fill each box with a random value
input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = static_cast<int32_t>(normal_distribution(rnd));
input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = static_cast<int32_t>(normal_distribution(rnd));
input_img.at<cv::Vec3b>(cv::Point(y, x))[2] = static_cast<int32_t>(normal_distribution(rnd));
input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = static_cast<int32_t>(normal_distribution(*rnd));
input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = static_cast<int32_t>(normal_distribution(*rnd));
input_img.at<cv::Vec3b>(cv::Point(y, x))[2] = static_cast<int32_t>(normal_distribution(*rnd));
} else {
input_img.at<cv::Vec3b>(cv::Point(y, x))[0] = fill_r;
input_img.at<cv::Vec3b>(cv::Point(y, x))[1] = fill_g;


+ 2
- 8
mindspore/ccsrc/dataset/kernels/image/image_utils.h View File

@@ -196,12 +196,6 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
// @param output: Adjusted image of same shape and type.
Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue);

Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr,
cv::Rect *crop_box, uint32_t seed = std::mt19937::default_seed);

Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Rect> &bounding_boxes,
float min_intersect_ratio, bool *is_satisfied);

// Masks out a random section from the image with set dimension
// @param input: input Tensor
// @param output: cutOut Tensor
@@ -214,8 +208,8 @@ Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector<cv::Re
// @param fill_g: green fill value for erase
// @param fill_b: blue fill value for erase.
Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height,
int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r = 0,
uint8_t fill_g = 0, uint8_t fill_b = 0);
int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd,
uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);

// Pads the input image and puts the padded image in the output
// @param input: input Tensor


+ 3
- 3
mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc View File

@@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
rnd_.seed(GetSeed());
}

Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) {
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal");

@@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std:
(void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_);
}
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) {
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
TensorShape out = TensorShape{target_height_, target_width_};
@@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs
if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
}
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) {
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {
double scale, aspect;
*crop_width = w_in;
*crop_height = h_in;


+ 87
- 0
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc View File

@@ -0,0 +1,87 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/util/random.h"

namespace mindspore {
namespace dataset {
const int UniformAugOp::kDefNumOps = 2;

// Constructor: caches the candidate ops (python callables are wrapped in a PyFuncOp,
// C++ TensorOps are taken as-is) and seeds the random generator from GetSeed().
UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) {
  std::shared_ptr<TensorOp> tensor_op;
  // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
  for (auto op : op_list) {
    if (py::isinstance<py::function>(op)) {
      // python op
      tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
    } else if (py::isinstance<TensorOp>(op)) {
      // C++ op
      tensor_op = op.cast<std::shared_ptr<TensorOp>>();
    }
    // NOTE(review): if an element is neither a python function nor a TensorOp, tensor_op keeps
    // its previous value (null for the first element) and is still inserted — presumably the
    // python layer validates op_list; confirm upstream.
    // inserting at begin() means tensor_op_list_ ends up in reverse order of op_list
    tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
  }

  rnd_.seed(GetSeed());
}
// compute method to apply uniformly random selected augmentations from a list
// @param input - row of input tensors
// @param output - row after the randomly selected subset of candidate ops has been applied
// @return Status - the error code returned; the first failure of any applied op is propagated
Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
                             std::vector<std::shared_ptr<Tensor>> *output) {
  IO_CHECK_VECTOR(input, output);

  // variables to generate random number to select ops from the list
  std::vector<int> random_indexes;

  // variables to copy the result to output if it is not already
  std::vector<std::shared_ptr<Tensor>> even_out;
  std::vector<std::shared_ptr<Tensor>> *even_out_ptr = &even_out;
  int count = 1;

  // select random indexes for candidates to be applied
  for (int i = 0; i < num_ops_; ++i) {
    random_indexes.insert(random_indexes.end(),
                          std::uniform_int_distribution<int>(0, tensor_op_list_.size() - 1)(rnd_));
  }

  for (auto it = random_indexes.begin(); it != random_indexes.end(); ++it) {
    // randomly skip the candidate with probability 1/2 (skip when the draw is nonzero)
    if (std::uniform_int_distribution<int>(0, 1)(rnd_)) {
      continue;
    }
    std::shared_ptr<TensorOp> tensor_op = tensor_op_list_[*it];

    // apply python/C++ op, ping-ponging between *output and even_out so that each op
    // reads the previous op's result
    // fix: propagate the Status of each applied op; previously failures were silently ignored
    if (count == 1) {
      RETURN_IF_NOT_OK(tensor_op->Compute(input, output));
    } else if (count % 2 == 0) {
      RETURN_IF_NOT_OK(tensor_op->Compute(*output, even_out_ptr));
    } else {
      RETURN_IF_NOT_OK(tensor_op->Compute(even_out, output));
    }
    count++;
  }

  // copy the result to output if it is not in output
  if (count == 1) {
    // no op was applied; pass the input through unchanged
    *output = input;
  } else if ((count % 2 == 1)) {
    // the last result landed in even_out; swap it into output
    (*output).swap(even_out);
  }

  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 60
- 0
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h View File

@@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_
#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_

#include <memory>
#include <random>
#include <string>
#include <vector>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
#include "dataset/kernels/py_func_op.h"

#include "pybind11/stl.h"

namespace mindspore {
namespace dataset {
// TensorOp that applies a randomly selected subset of candidate augmentation ops to a row.
class UniformAugOp : public TensorOp {
 public:
  // Default number of Operations to be applied
  static const int kDefNumOps;

  // Constructor for UniformAugOp
  // @param op_list list of candidate operations (python callables and/or C++ TensorOps)
  // @param num_ops number of augmentation operations to be applied
  UniformAugOp(py::list op_list, int32_t num_ops);

  ~UniformAugOp() override = default;

  // Prints the op name and the configured number of ops.
  void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; }

  // Overrides the base class compute function
  // @return Status - The error code return
  Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
                 std::vector<std::shared_ptr<Tensor>> *output) override;

 private:
  int32_t num_ops_;                                        // number of candidates drawn per row
  std::vector<std::shared_ptr<TensorOp>> tensor_op_list_;  // candidate ops, built in the constructor
  std::mt19937 rnd_;                                       // random generator, seeded from GetSeed()
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_

+ 6
- 1
mindspore/ccsrc/dataset/util/random.cc View File

@@ -18,6 +18,9 @@

#include "dataset/util/random.h"

#if defined(_WIN32) || defined(_WIN64)
#include <stdlib.h>
#endif
#include <limits>
#include <memory>
#include <random>
@@ -33,7 +36,9 @@ uint32_t GetSeed() {
uint32_t seed = GlobalContext::config_manager()->seed();
if (seed == std::mt19937::default_seed) {
#if defined(_WIN32) || defined(_WIN64)
std::random_device random_device;
unsigned int number;
rand_s(&number);
std::mt19937 random_device{static_cast<uint32_t>(number)};
#else
std::random_device random_device("/dev/urandom");
#endif


+ 5
- 1
mindspore/ccsrc/dataset/util/services.cc View File

@@ -18,6 +18,8 @@
#include <limits.h>
#if !defined(_WIN32) && !defined(_WIN64)
#include <sys/syscall.h>
#else
#include <stdlib.h>
#endif
#include <unistd.h>
#include <random>
@@ -49,7 +51,9 @@ int Services::GetLWP() { return syscall(SYS_gettid); }
std::string Services::GetUniqueID() {
const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789";
#if defined(_WIN32) || defined(_WIN64)
std::mt19937 gen{std::random_device{}()};
unsigned int number;
rand_s(&number);
std::mt19937 gen{static_cast<uint32_t>(number)};
#else
std::mt19937 gen{std::random_device{"/dev/urandom"}()};
#endif


+ 1
- 1
mindspore/ccsrc/debug/anf_ir_dump.h View File

@@ -22,7 +22,7 @@

namespace mindspore {
constexpr char PARALLEL_STRATEGY[] = "strategy";
void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false);
void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false);

} // namespace mindspore



+ 108
- 108
mindspore/ccsrc/debug/anf_ir_utils.cc View File

@@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF;
// get MindSpore Intermediate Representation Path
std::string GetMsIrPath(void) {
std::string path;
const char* path_ptr = getenv("MS_IR_PATH");
const char *path_ptr = getenv("MS_IR_PATH");
if (path_ptr != nullptr) {
path = path_ptr;
char real_path[PATH_MAX] = {0};
@@ -62,13 +62,13 @@ std::string GetMsIrPath(void) {
return path;
}

std::string dump_obj(const py::object& obj, const std::string& path) {
std::string dump_obj(const py::object &obj, const std::string &path) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
return py::str(name);
}

py::object load_obj(const std::string& path) {
py::object load_obj(const std::string &path) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
return obj;
@@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) {

// ============================================= MindSpore IR Exporter =============================================

std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
TypePtr type = dyn_cast<Type>(nd->Type());
std::ostringstream oss;
@@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
return oss.str();
}

std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const {
std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const {
std::string pkl_path = GetMsIrPath();
// if not specified env 'MS_IR_PATH', do not create any files
if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
@@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca
return file_prefix + file_name;
}

int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) {
int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp) {
if (func_graph == nullptr || param == nullptr) {
return -1;
}
@@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&

// try to find index of parameter for SymbolicKeyInstance from all exported graphs
// NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) {
int AnfExporter::GetParamIndexFromExported(const AnfNodePtr &param) {
if (param == nullptr) {
return -1;
}

int ret = -1;
for (const auto& item : exported) {
for (const auto &item : exported) {
auto pram_iter = item.second.find(param);
if (pram_iter != item.second.end()) {
return pram_iter->second;
@@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) {
return ret;
}

std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) {
std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return GetValueText(fg, node->value());
}

std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) {
std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) {
auto py_funcs = mt_func_graph->GetPyFunctions();
if (py_funcs.empty()) {
return "";
@@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap

oss << "{";
bool is_first = true;
for (const auto& py_func : py_funcs) {
for (const auto &py_func : py_funcs) {
if (is_first) {
is_first = false;
} else {
@@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
* ├── GradOperation
* └── TupleAdd
*/
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) {
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
if (meta_func_graph == nullptr) {
return "";
}
@@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_
return oss.str();
}

std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) {
std::ostringstream oss;
if (prim == nullptr) {
return oss.str();
@@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {

if (prim->isa<prim::DoSignaturePrimitive>()) {
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
auto& func = do_signature->function();
auto &func = do_signature->function();
if (func->isa<Primitive>()) {
auto sig_prim = dyn_cast<Primitive>(func);
oss << sig_prim->GetAttrsText();
@@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
return oss.str();
}

std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) {
std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) {
std::ostringstream oss;
if (ns == nullptr) {
return oss.str();
@@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) {
return oss.str();
}

std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph,
const SymbolicKeyInstancePtr& sym_inst) {
std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph,
const SymbolicKeyInstancePtr &sym_inst) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(sym_inst);
AnfNodePtr sym_node = sym_inst->node();
@@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra
return oss.str();
}

std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss;
// output ValueList, ValueTuple
ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
@@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V
return oss.str();
}

std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss;
ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
oss << "{";
bool first_flag = true;
for (const auto& elem : dict->value()) {
for (const auto &elem : dict->value()) {
if (first_flag) {
first_flag = false;
} else {
@@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
return oss.str();
}

std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) {
std::ostringstream oss;

if (check_integrity_) {
@@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr&
return oss.str();
}

std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss;
bool is_null_ptr = (func_graph == nullptr || value == nullptr);
if (is_null_ptr) {
@@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu
}

// this function is used to output node in CNode's inputs
std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
const std::map<AnfNodePtr, int>& apply_map) {
std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int> &apply_map) {
std::ostringstream oss;
if (func_graph == nullptr || node == nullptr) {
return oss.str();
@@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
return oss.str();
}

void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map) {
void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) {
bool first_flag = true;
for (const AnfNodePtr& param : parameters) {
for (const AnfNodePtr &param : parameters) {
if (first_flag) {
first_flag = false;
ofs << " ";
@@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNode
}
}

void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& node) {
void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) {
if (node == nullptr) {
return;
}

// output type of each input argument
auto& inputs = node->inputs();
auto &inputs = node->inputs();
if (inputs.size() > 1) {
ofs << " #(";
for (size_t i = 1; i < inputs.size(); ++i) {
@@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod
ofs << " #scope: " << node->scope()->name();
}

void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes,
const FuncGraphPtr& func_graph) {
void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}

int idx = 1;
std::map<AnfNodePtr, int> apply_map;
for (const AnfNodePtr& node : nodes) {
for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
@@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
}

auto cnode = node->cast<CNodePtr>();
auto& inputs = cnode->inputs();
auto &inputs = cnode->inputs();
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
// non-return node
if (node != func_graph->get_return()) {
@@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
}
}

void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) {
void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun
ofs << "}\n";
}

void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) {
void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt
ofs.close();
}

void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
if (graphs.empty()) {
return;
}
@@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector

param_index = 1;

for (const auto& tagged_graph : graphs) {
for (const auto &tagged_graph : graphs) {
tagged_cnodes_ = tagged_graph.second;
ExportOneFuncGraph(ofs, tagged_graph.first);
tagged_cnodes_.clear();
@@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
}

#ifdef ENABLE_DUMP_IR
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) {
void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap
ChangeFileMode(filename, S_IRUSR);
}

void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
AnfExporter exporter("", false);
ChangeFileMode(filename, S_IRWXU);
exporter.ExportFuncGraph(filename, graphs);
@@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graph
ChangeFileMode(filename, S_IRUSR);
}
#else
void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) {
void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
@@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script.";
}

void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
static bool already_printed = false;
if (already_printed) {
return;
@@ -732,7 +732,7 @@ enum Token : int {
TOK_ERROR // file read error
};

std::map<Token, const char*> token_text = {
std::map<Token, const char *> token_text = {
{TOK_INVALID, "invalid"}, // invalid token
{TOK_LPARENTHESIS, "("}, // ( left parenthesis
{TOK_RPARENTHESIS, ")"}, // ) right parenthesis
@@ -761,14 +761,14 @@ std::map<Token, const char*> token_text = {
class Lexer {
public:
// filename is checked in ImportIR;
explicit Lexer(const char* filename) : fin(filename) {}
explicit Lexer(const char *filename) : fin(filename) {}

~Lexer() {
try {
if (fin.is_open()) {
fin.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file";
} catch (...) {
std::string exName(abi::__cxa_current_exception_type()->name());
@@ -776,7 +776,7 @@ class Lexer {
}
}

bool IsSingleCharToken(char ch, Token* token_ptr) {
bool IsSingleCharToken(char ch, Token *token_ptr) {
// clang-format off
std::unordered_map<char, Token> char_to_token = {
{'(', TOK_LPARENTHESIS},
@@ -806,7 +806,7 @@ class Lexer {
Token GetNextToken() {
#ifdef DEBUG
Token token = GetNextTokenInner();
const char* str = token_text[token];
const char *str = token_text[token];
std::string text = (str == nullptr ? GetTokenText() : str);
MS_LOG(DEBUG) << "------Parse token] " << text;
return token;
@@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE;

class IrParser {
public:
explicit IrParser(const char* filename) : lexer_(filename) {}
explicit IrParser(const char *filename) : lexer_(filename) {}

~IrParser() {}

py::object LoadObject(const std::string& file_name) const {
py::object LoadObject(const std::string &file_name) const {
std::string pkl_path = GetMsIrPath();
py::object default_obj = load_obj(pkl_path + "/" + file_name);
return default_obj;
@@ -1087,7 +1087,7 @@ class IrParser {
MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
}

Token ParseParent(FuncGraphPtr* const parent_ptr) {
Token ParseParent(FuncGraphPtr *const parent_ptr) {
if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
return TOK_ERROR;
}
@@ -1168,7 +1168,7 @@ class IrParser {
return func_graph;
}

FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) {
FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) {
Token tok = lexer_.SkipWhiteToken();
while (tok == TOK_VARIABLE) {
if (ParseStatement(func_graph) == nullptr) {
@@ -1264,56 +1264,56 @@ class IrParser {
return func_graph;
}

void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const {
void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const {
if (ptr == nullptr) {
return;
}
*ptr = dtype;
}

void SetTupleType(TypePtr* ptr) {
void SetTupleType(TypePtr *ptr) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<Tuple>();
}

void SetTupleType(TypePtr* ptr, const TypePtrList& elems) {
void SetTupleType(TypePtr *ptr, const TypePtrList &elems) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<Tuple>(elems);
}

void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector<int>&) {
void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<TensorType>(elem_type);
}

void SetListType(TypePtr* ptr) {
void SetListType(TypePtr *ptr) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<List>();
}

void SetListType(TypePtr* ptr, const TypePtrList& elems) {
void SetListType(TypePtr *ptr, const TypePtrList &elems) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<List>(elems);
}

void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) {
void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<JTagged>(elem);
}

void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const {
void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const {
if (ptr == nullptr) {
return;
}
@@ -1321,45 +1321,45 @@ class IrParser {
}

// void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const {
void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<abstract::AbstractNone>();
}

void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {}
void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {}
void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {}
void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {}

void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) {
void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
if (ptr == nullptr) {
return;
}
// if one of elems is nullptr, just return
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) {
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
return;
}
*ptr = std::make_shared<abstract::AbstractTuple>(elems);
}

void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector<int>& shape) {
void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
}

void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) {
void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
if (ptr == nullptr) {
return;
}
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) {
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
return;
}
*ptr = std::make_shared<abstract::AbstractList>(elems);
}

void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) {
void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) {
if (ptr == nullptr) {
return;
}
@@ -1367,7 +1367,7 @@ class IrParser {
}

template <typename T>
Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) {
Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) {
if (tok != TOK_LBRACKET) {
MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
return tok;
@@ -1415,7 +1415,7 @@ class IrParser {
}

template <typename T>
Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) {
Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
if (tok != TOK_LPARENTHESIS) {
if (ptr != nullptr) {
SetBasicType(ptr, std::make_shared<TensorType>());
@@ -1454,7 +1454,7 @@ class IrParser {
return lexer_.GetNextToken();
}

bool IsNumberType(const std::string& type, TypeId* typeid_ptr) {
bool IsNumberType(const std::string &type, TypeId *typeid_ptr) {
// clang-format off
static std::unordered_map<std::string, TypeId> basic_types = {
{"Bool", kNumberTypeBool},
@@ -1486,7 +1486,7 @@ class IrParser {
}

template <typename T>
void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) {
void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) {
TypePtr dtype = nullptr;

std::unordered_map<int, TypePtr> type_map = {
@@ -1519,7 +1519,7 @@ class IrParser {
}

template <typename T>
Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) {
Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) {
if (type == "NoneType") {
SetBasicType(ptr, std::make_shared<TypeNone>());
return lexer_.GetNextToken();
@@ -1541,7 +1541,7 @@ class IrParser {
}

template <typename T>
Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) {
Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
if (tok != TOK_IDENTIFIER) {
return TOK_ERROR;
}
@@ -1588,11 +1588,11 @@ class IrParser {
}
}

Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) {
Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) {
return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
}

Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) {
Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
Token tok = ParseAttribute(func_graph, prim);
while (tok == TOK_COMMA) {
tok = ParseAttribute(func_graph, prim);
@@ -1603,7 +1603,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) {
Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
Token tok = lexer_.GetNextToken();
if (tok != TOK_IDENTIFIER) {
return TOK_ERROR;
@@ -1670,7 +1670,7 @@ class IrParser {
return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
}

FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr>* const inputs_ptr) {
FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
Token tok = ParseArgument(func_graph, inputs_ptr);
while (tok == TOK_COMMA) {
tok = ParseArgument(func_graph, inputs_ptr);
@@ -1681,9 +1681,9 @@ class IrParser {
return func_graph;
}

AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) {
AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string &param_name) {
while (func_graph != nullptr) {
for (auto& ptr : func_graph->parameters()) {
for (auto &ptr : func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(ptr);
ParameterPtr param = ptr->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
@@ -1701,12 +1701,12 @@ class IrParser {
return nullptr;
}

bool Match(const std::string& str, const std::string& pattern) const {
bool Match(const std::string &str, const std::string &pattern) const {
return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
}

template <typename T, typename V>
Token ParseScalar(ValuePtr* const val_ptr) {
Token ParseScalar(ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_NUMBER) {
return TOK_ERROR;
}
@@ -1725,7 +1725,7 @@ class IrParser {
}

template <typename VT, typename V, typename T>
Token ParseScalar(ValuePtr* const val_ptr, Token tok) {
Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
if (tok != TOK_LPARENTHESIS) {
*val_ptr = std::make_shared<T>();
return tok;
@@ -1735,7 +1735,7 @@ class IrParser {
}

template <typename VT, typename V, typename T, const unsigned nbits>
Token ParseScalar(ValuePtr* const val_ptr, Token tok) {
Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
if (tok != TOK_LPARENTHESIS) {
*val_ptr = std::make_shared<T>(nbits);
return tok;
@@ -1745,7 +1745,7 @@ class IrParser {
}

template <typename T>
T StringToScalar(const std::string& text) {
T StringToScalar(const std::string &text) {
std::stringstream ss;
T value;
ss << text;
@@ -1753,7 +1753,7 @@ class IrParser {
return value;
}

Token ParseTensor(ValuePtr* const val_ptr) {
Token ParseTensor(ValuePtr *const val_ptr) {
// parse type
TypeId type;
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
@@ -1803,7 +1803,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParsePrimType(Token tok, PrimType* prim_type_ptr) {
Token ParsePrimType(Token tok, PrimType *prim_type_ptr) {
if (tok != TOK_LBRACE) {
return tok;
}
@@ -1830,7 +1830,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) {
Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
if (tok != TOK_LPARENTHESIS) {
return TOK_ERROR;
}
@@ -1855,7 +1855,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) {
Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
if (tok != TOK_LBRACE) {
return tok;
}
@@ -1868,7 +1868,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParseBoolValue(const std::string& key, bool* val_ptr) {
Token ParseBoolValue(const std::string &key, bool *val_ptr) {
if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
return TOK_ERROR;
}
@@ -1892,7 +1892,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) {
Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_LBRACE) {
return TOK_ERROR;
}
@@ -1920,7 +1920,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) {
Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) {
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
return TOK_ERROR;
}
@@ -1951,7 +1951,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) {
Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_AT_FILE) {
return TOK_ERROR;
}
@@ -1984,7 +1984,7 @@ class IrParser {
return next;
}

Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) {
Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) {
if (Match(id, "MultitypeFuncGraph::")) {
std::string name = id.substr(strlen("MultitypeFuncGraph::"));
auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
@@ -2024,8 +2024,8 @@ class IrParser {
}
}

Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr,
AnfNodePtr* const node_ptr = nullptr) {
Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr,
AnfNodePtr *const node_ptr = nullptr) {
if (id == "None") {
*val_ptr = std::make_shared<None>();
return lexer_.GetNextToken();
@@ -2075,9 +2075,9 @@ class IrParser {
}
}

Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid,
const std::vector<ValuePtr>& elems, const std::vector<AnfNodePtr>& nodes,
ValuePtr* const val_ptr, AnfNodePtr* node_ptr) {
Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid,
const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes,
ValuePtr *const val_ptr, AnfNodePtr *node_ptr) {
if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
if (node_is_valid && node_ptr != nullptr) {
MS_EXCEPTION_IF_NULL(func_graph);
@@ -2097,8 +2097,8 @@ class IrParser {
}
}

Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr,
AnfNodePtr* node_ptr = nullptr) {
Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr,
AnfNodePtr *node_ptr = nullptr) {
Token left_tok = tok;

std::vector<ValuePtr> elems;
@@ -2138,7 +2138,7 @@ class IrParser {
return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
}

Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) {
Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) {
// tuple or list
if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
@@ -2152,7 +2152,7 @@ class IrParser {
return TOK_ERROR;
}

Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr,
Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr,
Token tok = TOK_INVALID) {
if (tok == TOK_INVALID) {
tok = lexer_.GetNextToken();
@@ -2193,7 +2193,7 @@ class IrParser {
return lexer_.GetNextToken();
}

Token ParseArgument(const FuncGraphPtr& func_graph, std::vector<AnfNodePtr>* const inputs_ptr) {
Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
Token tok = lexer_.GetNextToken();
if (tok == TOK_RPARENTHESIS) {
return tok;
@@ -2208,7 +2208,7 @@ class IrParser {
return tok;
}

const std::vector<FuncGraphPtr>& GetFuncGraphs() const { return func_graphs_; }
const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; }

private:
Lexer lexer_;
@@ -2226,14 +2226,14 @@ class IrParser {
std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter
};

std::vector<FuncGraphPtr> ImportIR(const std::string& filename) {
std::vector<FuncGraphPtr> ImportIR(const std::string &filename) {
IrParser parser(filename.c_str());
parser.ParseFile();
return parser.GetFuncGraphs();
}

#ifdef ENABLE_DUMP_IR
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func graph is nullptr";
return;
@@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
return;
}
char real_path[PATH_MAX] = {0};
char* real_path_ret = nullptr;
char *real_path_ret = nullptr;
#if defined(_WIN32) || defined(_WIN64)
real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX);
#else
@@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
ChangeFileMode(file_path, S_IRUSR);
}
#else
void DumpIRProto(const FuncGraphPtr&, const std::string&) {
void DumpIRProto(const FuncGraphPtr &, const std::string &) {
static bool already_printed = false;
if (already_printed) {
return;


+ 33
- 33
mindspore/ccsrc/debug/anf_ir_utils.h View File

@@ -39,7 +39,7 @@
namespace mindspore {

struct ParamPtrEqual {
bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const {
bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const {
const ParameterPtr param1 = dyn_cast<Parameter>(t1);
const ParameterPtr param2 = dyn_cast<Parameter>(t2);

@@ -52,7 +52,7 @@ struct ParamPtrEqual {
};

struct ParamPtrHasher {
std::size_t operator()(AnfNodePtr const& param) const {
std::size_t operator()(AnfNodePtr const &param) const {
const ParameterPtr parameter = dyn_cast<Parameter>(param);
if (parameter == nullptr) {
return 0;
@@ -64,39 +64,39 @@ struct ParamPtrHasher {

class AnfExporter {
public:
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false)
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
func_graph_set.clear();
exported.clear();
}
virtual ~AnfExporter() {}

void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs);

protected:
virtual std::string GetNodeType(const AnfNodePtr& nd);
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr& param);
std::string DumpObject(const py::object& obj, const std::string& category) const;
std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst);
std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetPrimitiveText(const PrimitivePtr& prim);
std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetNameSpaceText(const parse::NameSpacePtr& ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
const std::map<AnfNodePtr, int>& apply_map);
void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph);
void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map);
void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node);
void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph);
virtual std::string GetNodeType(const AnfNodePtr &nd);
int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr &param);
std::string DumpObject(const py::object &obj, const std::string &category) const;
std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst);
std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetPrimitiveText(const PrimitivePtr &prim);
std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetNameSpaceText(const parse::NameSpacePtr &ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int> &apply_map);
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map);
void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);

int param_index;
OrderedSet<FuncGraphPtr> func_graph_set{};
@@ -108,16 +108,16 @@ class AnfExporter {
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
};

void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs);
void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph);
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs);

std::vector<FuncGraphPtr> ImportIR(const std::string& filename);
std::vector<FuncGraphPtr> ImportIR(const std::string &filename);

std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);

void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);

std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
} // namespace mindspore

#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_

+ 32
- 32
mindspore/ccsrc/debug/draw.cc View File

@@ -34,7 +34,7 @@ namespace draw {

namespace {
// Only for ValueNode
std::string ValueType(const ValueNodePtr& node) {
std::string ValueType(const ValueNodePtr &node) {
if (node == nullptr) {
return "";
}
@@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) {
return v->type_name();
}

std::string ReplaceSpecialChar(const std::string& str) {
std::string ReplaceSpecialChar(const std::string &str) {
std::ostringstream oss;
for (size_t i = 0; i < str.size(); i++) {
if (str[i] == '<') {
@@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) {
} // namespace

// API of debug utils
void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs,
void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
bool is_user) {
if (sub_graphs == nullptr) {
return;
}
for (auto& nd : nodes) {
for (auto &nd : nodes) {
MS_EXCEPTION_IF_NULL(nd);
auto sub_graph = nd->func_graph();
if (sub_graph != nullptr) {
@@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st
}
}

void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) {
void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
if (sub_graphs == nullptr) {
return;
}

int dup_idx = 0;

for (auto& nd : nodes) {
for (auto& t : SuccIncoming(nd)) {
for (auto &nd : nodes) {
for (auto &t : SuccIncoming(nd)) {
MS_EXCEPTION_IF_NULL(t);
MS_EXCEPTION_IF_NULL(nd);
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
@@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
}
}

void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) {
void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
if (digraph == nullptr) {
return;
}
@@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
}

// Draw edge
for (auto& nd : nodes) {
for (auto &nd : nodes) {
auto succs = SuccIncoming(nd);
auto num = succs.size();
for (size_t i = 0; i < num; i++) {
auto& t = succs.at(i);
auto &t = succs.at(i);
MS_EXCEPTION_IF_NULL(t);
if (t->isa<ValueNode>() || t->isa<Parameter>()) {
if ((!is_user) || (i != 0)) {
@@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
}
}

void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) {
void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) {
if (func_graph == nullptr) {
return;
}
@@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
DrawValueNodes(nodes, &sub_graphs);

// Draw subgraph
for (const auto& gsub : sub_graphs) {
for (const auto &gsub : sub_graphs) {
digraph->SubGraph(gsub.first, gsub.second);
}

@@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
}

#ifdef ENABLE_DUMP_IR
void Draw(const std::string& filename, const FuncGraphPtr& func_graph) {
void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
const std::string dot_suffix = ".dot";
std::string filename_with_suffix =
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
DrawByOpt(filename_with_suffix, func_graph, false);
}

void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) {
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
DrawByOpt(filename, func_graph, true);
}
#else
void Draw(const std::string&, const FuncGraphPtr&) {
void Draw(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
@@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script.";
}

void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) {
void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
@@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) {
return "plaintext";
}

std::string Graphviz::Color(const AnfNodePtr& node) {
std::string Graphviz::Color(const AnfNodePtr &node) {
if (node == nullptr) {
return "";
}
@@ -259,7 +259,7 @@ void BaseDigraph::Start() {
buffer_ << "compound=true" << std::endl;
}

void BaseDigraph::Head(const AnfNodePtr& node, int id) {
void BaseDigraph::Head(const AnfNodePtr &node, int id) {
if (node == nullptr) {
return;
}
@@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) {
}
}

void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
if (node == nullptr) {
return;
}
@@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
buffer_ << ":" << idx;
}

void BaseDigraph::Tail(const FuncGraphPtr& func_graph) {
void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@@ -304,12 +304,12 @@ void BaseDigraph::End() {
}
}

void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << "parameters_" << key << "[shape=plaintext ";
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
buffer_ << "<tr><td>parameters</td></tr>";
int count = 0;
for (auto& parameter : key->parameters()) {
for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>";
buffer_ << parameter->ToString();
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
@@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
buffer_ << "</table>>,];";
}

void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) {
void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
if (key == nullptr || gsub == nullptr) {
return;
}
@@ -361,12 +361,12 @@ Digraph::~Digraph() {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file " << filename_;
}
}

static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) {
static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
(void)str.replace(start_pos, from.length(), to);
@@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st
return str;
}

static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph_obj);
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
<< "'>";
@@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td align='left'>";
int i = 0;
for (const auto& attr : attrs) {
for (const auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
@@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</table>>,";
}

static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) {
return;
}
@@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
}
}

static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
return;
}
@@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
}
graph_obj->buffer() << ">";
int i = 0;
for (auto& attr : attrs) {
for (auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
@@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "exception when closing file " << filename_;
}
}


+ 18
- 18
mindspore/ccsrc/debug/draw.h View File

@@ -31,9 +31,9 @@ namespace parse = mindspore::parse;

class Graphviz {
public:
Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {}
Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {}

explicit Graphviz(const std::string& name) : name_(name) {}
explicit Graphviz(const std::string &name) : name_(name) {}

virtual ~Graphviz() {}

@@ -41,8 +41,8 @@ class Graphviz {
virtual void End() {}

virtual std::string Shape(AnfNodePtr node);
std::string Color(const AnfNodePtr& node);
std::ostringstream& buffer() { return buffer_; }
std::string Color(const AnfNodePtr &node);
std::ostringstream &buffer() { return buffer_; }
std::ostringstream buffer_;

protected:
@@ -53,8 +53,8 @@ class Graphviz {

class BaseDigraph : public Graphviz {
public:
BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string& name) : Graphviz(name) {}
BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string &name) : Graphviz(name) {}
~BaseDigraph() override = default;

virtual void Node(AnfNodePtr node, int id = 0) = 0;
@@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz {
void Start() override;
void End() override;
virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start);
void FuncGraphParameters(const FuncGraphPtr& key);
void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub);
void FuncGraphParameters(const FuncGraphPtr &key);
void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub);

const std::string& name() const { return name_; }
const std::string &name() const { return name_; }

protected:
void Head(const AnfNodePtr& node, int id = 0);
void Tail(const AnfNodePtr& node, int idx, int id = 0);
void Tail(const FuncGraphPtr& func_graph);
void Head(const AnfNodePtr &node, int id = 0);
void Tail(const AnfNodePtr &node, int idx, int id = 0);
void Tail(const FuncGraphPtr &func_graph);
};

class Digraph : public BaseDigraph {
public:
Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string& name) : BaseDigraph(name) {}
Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string &name) : BaseDigraph(name) {}
~Digraph() override;

void Node(AnfNodePtr node, int id = 0) override;
@@ -86,8 +86,8 @@ class Digraph : public BaseDigraph {

class ModelDigraph : public BaseDigraph {
public:
ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {}
ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {}
~ModelDigraph() override;

std::string Shape(AnfNodePtr node) override;
@@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph {
};

// API to draw
void Draw(const std::string& filename, const FuncGraphPtr& func_graph);
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void Draw(const std::string &filename, const FuncGraphPtr &func_graph);
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);

} // namespace draw
} // namespace mindspore


+ 85
- 85
mindspore/ccsrc/debug/dump_proto.cc View File

@@ -33,38 +33,38 @@ class ProtoExporter {
ProtoExporter() {}
~ProtoExporter() {}

std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);

private:
void InitModelInfo();
void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto);
std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
const std::map<AnfNodePtr, size_t>& apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr);
void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto);
void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto);
void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto);
void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto);
void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto);
void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto);
void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto);
void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto);
void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto,
std::map<AnfNodePtr, size_t>* const_map_ptr);
void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* apply_map_ptr,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto);
void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node,
const std::map<AnfNodePtr, size_t>& apply_map, std::map<AnfNodePtr, size_t>* const_map_ptr,
irpb::GraphProto* graph_proto);
void ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto);
void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto);
std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t> *const_map_ptr);
void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto);
void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto);
void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto);
void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto);
void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto);
void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
std::map<AnfNodePtr, size_t> *const_map_ptr);
void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto);
void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
irpb::GraphProto *graph_proto);
void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto);

static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }

irpb::ModelProto model_;
};

static irpb::DataType GetNumberDataType(const TypePtr& type) {
static irpb::DataType GetNumberDataType(const TypePtr &type) {
switch (type->type_id()) {
case kNumberTypeBool:
return irpb::DT_BOOL;
@@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) {
}
}

void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) {
void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
if (type_proto == nullptr) {
return;
}
@@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
type_proto->set_data_type(irpb::DT_TENSOR);
if (shape != nullptr && shape->isa<abstract::Shape>()) {
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
for (const auto& elem : shape_info->shape()) {
for (const auto &elem : shape_info->shape()) {
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
}
}
} else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE);
for (const auto& elem_type : tuple_type->elements()) {
for (const auto &elem_type : tuple_type->elements()) {
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
}
} else if (type->isa<TypeType>()) {
@@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
} else if (type->isa<List>()) {
ListPtr list_type = dyn_cast<List>(type);
type_proto->set_data_type(irpb::DT_LIST);
for (const auto& elem_type : list_type->elements()) {
for (const auto &elem_type : list_type->elements()) {
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
}
} else if (type->isa<TypeAnything>()) {
@@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
}
}

void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) {
void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) {
if (node == nullptr || type_proto == nullptr) {
return;
}
SetNodeOutputType(node->Type(), node->Shape(), type_proto);
}

void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}

if (val->isa<StringImm>()) {
const StringImmPtr& value = dyn_cast<StringImm>(val);
const StringImmPtr &value = dyn_cast<StringImm>(val);
value_proto->set_dtype(irpb::DT_STRING);
value_proto->set_str_val(value->value());
} else if (val->isa<Scalar>()) {
@@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
} else if (val->isa<tensor::Tensor>()) {
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
value_proto->set_dtype(irpb::DT_TENSOR);
irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val();
irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
for (auto& elem : tensor_ptr->shape()) {
for (auto &elem : tensor_ptr->shape()) {
tensor_proto->add_dims(elem);
}
} else if (val->isa<TensorType>()) {
value_proto->set_dtype(irpb::DT_TYPE);

irpb::TypeProto* type_proto = value_proto->mutable_type_val();
irpb::TypeProto *type_proto = value_proto->mutable_type_val();
type_proto->set_data_type(irpb::DT_TENSOR);
TypePtr elem_type = dyn_cast<TensorType>(val)->element();
type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
@@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
}
}

void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}

if (val->isa<BoolImm>()) {
const BoolImmPtr& value = dyn_cast<BoolImm>(val);
const BoolImmPtr &value = dyn_cast<BoolImm>(val);
value_proto->set_dtype(irpb::DT_BOOL);
value_proto->set_bool_val(value->value());
} else if (val->isa<Int8Imm>()) {
const Int8ImmPtr& value = dyn_cast<Int8Imm>(val);
const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
value_proto->set_dtype(irpb::DT_INT8);
value_proto->set_int_val(value->value());
} else if (val->isa<Int16Imm>()) {
const Int16ImmPtr& value = dyn_cast<Int16Imm>(val);
const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
value_proto->set_dtype(irpb::DT_INT16);
value_proto->set_int_val(value->value());
} else if (val->isa<Int32Imm>()) {
const Int32ImmPtr& value = dyn_cast<Int32Imm>(val);
const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
value_proto->set_dtype(irpb::DT_INT32);
value_proto->set_int_val(value->value());
} else if (val->isa<Int64Imm>()) {
const Int64ImmPtr& value = dyn_cast<Int64Imm>(val);
const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
value_proto->set_dtype(irpb::DT_INT64);
value_proto->set_int_val(value->value());
} else if (val->isa<UInt8Imm>()) {
const UInt8ImmPtr& value = dyn_cast<UInt8Imm>(val);
const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
value_proto->set_dtype(irpb::DT_UINT8);
value_proto->set_uint_val(value->value());
} else if (val->isa<UInt16Imm>()) {
const UInt16ImmPtr& value = dyn_cast<UInt16Imm>(val);
const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
value_proto->set_dtype(irpb::DT_UINT16);
value_proto->set_uint_val(value->value());
} else if (val->isa<UInt32Imm>()) {
const UInt32ImmPtr& value = dyn_cast<UInt32Imm>(val);
const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
value_proto->set_dtype(irpb::DT_UINT32);
value_proto->set_uint_val(value->value());
} else if (val->isa<UInt64Imm>()) {
const UInt64ImmPtr& value = dyn_cast<UInt64Imm>(val);
const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
value_proto->set_dtype(irpb::DT_UINT64);
value_proto->set_uint_val(value->value());
} else if (val->isa<FP32Imm>()) {
const FP32ImmPtr& value = dyn_cast<FP32Imm>(val);
const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
value_proto->set_dtype(irpb::DT_FLOAT32);
value_proto->set_float_val(value->value());
} else if (val->isa<FP64Imm>()) {
const FP64ImmPtr& value = dyn_cast<FP64Imm>(val);
const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
value_proto->set_dtype(irpb::DT_FLOAT64);
value_proto->set_double_val(value->value());
} else {
@@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val
}
}

void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}

if (val->isa<ValueTuple>()) {
const ValueTuplePtr& value = dyn_cast<ValueTuple>(val);
const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
value_proto->set_dtype(irpb::DT_TUPLE);
for (const auto& item : value->value()) {
for (const auto &item : value->value()) {
SetValueToProto(item, value_proto->add_values());
}
} else if (val->isa<ValueList>()) {
const ValueListPtr& value = dyn_cast<ValueList>(val);
const ValueListPtr &value = dyn_cast<ValueList>(val);
value_proto->set_dtype(irpb::DT_LIST);
for (const auto& item : value->value()) {
for (const auto &item : value->value()) {
SetValueToProto(item, value_proto->add_values());
}
}
}

void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}

value_proto->set_dtype(irpb::DT_DICT);
for (const auto& item : val->value()) {
irpb::NamedValueProto* named_val = value_proto->add_dict_val();
for (const auto &item : val->value()) {
irpb::NamedValueProto *named_val = value_proto->add_dict_val();
named_val->set_key(item.first);
SetValueToProto(item.second, named_val->mutable_value());
}
}

void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) {
void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) {
if (node == nullptr || node_proto == nullptr) {
return;
}
@@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr&
MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
}

const PrimitivePtr& prim = GetValueNode<PrimitivePtr>(node);
const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
node_proto->set_op_type(prim->name());
for (const auto& attr : prim->attrs()) {
irpb::AttributeProto* attr_proto = node_proto->add_attribute();
for (const auto &attr : prim->attrs()) {
irpb::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first);
SetValueToProto(attr.second, attr_proto->mutable_value());
}
node_proto->set_scope(node->scope()->name());
}

std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node,
const std::map<AnfNodePtr, size_t>& apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr) {
std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t> *const_map_ptr) {
if (node == nullptr || const_map_ptr == nullptr) {
return "";
}
@@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt
MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
}

std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) {
std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return "";
}

InitModelInfo();
irpb::GraphProto* graph_proto = model_.mutable_graph();
irpb::GraphProto *graph_proto = model_.mutable_graph();
ExportFuncGraph(func_graph, graph_proto);
return model_.SerializeAsString();
}

void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || graph_proto == nullptr) {
return;
}
@@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP
ExportValueNodes(const_map, graph_proto);
}

void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || graph_proto == nullptr) {
return;
}

std::vector<AnfNodePtr> parameters = func_graph->parameters();
for (auto& param : parameters) {
irpb::ParameterProto* param_proto = graph_proto->add_parameters();
for (auto &param : parameters) {
irpb::ParameterProto *param_proto = graph_proto->add_parameters();
param_proto->set_name(param->ToString());

SetNodeOutputType(param, param_proto->mutable_type());
@@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph
}
}

void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto,
std::map<AnfNodePtr, size_t>* const_map_ptr) {
void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
std::map<AnfNodePtr, size_t> *const_map_ptr) {
if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
return;
}
// topo sort nodes
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
std::map<AnfNodePtr, size_t> apply_map;
for (const AnfNodePtr& node : nodes) {
for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
@@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt
}
}

void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* apply_map_ptr,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *apply_map_ptr,
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
graph_proto == nullptr) {
return;
@@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
auto apply_idx = apply_map_ptr->size() + 1;
(*apply_map_ptr)[node] = apply_idx;

auto& inputs = node->inputs();
auto &inputs = node->inputs();
if (inputs.size() < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr op = inputs[0];
irpb::NodeProto* node_proto = graph_proto->add_node();
irpb::NodeProto *node_proto = graph_proto->add_node();

// CNode/ConstGraph/Const/Parameter
if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
@@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&

// process OP inputs
for (size_t i = 1; i < inputs.size(); ++i) {
irpb::InputProto* input_proto = node_proto->add_input();
irpb::InputProto *input_proto = node_proto->add_input();
input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
input_proto->set_name(id);
@@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
}
}

void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node,
const std::map<AnfNodePtr, size_t>& apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
if (ret_node == nullptr || !ret_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Graph return node is illegal";
}
@@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
if (graph_proto == nullptr) {
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
}
irpb::OutputProto* output_proto = graph_proto->add_outputs();
irpb::OutputProto *output_proto = graph_proto->add_outputs();
if (output_proto == nullptr) {
MS_LOG(EXCEPTION) << "output_proto is nullptr";
}
@@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
SetNodeOutputType(arg, output_proto->mutable_type());
}

static bool CompareValue(const std::pair<AnfNodePtr, size_t>& x, const std::pair<AnfNodePtr, size_t>& y) {
static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
return x.second < y.second;
}

void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) {
std::vector<std::pair<AnfNodePtr, size_t>> nodes;
(void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
[](const std::pair<AnfNodePtr, size_t>& item) { return item; });
[](const std::pair<AnfNodePtr, size_t> &item) { return item; });

sort(nodes.begin(), nodes.end(), CompareValue);

for (auto& item : nodes) {
for (auto &item : nodes) {
if (graph_proto == nullptr) {
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
}
irpb::NamedValueProto* named_value = graph_proto->add_const_vals();
irpb::NamedValueProto *named_value = graph_proto->add_const_vals();
MS_EXCEPTION_IF_NULL(named_value);
named_value->set_key(GetConstNodeId(item.second));
SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
@@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_m

void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); }

std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) {
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
ProtoExporter exporter;
return exporter.GetFuncGraphProtoString(func_graph);
}


+ 10
- 10
mindspore/ccsrc/debug/e2e_dump.cc View File

@@ -36,7 +36,7 @@ Dump::Dump()
dump_iter_(0),
cur_iter_(0) {}

bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
bool Dump::IsKernelNeedDump(const std::string &kernel_name) {
if (dump_mode_ == 0) {
// Dump All Kernels mode
return true;
@@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
return false;
}

bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
bool Dump::ParseDumpConfig(const std::string &dump_config_file) {
std::ifstream jsonFile(dump_config_file);
if (!jsonFile.is_open()) {
MS_LOG(ERROR) << dump_config_file << " open failed.";
@@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
return true;
}

bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) {
if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() ||
dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() ||
dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() ||
@@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
return true;
}

bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto trans_flag = dumpSettings.at("trans_flag");
auto enable = dumpSettings.at("enable");
auto mode = dumpSettings.at("mode");
@@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
dump_path_ = path;
dump_net_name_ = net_name;
dump_iter_ = iteration;
for (const auto& kernel : kernels) {
for (const auto &kernel : kernels) {
dump_kernels_.push_back(kernel);
}
return true;
}

bool Dump::SetDumpConfFromJsonFile() {
const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
if (config_path_str != nullptr) {
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
} else {
@@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() {
return ParseDumpConfig(dump_config_file);
}

bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) {
bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) {
if (filename.empty() || data == nullptr || len == 0) {
MS_LOG(ERROR) << "Incorrect parameter.";
return false;
@@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len)
MS_LOG(ERROR) << "Open file " << realpath << " fail.";
return false;
}
(void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len));
(void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len));
fd.close();
return true;
}

bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) {
MS_EXCEPTION_IF_NULL(outpath);
auto path_split_pos = inpath.find_last_of('/');
if (path_split_pos == std::string::npos) {
@@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
return true;
}

bool Dump::CreateNotExistDirs(const std::string& path) {
bool Dump::CreateNotExistDirs(const std::string &path) {
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
char temp_path[PATH_MAX] = {0};


+ 7
- 7
mindspore/ccsrc/debug/e2e_dump.h View File

@@ -43,11 +43,11 @@ class Dump {

uint32_t cur_iter() const { return cur_iter_; }

bool IsKernelNeedDump(const std::string& kernel_name);
bool IsKernelNeedDump(const std::string &kernel_name);

bool SetDumpConfFromJsonFile();

static bool DumpToFile(const std::string& filename, const void* data, size_t len);
static bool DumpToFile(const std::string &filename, const void *data, size_t len);

protected:
bool dump_enable_;
@@ -59,14 +59,14 @@ class Dump {
uint32_t cur_iter_;
std::vector<std::string> dump_kernels_;

static bool GetRealPath(const std::string& inpath, std::string* outpath);
static bool GetRealPath(const std::string &inpath, std::string *outpath);

static bool CreateNotExistDirs(const std::string& path);
static bool CreateNotExistDirs(const std::string &path);

private:
bool ParseDumpConfig(const std::string& dump_config_file);
bool IsConfigExist(const nlohmann::json& dumpSettings);
bool IsConfigValid(const nlohmann::json& dumpSettings);
bool ParseDumpConfig(const std::string &dump_config_file);
bool IsConfigExist(const nlohmann::json &dumpSettings);
bool IsConfigValid(const nlohmann::json &dumpSettings);
};

using DumpConfPtr = std::shared_ptr<Dump>;


+ 9
- 9
mindspore/ccsrc/debug/info.cc View File

@@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h"

namespace mindspore {
std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) {
std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
std::string temp_line = line;
if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) &&
tip != kSourceLineTipDiscard) {
@@ -101,14 +101,14 @@ DebugInfo::DebugInfo() {
name_ = "";
}

DebugInfo::DebugInfo(const std::string& name) {
DebugInfo::DebugInfo(const std::string &name) {
InitValueFromContext();
unique_id_ = gen_unique_id();
debug_id_ = -1;
name_ = name;
}

DebugInfo::DebugInfo(const LocationPtr& loc) {
DebugInfo::DebugInfo(const LocationPtr &loc) {
InitValueFromContext();
unique_id_ = gen_unique_id();
debug_id_ = -1;
@@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() {
}

int64_t DebugInfo::unique_id_through_copy() const {
TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info();
TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info();
if (trace_info != nullptr) {
if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) {
return trace_info->debug_info()->unique_id_through_copy();
@@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() {
}
return DebugInfo::location();
}
void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; }
void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; }

TraceContextPtr TraceManager::CurrentContextInfo() {
if (!TraceManager::trace_context_stack_.empty()) {
@@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() {
return nullptr;
}

void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) {
void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location);
context->set_func_name(func_name);
TraceManager::trace_context_stack_.push(context);
}

void TraceManager::DebugTrace(const LocationPtr& location) {
void TraceManager::DebugTrace(const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location);
TraceManager::trace_context_stack_.push(context);
}

void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
}
@@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
TraceManager::trace_context_stack_.push(context);
}

void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) {
void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
}


+ 26
- 26
mindspore/ccsrc/debug/info.h View File

@@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou
// Location class record the location in source code.
class Location {
public:
Location(const std::string& file_name, int line, int column, int line_end, int column_end)
Location(const std::string &file_name, int line, int column, int line_end, int column_end)
: file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}
Location(const Location& loc)
Location(const Location &loc)
: file_name_(loc.file_name_),
line_(loc.line_),
column_(loc.column_),
@@ -77,21 +77,21 @@ class TraceManager {
TraceManager() = default;
~TraceManager() = default;
static TraceContextPtr CurrentContextInfo();
static void DebugTrace(const std::string& func_name, const LocationPtr& location);
static void DebugTrace(const LocationPtr& location);
static void DebugTrace(const TraceInfoPtr& trace_info);
static void DebugTrace(const std::string &func_name, const LocationPtr &location);
static void DebugTrace(const LocationPtr &location);
static void DebugTrace(const TraceInfoPtr &trace_info);
// debug trace with a cloned trace info with debug_info
static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info);
static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info);
static void EndTrace();
static std::stack<TraceContextPtr> trace_context_stack_;
};

class TraceGuard {
public:
explicit TraceGuard(const std::string func_name, const LocationPtr& location) {
explicit TraceGuard(const std::string func_name, const LocationPtr &location) {
TraceManager::DebugTrace(func_name, location);
}
explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); }
explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
~TraceGuard() { TraceManager::EndTrace(); }
};

@@ -106,23 +106,23 @@ class TraceContext {

public:
~TraceContext() = default;
explicit TraceContext(const LocationPtr& loc) {
explicit TraceContext(const LocationPtr &loc) {
ProcessAttributeFromContext();
location_ = loc;
}
explicit TraceContext(const std::string& func_name) {
explicit TraceContext(const std::string &func_name) {
ProcessAttributeFromContext();
func_name_ = func_name;
}
explicit TraceContext(const TraceInfoPtr& trace_info) {
explicit TraceContext(const TraceInfoPtr &trace_info) {
ProcessAttributeFromContext();
trace_info_ = trace_info;
}
void set_location(const LocationPtr& loc) { location_ = loc; }
void set_location(const LocationPtr &loc) { location_ = loc; }
LocationPtr location() { return location_; }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; }
void set_func_name(const std::string& func_name) { func_name_ = func_name; }
void set_func_name(const std::string &func_name) { func_name_ = func_name; }
std::string func_name() { return func_name_; }
};

@@ -130,9 +130,9 @@ class DebugInfo : public Base {
public:
DebugInfo();

explicit DebugInfo(const std::string& name);
explicit DebugInfo(const std::string &name);

explicit DebugInfo(const LocationPtr& loc);
explicit DebugInfo(const LocationPtr &loc);

virtual ~DebugInfo() = default;
MS_DECLARE_PARENT(DebugInfo, Base);
@@ -141,12 +141,12 @@ class DebugInfo : public Base {
int64_t unique_id_through_copy() const;
std::string get_id() { return std::to_string(debug_id()); }

void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; }
void set_location(const LocationPtr& loc) { location_ = loc; }
void set_location(const LocationPtr &loc) { location_ = loc; }
virtual LocationPtr location() { return location_; }
std::string name() { return name_; }
void set_name(const std::string& name) { name_ = name; }
void set_name(const std::string &name) { name_ = name; }
virtual std::string debug_name();

virtual std::string get_python_func_belonged() { return ""; }
@@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo {
py_func_belonged_ = context_info->func_name();
}
}
explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) {
explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo();
py_func_belonged_ = context_info->func_name();
@@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo {
~NodeDebugInfo() override = default;

std::string debug_name() override;
void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); }
void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); }
std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; }
void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; }
std::string get_python_func_belonged() override { return py_func_belonged_; }
AnfNodeWeakPtr node_;
std::string py_func_belonged_;
@@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo {
}
}

explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) {
explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo();
py_func_name_ = context_info->func_name();
@@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo {
std::string debug_name() override;
LocationPtr location() override;
LocationPtr deco_location() { return deco_loc_; }
void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
FuncGraphPtr get_graph() const { return func_graph_.lock(); }
void set_full_name(const std::string& name) { full_name_ = name; }
void set_full_name(const std::string &name) { full_name_ = name; }
std::string get_full_name() { return full_name_; }
void set_deco_location(const LocationPtr& deco_list_loc);
void set_deco_location(const LocationPtr &deco_list_loc);
std::string get_python_func_belonged() override { return py_func_name_; }
FuncGraphWeakPtr func_graph_;
LocationPtr deco_loc_;


+ 8
- 8
mindspore/ccsrc/debug/label.cc View File

@@ -31,7 +31,7 @@ struct NameWithTrace {
std::string name;
std::vector<std::string> trace_labels;
};
static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) {
static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
switch (trace_label) {
case TraceLabelType::kShortSymbol:
return trace_info->symbol();
@@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t
}
}

NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name;
// find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
auto temp_info = debug_info;
@@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe
return trace_name;
}

std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) {
std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) {
std::string tags = "";
for (auto& itr : trace_labels) {
for (auto &itr : trace_labels) {
std::string symbol = itr;
tags = tags + symbol;
}
@@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st
}

// get the label name of the node debug info
std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name = RootName(debug_info, trace_label);
return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
}

std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
auto temp_info = debug_info;
std::string label = "";
while (temp_info != nullptr) {
@@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
}

// get trace with unique id chain
std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); }
std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }

std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
return LabelStringUnique(debug_info);
}


+ 1
- 1
mindspore/ccsrc/debug/label.h View File

@@ -29,7 +29,7 @@ namespace label_manage {
enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId };
TraceLabelType GetGlobalTraceLabelType();
void SetGlobalTraceLabelType(TraceLabelType label_type);
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
} // namespace label_manage
} // namespace mindspore



+ 24
- 24
mindspore/ccsrc/debug/trace.cc View File

@@ -37,7 +37,7 @@
namespace mindspore {
// namespace to support debug trace infomation
namespace trace {
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) {
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) {
if (abs == nullptr) {
return "Null Abstract";
}
@@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) {
return debug_with_loc_vec;
}

DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) {
auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info);
if (debug_with_loc_vec.size() > 0) {
return debug_with_loc_vec[0];
@@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
}
}

std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) {
return "";
}
@@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {

// a trace info identifies a node transform, so we can trace the node transform through
// a link of trace info and debug info
std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) {
std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) {
if (info_vec.size() < 1) {
return "";
}
@@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL
return traced_info;
}

std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) {
return "";
}
@@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
return "";
}

std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) {
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) {
std::ostringstream oss;
if (info == nullptr) {
return "";
@@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So
return oss.str();
}

std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) {
std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) {
std::ostringstream oss;
oss << "graph:" << graph->ToString() << " with args[";
auto params = graph->parameters();
@@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas
return oss.str();
}

void DumpInferStack(std::ostringstream& oss) {
auto& infer_stack = GetCurrenGraphInferStack();
void DumpInferStack(std::ostringstream &oss) {
auto &infer_stack = GetCurrenGraphInferStack();
if (infer_stack.empty()) {
return;
}
@@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) {
}
std::reverse(infer_vec.begin(), infer_vec.end());
int index = 0;
for (auto& item : infer_vec) {
for (auto &item : infer_vec) {
auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first);
if (graph_infer == nullptr) {
MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator";
@@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) {
}

void TraceGraphInfer() {
auto& infer_stack = GetCurrenGraphInferStack();
auto &infer_stack = GetCurrenGraphInferStack();
std::ostringstream oss;
if (infer_stack.empty()) {
return;
@@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
~AnalyzedFuncGraphExporter() override = default;

void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs);

private:
std::string GetNodeType(const AnfNodePtr& nd) override;
std::string GetNodeType(const AnfNodePtr &nd) override;
};

std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
auto& list = GetCNodeDebugStack();
auto &list = GetCNodeDebugStack();
for (size_t i = 0; i < list.size(); ++i) {
auto node_cfg = list[i];
auto fg = node_cfg->context()->func_graph();
@@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() {
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
}

std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
if (node_cfg_ == nullptr) {
return AnfExporter::GetNodeType(node);
}
@@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
return oss.str();
}

void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) {
if (node_cfgs.empty()) {
MS_LOG(DEBUG) << "Node configs is empty";
return;
@@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
auto tagged_func_graphs = CalcTaggedFuncGraphs();

// first output graph on the analysis stack
for (const auto& node_cfg : node_cfgs) {
for (const auto &node_cfg : node_cfgs) {
auto fg = node_cfg->context()->func_graph();
// the graph is already output, skip it
if (exported.find(fg) != exported.end()) {
@@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
ofs.close();
}

void GetInferStackInfo(std::ostringstream& oss) {
void GetInferStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Get graph analysis information begin";
auto stack = GetCNodeDebugStack();
if (stack.empty()) {
@@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) {
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
// trace the cnode infer debug info
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) {
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
}
@@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An
}
}

void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) {
if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
}
@@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
}
}

void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }

void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); }

std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; }
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }

std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() {
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() {
return graph_infer_stack;
}
void ClearTraceStack() {


+ 10
- 10
mindspore/ccsrc/debug/trace.h View File

@@ -31,19 +31,19 @@

namespace mindspore {
namespace trace {
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix,
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
void TraceGraphInfer();
void GetInferStackInfo(std::ostringstream& oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg);
void GetInferStackInfo(std::ostringstream &oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
void TraceInferCNodeLeave();
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs);
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
void ClearTraceStack();
} // namespace trace
} // namespace mindspore


+ 1
- 1
mindspore/ccsrc/debug/trace_info.cc View File

@@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h"

namespace mindspore {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {
if (info == nullptr) {
return "";
}


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save