
validators.py 67 kB

Commit history:

- added python api based on cpp api
- 1st draft of python iterator
- Added Cifar10 and Cifar100 pybind port
- Change pybind to use IR for Skip and Manifest (Signed-off-by: alex-yuyue <yue.yu1@huawei.com>)
- DatasetNode as a base for all IR nodes
- namespace change
- Fix the namespace issue and make ut tests work (Signed-off-by: alex-yuyue <yue.yu1@huawei.com>)
- Add VOCDataset
- !63 Added RandomDataset
- add imagefolder ir
- Pybind switch: CelebA and UT
- !61 CLUE example with class definition
  * Merge branch 'python-api' of gitee.com:ezphlow/mindspore into clue_class_pybind
  * Passing testcases
  * Added CLUE, not working
- add ManifestDataset IR (Signed-off-by: alex-yuyue <yue.yu1@huawei.com>)
- Update Coco & VOC & TFReader, update clang-format, reorder datasets_binding
- !69 Add Generator and move c_dataset.Iterator to dataset.Iterator
  * Add GeneratorDataset to c_dataset
- !67 Moving c_datasets and adding sampler wrapper
  * Need to add create() method in datasets.py
  * migration from c_dataset to dataset, part 1
- !71 Fix indentation error
- !72 Fix c_api test cases
- !73 Added CSVDataset
- pybind switch: Take and CelebA fixes
- !75 Move c_dataset functionality to datasets
  * Fixed existing testcases
  * Added working clue and imagefolder
  * Added sampler conversion from pybind
  * Added sampler creation
- !77 Add Python API tree
- add minddataset
- TextFileDataset pybind
- Rename to skip test_concat.py and test_minddataset_exception.py
- !80 Add batch IR to python-api branch, most test cases work
  * staging III
  * staging, add pybind
- Enable more c_api take and CelebA tests; delete util_c_api
- !84 Schema changes in datasets.py
- !85 Remove input_indexes from sub-classes
- !83 Remove C datasets
  * Removed c_dataset package
- !82 pybind switch: shuffle
- !86 Add build_vocab
- Rebase with upstream/master: _shuffle conflict, BatchNode error
- !88 Fix rebase problem
- Enable more unit tests; code typo/nit fixes
- !91 Fix python vocab hang
- !89 Added BucketBatchByLength pybind switch
- Update and enable more test_c_api_*.py tests
- !95 Add BuildSentencePieceVocab
- !96 Fix more tests
  * Fix some tests
  * Enable more test_c_api_*
  * Add syncwait
- !99 pybind switch for device op
- !93 Add getters to python API
- !101 Validate tree, error if graph
  * Add sync wait
- !103 TFRecord/Random datasets schema problem
- !102 Added Filter pybind switch
- !104 Fix num_samples
- !105 Fix to_device hang
- !94 Add Cache support for CLUE dataset
  * Added cache for all dataset ops
  * format change
  * Added Cache conversion
- Add save pybind
- fix compile err
- init modify concat_node
- !107 Fix some test cases
- Enable and fix more tests
- !109 pybind switch for get dataset size
- some check-code fixes for pylint, cpplint and clang-format
- !113 Add callback
  * revert
  * dataset_sz 1 line
  * fix typo
  * get callback to work
- !114 Make Android compile clean
- Fix build issues due to rebase
- !115 Fix more tests
  * Fix test cases
  * !93 Add getters to python API
- fix test_profiling.py
- !116 fix get dataset size
- !117 GetColumnNames pybind switch
- code-check fixes: clang-format, cppcheck, cpplint, pylint
- Delete duplicate test_c_api_*.py files; more lint fixes
- !121 Fix cpp tests
  * Remove extra call to getNext in cpp tests
- !122 Fix Schema with Generator
- fix some cases of csv & mindrecord
- !124 fix tfrecord get_dataset_size and add some UTs for get_dataset_size
- !125 Getter separation
- !126 Fix sampler.GetNumSamples
- !127 Assign runtime getter to each get function
- Fix compile issues
- !128 Match master code
- !129 Cleanup DeviceOp/save code
  * Cleanup ToDevice/Save code
- !130 Add cache fix
  * Added cache fix for map and image folder
- !132 Fix testing team issues
  * Pass queue_name from python to C++
  * Add Schema.from_json
- !131 Fix Cache op issues and delete de_pipeline
  * Roll back C++ change
  * Removed de_pipeline and passing all cache tests
- !134 Cleanup datasets.py part1
- !133 Updated validation for SentencePieceVocab.from_dataset
  * Added type_check for column names in SentencePieceVocab.from_dataset
- Rebase on master 181120 10:20
- fix profiling
- temporary solution of catching status from Node.Build()
- !141 ToDevice termination
- pylint fixes
- !137 Fix test team issues and add some corresponding tests
- !138 TreeGetter changes to use OptPass
  * Getter changes to use OptPass (Zirui)
- Rebase fix
- !143 Fix cpplint issue
- pylint fixes in updated testcases
- !145 Reset exceptions testcase
  * reset exception test to master
- !146 Fix Check_Pylint error
- !147 fix android
- !148 ToDevice changes
  * Add ToDevice to the iterator list for cleanup at exit
- !149 Pylint issue
- !150 Pylint 2
- !152 ExecutionTree error
  * ET destructor error
- !153 in getter_pass, only remove callback, without deleting map op
  * getter pass no longer removes map
- !156 early __del__ of iterator/to_device
- !155 Address review comments Eric 1
  * Added one-liner fix to validators.py
  * roll back signature fix
  * lint fix
  * Address comments Eric 2: C++ lint fix
- !158 Review rework for dataset bindings - part 1
  * Reorder nodes repeat and rename
- !154 Fixing minor problems in the comments (datasets.py, python_tree_consumer.cc, iterators_bindings.cc, and iterators.py)
- !157 add replace none
  * Add replace_none to datasets.py, address comments in tests
- Trying to resolve copy
- Override the deepcopy method of deviceop
- Create_ir_tree method
- Create_ir_tree method 2
- del to_device if already exists
- cache getters shapes and types
- Added yolov3 relaxation, to be rolled back
- Get shapes and types together
- bypass yolo
- NumWorkers for MapOp
- revert Yolo
- revert Thor
- Print more info
- Debug code: update LOG INFO to LOG ERROR
- do not remove epochctrl for getter pass
- Remove repeat(1)
- print batch size
- add log to tree_consumer and device_queue op
- Revert PR 8744 (Signed-off-by: alex-yuyue <yue.yu1@huawei.com>)
- __del__ toDevice
- __del__ toDevice2
- !165 add ifndef ENABLE_ANDROID to device queue print
- revert some changes
- !166 getter: get_data_info
- !168 add back tree print
  * revert info to warning in one log
  * add back the missed print tree log
- Release GIL in GetDataInfo
- optimize the comment and log description. Modified files:
  * ops/operations/_inner_ops.py
  * ops/operations/_quant_ops.py
  * ops/operations/array_ops.py
  * ops/operations/comm_ops.py
  * ops/operations/math_ops.py
  * ops/operations/quantum_ops.py
  * ops/operations/rl_ops.py
  * ops/operations/sponge_ops.py
  * ops/operations/sponge_update_ops.py
  * train/__init__.py
  * common/tensor.py
  * train/serialization.py
  * ccsrc/pipeline/jit/parse/parse.h
  * explainer/benchmark/_attribution/metric.py
  * ops/composite/multitype_ops/_constexpr_utils.py
  * RELEASE.md
  * mindspore/_extends/parse/standard_method.py
  * mindspore/ccsrc/backend/kernel_compiler/cpu/concat_offset_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_shape_cpu_kernel.cc
  * mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
  * mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
  * mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc
  * mindspore/ccsrc/frontend/parallel/strategy.h
  * mindspore/common/tensor.py
  * mindspore/core/abstract/prim_arrays.cc
  * mindspore/core/abstract/prim_nn.cc
  * mindspore/core/ops/conv2d.cc
  * mindspore/core/ops/logical_and.h
  * mindspore/core/ops/logical_not.h
  * mindspore/core/ops/logical_or.h
  * mindspore/core/ops/reduce_all.h
  * mindspore/core/ops/reduce_any.h
  * mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc
  * mindspore/nn/layer/quant.py
  * mindspore/nn/optim/sgd.py
  * mindspore/nn/sparse/sparse.py
  * mindspore/numpy/array_creations.py
  * mindspore/numpy/array_ops.py
  * mindspore/numpy/logic_ops.py
  * mindspore/numpy/math_ops.py
  * mindspore/ops/operations/_inner_ops.py
  * mindspore/ops/operations/array_ops.py
  * mindspore/ops/operations/rl_ops.py
  * mindspore/train/_utils.py
  * tests/ut/python/model/test_lenet_core_after_exception.py
  * mindspore/ccsrc/backend/kernel_compiler/cpu/ctcloss_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h
  * mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h
  * mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc
  * mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
  * mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
  * mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
  * mindspore/ccsrc/fl/server/server.cc
  * mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc
  * mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
  * mindspore/ccsrc/frontend/optimizer/irpass/inline.h
  * mindspore/ccsrc/minddata/dataset/core/device_tensor.cc
  * mindspore/ccsrc/minddata/dataset/core/tensor.cc
  * mindspore/ccsrc/minddata/dataset/engine/datasetops/source/emnist_op.cc
  * mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
  * mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.cc
  * mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  * mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc
  * mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc
  * mindspore/ccsrc/pipeline/jit/action.cc
  * mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
  * mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_adapter.cc
  * mindspore/compression/quant/quant_utils.py
  * mindspore/dataset/engine/validators.py
  * mindspore/lite/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.cc
  * mindspore/lite/micro/coder/opcoders/nnacl/int8/affine_int8_coder.cc
  * mindspore/lite/src/runtime/kernel/ascend310/src/custom_kernel.cc
  * mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
  * mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc
  * mindspore/lite/tools/common/graph_util.h
  * mindspore/lite/tools/optimizer/fisson/fisson_util.cc
  * mindspore/ops/composite/math_ops.py
  * mindspore/ops/operations/math_ops.py
  * mindspore/ops/operations/other_ops.py
  * mindspore/boost/boost_cell_wrapper.py
  * mindspore/ccsrc/common/trans.cc
  * mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
  * mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc
  * mindspore/lite/src/common/log_util.h
  * mindspore/nn/wrap/loss_scale.py
  * mindspore/parallel/nn/moe.py
  * tests/mindspore_test_framework/mindspore_test.py
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Built-in validators.
"""
import inspect as ins
import os
import re
from functools import wraps

import numpy as np

from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
    INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
    validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
    check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id
from . import datasets
from . import samplers
from . import cache_client


def check_imagefolderdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_list = ['extensions']
        nreq_param_dict = ['class_indexing']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

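# Each validator in this file follows the same wrap-and-check shape: parse the
# user's arguments, validate them, then call through to the wrapped method. A
# minimal self-contained sketch of that pattern (the `check_positive` decorator
# and `Dataset` class below are hypothetical stand-ins for the real helpers):

```python
from functools import wraps


def check_positive(method):
    """Reject a non-positive `count` before the wrapped method ever runs."""
    @wraps(method)
    def new_method(self, count):
        if not isinstance(count, int) or count <= 0:
            raise ValueError("count should be a positive int.")
        return method(self, count)
    return new_method


class Dataset:
    @check_positive
    def take(self, count):
        # Reaches this point only when the validator accepted `count`.
        return list(range(count))
```

# Because `@wraps(method)` copies the wrapped method's name and docstring onto
# `new_method`, the decorated API keeps its original introspection behavior.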
def check_mnist_cifar_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(MnistDataset, Cifar10/100Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_photo_tour_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(PhotoTourDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test"], "usage")

        name = param_dict.get('name')
        check_valid_str(name, ["notredame", "yosemite", "liberty", "notredame_harris",
                               "yosemite_harris", "liberty_harris"], "name")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_places365_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Places365Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'small', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train-standard", "train-challenge", "val"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_qmnist_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(QMnistDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'compat']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "test10k", "test50k", "nist", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_manifestdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_str = ['usage']
        nreq_param_dict = ['class_indexing']

        dataset_file = param_dict.get('dataset_file')
        check_file(dataset_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_sbu_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SBUDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_urls.txt"))
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_captions.txt"))
        check_dir(os.path.join(dataset_dir, "sbu_images"))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_tfrecorddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_list = ['columns_list']
        nreq_param_bool = ['shard_equal_rows']

        dataset_files = param_dict.get('dataset_files')
        if not isinstance(dataset_files, (str, list)):
            raise TypeError("dataset_files should be of type str or a list of strings.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_usps_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(USPSDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_vocdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_dict = ['class_indexing']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        task = param_dict.get('task')
        type_check(task, (str,), "task")

        usage = param_dict.get('usage')
        type_check(usage, (str,), "usage")

        dataset_dir = os.path.realpath(dataset_dir)
        if task == "Segmentation":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
            if param_dict.get('class_indexing') is not None:
                raise ValueError("class_indexing is not supported in Segmentation task.")
        elif task == "Detection":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
        else:
            raise ValueError("Invalid task: " + task + ".")
        check_file(imagesets_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_cocodataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        annotation_file = param_dict.get('annotation_file')
        check_file(annotation_file)

        task = param_dict.get('task')
        type_check(task, (str,), "task")

        if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
            raise ValueError("Invalid task type: " + task + ".")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        sampler = param_dict.get('sampler')
        if sampler is not None and isinstance(sampler, samplers.PKSampler):
            raise ValueError("CocoDataset doesn't support PKSampler.")
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_celebadataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_list = ['extensions']
        nreq_param_str = ['dataset_type']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_str, param_dict, str)

        usage = param_dict.get('usage')
        if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
            raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(param_dict)

        sampler = param_dict.get('sampler')
        if sampler is not None and isinstance(sampler, samplers.PKSampler):
            raise ValueError("CelebADataset doesn't support PKSampler.")

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_lj_speech_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(LJSpeechDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_save(method):
    """A wrapper that wraps a parameter checker around the saved operator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_files']
        nreq_param_str = ['file_name', 'file_type']
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        if param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000:
            raise ValueError("num_files should be between 1 and 1000.")
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        if param_dict.get('file_type') != 'mindrecord':
            raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
        return method(self, *args, **kwargs)

    return new_method

def check_tuple_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns, num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_bool = ['output_numpy']
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        if columns is not None:
            check_columns(columns, "column_names")

        return method(self, *args, **kwargs)

    return new_method


def check_dict_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_bool = ['output_numpy']
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return new_method

def check_minddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded']
        nreq_param_list = ['columns_list']
        nreq_param_dict = ['padded_sample']

        dataset_file = param_dict.get('dataset_file')
        if isinstance(dataset_file, list):
            if len(dataset_file) > 4096:
                raise ValueError("length of dataset_file should be less than or equal to {}.".format(4096))
            for f in dataset_file:
                check_file(f)
        else:
            check_file(dataset_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)

        check_sampler_shuffle_shard_options(param_dict)
        check_padding_options(param_dict)
        return method(self, *args, **kwargs)

    return new_method

def check_generatordataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        source = param_dict.get('source')
        if not callable(source):
            try:
                iter(source)
            except TypeError:
                raise TypeError("Input `source` function of GeneratorDataset should be callable, iterable or random"
                                " accessible, commonly it should implement one of the methods like yield, __getitem__"
                                " or __next__(__iter__).")

        column_names = param_dict.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")
        schema = param_dict.get('schema')
        if column_names is None and schema is None:
            raise ValueError("Neither column_names nor schema is provided.")

        if schema is not None:
            if not isinstance(schema, (datasets.Schema, str)):
                raise ValueError("schema should be a path to schema file or a schema object.")

        # check optional argument
        nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        nreq_param_list = ["column_types"]
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        nreq_param_bool = ["shuffle"]
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        num_shards = param_dict.get("num_shards")
        shard_id = param_dict.get("shard_id")
        check_dataset_num_shards_shard_id(num_shards, shard_id)

        sampler = param_dict.get("sampler")
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("GeneratorDataset doesn't support PKSampler.")
            if not isinstance(sampler, samplers.BuiltinSampler):
                try:
                    iter(sampler)
                except TypeError:
                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.")

        if sampler is not None and not hasattr(source, "__getitem__"):
            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.")
        if num_shards is not None and not hasattr(source, "__getitem__"):
            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.")

        return method(self, *args, **kwargs)

    return new_method

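# check_generatordataset accepts a `source` that is callable, iterable, or
# random-accessible (has __getitem__). A simplified standalone restatement of
# that acceptance test (`validate_source` is a hypothetical helper, not part of
# this module):

```python
def validate_source(source):
    """Classify a GeneratorDataset-style source, mirroring the checks in
    check_generatordataset: callable first, then iterable, with random
    access detected via __getitem__ (simplified sketch)."""
    if callable(source):
        return "callable"
    try:
        iter(source)
    except TypeError:
        raise TypeError("source should be callable, iterable or random accessible.")
    return "random-access" if hasattr(source, "__getitem__") else "iterable"
```

# Random-accessible sources are the only ones that support `sampler` and
# `num_shards`, which is why the validator above checks for `__getitem__`.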
def check_random_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows']
        nreq_param_bool = ['shuffle']
        nreq_param_list = ['columns_list']

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_pad_info(key, val):
    """check the key and value pair of pad_info in batch"""
    type_check(key, (str,), "key in pad_info")
    if val is not None:
        type_check(val, (tuple,), "value in pad_info")
        if len(val) != 2:
            raise ValueError("value of pad_info should be a tuple of size 2.")
        if val[0] is not None:
            type_check(val[0], (list,), "shape in pad_info")
            for dim in val[0]:
                if dim is not None:
                    check_pos_int32(dim, "dim of shape in pad_info")
        if val[1] is not None:
            type_check(val[1], (int, float, str, bytes), "pad_value")

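# The pad_info rules above can be restated as a standalone predicate;
# `is_valid_pad_entry` below is a hypothetical sketch mirroring check_pad_info,
# returning False where the real validator would raise:

```python
def is_valid_pad_entry(key, val):
    """True when (key, val) would pass check_pad_info: key is a str, and val
    is None or a 2-tuple of (shape list or None, scalar pad value)."""
    if not isinstance(key, str):
        return False
    if val is None:
        return True
    if not (isinstance(val, tuple) and len(val) == 2):
        return False
    shape, pad_value = val
    if shape is not None:
        if not isinstance(shape, list):
            return False
        # Each known dimension must be a positive int; None means "pad to max".
        if any(d is not None and (not isinstance(d, int) or d <= 0) for d in shape):
            return False
    return pad_value is None or isinstance(pad_value, (int, float, str, bytes))
```

# So `{"image": ([224, 224], 0)}` pads "image" to 224x224 with zeros, while
# `{"label": (None, -1)}` pads "label" to the max shape in the batch with -1.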
def check_bucket_batch_by_length(method):
    """check the input arguments of bucket_batch_by_length."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
         pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs)

        nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
        type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list)

        nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
        type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)

        # check column_names: must be list of string.
        check_columns(column_names, "column_names")

        if element_length_function is None and len(column_names) != 1:
            raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")

        if element_length_function is not None and not callable(element_length_function):
            raise TypeError("element_length_function object is not callable.")

        # check bucket_boundaries: must be list of int, positive and strictly increasing
        if not bucket_boundaries:
            raise ValueError("bucket_boundaries cannot be empty.")

        all_int = all(isinstance(item, int) for item in bucket_boundaries)
        if not all_int:
            raise TypeError("bucket_boundaries should be a list of int.")

        all_positive = all(item > 0 for item in bucket_boundaries)
        if not all_positive:
            raise ValueError("bucket_boundaries must only contain positive numbers.")

        for i in range(len(bucket_boundaries) - 1):
            if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
                raise ValueError("bucket_boundaries should be strictly increasing.")

        # check bucket_batch_sizes: must be list of int and positive
        if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
            raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")

        all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
        if not all_int:
            raise TypeError("bucket_batch_sizes should be a list of int.")

        all_positive = all(item > 0 for item in bucket_batch_sizes)
        if not all_positive:
            raise ValueError("bucket_batch_sizes should be a list of positive numbers.")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")
            for k, v in pad_info.items():
                check_pad_info(k, v)

        return method(self, *args, **kwargs)

    return new_method

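# The bucket_boundaries rules enforced above (non-empty, all positive ints,
# strictly increasing) can be sketched as one standalone predicate
# (`boundaries_ok` is a hypothetical name, not part of this module):

```python
def boundaries_ok(bucket_boundaries):
    """Mirror of the bucket_boundaries checks in check_bucket_batch_by_length:
    the list must be non-empty, contain only positive ints, and be strictly
    increasing so each element falls in exactly one bucket."""
    if not bucket_boundaries:
        return False
    if not all(isinstance(b, int) and b > 0 for b in bucket_boundaries):
        return False
    # Strictly increasing: every adjacent pair must satisfy a < b.
    return all(a < b for a, b in zip(bucket_boundaries, bucket_boundaries[1:]))
```

# N boundaries define N+1 buckets (below the first, between each pair, and at
# or above the last), which is why bucket_batch_sizes needs one extra element.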
def check_batch(method):
    """check the input arguments of batch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, drop_remainder, num_parallel_workers, per_batch_map,
         input_columns, output_columns, column_order, pad_info,
         python_multiprocessing, max_rowsize], param_dict = parse_user_args(method, *args, **kwargs)

        if not (isinstance(batch_size, int) or callable(batch_size)):
            raise TypeError("batch_size should either be an int or a callable.")

        if callable(batch_size):
            sig = ins.signature(batch_size)
            if len(sig.parameters) != 1:
                raise ValueError("callable batch_size should take one parameter (BatchInfo).")
        else:
            check_pos_int32(int(batch_size), "batch_size")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")
        type_check(max_rowsize, (int,), "max_rowsize")

        if (pad_info is not None) and (per_batch_map is not None):
            raise ValueError("pad_info and per_batch_map can't both be set.")

        if pad_info is not None:
            type_check(param_dict["pad_info"], (dict,), "pad_info")
            for k, v in param_dict.get('pad_info').items():
                check_pad_info(k, v)

        if (per_batch_map is None) != (input_columns is None):
            # These two parameters appear together.
            raise ValueError("per_batch_map and input_columns need to be passed in together.")

        if input_columns is not None:
            check_columns(input_columns, "input_columns")
            if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
                raise ValueError("The signature of per_batch_map should match with input columns.")

        if output_columns is not None:
            check_columns(output_columns, "output_columns")

        if column_order is not None:
            check_columns(column_order, "column_order")

        if python_multiprocessing is not None:
            type_check(python_multiprocessing, (bool,), "python_multiprocessing")

        return method(self, *args, **kwargs)

    return new_method

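# The per_batch_map arity rule in check_batch requires the callable to take one
# parameter per input column plus a trailing BatchInfo argument. A standalone
# restatement of that rule (`per_batch_map_matches` is a hypothetical helper):

```python
import inspect


def per_batch_map_matches(per_batch_map, input_columns):
    """True when per_batch_map declares len(input_columns) + 1 parameters:
    one per column, plus the trailing BatchInfo argument."""
    return len(input_columns) == len(inspect.signature(per_batch_map).parameters) - 1
```

# For example, mapping over ["image", "label"] requires a callable of the form
# lambda image, label, batch_info: ... with exactly three parameters.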
def check_sync_wait(method):
    """check the input arguments of sync_wait."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs)
        type_check(condition_name, (str,), "condition_name")
        type_check(num_batch, (int,), "num_batch")

        return method(self, *args, **kwargs)

    return new_method


def check_shuffle(method):
    """check the input arguments of shuffle."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [buffer_size], _ = parse_user_args(method, *args, **kwargs)

        type_check(buffer_size, (int,), "buffer_size")
        check_value(buffer_size, [2, INT32_MAX], "buffer_size")

        return method(self, *args, **kwargs)

    return new_method

def check_map(method):
    """check the input arguments of map."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        from mindspore.dataset.callback import DSCallback
        [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
         callbacks, max_rowsize, offload], _ = \
            parse_user_args(method, *args, **kwargs)

        nreq_param_columns = ['input_columns', 'output_columns', 'column_order']

        if column_order is not None:
            type_check(column_order, (list,), "column_order")
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(python_multiprocessing, (bool,), "python_multiprocessing")
        check_cache_option(cache)
        type_check(max_rowsize, (int,), "max_rowsize")
        if offload is not None:
            type_check(offload, (bool,), "offload")

        if callbacks is not None:
            if isinstance(callbacks, (list, tuple)):
                type_check_list(callbacks, (DSCallback,), "callbacks")
            else:
                type_check(callbacks, (DSCallback,), "callbacks")

        for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]):
            if param is not None:
                check_columns(param, param_name)

        return method(self, *args, **kwargs)

    return new_method

def check_filter(method):
    """check the input arguments of filter."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)
        if not callable(predicate):
            raise TypeError("Predicate should be a Python function or a callable Python object.")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(self, *args, **kwargs)

    return new_method

def check_repeat(method):
    """check the input arguments of repeat."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)
        type_check(count, (int, type(None)), "repeat")
        if isinstance(count, int):
            if (count <= 0 and count != -1) or count > INT32_MAX:
                raise ValueError("count should be either -1 or a positive integer within the range [1, INT32_MAX].")
        return method(self, *args, **kwargs)

    return new_method


def check_skip(method):
    """check the input arguments of skip."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)

        type_check(count, (int,), "count")
        check_value(count, (0, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return new_method


def check_take(method):
    """check the input arguments of take."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)
        type_check(count, (int,), "count")
        if (count <= 0 and count != -1) or count > INT32_MAX:
            raise ValueError("count should be either -1 or within the required interval of ({}, {}], got {}."
                             .format(0, INT32_MAX, count))

        return method(self, *args, **kwargs)

    return new_method


def check_positive_int32(method):
    """check whether the input argument is a positive int; only works for functions with one input."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], param_dict = parse_user_args(method, *args, **kwargs)
        para_name = None
        for key in list(param_dict.keys()):
            if key not in ['self', 'cls']:
                para_name = key
        # Need to get default value of param
        if count is not None:
            check_pos_int32(count, para_name)
        return method(self, *args, **kwargs)

    return new_method

def check_device_send(method):
    """check the input arguments for to_device and device_que."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [send_epoch_end, create_data_info_queue], _ = parse_user_args(method, *args, **kwargs)
        type_check(send_epoch_end, (bool,), "send_epoch_end")
        type_check(create_data_info_queue, (bool,), "create_data_info_queue")
        return method(self, *args, **kwargs)

    return new_method

def check_zip(method):
    """check the input arguments of zip."""

    @wraps(method)
    def new_method(*args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (tuple,), "datasets")
        return method(*args, **kwargs)

    return new_method


def check_zip_dataset(method):
    """check the input arguments of zip method in `Dataset`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (tuple, datasets.Dataset), "datasets")
        return method(self, *args, **kwargs)

    return new_method

def check_concat(method):
    """check the input arguments of concat method in `Dataset`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (list, datasets.Dataset), "datasets")
        if isinstance(ds, list):
            type_check_list(ds, (datasets.Dataset,), "dataset")
        return method(self, *args, **kwargs)

    return new_method

def check_rename(method):
    """check the input arguments of rename."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        req_param_columns = ['input_columns', 'output_columns']
        for param_name, param in zip(req_param_columns, values):
            check_columns(param, param_name)

        input_size, output_size = 1, 1
        input_columns, output_columns = values
        if isinstance(input_columns, list):
            input_size = len(input_columns)
        if isinstance(output_columns, list):
            output_size = len(output_columns)
        if input_size != output_size:
            raise ValueError("Number of columns in input_columns and output_columns is not equal.")
        return method(self, *args, **kwargs)

    return new_method

def check_project(method):
    """check the input arguments of project."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns], _ = parse_user_args(method, *args, **kwargs)
        check_columns(columns, 'columns')
        return method(self, *args, **kwargs)

    return new_method


def check_schema(method):
    """check the input arguments of Schema.__init__."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [schema_file], _ = parse_user_args(method, *args, **kwargs)
        if schema_file is not None:
            check_file(schema_file)
        return method(self, *args, **kwargs)

    return new_method

def check_add_column(method):
    """check the input arguments of add_column."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)
        type_check(name, (str,), "name")
        if not name:
            raise TypeError("Expected non-empty string for column name.")
        if de_type is not None:
            if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
                raise TypeError("Unknown column type: {}.".format(de_type))
        else:
            raise TypeError("Expected non-empty string for de_type.")
        if shape is not None:
            type_check(shape, (list,), "shape")
            type_check_list(shape, (int,), "shape")
        return method(self, *args, **kwargs)

    return new_method

def check_cluedataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")

        # check task
        task_param = param_dict.get('task')
        if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")

        # check usage
        usage_param = param_dict.get('usage')
        if usage_param not in ['train', 'test', 'eval']:
            raise ValueError("usage should be 'train', 'test' or 'eval'.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_csvdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_files; required argument
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")

        # check field_delim
        field_delim = param_dict.get('field_delim')
        if field_delim is not None:
            type_check(field_delim, (str,), 'field delim')
            if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
                raise ValueError("field_delim should be a single character other than '\"', '\\r' and '\\n'.")

        # check column_defaults
        column_defaults = param_dict.get('column_defaults')
        if column_defaults is not None:
            if not isinstance(column_defaults, list):
                raise TypeError("column_defaults should be of type list.")
            for item in column_defaults:
                if not isinstance(item, (str, int, float)):
                    raise TypeError("Each item in column_defaults should be of type str, int or float.")

        # check column_names: must be list of string.
        column_names = param_dict.get("column_names")
        if column_names is not None:
            all_string = all(isinstance(item, str) for item in column_names)
            if not all_string:
                raise TypeError("column_names should be a list of str.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
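The field_delim rule above amounts to: None, or a one-character string that is not a double quote, carriage return, or line feed. A standalone sketch under that reading; the helper name `validate_field_delim` is hypothetical and not part of this module:

```python
def validate_field_delim(field_delim):
    """Accept None, or a single-character str other than '"', '\r' and '\n'."""
    if field_delim is None:
        return None
    if not isinstance(field_delim, str):
        raise TypeError("field_delim should be a str.")
    if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
        raise ValueError("field_delim should be a single character other than '\"', '\\r' and '\\n'.")
    return field_delim
```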

def check_flowers102dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Flowers102Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)
        check_dir(os.path.join(dataset_dir, "jpg"))
        check_file(os.path.join(dataset_dir, "imagelabels.mat"))
        check_file(os.path.join(dataset_dir, "setid.mat"))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Classification", "Segmentation"], "task")
            if task == "Segmentation":
                check_dir(os.path.join(dataset_dir, "segmim"))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method

def check_textfiledataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_split(method):
    """check the input arguments of split."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [sizes, randomize], _ = parse_user_args(method, *args, **kwargs)
        type_check(sizes, (list,), "sizes")
        type_check(randomize, (bool,), "randomize")

        # check sizes: must be list of float or list of int
        if not sizes:
            raise ValueError("sizes cannot be empty.")
        all_int = all(isinstance(item, int) for item in sizes)
        all_float = all(isinstance(item, float) for item in sizes)

        if not (all_int or all_float):
            raise ValueError("sizes should be list of int or list of float.")

        if all_int:
            all_positive = all(item > 0 for item in sizes)
            if not all_positive:
                raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")

        if all_float:
            all_valid_percentages = all(0 < item <= 1 for item in sizes)
            if not all_valid_percentages:
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")

            epsilon = 0.00001
            if not abs(sum(sizes) - 1) < epsilon:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(self, *args, **kwargs)

    return new_method
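The sizes validation above reduces to two modes: an all-int list of positive absolute sizes, or an all-float list of percentages in (0, 1] that sum to 1 within a small epsilon. A standalone sketch of just that rule; the name `validate_split_sizes` is illustrative, not part of this module:

```python
EPSILON = 1e-5  # same tolerance the validator above uses for the float sum

def validate_split_sizes(sizes):
    """Return sizes if they satisfy the split rule, else raise ValueError."""
    if not isinstance(sizes, list) or not sizes:
        raise ValueError("sizes should be a non-empty list.")
    if all(isinstance(s, int) for s in sizes):
        # absolute sizes: every entry must be a positive int
        if any(s <= 0 for s in sizes):
            raise ValueError("sizes is a list of int, but all entries must be positive.")
    elif all(isinstance(s, float) for s in sizes):
        # percentages: each in (0, 1], summing to 1 within EPSILON
        if any(not 0 < s <= 1 for s in sizes):
            raise ValueError("sizes is a list of float, but all entries must lie in (0, 1].")
        if abs(sum(sizes) - 1) >= EPSILON:
            raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
    else:
        raise ValueError("sizes should be list of int or list of float.")
    return sizes
```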

def check_hostname(hostname):
    if not hostname or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        hostname = hostname[:-1]  # strip exactly one dot from the right, if present
    allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
    return all(allowed.match(x) for x in hostname.split("."))
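`check_hostname` is self-contained apart from `re`, so it can be exercised directly: each dot-separated label must be 1-63 alphanumeric/hyphen characters, not starting or ending with a hyphen, with the whole name at most 255 characters. A runnable copy with a few illustrative inputs:

```python
import re

def check_hostname(hostname):
    """Mirror of the helper above: per-label and total-length hostname checks."""
    if not hostname or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        hostname = hostname[:-1]  # strip exactly one trailing dot
    allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
    return all(allowed.match(x) for x in hostname.split("."))

# valid: plain label and multi-label name with interior hyphens
assert check_hostname("localhost")
assert check_hostname("my-host.example.com")
# invalid: leading hyphen, over-long label, empty string
assert not check_hostname("-bad.example.com")
assert not check_hostname("a" * 64 + ".com")
assert not check_hostname("")
```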

def check_gnn_graphdata(method):
    """check the input arguments of graphdata."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [dataset_file, num_parallel_workers, working_mode, hostname,
         port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
        check_file(dataset_file)
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(hostname, (str,), "hostname")
        if check_hostname(hostname) is False:
            raise ValueError("The hostname is illegal.")
        type_check(working_mode, (str,), "working_mode")
        if working_mode not in {'local', 'client', 'server'}:
            raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'.")
        type_check(port, (int,), "port")
        check_value(port, (1024, 65535), "port")
        type_check(num_client, (int,), "num_client")
        check_value(num_client, (1, 255), "num_client")
        type_check(auto_shutdown, (bool,), "auto_shutdown")
        return method(self, *args, **kwargs)

    return new_method

def check_gnn_get_all_nodes(method):
    """A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_type], _ = parse_user_args(method, *args, **kwargs)
        type_check(node_type, (int,), "node_type")
        return method(self, *args, **kwargs)

    return new_method


def check_gnn_get_all_edges(method):
    """A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_type], _ = parse_user_args(method, *args, **kwargs)
        type_check(edge_type, (int,), "edge_type")
        return method(self, *args, **kwargs)

    return new_method


def check_gnn_get_nodes_from_edges(method):
    """A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_list], _ = parse_user_args(method, *args, **kwargs)
        check_gnn_list_or_ndarray(edge_list, "edge_list")
        return method(self, *args, **kwargs)

    return new_method


def check_gnn_get_edges_from_nodes(method):
    """A wrapper that wraps a parameter checker around the GNN `get_edges_from_nodes` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list], _ = parse_user_args(method, *args, **kwargs)
        check_gnn_list_of_pair_or_ndarray(node_list, "node_list")
        return method(self, *args, **kwargs)

    return new_method

def check_gnn_get_all_neighbors(method):
    """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, neighbour_type, _], _ = parse_user_args(method, *args, **kwargs)
        check_gnn_list_or_ndarray(node_list, 'node_list')
        type_check(neighbour_type, (int,), "neighbour_type")
        return method(self, *args, **kwargs)

    return new_method


def check_gnn_get_sampled_neighbors(method):
    """A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')

        check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
        if not neighbor_nums or len(neighbor_nums) > 6:
            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
                'neighbor_nums', len(neighbor_nums)))

        check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
        if not neighbor_types or len(neighbor_types) > 6:
            raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
                'neighbor_types', len(neighbor_types)))

        if len(neighbor_nums) != len(neighbor_types):
            raise ValueError(
                "The number of members of neighbor_nums and neighbor_types is inconsistent.")

        return method(self, *args, **kwargs)

    return new_method

def check_gnn_get_neg_sampled_neighbors(method):
    """A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, neg_neighbor_num, neg_neighbor_type], _ = parse_user_args(method, *args, **kwargs)
        check_gnn_list_or_ndarray(node_list, 'node_list')
        type_check(neg_neighbor_num, (int,), "neg_neighbor_num")
        type_check(neg_neighbor_type, (int,), "neg_neighbor_type")
        return method(self, *args, **kwargs)

    return new_method


def check_gnn_random_walk(method):
    """A wrapper that wraps a parameter checker around the GNN `random_walk` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [target_nodes, meta_path, step_home_param, step_away_param, default_node], _ = \
            parse_user_args(method, *args, **kwargs)
        check_gnn_list_or_ndarray(target_nodes, 'target_nodes')
        check_gnn_list_or_ndarray(meta_path, 'meta_path')
        type_check(step_home_param, (float,), "step_home_param")
        type_check(step_away_param, (float,), "step_away_param")
        type_check(default_node, (int,), "default_node")
        check_value(default_node, (-1, INT32_MAX), "default_node")
        return method(self, *args, **kwargs)

    return new_method

def check_aligned_list(param, param_name, member_type):
    """Check whether the structure of each member of the list is the same."""
    type_check(param, (list,), "param")
    if not param:
        raise TypeError(
            "Parameter {0} or its members are empty.".format(param_name))
    member_have_list = None
    list_len = None
    for member in param:
        if isinstance(member, list):
            check_aligned_list(member, param_name, member_type)
            if member_have_list not in (None, True):
                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
                    param_name))
            if list_len is not None and len(member) != list_len:
                raise TypeError("The size of each member of parameter {0} is inconsistent.".format(
                    param_name))
            member_have_list = True
            list_len = len(member)
        else:
            type_check(member, (member_type,), param_name)
            if member_have_list not in (None, False):
                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
                    param_name))
            member_have_list = False
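The alignment rule above says every member of the list must have the same shape: either all scalars of `member_type`, or all sub-lists of equal length, checked recursively. A standalone sketch with the module's `type_check` helper stubbed in for illustration:

```python
def type_check(value, types, name):
    # minimal stand-in for the module's type_check helper
    if not isinstance(value, types):
        raise TypeError("{0} should be of type {1}.".format(name, types))

def check_aligned_list(param, param_name, member_type):
    """Check whether the structure of each member of the list is the same."""
    type_check(param, (list,), "param")
    if not param:
        raise TypeError("Parameter {0} or its members are empty.".format(param_name))
    member_have_list = None  # whether members seen so far were lists
    list_len = None          # common length of list members seen so far
    for member in param:
        if isinstance(member, list):
            check_aligned_list(member, param_name, member_type)
            if member_have_list not in (None, True):
                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(param_name))
            if list_len is not None and len(member) != list_len:
                raise TypeError("The size of each member of parameter {0} is inconsistent.".format(param_name))
            member_have_list = True
            list_len = len(member)
        else:
            type_check(member, (member_type,), param_name)
            if member_have_list not in (None, False):
                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(param_name))
            member_have_list = False

# aligned inputs pass silently; ragged or mixed inputs raise TypeError
check_aligned_list([[1, 2], [3, 4]], "node_list", int)
check_aligned_list([1, 2, 3], "node_list", int)
```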

def check_gnn_get_node_feature(method):
    """A wrapper that wraps a parameter checker around the GNN `get_node_feature` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs)
        type_check(node_list, (list, np.ndarray), "node_list")
        if isinstance(node_list, list):
            check_aligned_list(node_list, 'node_list', int)
        elif isinstance(node_list, np.ndarray):
            if not node_list.dtype == np.int32:
                raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                    node_list, node_list.dtype))
        check_gnn_list_or_ndarray(feature_types, 'feature_types')
        return method(self, *args, **kwargs)

    return new_method


def check_gnn_get_edge_feature(method):
    """A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs)
        type_check(edge_list, (list, np.ndarray), "edge_list")
        if isinstance(edge_list, list):
            check_aligned_list(edge_list, 'edge_list', int)
        elif isinstance(edge_list, np.ndarray):
            if not edge_list.dtype == np.int32:
                raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                    edge_list, edge_list.dtype))
        check_gnn_list_or_ndarray(feature_types, 'feature_types')
        return method(self, *args, **kwargs)

    return new_method

def check_numpyslicesdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        data = param_dict.get("data")
        column_names = param_dict.get("column_names")
        type_check(data, (list, tuple, dict, np.ndarray), "data")
        if data is None or len(data) == 0:  # pylint: disable=len-as-condition
            raise ValueError("Argument data cannot be empty.")
        if isinstance(data, tuple):
            type_check(data[0], (list, np.ndarray), "data[0]")

        # check column_names
        if column_names is not None:
            check_columns(column_names, "column_names")

            # check num of input column in column_names
            column_num = 1 if isinstance(column_names, str) else len(column_names)
            if isinstance(data, dict):
                data_column = len(list(data.keys()))
                if column_num != data_column:
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, data_column))
            elif isinstance(data, tuple):
                if column_num != len(data):
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, len(data)))
            else:
                if column_num != 1:
                    raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
                                     .format(column_num, 1))

        return method(self, *args, **kwargs)

    return new_method
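The column-count rule above can be summarized: a dict yields one column per key, a tuple one column per element, and anything else (a list or ndarray) a single column. A small sketch of just that mapping; the helper name `expected_column_count` is hypothetical, not part of this module:

```python
def expected_column_count(data):
    """How many column names the NumpySlices-style check above expects."""
    if isinstance(data, dict):
        return len(data)  # one column per key
    if isinstance(data, tuple):
        return len(data)  # one column per element
    return 1              # list or ndarray -> a single column
```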

def check_paddeddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        padded_samples = param_dict.get("padded_samples")
        if not padded_samples:
            raise ValueError("padded_samples cannot be empty.")
        type_check(padded_samples, (list,), "padded_samples")
        type_check(padded_samples[0], (dict,), "padded_element")
        return method(self, *args, **kwargs)

    return new_method


def check_cache_option(cache):
    """Sanity check for cache parameter"""
    if cache is not None:
        type_check(cache, (cache_client.DatasetCache,), "cache")

def check_to_device_send(method):
    """Check the input arguments of send function for TransferDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_epochs], _ = parse_user_args(method, *args, **kwargs)
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
        return method(self, *args, **kwargs)

    return new_method

def check_emnist_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(EMnistDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        name = param_dict.get('name')
        check_valid_str(name, ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"], "name")

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_flickr_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Flickr8k, Flickr30k)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        annotation_file = param_dict.get('annotation_file')
        check_dir(dataset_dir)
        check_file(annotation_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_sb_dataset(method):
    """A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Boundaries", "Segmentation"], "task")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method

def check_speech_commands_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SpeechCommandsDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_cityscapes_dataset(method):
    """A wrapper that wraps a parameter checker around the original CityScapesDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        task = param_dict.get('task')
        check_valid_str(task, ["instance", "semantic", "polygon", "color"], "task")

        quality_mode = param_dict.get('quality_mode')
        check_valid_str(quality_mode, ["fine", "coarse"], "quality_mode")

        usage = param_dict.get('usage')
        if quality_mode == "fine":
            valid_strings = ["train", "test", "val", "all"]
        else:
            valid_strings = ["train", "train_extra", "val", "all"]
        check_valid_str(usage, valid_strings, "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method

def check_div2k_dataset(method):
    """A wrapper that wraps a parameter checker around the original DIV2KDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        check_valid_str(usage, ['train', 'valid', 'all'], "usage")

        downgrade = param_dict.get('downgrade')
        check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade')

        validate_dataset_param_value(['scale'], param_dict, int)
        scale = param_dict.get('scale')
        scale_values = [2, 3, 4, 8]
        if scale not in scale_values:
            raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values)))

        if scale == 8 and downgrade != "bicubic":
            raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.")

        downgrade_2018 = ["mild", "difficult", "wild"]
        if downgrade in downgrade_2018 and scale != 4:
            raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method

def check_fake_image_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(FakeImageDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_images', 'num_classes', 'base_seed', 'num_samples',
                          'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        num_images = param_dict.get("num_images")
        check_pos_int32(num_images, "num_images")

        image_size = param_dict.get("image_size")
        type_check(image_size, (list, tuple), "image_size")
        if len(image_size) != 3:
            raise ValueError("image_size should be a list or tuple of length 3, but got {0}.".format(len(image_size)))
        for i, value in enumerate(image_size):
            check_pos_int32(value, "image_size[{0}]".format(i))

        num_classes = param_dict.get("num_classes")
        check_pos_int32(num_classes, "num_classes")

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_ag_news_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(AGNewsDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method

def check_dbpedia_dataset(method):
    """A wrapper that wraps a parameter checker around the original DBpediaDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_yes_no_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(YesNoDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method