Browse Source

boilerplate code for future IR optimizer

add 2 test cases to IRNode DeepCopy()

address review cmts

fix ut

samplerObj copy

ci

fix ci

fix ci round III

address further review cmts

add a missing macro

fix merge conflict

fix compile err

fix lite compile err

fix compile err

fix lite compile round III

address an issue

fix minor comments
tags/v1.1.0
Nat Sutyanyong Zirui Wu 5 years ago
parent
commit
5e1bb0b697
86 changed files with 1922 additions and 602 deletions
  1. +4
    -4
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +10
    -4
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  3. +2
    -4
      mindspore/ccsrc/minddata/dataset/api/transforms.cc
  4. +44
    -19
      mindspore/ccsrc/minddata/dataset/api/vision.cc
  5. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt
  6. +17
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc
  7. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h
  8. +11
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc
  9. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h
  10. +24
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc
  11. +24
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h
  12. +23
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc
  13. +24
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h
  14. +23
    -4
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
  15. +24
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
  16. +93
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  17. +202
    -5
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  18. +67
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc
  19. +63
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h
  20. +23
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc
  21. +24
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h
  22. +24
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc
  23. +29
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h
  24. +8
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc
  25. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h
  26. +10
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc
  27. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h
  28. +20
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  29. +24
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  30. +85
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc
  31. +78
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h
  32. +11
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc
  33. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h
  34. +7
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc
  35. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h
  36. +11
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc
  37. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h
  38. +11
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc
  39. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h
  40. +11
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc
  41. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h
  42. +11
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc
  43. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h
  44. +12
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc
  45. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h
  46. +9
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc
  47. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h
  48. +12
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc
  49. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h
  50. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
  51. +13
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
  52. +14
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  53. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h
  54. +18
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc
  55. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h
  56. +17
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc
  57. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h
  58. +9
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc
  59. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h
  60. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc
  61. +16
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h
  62. +11
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc
  63. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h
  64. +17
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
  65. +15
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h
  66. +9
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc
  67. +13
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h
  68. +10
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc
  69. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h
  70. +7
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc
  71. +12
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h
  72. +25
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc
  73. +26
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h
  74. +28
    -10
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc
  75. +23
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h
  76. +41
    -265
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
  77. +55
    -210
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
  78. +14
    -4
      mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc
  79. +0
    -4
      mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
  80. +30
    -2
      mindspore/ccsrc/minddata/dataset/include/samplers.h
  81. +11
    -1
      mindspore/ccsrc/minddata/dataset/include/transforms.h
  82. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
  83. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h
  84. +1
    -1
      mindspore/dataset/engine/samplers.py
  85. +1
    -0
      tests/ut/cpp/dataset/CMakeLists.txt
  86. +142
    -0
      tests/ut/cpp/dataset/c_api_dataset_ir_node_test.cc

+ 4
- 4
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -568,8 +568,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage, const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) { SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto vocab = std::make_shared<SentencePieceVocab>(); auto vocab = std::make_shared<SentencePieceVocab>();
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params);
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size,
character_coverage, model_type, params);


std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();
@@ -600,8 +600,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k, const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first) { const std::vector<std::string> &special_tokens, bool special_first) {
auto vocab = std::make_shared<Vocab>(); auto vocab = std::make_shared<Vocab>();
auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens,
special_first);


std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();


+ 10
- 4
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -190,13 +190,12 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
return sampler; return sampler;
} }


#ifndef ENABLE_ANDROID
// PreBuiltOperation // PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}


#ifndef ENABLE_ANDROID
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
: sp_minddataset_(std::move(sampler)) {}
#endif #endif


bool PreBuiltSamplerObj::ValidateParams() { return true; } bool PreBuiltSamplerObj::ValidateParams() { return true; }
@@ -207,6 +206,13 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif #endif


// Create a new PreBuiltSamplerObj that wraps the same pre-built sampler pointer held by this object.
// The returned object is a new wrapper; the wrapped pointer itself is passed along, not cloned.
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
#ifndef ENABLE_ANDROID
  // If a mindrecord ShardOperator is held, the copy is constructed from it rather than from sp_.
  if (sp_minddataset_) {
    return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
  }
#endif
  // Otherwise construct the copy from the runtime sampler pointer.
  return std::make_shared<PreBuiltSamplerObj>(sp_);
}

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object // runtime mindrecord sampler object


+ 2
- 4
mindspore/ccsrc/minddata/dataset/api/transforms.cc View File

@@ -30,8 +30,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


TensorOperation::TensorOperation() {}

/* ####################################### Validator Functions ############################################ */ /* ####################################### Validator Functions ############################################ */
Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) { Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) {
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) { if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
@@ -231,7 +229,7 @@ std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }


// RandomApplyOperation // RandomApplyOperation
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob) RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
: transforms_(transforms), prob_(prob) {}
: TensorOperation(true), transforms_(transforms), prob_(prob) {}


Status RandomApplyOperation::ValidateParams() { Status RandomApplyOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_)); RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_));
@@ -248,7 +246,7 @@ std::shared_ptr<TensorOp> RandomApplyOperation::Build() {


// RandomChoiceOperation // RandomChoiceOperation
RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms) RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
: transforms_(transforms) {}
: TensorOperation(true), transforms_(transforms) {}


Status RandomChoiceOperation::ValidateParams() { Status RandomChoiceOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_)); RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_));


+ 44
- 19
mindspore/ccsrc/minddata/dataset/api/vision.cc View File

@@ -734,7 +734,9 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees
scale_range_(scale_range), scale_range_(scale_range),
shear_ranges_(shear_ranges), shear_ranges_(shear_ranges),
interpolation_(interpolation), interpolation_(interpolation),
fill_value_(fill_value) {}
fill_value_(fill_value) {
random_op_ = true;
}


Status RandomAffineOperation::ValidateParams() { Status RandomAffineOperation::ValidateParams() {
// Degrees // Degrees
@@ -867,7 +869,7 @@ std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
} }


// RandomColorOperation. // RandomColorOperation.
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) { random_op_ = true; }


Status RandomColorOperation::ValidateParams() { Status RandomColorOperation::ValidateParams() {
// Do some input validation. // Do some input validation.
@@ -891,7 +893,9 @@ Status RandomColorOperation::ValidateParams() {
// RandomColorAdjustOperation. // RandomColorAdjustOperation.
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast, RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
std::vector<float> saturation, std::vector<float> hue) std::vector<float> saturation, std::vector<float> hue)
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {
random_op_ = true;
}


Status RandomColorAdjustOperation::ValidateParams() { Status RandomColorAdjustOperation::ValidateParams() {
// brightness // brightness
@@ -1012,11 +1016,14 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
// RandomCropOperation // RandomCropOperation
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed, RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
std::vector<uint8_t> fill_value, BorderType padding_mode) std::vector<uint8_t> fill_value, BorderType padding_mode)
: size_(size),
: TensorOperation(true),
size_(size),
padding_(padding), padding_(padding),
pad_if_needed_(pad_if_needed), pad_if_needed_(pad_if_needed),
fill_value_(fill_value), fill_value_(fill_value),
padding_mode_(padding_mode) {}
padding_mode_(padding_mode) {
random_op_ = true;
}


Status RandomCropOperation::ValidateParams() { Status RandomCropOperation::ValidateParams() {
// size // size
@@ -1083,7 +1090,12 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale, RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale,
std::vector<float> ratio, std::vector<float> ratio,
InterpolationMode interpolation, int32_t max_attempts) InterpolationMode interpolation, int32_t max_attempts)
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
: TensorOperation(true),
size_(size),
scale_(scale),
ratio_(ratio),
interpolation_(interpolation),
max_attempts_(max_attempts) {}


Status RandomCropDecodeResizeOperation::ValidateParams() { Status RandomCropDecodeResizeOperation::ValidateParams() {
// size // size
@@ -1176,7 +1188,8 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding, RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding,
bool pad_if_needed, std::vector<uint8_t> fill_value, bool pad_if_needed, std::vector<uint8_t> fill_value,
BorderType padding_mode) BorderType padding_mode)
: size_(size),
: TensorOperation(true),
size_(size),
padding_(padding), padding_(padding),
pad_if_needed_(pad_if_needed), pad_if_needed_(pad_if_needed),
fill_value_(fill_value), fill_value_(fill_value),
@@ -1245,7 +1258,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() {
} }


// RandomHorizontalFlipOperation // RandomHorizontalFlipOperation
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability)
: TensorOperation(true), probability_(probability) {}


Status RandomHorizontalFlipOperation::ValidateParams() { Status RandomHorizontalFlipOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_)); RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_));
@@ -1260,7 +1274,7 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {


// RandomHorizontalFlipWithBBoxOperation // RandomHorizontalFlipWithBBoxOperation
RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability) RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability)
: probability_(probability) {}
: TensorOperation(true), probability_(probability) {}


Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() { Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_)); RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_));
@@ -1275,7 +1289,8 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipWithBBoxOperation::Build() {
} }


// RandomPosterizeOperation // RandomPosterizeOperation
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range)
: TensorOperation(true), bit_range_(bit_range) {}


Status RandomPosterizeOperation::ValidateParams() { Status RandomPosterizeOperation::ValidateParams() {
if (bit_range_.size() != 2) { if (bit_range_.size() != 2) {
@@ -1309,7 +1324,7 @@ std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
} }


// RandomResizeOperation // RandomResizeOperation
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : size_(size) {}
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : TensorOperation(true), size_(size) {}


Status RandomResizeOperation::ValidateParams() { Status RandomResizeOperation::ValidateParams() {
// size // size
@@ -1343,7 +1358,8 @@ std::shared_ptr<TensorOp> RandomResizeOperation::Build() {
} }


// RandomResizeWithBBoxOperation // RandomResizeWithBBoxOperation
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) : size_(size) {}
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size)
: TensorOperation(true), size_(size) {}


Status RandomResizeWithBBoxOperation::ValidateParams() { Status RandomResizeWithBBoxOperation::ValidateParams() {
// size // size
@@ -1380,7 +1396,12 @@ std::shared_ptr<TensorOp> RandomResizeWithBBoxOperation::Build() {
RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale, RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale,
std::vector<float> ratio, InterpolationMode interpolation, std::vector<float> ratio, InterpolationMode interpolation,
int32_t max_attempts) int32_t max_attempts)
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
: TensorOperation(true),
size_(size),
scale_(scale),
ratio_(ratio),
interpolation_(interpolation),
max_attempts_(max_attempts) {}


Status RandomResizedCropOperation::ValidateParams() { Status RandomResizedCropOperation::ValidateParams() {
// size // size
@@ -1536,7 +1557,8 @@ std::shared_ptr<TensorOp> RandomResizedCropWithBBoxOperation::Build() {
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
bool expand, std::vector<float> center, bool expand, std::vector<float> center,
std::vector<uint8_t> fill_value) std::vector<uint8_t> fill_value)
: degrees_(degrees),
: TensorOperation(true),
degrees_(degrees),
interpolation_mode_(interpolation_mode), interpolation_mode_(interpolation_mode),
expand_(expand), expand_(expand),
center_(center), center_(center),
@@ -1603,7 +1625,7 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
// RandomSelectSubpolicyOperation. // RandomSelectSubpolicyOperation.
RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation(
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy) std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy)
: policy_(policy) {}
: TensorOperation(true), policy_(policy) {}


Status RandomSelectSubpolicyOperation::ValidateParams() { Status RandomSelectSubpolicyOperation::ValidateParams() {
if (policy_.empty()) { if (policy_.empty()) {
@@ -1650,7 +1672,8 @@ std::shared_ptr<TensorOp> RandomSelectSubpolicyOperation::Build() {
} }


// Function to create RandomSharpness. // Function to create RandomSharpness.
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees)
: TensorOperation(true), degrees_(degrees) {}


Status RandomSharpnessOperation::ValidateParams() { Status RandomSharpnessOperation::ValidateParams() {
if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) { if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) {
@@ -1674,7 +1697,8 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
} }


// RandomSolarizeOperation. // RandomSolarizeOperation.
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold)
: TensorOperation(true), threshold_(threshold) {}


Status RandomSolarizeOperation::ValidateParams() { Status RandomSolarizeOperation::ValidateParams() {
if (threshold_.size() != 2) { if (threshold_.size() != 2) {
@@ -1705,7 +1729,8 @@ std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
} }


// RandomVerticalFlipOperation // RandomVerticalFlipOperation
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability)
: TensorOperation(true), probability_(probability) {}


Status RandomVerticalFlipOperation::ValidateParams() { Status RandomVerticalFlipOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_)); RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_));
@@ -1720,7 +1745,7 @@ std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {


// RandomVerticalFlipWithBBoxOperation // RandomVerticalFlipWithBBoxOperation
RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability) RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability)
: probability_(probability) {}
: TensorOperation(true), probability_(probability) {}


Status RandomVerticalFlipWithBBoxOperation::ValidateParams() { Status RandomVerticalFlipWithBBoxOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_)); RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_));


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt View File

@@ -9,11 +9,13 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
build_sentence_piece_vocab_node.cc build_sentence_piece_vocab_node.cc
build_vocab_node.cc build_vocab_node.cc
concat_node.cc concat_node.cc
epoch_ctrl_node.cc
filter_node.cc filter_node.cc
map_node.cc map_node.cc
project_node.cc project_node.cc
rename_node.cc rename_node.cc
repeat_node.cc repeat_node.cc
root_node.cc
shuffle_node.cc shuffle_node.cc
skip_node.cc skip_node.cc
sync_wait_node.cc sync_wait_node.cc


+ 17
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc View File

@@ -43,14 +43,29 @@ BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, boo
batch_size_func_(batch_size_func), batch_size_func_(batch_size_func),
batch_map_func_(batch_map_func), batch_map_func_(batch_map_func),
pad_map_(pad_map) { pad_map_(pad_map) {
this->children.push_back(child);
this->AddChild(child);
} }
#endif #endif


// constructor #2, called by C++ API // constructor #2, called by C++ API
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder) BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder)
: batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) { : batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) {
this->children.push_back(child);
this->AddChild(child);
}

// Create a new BatchNode from this node's parameters. The copy is constructed with a null child.
std::shared_ptr<DatasetNode> BatchNode::Copy() {
#ifdef ENABLE_PYTHON
  // Python-enabled builds use the constructor that also takes the padding, column and callback parameters.
  return std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_, pad_, in_col_names_, out_col_names_,
                                     col_order_, batch_size_func_, batch_map_func_, pad_map_);
#else
  // Other builds use the constructor that takes only batch_size and drop_remainder.
  return std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_);
#endif
}

// Write a one-line description of this node: its name followed by batch_size and drop_remainder.
void BatchNode::Print(std::ostream &out) const {
  const std::string drop_text = drop_remainder_ ? "true" : "false";
  // Assemble the whole description first so it reaches the stream as a single string.
  std::string description = Name();
  description += "(batch_size:" + std::to_string(batch_size_);
  description += " drop_remainder:" + drop_text + ")";
  out << description;
}


Status BatchNode::ValidateParams() { Status BatchNode::ValidateParams() {


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h View File

@@ -44,6 +44,18 @@ class BatchNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~BatchNode() = default; ~BatchNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBatchNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 11
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc View File

@@ -41,7 +41,17 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
pad_info_(pad_info), pad_info_(pad_info),
pad_to_bucket_boundary_(pad_to_bucket_boundary), pad_to_bucket_boundary_(pad_to_bucket_boundary),
drop_remainder_(drop_remainder) { drop_remainder_(drop_remainder) {
this->children.push_back(child);
this->AddChild(child);
}

std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_);
return node;
}

// Write a short description of this node: its name and column names; other parameters are elided as "...".
void BucketBatchByLengthNode::Print(std::ostream &out) const {
  // Assemble the whole description first so it reaches the stream as a single string.
  std::string description = Name();
  description += "(columns:";
  description += PrintColumns(column_names_);
  description += ",...)";
  out << description;
}


std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h View File

@@ -40,6 +40,18 @@ class BucketBatchByLengthNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~BucketBatchByLengthNode() = default; ~BucketBatchByLengthNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBucketBatchByLengthNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 24
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc View File

@@ -22,6 +22,7 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h" #include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"


namespace mindspore { namespace mindspore {
@@ -38,7 +39,18 @@ BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> chil
character_coverage_(character_coverage), character_coverage_(character_coverage),
model_type_(model_type), model_type_(model_type),
params_(params) { params_(params) {
this->children.push_back(child);
this->AddChild(child);
}

// Produce a childless node with the same parameters as this one.
// vocab_ is a shared_ptr, so the new node refers to the same vocab object as the original.
std::shared_ptr<DatasetNode> BuildSentenceVocabNode::Copy() {
  return std::make_shared<BuildSentenceVocabNode>(nullptr, vocab_, col_names_, vocab_size_, character_coverage_,
                                                  model_type_, params_);
}

// Write a one-line description of this node: name, a vocab placeholder, the columns and the vocab size.
// \param out - The output stream to write output to
void BuildSentenceVocabNode::Print(std::ostream &out) const {
  // The argument list opens with '(' so that it balances the trailing ')' and matches the
  // format used by BuildVocabNode::Print.
  out << Name() + "(<vocab>," + "columns:" + PrintColumns(col_names_) + ",vocab_size:" + std::to_string(vocab_size_) +
           ",...)";
}


// Function to build BuildSentenceVocabNode // Function to build BuildSentenceVocabNode
@@ -81,5 +93,16 @@ Status BuildSentenceVocabNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


// NodePass hook called on the way down the tree (first visit of this node).
Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as the concrete node class rather than the base class.
  auto self = shared_from_base<BuildSentenceVocabNode>();
  return p->Visit(self, modified);
}

// NodePass hook called on the way back up the tree, after the descendants were visited.
Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as the concrete node class rather than the base class.
  auto self = shared_from_base<BuildSentenceVocabNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 24
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h View File

@@ -38,6 +38,18 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~BuildSentenceVocabNode() = default; ~BuildSentenceVocabNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBuildSentencePieceVocabNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -46,6 +58,18 @@ class BuildSentenceVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;

private: private:
std::shared_ptr<SentencePieceVocab> vocab_; std::shared_ptr<SentencePieceVocab> vocab_;
std::vector<std::string> col_names_; std::vector<std::string> col_names_;


+ 23
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc View File

@@ -22,7 +22,7 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/build_vocab_op.h" #include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -36,7 +36,17 @@ BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_p
top_k_(top_k), top_k_(top_k),
special_tokens_(special_tokens), special_tokens_(special_tokens),
special_first_(special_first) { special_first_(special_first) {
this->children.push_back(child);
this->AddChild(child);
}

// Produce a childless node with the same parameters as this one.
// vocab_ is a shared_ptr, so the new node refers to the same vocab object as the original.
std::shared_ptr<DatasetNode> BuildVocabNode::Copy() {
  return std::make_shared<BuildVocabNode>(nullptr, vocab_, columns_, freq_range_, top_k_, special_tokens_,
                                          special_first_);
}

// Write a one-line description of this node: name, a vocab placeholder and the column list.
// \param out - The output stream to write output to
void BuildVocabNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(<vocab>,";
  text += "columns:";
  text += PrintColumns(columns_);
  text += ",...)";
  out << text;
}


// Function to build BuildVocabNode // Function to build BuildVocabNode
@@ -78,5 +88,16 @@ Status BuildVocabNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


// NodePass hook called on the way down the tree (first visit of this node).
Status BuildVocabNode::Accept(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as the concrete node class rather than the base class.
  auto self = shared_from_base<BuildVocabNode>();
  return p->Visit(self, modified);
}

// NodePass hook called on the way back up the tree, after the descendants were visited.
Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as the concrete node class rather than the base class.
  auto self = shared_from_base<BuildVocabNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 24
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h View File

@@ -37,6 +37,18 @@ class BuildVocabNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~BuildVocabNode() = default; ~BuildVocabNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kBuildVocabNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -45,6 +57,18 @@ class BuildVocabNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;

private: private:
std::shared_ptr<Vocab> vocab_; std::shared_ptr<Vocab> vocab_;
std::vector<std::string> columns_; std::vector<std::string> columns_;


+ 23
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc View File

@@ -22,7 +22,7 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/concat_op.h" #include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -35,17 +35,25 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
: sampler_(sampler), : sampler_(sampler),
children_flag_and_nums_(children_flag_and_nums), children_flag_and_nums_(children_flag_and_nums),
children_start_end_index_(children_start_end_index) { children_start_end_index_(children_start_end_index) {
this->children = datasets;
for (auto const &child : datasets) AddChild(child);
}

std::shared_ptr<DatasetNode> ConcatNode::Copy() {
// create an empty vector to copy a concat
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>());
return node;
} }


// Write a one-line description of this node; only the node name is printed.
// \param out - The output stream to write output to
void ConcatNode::Print(std::ostream &out) const {
  const std::string text = Name();
  out << text;
}

Status ConcatNode::ValidateParams() { Status ConcatNode::ValidateParams() {
if (children.size() < 2) {
if (children_.size() < 2) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified."; std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg; MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg); RETURN_STATUS_SYNTAX_ERROR(err_msg);
} }


if (find(children.begin(), children.end(), nullptr) != children.end()) {
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null."; std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg; MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg); RETURN_STATUS_SYNTAX_ERROR(err_msg);
@@ -73,5 +81,16 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
return node_ops; return node_ops;
} }


// NodePass hook called on the way down the tree (first visit of this node).
Status ConcatNode::Accept(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as the concrete node class rather than the base class.
  auto self = shared_from_base<ConcatNode>();
  return p->Visit(self, modified);
}

// NodePass hook called on the way back up the tree, after the descendants were visited.
Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as the concrete node class rather than the base class.
  auto self = shared_from_base<ConcatNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 24
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h View File

@@ -38,6 +38,18 @@ class ConcatNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~ConcatNode() = default; ~ConcatNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kConcatNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -50,6 +62,18 @@ class ConcatNode : public DatasetNode {
std::shared_ptr<SamplerObj> sampler_; std::shared_ptr<SamplerObj> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_; std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_; std::vector<std::pair<int, int>> children_start_end_index_;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
}; };


} // namespace dataset } // namespace dataset


+ 93
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -233,14 +233,92 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
return shared_from_this(); return shared_from_this();
} }


DatasetNode::DatasetNode() {
DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) {
// Fetch some default value from config manager // Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers(); num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer(); rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size(); connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size(); worker_connector_size_ = cfg->worker_connector_size();
build_status = Status::OK(); // remove me after changing return val of Build()
}

// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
std::shared_ptr<DatasetNode> new_node = this->Copy();
for (const auto &child : children_) {
new_node->AddChild(child->DeepCopy());
}
return new_node;
}

// Format a list of column names for printing.
// \param columns - The column names to format
// \return "<nil>" for an empty list, otherwise the names joined by ", " and wrapped in brackets
std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
  if (columns.empty()) {
    return "<nil>";
  }
  std::string joined = "[";
  const char *separator = "";  // nothing precedes the first name
  for (const auto &column : columns) {
    joined += separator;
    joined += column;
    separator = ", ";
  }
  joined += "]";
  return joined;
}

// Print the tree rooted at this node, starting the recursion at nesting level 0.
// \param out - The output stream to write output to
void DatasetNode::PrintTree(std::ostream &out) const {
  int depth{0};
  PrintNode(out, &depth);
}

void DatasetNode::PrintNode(std::ostream &out, int *level) const {
const std::string prefix = "+-";
const std::string indent = " ";
out << prefix;
Print(out);
for (const auto &c : this->Children()) {
out << '\n';
++(*level);
for (auto i = 0; i < *level; i++) {
out << indent;
}
c->PrintNode(out, level);
--(*level);
}
}

// Add a node as a child, node's parent needs to be nullptr
// this function will allow child to be a nullptr, in which case it will simply skip
void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
if (child != nullptr && child->parent_ == nullptr) {
children_.push_back(child);
child->parent_ = this;
} else if (child != nullptr) {
MS_LOG(WARNING) << "DatasetNode::AddChild() Fail" + child->Name() + "'s parent isn't a nullptr.";
}
}

// Remove this node from its parent. Add the child of this node to its parent.
// For now, this remove is limited to a node with a single child or no child.
// \return Status::OK() on success; an error if the node has no parent, has more than one child,
//     or (single-child case) is not found in its parent's children list
Status DatasetNode::Remove() {
  CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent.");
  CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child.");
  // The parent's children list may hold the last owning reference to this node. Keep one more
  // reference for the rest of the function so the members touched below stay valid after this
  // node is taken out of that list.
  const std::shared_ptr<DatasetNode> self = shared_from_this();
  auto &siblings = parent_->children_;
  if (children_.empty()) {  // I am a leaf node, remove me from my parent's children list
    siblings.erase(std::remove(siblings.begin(), siblings.end(), self), siblings.end());  // "erase remove idiom"
  } else {  // replace my position in my parent's children list with my single child
    auto itr = std::find(siblings.begin(), siblings.end(), self);
    CHECK_FAIL_RETURN_UNEXPECTED(itr != siblings.end(), "I am not in my parent's children list.");
    children_[0]->parent_ = parent_;  // set my single child's parent ptr to my parent
    *itr = std::move(children_[0]);   // replace me in my parent's children list with my single child
    children_.clear();                // release my single child from my children list
  }
  parent_ = nullptr;
  return Status::OK();
}


// In DFS tree traversal, each node is visited twice. Accept is called on the first visit. // In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
@@ -255,13 +333,25 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one. // This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified); return p->VisitAfter(shared_from_this(), modified);
} }

Status DatasetNode::GetShardId(int32_t *shard_id) { Status DatasetNode::GetShardId(int32_t *shard_id) {
if (!Children().empty()) { if (!Children().empty()) {
// Get shard id from the child node // Get shard id from the child node
return Children()[0]->GetShardId(shard_id); return Children()[0]->GetShardId(shard_id);
} else { } else {
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
} }
} }
// NodePass hook called on the way down the tree (first visit of this node).
Status SourceNode::Accept(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as SourceNode rather than the DatasetNode base class.
  auto self = shared_from_base<SourceNode>();
  return p->Visit(self, modified);
}

// NodePass hook called on the way back up the tree, after the descendants were visited.
Status SourceNode::AcceptAfter(NodePass *p, bool *modified) {
  // Hand the pass a pointer typed as SourceNode rather than the DatasetNode base class.
  auto self = shared_from_base<SourceNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 202
- 5
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -42,6 +42,45 @@ class NodePass;
} \ } \
} while (false) } while (false)


// Names for non-leaf IR node
// Each constant is the string a node class returns from its Name() override (for example
// BucketBatchByLengthNode::Name() returns kBucketBatchByLengthNode); the Print() functions
// start their output with that name.
constexpr char kBatchNode[] = "Batch";
constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength";
constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab";
constexpr char kBuildVocabNode[] = "BuildVocab";
constexpr char kConcatNode[] = "Concat";
constexpr char kDatasetNode[] = "Dataset";
constexpr char kEpochCtrlNode[] = "EpochCtrl";
constexpr char kFilterNode[] = "Filter";
constexpr char kMapNode[] = "Map";
constexpr char kProjectNode[] = "Project";
constexpr char kRenameNode[] = "Rename";
constexpr char kRepeatNode[] = "Repeat";
constexpr char kRootNode[] = "Top";
constexpr char kShuffleNode[] = "Shuffle";
constexpr char kSkipNode[] = "Skip";
constexpr char kSyncWaitNode[] = "SyncWait";
constexpr char kTakeNode[] = "Take";
constexpr char kTransferNode[] = "Transfer";
constexpr char kZipNode[] = "Zip";

// Names for leaf IR node
// All of these carry a "Dataset" suffix, unlike the non-leaf names above.
constexpr char kAlbumNode[] = "AlbumDataset";
constexpr char kCelebANode[] = "CelebADataset";
constexpr char kCifar100Node[] = "Cifar100Dataset";
constexpr char kCifar10Node[] = "Cifar10Dataset";
constexpr char kCLUENode[] = "CLUEDataset";
constexpr char kCocoNode[] = "CocoDataset";
constexpr char kCSVNode[] = "CSVDataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";
constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";
constexpr char kRandomNode[] = "RandomDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kVOCNode[] = "VOCDataset";

Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op); int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op);


@@ -75,6 +114,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data
/// \return Shared pointer to the current Sampler. /// \return Shared pointer to the current Sampler.
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id); std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id);


// The base class of all IR nodes
class DatasetNode : public std::enable_shared_from_this<DatasetNode> { class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
public: public:
/// \brief Constructor /// \brief Constructor
@@ -87,6 +127,36 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \brief Destructor /// \brief Destructor
~DatasetNode() = default; ~DatasetNode() = default;


/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;

/// \brief Pure virtual function to print the description
/// \param out - The output stream to write output to
virtual void Print(std::ostream &out) const = 0;

/// \brief Pure virtual function to make a new copy of the node
/// \return The new copy of the node
virtual std::shared_ptr<DatasetNode> Copy() = 0;

/// \brief Print the IR tree to output stream
/// \param out - The output stream to write output to
void PrintTree(std::ostream &out) const;

/// \brief << Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param out - reference to the output stream being overloaded
/// \param node - reference to the DatasetNode to display
/// \return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) {
node.PrintTree(out);
return out;
}

/// \brief Make a new copy of the tree from the current node
/// \return The new copy of the tree
std::shared_ptr<DatasetNode> DeepCopy();

/// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0; virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
@@ -95,17 +165,38 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
virtual Status ValidateParams() = 0; virtual Status ValidateParams() = 0;


const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; }

/// \brief Pure virtual function for derived class to get the shard id of specific node /// \brief Pure virtual function for derived class to get the shard id of specific node
/// \return Status Status::OK() if get shard id successfully /// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id); virtual Status GetShardId(int32_t *shard_id);


/// \brief Getter function for child nodes
/// \return Child nodes
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }

/// \brief Establish the parent-child relationship between this node and its child.
void AddChild(std::shared_ptr<DatasetNode> child);

/// \brief detach this node from its parent, add its child (if any) to its parent
/// \return error code, return error if node has more than 1 children
Status Remove();

/// \brief Check if this node has cache
/// \return True if the data of this node will be cached
const bool IsCached() const { return (cache_ != nullptr); }

/// \brief Setter function for runtime number of workers /// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator /// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object /// \return Shared pointer to the original object
std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers); std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);


/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
/// Similar to shared_from_this, except this one will give you the derived class as shared_ptr
/// \return A shared_ptr casted to the derived class
template <typename Derived>
std::shared_ptr<Derived> shared_from_base() {
return std::static_pointer_cast<Derived>(shared_from_this());
}

/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
/// visit on the way back up the tree after its descendants are visited. /// visit on the way back up the tree after its descendants are visited.
@@ -129,17 +220,123 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
Status BuildStatus() { return build_status; } Status BuildStatus() { return build_status; }


protected: protected:
std::vector<std::shared_ptr<DatasetNode>> children;
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_;
std::shared_ptr<DatasetCache> cache_; std::shared_ptr<DatasetCache> cache_;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);

int32_t num_workers_; int32_t num_workers_;
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
int32_t connector_que_size_; int32_t connector_que_size_;
int32_t worker_connector_size_; int32_t worker_connector_size_;
Status build_status; // remove me after changing return val of Build() Status build_status; // remove me after changing return val of Build()
std::string PrintColumns(const std::vector<std::string> &columns) const;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
void PrintNode(std::ostream &out, int *level) const;
}; };


// SourceNode represents the leaf nodes of a pipeline where the data is pulled into.
class SourceNode : public DatasetNode {
public:
/// \brief Constructor
SourceNode() : DatasetNode() {}

/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {}

/// \brief Destructor
~SourceNode() = default;

/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;

/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
/// \return True if the dataset represented by this node is a mappable dataset
const bool IsMappable() const { return mappable_; }

protected:
bool mappable_;
};

// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
class MappableSourceNode : public SourceNode {
public:
/// \brief Constructor
MappableSourceNode() : SourceNode() { mappable_ = true; }

/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
mappable_ = true;
}

/// \brief Destructor
~MappableSourceNode() = default;

/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};

// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
class NonMappableSourceNode : public SourceNode {
public:
/// \brief Constructor
NonMappableSourceNode() : SourceNode() { mappable_ = false; }

/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
mappable_ = false;
}

/// \brief Destructor
~NonMappableSourceNode() = default;

/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};

// NonLeafNode represents operations over data in a pipeline.
class NonLeafNode : public DatasetNode {
public:
/// \brief Constructor
NonLeafNode() = default;

/// \brief Destructor
~NonLeafNode() = default;

/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};

// SinkNode represents the end node of a pipeline where the data is pushed out
class SinkNode : public DatasetNode {
public:
/// \brief Constructor
SinkNode() = default;

/// \brief Destructor
~SinkNode() = default;

/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
};
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_

+ 67
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc View File

@@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

// Constructor for EpochCtrlNode: records the epoch count and attaches the given child.
EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) {
  // AddChild() only attaches a child whose parent pointer is still null and skips a null child.
  AddChild(child);
}
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_);
return node;
}

// Write a one-line description of this node: its name and the configured number of epochs.
// \param out - The output stream to write output to
void EpochCtrlNode::Print(std::ostream &out) const {
  std::string text = Name();
  text += "(epoch:";
  text += std::to_string(num_epochs_);
  text += ")";
  out << text;
}

// Function to build the EpochCtrlOp
std::vector<std::shared_ptr<DatasetOp>> EpochCtrlNode::Build() {
// A dummy vector
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<EpochCtrlOp>(num_epochs_));
return node_ops;
}

// Function to validate the parameters for EpochCtrlNode
// Accepts num_epochs of -1 or any positive value, and requires exactly one non-null child.
// Each violation is logged and then reported through RETURN_STATUS_SYNTAX_ERROR.
Status EpochCtrlNode::ValidateParams() {
  const bool epochs_ok = (num_epochs_ > 0) || (num_epochs_ == -1);
  if (!epochs_ok) {
    std::string msg = "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: ";
    msg += std::to_string(num_epochs_);
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  // The size test runs first so children_[0] is only read when it exists.
  const bool single_valid_child = (children_.size() == 1) && (children_[0] != nullptr);
  if (!single_valid_child) {
    std::string msg = "Internal error: epoch control node should have one child node";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 63
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h View File

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

/// \brief IR node that carries a number of epochs and builds an EpochCtrlOp from it.
///     ValidateParams() requires num_epochs to be -1 or positive and the node to have exactly one child.
class EpochCtrlNode : public DatasetNode {
public:
/// \brief Constructor
/// \param[in] child The node to attach as the child of this node
/// \param[in] num_epochs The number of epochs stored in this node (-1 or a positive value passes validation)
explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);

/// \brief Destructor
~EpochCtrlNode() = default;

/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kEpochCtrlNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
// Number of epochs handed to the EpochCtrlOp created by Build()
int32_t num_epochs_;
};

} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_

+ 23
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc View File

@@ -21,7 +21,7 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/filter_op.h" #include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"


namespace mindspore { namespace mindspore {
@@ -31,7 +31,16 @@ namespace dataset {
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
std::vector<std::string> input_columns) std::vector<std::string> input_columns)
: predicate_(predicate), input_columns_(input_columns) { : predicate_(predicate), input_columns_(input_columns) {
this->children.push_back(child);
this->AddChild(child);
}

// Create a new FilterNode sharing the same predicate and input columns.
// nullptr is passed as the child, so this node's children are not handed to the copy.
std::shared_ptr<DatasetNode> FilterNode::Copy() {
  return std::make_shared<FilterNode>(nullptr, predicate_, input_columns_);
}

void FilterNode::Print(std::ostream &out) const {
out << Name() + "(<predicate>," + "input_cols:" + PrintColumns(input_columns_) + ")";
} }


std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() { std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
@@ -54,5 +63,17 @@ Status FilterNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


// Visitor accepting method for NodePass
// Passes this node, typed as FilterNode, to the pass's Visit() and returns its Status.
Status FilterNode::Accept(NodePass *p, bool *modified) {
  auto self = shared_from_base<FilterNode>();
  return p->Visit(self, modified);
}

// Visitor accepting method for NodePass
// Passes this node, typed as FilterNode, to the pass's VisitAfter() and returns its Status.
Status FilterNode::AcceptAfter(NodePass *p, bool *modified) {
  auto self = shared_from_base<FilterNode>();
  return p->VisitAfter(self, modified);
}

} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 24
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h View File

@@ -35,6 +35,18 @@ class FilterNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~FilterNode() = default; ~FilterNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kFilterNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -43,6 +55,18 @@ class FilterNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;

private: private:
std::shared_ptr<TensorOp> predicate_; std::shared_ptr<TensorOp> predicate_;
std::vector<std::string> input_columns_; std::vector<std::string> input_columns_;


+ 24
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc View File

@@ -22,6 +22,7 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/map_op/map_op.h" #include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/include/transforms.h" #include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
@@ -37,7 +38,18 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
project_columns_(project_columns), project_columns_(project_columns),
DatasetNode(std::move(cache)), DatasetNode(std::move(cache)),
callbacks_(callbacks) { callbacks_(callbacks) {
this->children.push_back(child);
this->AddChild(child);
}

// Create a new MapNode with the same operations, column lists, cache and callbacks.
// nullptr is passed as the child, so this node's children are not handed to the copy.
std::shared_ptr<DatasetNode> MapNode::Copy() {
  return std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_,
                                   callbacks_);
}

void MapNode::Print(std::ostream &out) const {
out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) +
",<project_cols>" + ",...)";
} }


std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
@@ -93,5 +105,16 @@ Status MapNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


// Visitor accepting method for NodePass
// Passes this node, typed as MapNode, to the pass's Visit() and returns its Status.
Status MapNode::Accept(NodePass *p, bool *modified) {
  auto self = shared_from_base<MapNode>();
  return p->Visit(self, modified);
}

// Visitor accepting method for NodePass
// Passes this node, typed as MapNode, to the pass's VisitAfter() and returns its Status.
Status MapNode::AcceptAfter(NodePass *p, bool *modified) {
  auto self = shared_from_base<MapNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 29
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h View File

@@ -37,6 +37,18 @@ class MapNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~MapNode() = default; ~MapNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kMapNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -45,6 +57,23 @@ class MapNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Getter of tensor operations
/// \return Vector of operations the Map node will process
const auto &TensorOperations() const { return operations_; }
auto &TensorOperations() { return operations_; }

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;

private: private:
std::vector<std::shared_ptr<TensorOperation>> operations_; std::vector<std::shared_ptr<TensorOperation>> operations_;
std::vector<std::string> input_columns_; std::vector<std::string> input_columns_;


+ 8
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc View File

@@ -29,9 +29,16 @@ namespace dataset {
// Function to build ProjectOp // Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
: columns_(columns) { : columns_(columns) {
this->children.push_back(child);
this->AddChild(child);
} }


std::shared_ptr<DatasetNode> ProjectNode::Copy() {
auto node = std::make_shared<ProjectNode>(nullptr, this->columns_);
return node;
}

// Write "<name>(column: <columns>)" to the given stream as a single string.
void ProjectNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(column: ";
  description += PrintColumns(columns_);
  description += ")";
  out << description;
}

Status ProjectNode::ValidateParams() { Status ProjectNode::ValidateParams() {
if (columns_.empty()) { if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified."; std::string err_msg = "ProjectNode: No columns are specified.";


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h View File

@@ -34,6 +34,18 @@ class ProjectNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~ProjectNode() = default; ~ProjectNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kProjectNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 10
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc View File

@@ -30,7 +30,16 @@ namespace dataset {
RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns) const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) { : input_columns_(input_columns), output_columns_(output_columns) {
this->children.push_back(child);
this->AddChild(child);
}

// Create a new RenameNode with the same input and output column lists.
// nullptr is passed as the child, so this node's children are not handed to the copy.
std::shared_ptr<DatasetNode> RenameNode::Copy() {
  return std::make_shared<RenameNode>(nullptr, input_columns_, output_columns_);
}

void RenameNode::Print(std::ostream &out) const {
out << Name() + "(input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + ")";
} }


Status RenameNode::ValidateParams() { Status RenameNode::ValidateParams() {


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h View File

@@ -35,6 +35,18 @@ class RenameNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~RenameNode() = default; ~RenameNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kRenameNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 20
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc View File

@@ -21,15 +21,22 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child);
this->AddChild(child);
}

std::shared_ptr<DatasetNode> RepeatNode::Copy() {
auto node = std::make_shared<RepeatNode>(nullptr, this->repeat_count_);
return node;
} }


// Write "<name>(count:<repeat_count>)" to the given stream as a single string.
void RepeatNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(count:";
  description += std::to_string(repeat_count_);
  description += ")";
  out << description;
}

std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() { std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
@@ -49,5 +56,16 @@ Status RepeatNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


// Visitor accepting method for NodePass
// Passes this node, typed as RepeatNode, to the pass's Visit() and returns its Status.
Status RepeatNode::Accept(NodePass *p, bool *modified) {
  auto self = shared_from_base<RepeatNode>();
  return p->Visit(self, modified);
}

// Visitor accepting method for NodePass
// Passes this node, typed as RepeatNode, to the pass's VisitAfter() and returns its Status.
Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) {
  auto self = shared_from_base<RepeatNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 24
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h View File

@@ -36,6 +36,18 @@ class RepeatNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~RepeatNode() = default; ~RepeatNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kRepeatNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -44,6 +56,18 @@ class RepeatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;

private: private:
int32_t repeat_count_; int32_t repeat_count_;
}; };


+ 85
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc View File

@@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/root_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

// Constructor for RootNode
// Stores the number of epochs and attaches the supplied node as the child of the root.
RootNode::RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) {
  // Only the child link is set up here; the root's own parent is left as the DatasetNode
  // constructor initialised it, since a root node must not have a parent.
  this->AddChild(child);
}

// Create a new RootNode with the same number of epochs.
// nullptr is passed as the child, so this node's children are not handed to the copy.
std::shared_ptr<DatasetNode> RootNode::Copy() {
  return std::make_shared<RootNode>(nullptr, num_epochs_);
}

// Write the node name to the given stream; no parameters are printed for the root.
void RootNode::Print(std::ostream &out) const {
  const std::string node_name = Name();
  out << node_name;
}

// The root node doesn't build a runtime Op, so the returned list is always empty.
// NOTE(review): the original intent was for this call to signal an error, but the return type
// cannot carry a Status; callers only ever see an empty list.
std::vector<std::shared_ptr<DatasetOp>> RootNode::Build() {
  return std::vector<std::shared_ptr<DatasetOp>>();
}

// Function to validate the parameters for RootNode
// Checks, in order: num_epochs is -1 or positive, the node has no parent, the node has exactly
// one child, and that child is not null. Each violation is logged and then reported through
// RETURN_STATUS_SYNTAX_ERROR.
Status RootNode::ValidateParams() {
  const bool epochs_ok = (num_epochs_ > 0) || (num_epochs_ == -1);
  if (!epochs_ok) {
    std::string msg = "RootNode: num_epochs should be either -1 or positive integer, num_epochs: ";
    msg += std::to_string(num_epochs_);
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  if (!(parent_ == nullptr)) {
    std::string msg = "Internal error: root node should not have a parent";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  const bool has_single_child = (children_.size() == 1);
  if (!has_single_child) {
    std::string msg = "Internal error: root node should have one child node";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  // Safe to index: the size check above guarantees element 0 exists.
  if (!(children_[0] != nullptr)) {
    std::string msg = "Internal error: root node's child is a null pointer";
    MS_LOG(ERROR) << msg;
    RETURN_STATUS_SYNTAX_ERROR(msg);
  }

  return Status::OK();
}

// Visitor accepting method for NodePass
// Passes this node, typed as RootNode, to the pass's Visit() and returns its Status.
Status RootNode::Accept(NodePass *p, bool *modified) {
  auto self = shared_from_base<RootNode>();
  return p->Visit(self, modified);
}

// Visitor accepting method for NodePass
// Passes this node, typed as RootNode, to the pass's VisitAfter() and returns its Status.
Status RootNode::AcceptAfter(NodePass *p, bool *modified) {
  auto self = shared_from_base<RootNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

+ 78
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h View File

@@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

/// \brief IR node that carries the number of epochs. Its ValidateParams() requires it to have
///     no parent and exactly one non-null child, and its Build() creates no runtime op.
class RootNode : public DatasetNode {
 public:
  /// \brief Constructor
  /// \param[in] child The node to attach as the child of the root
  /// \param[in] num_epochs The number of epochs stored in this node (-1 or a positive value passes validation)
  RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);

  /// \brief Destructor
  ~RootNode() = default;

  /// \brief Node name getter
  /// \return Name of the current node
  std::string Name() const override { return kRootNode; }

  /// \brief Print the description
  /// \param out - The output stream to write output to
  void Print(std::ostream &out) const override;

  /// \brief Copy the node to a new object
  /// \return A shared pointer to the new copy
  std::shared_ptr<DatasetNode> Copy() override;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return shared pointer to the list of newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Getter of number of epochs
  /// \return The number of epochs given at construction
  /// \note Declared const so it can also be called through a const reference or pointer.
  int32_t num_epochs() const { return num_epochs_; }

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

  /// \brief Base-class override for accepting NodePass visitor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;

  /// \brief Base-class override for accepting NodePass visitor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status AcceptAfter(NodePass *p, bool *modified) override;

 private:
  // Number of epochs checked by ValidateParams() and exposed through num_epochs()
  int32_t num_epochs_;
};

} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_

+ 11
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc View File

@@ -29,7 +29,17 @@ namespace dataset {
// Constructor for ShuffleNode // Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
this->children.push_back(child);
this->AddChild(child);
}

// Create a new ShuffleNode with the same shuffle size and reset_every_epoch flag.
// nullptr is passed as the child, so this node's children are not handed to the copy.
// NOTE(review): the constructor initialises shuffle_seed_ from GetSeed(), so the copy takes its
// seed from a fresh GetSeed() call rather than from this node - confirm this is intended.
std::shared_ptr<DatasetNode> ShuffleNode::Copy() {
  return std::make_shared<ShuffleNode>(nullptr, shuffle_size_, reset_every_epoch_);
}

void ShuffleNode::Print(std::ostream &out) const {
out << Name() + "(shuffle_size:" + std::to_string(shuffle_size_) +
",reset_every_epoch:" + (reset_every_epoch_ ? "true" : "false") + ")";
} }


// Function to build the ShuffleOp // Function to build the ShuffleOp


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h View File

@@ -34,6 +34,18 @@ class ShuffleNode : public DatasetNode {


~ShuffleNode() = default; ~ShuffleNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kShuffleNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


Status ValidateParams() override; Status ValidateParams() override;


+ 7
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc View File

@@ -27,10 +27,15 @@ namespace mindspore {
namespace dataset { namespace dataset {


// Constructor for SkipNode // Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) {
this->children.push_back(child);
// Stores the number of rows to skip and attaches the supplied node as a child.
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) {
  AddChild(child);
}

std::shared_ptr<DatasetNode> SkipNode::Copy() {
auto node = std::make_shared<SkipNode>(nullptr, skip_count_);
return node;
} }


// Write "<name>(skip_count:<skip_count>)" to the given stream as a single string.
void SkipNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(skip_count:";
  description += std::to_string(skip_count_);
  description += ")";
  out << description;
}

// Function to build the SkipOp // Function to build the SkipOp
std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() { std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h View File

@@ -34,6 +34,18 @@ class SkipNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~SkipNode() = default; ~SkipNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kSkipNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 11
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc View File

@@ -32,13 +32,23 @@ namespace dataset {
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode, const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir), dataset_dir_(dataset_dir),
schema_path_(data_schema), schema_path_(data_schema),
column_names_(column_names), column_names_(column_names),
decode_(decode), decode_(decode),
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> AlbumNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
return node;
}

void AlbumNode::Print(std::ostream &out) const {
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
}

Status AlbumNode::ValidateParams() { Status AlbumNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_));




+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h View File

@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class AlbumNode : public DatasetNode {
class AlbumNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
AlbumNode(const std::string &dataset_dir, const std::string &data_schema, AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
@@ -36,6 +36,18 @@ class AlbumNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~AlbumNode() = default; ~AlbumNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kAlbumNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create a runtime dataset op object from this class /// \brief a base class override function to create a runtime dataset op object from this class
/// \return shared pointer to the newly created DatasetOp /// \return shared pointer to the newly created DatasetOp
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 11
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -31,13 +31,23 @@ namespace dataset {
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode, const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir), dataset_dir_(dataset_dir),
usage_(usage), usage_(usage),
sampler_(sampler), sampler_(sampler),
decode_(decode), decode_(decode),
extensions_(extensions) {} extensions_(extensions) {}


std::shared_ptr<DatasetNode> CelebANode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
return node;
}

// Write a one-line description of this node: its name plus whether a cache is attached
void CelebANode::Print(std::ostream &out) const {
  const std::string cache_state = (cache_ == nullptr) ? "false" : "true";
  out << Name() + "(cache:" + cache_state + ")";
}

Status CelebANode::ValidateParams() { Status CelebANode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));




+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h View File

@@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class CelebANode : public DatasetNode {
class CelebANode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
@@ -37,6 +37,18 @@ class CelebANode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~CelebANode() = default; ~CelebANode() = default;


/// \brief Report the name that identifies this node type
/// \return The kCelebANode name constant
std::string Name() const override {
  return kCelebANode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 11
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc View File

@@ -30,7 +30,17 @@ namespace dataset {
// Constructor for Cifar100Node // Constructor for Cifar100Node
Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache) std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
return node;
}

// Write a one-line description of this node: its name plus whether a cache is attached
void Cifar100Node::Print(std::ostream &out) const {
  const std::string cache_state = (cache_ == nullptr) ? "false" : "true";
  out << Name() + "(cache:" + cache_state + ")";
}


Status Cifar100Node::ValidateParams() { Status Cifar100Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));


+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h View File

@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class Cifar100Node : public DatasetNode {
class Cifar100Node : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
@@ -35,6 +35,18 @@ class Cifar100Node : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~Cifar100Node() = default; ~Cifar100Node() = default;


/// \brief Report the name that identifies this node type
/// \return The kCifar100Node name constant
std::string Name() const override {
  return kCifar100Node;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 11
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc View File

@@ -30,7 +30,17 @@ namespace dataset {
// Constructor for Cifar10Node // Constructor for Cifar10Node
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache) std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
return node;
}

// Write a one-line description of this node: its name plus whether a cache is attached
void Cifar10Node::Print(std::ostream &out) const {
  const std::string cache_state = (cache_ == nullptr) ? "false" : "true";
  out << Name() + "(cache:" + cache_state + ")";
}


Status Cifar10Node::ValidateParams() { Status Cifar10Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));


+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h View File

@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class Cifar10Node : public DatasetNode {
class Cifar10Node : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
@@ -35,6 +35,18 @@ class Cifar10Node : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~Cifar10Node() = default; ~Cifar10Node() = default;


/// \brief Report the name that identifies this node type
/// \return The kCifar10Node name constant
std::string Name() const override {
  return kCifar10Node;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 12
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc View File

@@ -32,7 +32,7 @@ namespace dataset {
// Constructor for CLUENode // Constructor for CLUENode
CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(clue_files), dataset_files_(clue_files),
task_(task), task_(task),
usage_(usage), usage_(usage),
@@ -41,6 +41,17 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task,
num_shards_(num_shards), num_shards_(num_shards),
shard_id_(shard_id) {} shard_id_(shard_id) {}


// Create a new CLUENode from this node's stored constructor arguments; cache_ is passed along unchanged
std::shared_ptr<DatasetNode> CLUENode::Copy() {
  return std::make_shared<CLUENode>(dataset_files_, task_, usage_, num_samples_, shuffle_, num_shards_, shard_id_,
                                    cache_);
}

// Write a one-line description of this node: name, cache presence, and the sharding parameters.
// The description is assembled first and written to the stream in a single insertion.
void CLUENode::Print(std::ostream &out) const {
  const std::string cache_state = (cache_ == nullptr) ? "false" : "true";
  std::string description = Name() + "(cache:" + cache_state + ",...";
  description += ",num_shards:" + std::to_string(num_shards_);
  description += ",shard_id:" + std::to_string(shard_id_);
  description += ")";
  out << description;
}

Status CLUENode::ValidateParams() { Status CLUENode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));




+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h View File

@@ -28,7 +28,7 @@ namespace dataset {


/// \class CLUENode /// \class CLUENode
/// \brief A Dataset derived class to represent CLUE dataset /// \brief A Dataset derived class to represent CLUE dataset
class CLUENode : public DatasetNode {
class CLUENode : public NonMappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
@@ -37,6 +37,18 @@ class CLUENode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~CLUENode() = default; ~CLUENode() = default;


/// \brief Report the name that identifies this node type
/// \return The kCLUENode name constant
std::string Name() const override {
  return kCLUENode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 9
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc View File

@@ -30,13 +30,21 @@ namespace dataset {
// Constructor for CocoNode // Constructor for CocoNode
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir), dataset_dir_(dataset_dir),
annotation_file_(annotation_file), annotation_file_(annotation_file),
task_(task), task_(task),
decode_(decode), decode_(decode),
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> CocoNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
return node;
}

// Write this node's description, which consists of the node name only
void CocoNode::Print(std::ostream &out) const {
  const std::string node_name = Name();
  out << node_name;
}

Status CocoNode::ValidateParams() { Status CocoNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));




+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h View File

@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class CocoNode : public DatasetNode {
class CocoNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
@@ -35,6 +35,18 @@ class CocoNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~CocoNode() = default; ~CocoNode() = default;


/// \brief Report the name that identifies this node type
/// \return The kCocoNode name constant
std::string Name() const override {
  return kCocoNode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 12
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc View File

@@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(csv_files), dataset_files_(csv_files),
field_delim_(field_delim), field_delim_(field_delim),
column_defaults_(column_defaults), column_defaults_(column_defaults),
@@ -43,6 +43,17 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
num_shards_(num_shards), num_shards_(num_shards),
shard_id_(shard_id) {} shard_id_(shard_id) {}


// Create a new CSVNode from this node's stored constructor arguments; cache_ is passed along unchanged
std::shared_ptr<DatasetNode> CSVNode::Copy() {
  return std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_,
                                   shuffle_, num_shards_, shard_id_, cache_);
}

// Write a one-line description of this node: name, cache presence, and the sharding parameters.
// The description is assembled first and written to the stream in a single insertion.
void CSVNode::Print(std::ostream &out) const {
  const std::string cache_state = (cache_ == nullptr) ? "false" : "true";
  std::string description = Name() + "(cache:" + cache_state + ",...";
  description += ",num_shards:" + std::to_string(num_shards_);
  description += ",shard_id:" + std::to_string(shard_id_);
  description += ")";
  out << description;
}

Status CSVNode::ValidateParams() { Status CSVNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_));




+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h View File

@@ -47,7 +47,7 @@ class CsvRecord : public CsvBase {
T value; T value;
}; };


class CSVNode : public DatasetNode {
class CSVNode : public NonMappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
CSVNode(const std::vector<std::string> &dataset_files, char field_delim, CSVNode(const std::vector<std::string> &dataset_files, char field_delim,
@@ -58,6 +58,18 @@ class CSVNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~CSVNode() = default; ~CSVNode() = default;


/// \brief Report the name that identifies this node type
/// \return The kCSVNode name constant
std::string Name() const override {
  return kCSVNode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc View File

@@ -28,7 +28,19 @@ namespace dataset {


GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
const std::vector<DataType> &column_types) const std::vector<DataType> &column_types)
: generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
: MappableSourceNode(),
generator_function_(generator_function),
column_names_(column_names),
column_types_(column_types) {}

std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
return node;
}

// Write a one-line description of this node: its name and column names.
// The function and column types are shown only as placeholders.
void GeneratorNode::Print(std::ostream &out) const {
  std::string description = Name() + "(<func>:" + ",columns:";
  description += PrintColumns(column_names_);
  description += ",<col_types>)";
  out << description;
}


GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
: generator_function_(generator_function), schema_(schema) {} : generator_function_(generator_function), schema_(schema) {}


+ 13
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h View File

@@ -26,10 +26,9 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {

/// \class GeneratorNode /// \class GeneratorNode
/// \brief A Dataset derived class to represent GeneratorNode dataset /// \brief A Dataset derived class to represent GeneratorNode dataset
class GeneratorNode : public DatasetNode {
class GeneratorNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
@@ -41,6 +40,18 @@ class GeneratorNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~GeneratorNode() = default; ~GeneratorNode() = default;


/// \brief Report the name that identifies this node type
/// \return The kGeneratorNode name constant
std::string Name() const override {
  return kGeneratorNode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 14
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -33,13 +33,24 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
bool recursive, std::set<std::string> extensions, bool recursive, std::set<std::string> extensions,
std::map<std::string, int32_t> class_indexing, std::map<std::string, int32_t> class_indexing,
std::shared_ptr<DatasetCache> cache = nullptr) std::shared_ptr<DatasetCache> cache = nullptr)
: dataset_dir_(dataset_dir),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
decode_(decode), decode_(decode),
sampler_(sampler), sampler_(sampler),
recursive_(recursive), recursive_(recursive),
class_indexing_(class_indexing), class_indexing_(class_indexing),
exts_(extensions),
DatasetNode(std::move(cache)) {}
exts_(extensions) {}

std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node =
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
return node;
}

// Write a one-line description of this node: name, dataset directory and the decode flag
void ImageFolderNode::Print(std::ostream &out) const {
  const std::string decode_state = decode_ ? "true" : "false";
  out << Name() + "(path:" + dataset_dir_ + ",decode:" + decode_state + ",...)";
}


Status ImageFolderNode::ValidateParams() { Status ImageFolderNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));


+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h View File

@@ -31,7 +31,7 @@ namespace dataset {


/// \class ImageFolderNode /// \class ImageFolderNode
/// \brief A Dataset derived class to represent ImageFolder dataset /// \brief A Dataset derived class to represent ImageFolder dataset
class ImageFolderNode : public DatasetNode {
class ImageFolderNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive, ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
@@ -41,6 +41,18 @@ class ImageFolderNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~ImageFolderNode() = default; ~ImageFolderNode() = default;


/// \brief Report the name that identifies this node type
/// \return The kImageFolderNode name constant
std::string Name() const override {
  return kImageFolderNode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 18
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc View File

@@ -32,13 +32,30 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode, const std::map<std::string, int32_t> &class_indexing, bool decode,
std::shared_ptr<DatasetCache> cache) std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_file_(dataset_file), dataset_file_(dataset_file),
usage_(usage), usage_(usage),
decode_(decode), decode_(decode),
class_index_(class_indexing), class_index_(class_indexing),
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> ManifestNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
return node;
}

// Write a one-line description of this node: name and manifest file, followed by
// ",sampler" and/or ",cache" markers when those members are set
void ManifestNode::Print(std::ostream &out) const {
  const bool has_sampler = (sampler_ != nullptr);
  const bool has_cache = (cache_ != nullptr);
  out << Name() + "(file:" + dataset_file_;
  out << (has_sampler ? ",sampler" : "") << (has_cache ? ",cache" : "") << ")";
}

Status ManifestNode::ValidateParams() { Status ManifestNode::ValidateParams() {
std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
for (char c : dataset_file_) { for (char c : dataset_file_) {


+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h View File

@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class ManifestNode : public DatasetNode {
class ManifestNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
@@ -36,6 +36,18 @@ class ManifestNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~ManifestNode() = default; ~ManifestNode() = default;


/// \brief Report the name that identifies this node type
/// \return The kManifestNode name constant
std::string Name() const override {
  return kManifestNode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 17
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc View File

@@ -30,7 +30,8 @@ namespace dataset {


MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
: dataset_file_(std::string()),
: MappableSourceNode(),
dataset_file_(std::string()),
dataset_files_(dataset_files), dataset_files_(dataset_files),
search_for_pattern_(false), search_for_pattern_(false),
columns_list_(columns_list), columns_list_(columns_list),
@@ -41,7 +42,8 @@ MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const


MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list, MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded) const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
: dataset_file_(dataset_file),
: MappableSourceNode(),
dataset_file_(dataset_file),
dataset_files_({}), dataset_files_({}),
search_for_pattern_(true), search_for_pattern_(true),
columns_list_(columns_list), columns_list_(columns_list),
@@ -50,6 +52,19 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st
sample_bytes_({}), sample_bytes_({}),
num_padded_(num_padded) {} num_padded_(num_padded) {}


std::shared_ptr<DatasetNode> MindDataNode::Copy() {
std::shared_ptr<MindDataNode> node;
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
if (dataset_files_.empty()) {
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
} else {
node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_);
}
return node;
}

// Write a one-line description of this node: its name and the single dataset_file_ value
// (the dataset_files_ list is not included)
void MindDataNode::Print(std::ostream &out) const {
  const std::string description = Name() + "(file:" + dataset_file_ + ",...)";
  out << description;
}

Status MindDataNode::ValidateParams() { Status MindDataNode::ValidateParams() {
if (!search_for_pattern_ && dataset_files_.size() > 4096) { if (!search_for_pattern_ && dataset_files_.size() > 4096) {
std::string err_msg = std::string err_msg =


+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h View File

@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class MindDataNode : public DatasetNode {
class MindDataNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
@@ -40,6 +40,18 @@ class MindDataNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~MindDataNode() = default; ~MindDataNode() = default;


/// \brief Report the name that identifies this node type
/// \return The kMindDataNode name constant
std::string Name() const override {
  return kMindDataNode;
}

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 9
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc View File

@@ -29,7 +29,15 @@ namespace dataset {


MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache) std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

std::shared_ptr<DatasetNode> MnistNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
return node;
}

// Write this node's description to the stream; only the node name is emitted
void MnistNode::Print(std::ostream &out) const {
  out << Name();
}


Status MnistNode::ValidateParams() { Status MnistNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));


+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h View File

@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class MnistNode : public DatasetNode {
class MnistNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
@@ -35,6 +35,18 @@ class MnistNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~MnistNode() = default; ~MnistNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kMnistNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc View File

@@ -27,6 +27,18 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


// Create a new RandomNode with the same total_rows_, columns_list_ and cache_.
// The schema object is used when it is set; otherwise the copy is built from schema_path_.
std::shared_ptr<DatasetNode> RandomNode::Copy() {
  if (schema_ == nullptr) {
    return std::make_shared<RandomNode>(total_rows_, schema_path_, columns_list_, cache_);
  }
  return std::make_shared<RandomNode>(total_rows_, schema_, columns_list_, cache_);
}

// Write a short description of this node to the stream: the node name and total_rows_
void RandomNode::Print(std::ostream &out) const {
  std::string description = Name();
  description.append("(num_row:").append(std::to_string(total_rows_)).append(",...)");
  out << description;
}

// ValidateParams for RandomNode // ValidateParams for RandomNode
Status RandomNode::ValidateParams() { Status RandomNode::ValidateParams() {
if (total_rows_ < 0) { if (total_rows_ < 0) {


+ 16
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h View File

@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class RandomNode : public DatasetNode {
class RandomNode : public NonMappableSourceNode {
public: public:
// Some constants to provide limits to random generation. // Some constants to provide limits to random generation.
static constexpr int32_t kMaxNumColumns = 4; static constexpr int32_t kMaxNumColumns = 4;
@@ -37,7 +37,7 @@ class RandomNode : public DatasetNode {
/// \brief Constructor /// \brief Constructor
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
std::shared_ptr<DatasetCache> cache) std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
total_rows_(total_rows), total_rows_(total_rows),
schema_path_(""), schema_path_(""),
schema_(std::move(schema)), schema_(std::move(schema)),
@@ -46,14 +46,27 @@ class RandomNode : public DatasetNode {
/// \brief Constructor /// \brief Constructor
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
std::shared_ptr<DatasetCache> cache) std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
total_rows_(total_rows), total_rows_(total_rows),
schema_path_(schema_path), schema_path_(schema_path),
schema_(nullptr),
columns_list_(columns_list) {} columns_list_(columns_list) {}


/// \brief Destructor /// \brief Destructor
~RandomNode() = default; ~RandomNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kRandomNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 11
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc View File

@@ -31,13 +31,23 @@ namespace dataset {
// Constructor for TextFileNode // Constructor for TextFileNode
TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(dataset_files), dataset_files_(dataset_files),
num_samples_(num_samples), num_samples_(num_samples),
shuffle_(shuffle), shuffle_(shuffle),
num_shards_(num_shards), num_shards_(num_shards),
shard_id_(shard_id) {} shard_id_(shard_id) {}


// Create a new TextFileNode built from the same constructor arguments as this node
std::shared_ptr<DatasetNode> TextFileNode::Copy() {
  return std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
}

void TextFileNode::Print(std::ostream &out) const {
out << Name() + "(file:..." + ",num_shards:" + std::to_string(num_shards_) +
",shard_id:" + std::to_string(shard_id_) + ",cache:" + ((cache_ != nullptr) ? "true" : "false") + ",...)";
}

Status TextFileNode::ValidateParams() { Status TextFileNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));




+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h View File

@@ -28,7 +28,7 @@ namespace dataset {


/// \class TextFileNode /// \class TextFileNode
/// \brief A Dataset derived class to represent TextFile dataset /// \brief A Dataset derived class to represent TextFile dataset
class TextFileNode : public DatasetNode {
class TextFileNode : public NonMappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
@@ -37,6 +37,18 @@ class TextFileNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~TextFileNode() = default; ~TextFileNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTextFileNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 17
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc View File

@@ -30,6 +30,23 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


// Create a new TFRecordNode with the same files, columns, sampling/sharding settings and cache_.
// The schema object is used when it is set; otherwise the copy is built from schema_path_.
std::shared_ptr<DatasetNode> TFRecordNode::Copy() {
  if (schema_obj_ == nullptr) {
    return std::make_shared<TFRecordNode>(dataset_files_, schema_path_, columns_list_, num_samples_, shuffle_,
                                          num_shards_, shard_id_, shard_equal_rows_, cache_);
  }
  return std::make_shared<TFRecordNode>(dataset_files_, schema_obj_, columns_list_, num_samples_, shuffle_,
                                        num_shards_, shard_id_, shard_equal_rows_, cache_);
}

// Write a short description of this node to the stream: the node name, num_samples_ and the shard settings
void TFRecordNode::Print(std::ostream &out) const {
  std::string description = Name();
  description += "(num_samples:";
  description += std::to_string(num_samples_);
  description += ",num_shards:";
  description += std::to_string(num_shards_);
  description += ",shard_id:";
  description += std::to_string(shard_id_);
  description += ",...)";
  out << description;
}

// Validator for TFRecordNode // Validator for TFRecordNode
Status TFRecordNode::ValidateParams() { Status TFRecordNode::ValidateParams() {
if (dataset_files_.empty()) { if (dataset_files_.empty()) {


+ 15
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h View File

@@ -29,14 +29,14 @@ namespace dataset {


/// \class TFRecordNode /// \class TFRecordNode
/// \brief A Dataset derived class to represent TFRecord dataset /// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordNode : public DatasetNode {
class TFRecordNode : public NonMappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
/// \note Parameter 'schema' is the path to the schema file /// \note Parameter 'schema' is the path to the schema file
TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(dataset_files), dataset_files_(dataset_files),
schema_path_(schema), schema_path_(schema),
columns_list_(columns_list), columns_list_(columns_list),
@@ -51,7 +51,7 @@ class TFRecordNode : public DatasetNode {
TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: NonMappableSourceNode(std::move(cache)),
dataset_files_(dataset_files), dataset_files_(dataset_files),
schema_obj_(schema), schema_obj_(schema),
columns_list_(columns_list), columns_list_(columns_list),
@@ -64,6 +64,18 @@ class TFRecordNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~TFRecordNode() = default; ~TFRecordNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTFRecordNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 9
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc View File

@@ -32,7 +32,7 @@ namespace dataset {
VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache) std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir), dataset_dir_(dataset_dir),
task_(task), task_(task),
usage_(usage), usage_(usage),
@@ -40,6 +40,14 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
decode_(decode), decode_(decode),
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> VOCNode::Copy() {
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
return node;
}

// Write this node's description to the stream; only the node name is emitted
void VOCNode::Print(std::ostream &out) const {
  out << Name();
}

Status VOCNode::ValidateParams() { Status VOCNode::ValidateParams() {
Path dir(dataset_dir_); Path dir(dataset_dir_);




+ 13
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h View File

@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class VOCNode : public DatasetNode {
class VOCNode : public MappableSourceNode {
public: public:
/// \brief Constructor /// \brief Constructor
VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
@@ -37,6 +37,18 @@ class VOCNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~VOCNode() = default; ~VOCNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kVOCNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 10
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc View File

@@ -29,7 +29,16 @@ namespace dataset {
// Constructor for SyncWaitNode // Constructor for SyncWaitNode
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback) SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback)
: condition_name_(condition_name), callback_(callback) { : condition_name_(condition_name), callback_(callback) {
this->children.push_back(child);
this->AddChild(child);
}

// Create a new SyncWaitNode with the same condition_name_ and callback_.
// nullptr is passed as the child argument, so this node's child is not carried over here.
std::shared_ptr<DatasetNode> SyncWaitNode::Copy() {
  return std::make_shared<SyncWaitNode>(nullptr, condition_name_, callback_);
}

void SyncWaitNode::Print(std::ostream &out) const {
out << Name() + "(cond_name:" + condition_name_ + "<pyfunc>" + ")";
} }


// Function to build the BarrierOp // Function to build the BarrierOp


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h View File

@@ -36,6 +36,18 @@ class SyncWaitNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~SyncWaitNode() = default; ~SyncWaitNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kSyncWaitNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 7
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc View File

@@ -27,10 +27,15 @@ namespace mindspore {
namespace dataset { namespace dataset {


// Constructor for TakeNode // Constructor for TakeNode
TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) {
this->children.push_back(child);
TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { this->AddChild(child); }

std::shared_ptr<DatasetNode> TakeNode::Copy() {
auto node = std::make_shared<TakeNode>(nullptr, take_count_);
return node;
} }


// Write a short description of this node to the stream: the node name and take_count_
void TakeNode::Print(std::ostream &out) const {
  std::string description = Name();
  description.append("(num_rows:").append(std::to_string(take_count_)).append(")");
  out << description;
}

// Function to build the TakeOp // Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() { std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create


+ 12
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h View File

@@ -34,6 +34,18 @@ class TakeNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~TakeNode() = default; ~TakeNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTakeNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;


+ 25
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc View File

@@ -22,6 +22,7 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"


#include "utils/ms_context.h" #include "utils/ms_context.h"
@@ -39,7 +40,19 @@ TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue
total_batch_(total_batch), total_batch_(total_batch),
create_data_info_queue_(create_data_info_queue), create_data_info_queue_(create_data_info_queue),
device_id_(0) { device_id_(0) {
this->children.push_back(child);
this->AddChild(child);
}

// Create a new TransferNode with the same queue name, device type, epoch-end flag, total batch
// and data-info-queue flag. nullptr is passed as the child argument, so this node's child is not
// carried over here; device_id_ is not an argument and is left to the constructor.
std::shared_ptr<DatasetNode> TransferNode::Copy() {
  return std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_,
                                        create_data_info_queue_);
}

void TransferNode::Print(std::ostream &out) const {
out << Name() + "(prefetch_size:" + std::to_string(prefetch_size_) +
",send_epoch_end:" + (send_epoch_end_ ? "true" : "false") + ",total_batch:" + std::to_string(total_batch_) +
")";
} }


// Validator for TransferNode // Validator for TransferNode
@@ -94,5 +107,16 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
return node_ops; return node_ops;
} }


// NodePass visitor hook: hands this node, downcast to TransferNode, to the pass's Visit()
Status TransferNode::Accept(NodePass *p, bool *modified) {
  auto self = shared_from_base<TransferNode>();
  return p->Visit(self, modified);
}

// NodePass visitor hook: hands this node, downcast to TransferNode, to the pass's VisitAfter()
Status TransferNode::AcceptAfter(NodePass *p, bool *modified) {
  auto self = shared_from_base<TransferNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 26
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h View File

@@ -35,6 +35,18 @@ class TransferNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~TransferNode() = default; ~TransferNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kTransferNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps /// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -43,6 +55,20 @@ class TransferNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The pass that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The pass that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;

private: private:
std::string queue_name_; std::string queue_name_;
int32_t device_id_; int32_t device_id_;


+ 28
- 10
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc View File

@@ -21,30 +21,36 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) : datasets_(datasets) {
for (auto dataset : datasets_) {
this->children.push_back(dataset);
}
ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) {
for (auto const &child : datasets) AddChild(child);
} }


// Create a new ZipNode with no input datasets attached: the constructor receives an empty
// vector, so this node's children are not carried over here.
std::shared_ptr<DatasetNode> ZipNode::Copy() {
  // A default-constructed vector is already empty, so no clear() call is needed.
  return std::make_shared<ZipNode>(std::vector<std::shared_ptr<DatasetNode>>{});
}

// Write this node's description to the stream; only the node name is emitted
void ZipNode::Print(std::ostream &out) const {
  out << Name();
}

Status ZipNode::ValidateParams() { Status ZipNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "ZipNode: datasets to zip are not specified.";
if (children_.size() < 2) {
std::string err_msg = "ZipNode: input datasets are not specified.";
MS_LOG(ERROR) << err_msg; MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg); RETURN_STATUS_SYNTAX_ERROR(err_msg);
} }


if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "ZipNode: zip datasets should not be null.";
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
std::string err_msg = "ZipNode: input datasets should not be null.";
MS_LOG(ERROR) << err_msg; MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg); RETURN_STATUS_SYNTAX_ERROR(err_msg);
} }

return Status::OK(); return Status::OK();
} }


@@ -56,5 +62,17 @@ std::vector<std::shared_ptr<DatasetOp>> ZipNode::Build() {
return node_ops; return node_ops;
} }


// NodePass visitor hook: hands this node, downcast to ZipNode, to the pass's Visit()
Status ZipNode::Accept(NodePass *p, bool *modified) {
  auto self = shared_from_base<ZipNode>();
  return p->Visit(self, modified);
}

// NodePass visitor hook: hands this node, downcast to ZipNode, to the pass's VisitAfter()
Status ZipNode::AcceptAfter(NodePass *p, bool *modified) {
  auto self = shared_from_base<ZipNode>();
  return p->VisitAfter(self, modified);
}

} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 23
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h View File

@@ -34,6 +34,18 @@ class ZipNode : public DatasetNode {
/// \brief Destructor /// \brief Destructor
~ZipNode() = default; ~ZipNode() = default;


/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kZipNode; }

/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;

/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;

/// \brief a base class override function to create the required runtime dataset op objects for this class /// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps /// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override; std::vector<std::shared_ptr<DatasetOp>> Build() override;
@@ -42,8 +54,17 @@ class ZipNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


private:
std::vector<std::shared_ptr<DatasetNode>> datasets_;
/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The pass that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The pass that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(NodePass *p, bool *modified) override;
}; };


} // namespace dataset } // namespace dataset


+ 41
- 265
mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc View File

@@ -22,10 +22,12 @@
#endif #endif
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h" #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h" #include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h" #include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h" #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h" #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h" #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
@@ -34,34 +36,6 @@
#include "minddata/dataset/engine/ir/datasetops/take_node.h" #include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h" #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h" #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#endif
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#endif
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"


////////////////////////////////// //////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
@@ -113,7 +87,12 @@ namespace mindspore {
namespace dataset { namespace dataset {


// Driver method for TreePass // Driver method for TreePass
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
if (root_ir == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
}
return this->RunOnTree(root_ir, modified);
}


// Driver method for NodePass // Driver method for NodePass
Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
@@ -132,15 +111,23 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {


// Helper function to perform DFS visit // Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
RETURN_IF_NOT_OK(node_ir->Accept(this, modified));
bool m = false;

RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
*modified |= m;
for (const auto &c : node_ir->Children()) { for (const auto &c : node_ir->Children()) {
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m));
*modified |= m;
} }
return node_ir->AcceptAfter(this, modified);
RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m));
*modified |= m;
return Status::OK();
} }


// Helper function to perform BFS visit // Helper function to perform BFS visit
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
bool m = false;

// Initialize bfs queue with root // Initialize bfs queue with root
std::queue<std::shared_ptr<DatasetNode>> bfsQueue; std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
bfsQueue.push(node_ir); bfsQueue.push(node_ir);
@@ -152,7 +139,8 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
bfsQueue.pop(); bfsQueue.pop();


// Run node pass // Run node pass
RETURN_IF_NOT_OK(curNode->Accept(this, modified));
RETURN_IF_NOT_OK(curNode->Accept(this, &m));
*modified |= m;


// Push children into bfs queue // Push children into bfs queue
for (const auto &c : curNode->Children()) { for (const auto &c : curNode->Children()) {
@@ -162,331 +150,119 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
return Status::OK(); return Status::OK();
} }


// For datasetops IR
// For non-leaf IR node
Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#endif

Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

#ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#endif

Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) { Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

// For datasetops/source IR
Status NodePass::Visit(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::Visit(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CelebANode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CLUENode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CocoNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<CSVNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
Status NodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif

Status NodePass::Visit(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<MnistNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::Visit(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

Status NodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *modified) {
// Fallback to base class visitor by default
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}

#ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#endif #endif

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
Status NodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#endif #endif


Status NodePass::Visit(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
// For leaf IR Node
Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }

Status NodePass::VisitAfter(std::shared_ptr<VOCNode> node, bool *modified) {
// Fallback to base class visitor by default
Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }




+ 55
- 210
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h View File

@@ -26,123 +26,87 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Non-leaf IR node
class BatchNode; class BatchNode;
class BucketBatchByLengthNode; class BucketBatchByLengthNode;
#ifndef ENABLE_ANDROID
class BuildSentenceVocabNode;
#endif
class BuildVocabNode; class BuildVocabNode;
class ConcatNode; class ConcatNode;
class FilterNode;
class MapNode; class MapNode;
class ProjectNode; class ProjectNode;
class RenameNode; class RenameNode;
class RepeatNode; class RepeatNode;
class RootNode;
class ShuffleNode; class ShuffleNode;
class SkipNode; class SkipNode;
#ifdef ENABLE_PYTHON
class SyncWaitNode;
#endif
class TakeNode; class TakeNode;
class TransferNode; class TransferNode;
class ZipNode; class ZipNode;
#ifdef ENABLE_PYTHON
class SyncWaitNode;
#endif
#ifndef ENABLE_ANDROID
class BuildSentenceVocabNode;
#endif
// Leaf IR node
class AlbumNode; class AlbumNode;
class CelebANode; class CelebANode;
class Cifar100Node; class Cifar100Node;
class Cifar10Node; class Cifar10Node;
#ifndef ENABLE_ANDROID
class CLUENode;
#endif
class CocoNode; class CocoNode;
#ifndef ENABLE_ANDROID
class CSVNode;
#endif
#ifdef ENABLE_PYTHON
class GeneratorNode;
#endif
class ImageFolderNode; class ImageFolderNode;
class ManifestNode; class ManifestNode;
#ifndef ENABLE_ANDROID
class MindDataNode;
#endif
class MnistNode; class MnistNode;
class RandomNode; class RandomNode;
#ifndef ENABLE_ANDROID
class TextFileNode;
class VOCNode;
#ifdef ENABLE_PYTHON
class GeneratorNode;
#endif #endif
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
class CLUENode;
class CSVNode;
class MindDataNode;
class TextFileNode;
class TFRecordNode; class TFRecordNode;
#endif #endif
class VOCNode;


////////////////////////////////// //////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
class BatchOp; class BatchOp;

class MapOp; class MapOp;

class ProjectOp; class ProjectOp;

class RenameOp; class RenameOp;

class SkipOp; class SkipOp;

class ShuffleOp; class ShuffleOp;

class AlbumOp; class AlbumOp;

class RandomDataOp; class RandomDataOp;

class RepeatOp; class RepeatOp;

class TakeOp; class TakeOp;

class ZipOp; class ZipOp;

class DeviceQueueOp; class DeviceQueueOp;

class ImageFolderOp; class ImageFolderOp;

class MnistOp; class MnistOp;

class ManifestOp; class ManifestOp;

class CifarOp; class CifarOp;

class VOCOp; class VOCOp;

class CocoOp; class CocoOp;

class CelebAOp; class CelebAOp;

class EpochCtrlOp; class EpochCtrlOp;

class BuildVocabOp; class BuildVocabOp;

class ConcatOp; class ConcatOp;

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
class MindRecordOp; class MindRecordOp;

class TFReaderOp; class TFReaderOp;

class CacheOp; class CacheOp;

class CacheMergeOp; class CacheMergeOp;

class CacheLookupOp; class CacheLookupOp;

class BuildSentencePieceVocabOp; class BuildSentencePieceVocabOp;

class ClueOp; class ClueOp;

class CsvOp; class CsvOp;

class TextFileOp; class TextFileOp;
#endif #endif

#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
class FilterOp; class FilterOp;

class GeneratorOp; class GeneratorOp;
#endif #endif
////////////////////////////////// //////////////////////////////////
@@ -175,6 +139,13 @@ class TreePass : public Pass {
/// \param[inout] modified Indicate if the tree was modified /// \param[inout] modified Indicate if the tree was modified
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final; Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final;


/// \brief Derived classes may implement the RunOnTree function to implement tree transformation.
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] root_ir The tree to operate on.
/// \param[inout] modified Indicate if the tree was modified.
/// \return Status The error code return
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }

////////////////////////////////// //////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
/// \brief Run the transformation pass against the execution tree. /// \brief Run the transformation pass against the execution tree.
@@ -191,8 +162,17 @@ class TreePass : public Pass {
////////////////////////////////// //////////////////////////////////
}; };


// NodePass is a basic Pass class which performs transformation on Node visiting.
// NodePass is a base Pass class which performs transformation on node visiting.
// NodePass implements Visitor design pattern. // NodePass implements Visitor design pattern.
// The visiting happens twice for each node in the DFS traversal: once on the way down of the traversal,
// and once more after all of its descendant nodes have been visited.
// Actual transformation is done by implementing a new derived class of NodePass.
// The derived class implements the Visit()/VisitAfter() methods for the specific node types
// it wants to act on, overriding the ones defined in NodePass.
// If the derived class wants to perform the same action on all node types,
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
// to call the Visit()/VisitAfter() in this parent NodePass class.
class NodePass : public Pass { class NodePass : public Pass {
public: public:
// Tree traversal order // Tree traversal order
@@ -223,153 +203,57 @@ class NodePass : public Pass {
/// \return Status The error code return /// \return Status The error code return
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }


// For datasetops IR
// Visit method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare Visit of its own type and override "Accept" from DatasetNode.
// Visit()/VisitAfter() method to be overridden.
// These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are declared here.
// Their implementations are in the .cc file to avoid adding the include files of those derived classes.
// Each implementation simply falls back to calling Visit()/VisitAfter() of class DatasetNode, the parent of
// the derived classes. With this technique, a transformation class derived from NodePass needs only to
// implement Visit()/VisitAfter() taking DatasetNode if it wants to act on all derived classes
// of DatasetNode in the same way.
// Note that virtual template functions are not permitted in C++.
//
// Non-leaf IR node
virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified); virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified);

// VisitAfter method to be overridden.
// Note that member template can not be virtual, any node which wants to work with NodePass
// should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode.
virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified); virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified); virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<FilterNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified); virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified); virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified); virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified); virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RootNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified); virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified); virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified);

#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified); virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified); virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified); virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified); virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified);

// For datasetops/source IR
virtual Status Visit(std::shared_ptr<AlbumNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<CelebANode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CelebANode> node, bool *modified);

virtual Status Visit(std::shared_ptr<Cifar100Node> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified);

virtual Status Visit(std::shared_ptr<Cifar10Node> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CLUENode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CLUENode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<CocoNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CocoNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CSVNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<CSVNode> node, bool *modified);
#endif

#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<ImageFolderNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<ManifestNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified);
#endif

virtual Status Visit(std::shared_ptr<MnistNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<MnistNode> node, bool *modified);

virtual Status Visit(std::shared_ptr<RandomNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *modified);

#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TextFileNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified);
#endif #endif

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified);
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
#endif #endif

virtual Status Visit(std::shared_ptr<VOCNode> node, bool *modified);

virtual Status VisitAfter(std::shared_ptr<VOCNode> node, bool *modified);
// Leaf IR node
virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified);
virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified);


////////////////////////////////// //////////////////////////////////
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
@@ -396,86 +280,47 @@ class NodePass : public Pass {
// Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode
// of its own type and override "Accept" from DatasetOp. // of its own type and override "Accept" from DatasetOp.
virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified);

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified);
#endif #endif

#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
#endif #endif
////////////////////////////////// //////////////////////////////////


+ 14
- 4
mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc View File

@@ -18,6 +18,7 @@


#include "minddata/dataset/core/client.h" #include "minddata/dataset/core/client.h"
#include "minddata/dataset/include/datasets.h" #include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h" #include "minddata/dataset/engine/opt/pre/input_validation_pass.h"


@@ -119,11 +120,16 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::sha
return Status::OK(); return Status::OK();
} }


Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) {
num_epochs_ = num_epochs;
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
optimize_ = true; // Always ON (temporary) optimize_ = true; // Always ON (temporary)


RETURN_UNEXPECTED_IF_NULL(root_ir);
RETURN_UNEXPECTED_IF_NULL(input_ir);
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';

// Copy the input IR tree and insert under the root node
// Create a root node to host the input IR tree
auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';


// Pre-pass of the IR tree // Pre-pass of the IR tree
RETURN_IF_NOT_OK(PrePass(root_ir)); RETURN_IF_NOT_OK(PrePass(root_ir));
@@ -136,11 +142,15 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
// Post-pass of the IR tree // Post-pass of the IR tree
RETURN_IF_NOT_OK(PostPass(root_ir)); RETURN_IF_NOT_OK(PostPass(root_ir));


MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n';

// This will evolve in the long run // This will evolve in the long run
tree_ = std::make_unique<ExecutionTree>(); tree_ = std::make_unique<ExecutionTree>();


// Build the Execution tree from the child of the root node
std::shared_ptr<DatasetOp> root_op; std::shared_ptr<DatasetOp> root_op;
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op));
// We will replace input_ir with root_ir->Children()[0] once IR optimizer is in
RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op));
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));


if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);


+ 0
- 4
mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h View File

@@ -67,10 +67,6 @@ class TreeAdapter {
// Optional optimizations status // Optional optimizations status
bool OptimizationEnabled() const { return optimize_; } bool OptimizationEnabled() const { return optimize_; }


// Getter function to get the total number of epochs to be run on this tree.
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }

private: private:
// This function runs a mandatory pass checking the syntax and semantics of the IR tree. // This function runs a mandatory pass checking the syntax and semantics of the IR tree.
Status PrePass(std::shared_ptr<DatasetNode> ir); Status PrePass(std::shared_ptr<DatasetNode> ir);


+ 30
- 2
mindspore/ccsrc/minddata/dataset/include/samplers.h View File

@@ -47,6 +47,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return Shared pointers to the newly created Sampler /// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> Build() = 0; virtual std::shared_ptr<SamplerRT> Build() = 0;


/// \brief Pure virtual function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj
virtual std::shared_ptr<SamplerObj> Copy() = 0;

/// \brief Function for derived class to get the shard id of sampler /// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler /// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; } virtual int64_t ShardId() { return 0; }
@@ -132,6 +136,11 @@ class DistributedSamplerObj : public SamplerObj {


std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_,
even_dist_);
}

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif
@@ -160,6 +169,10 @@ class PKSamplerObj : public SamplerObj {


std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
}

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif
@@ -174,9 +187,8 @@ class PKSamplerObj : public SamplerObj {


class PreBuiltSamplerObj : public SamplerObj { class PreBuiltSamplerObj : public SamplerObj {
public: public:
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif #endif


@@ -188,6 +200,8 @@ class PreBuiltSamplerObj : public SamplerObj {
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif


std::shared_ptr<SamplerObj> Copy() override;

bool ValidateParams() override; bool ValidateParams() override;


private: private:
@@ -205,6 +219,8 @@ class RandomSamplerObj : public SamplerObj {


std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); }

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif
@@ -224,6 +240,10 @@ class SequentialSamplerObj : public SamplerObj {


std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
}

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif
@@ -243,6 +263,10 @@ class SubsetRandomSamplerObj : public SamplerObj {


std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
}

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif
@@ -262,6 +286,10 @@ class WeightedRandomSamplerObj : public SamplerObj {


std::shared_ptr<SamplerRT> Build() override; std::shared_ptr<SamplerRT> Build() override;


std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
}

bool ValidateParams() override; bool ValidateParams() override;


private: private:


+ 11
- 1
mindspore/ccsrc/minddata/dataset/include/transforms.h View File

@@ -32,7 +32,10 @@ class TensorOp;
class TensorOperation : public std::enable_shared_from_this<TensorOperation> { class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
public: public:
/// \brief Constructor /// \brief Constructor
TensorOperation();
TensorOperation() : random_op_(false) {}

/// \brief Constructor
explicit TensorOperation(bool random) : random_op_(random) {}


/// \brief Destructor /// \brief Destructor
~TensorOperation() = default; ~TensorOperation() = default;
@@ -42,6 +45,13 @@ class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
virtual std::shared_ptr<TensorOp> Build() = 0; virtual std::shared_ptr<TensorOp> Build() = 0;


virtual Status ValidateParams() = 0; virtual Status ValidateParams() = 0;

/// \brief Check whether the operation is a random (non-deterministic) operation.
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
bool IsRandomOp() const { return random_op_; }

protected:
bool random_op_;
}; };


// Helper function to validate fill value // Helper function to validate fill value


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc View File

@@ -427,7 +427,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
std::vector<dsize_t> cur_ind, size_t cur_dim) { std::vector<dsize_t> cur_ind, size_t cur_dim) {
if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
dst->CopyLastDimAt(src, cur_ind);
RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind));
} else { // not the last dimension, keep doing recursion } else { // not the last dimension, keep doing recursion
dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
for (dsize_t i = 0; i < min_ind; i++) { for (dsize_t i = 0; i < min_ind; i++) {


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h View File

@@ -57,7 +57,7 @@ class RandomCropOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;


// Function breaks out the compute function's image padding functionality and makes available to other Ops // Function breaks out the compute function's image padding functionality and makes available to other Ops
// Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op
// Using this class as a base - re-structured to allow for RandomCropWithBBox Augmentation Op
// @param input: Input is the original Image // @param input: Input is the original Image
// @param pad_image: Pointer to new Padded image // @param pad_image: Pointer to new Padded image
// @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required


+ 1
- 1
mindspore/dataset/engine/samplers.py View File

@@ -570,7 +570,7 @@ class WeightedRandomSampler(BuiltinSampler):
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).


Args: Args:
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
num_samples (int, optional): Number of elements to sample (default=None, all elements). num_samples (int, optional): Number of elements to sample (default=None, all elements).
replacement (bool): If True, put the sample ID back for the next draw (default=True). replacement (bool): If True, put the sample ID back for the next draw (default=True).




+ 1
- 0
tests/ut/cpp/dataset/CMakeLists.txt View File

@@ -17,6 +17,7 @@ SET(DE_UT_SRCS
c_api_dataset_coco_test.cc c_api_dataset_coco_test.cc
c_api_dataset_config_test.cc c_api_dataset_config_test.cc
c_api_dataset_csv_test.cc c_api_dataset_csv_test.cc
c_api_dataset_ir_node_test.cc
c_api_dataset_iterator_test.cc c_api_dataset_iterator_test.cc
c_api_dataset_manifest_test.cc c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc c_api_dataset_minddata_test.cc


+ 142
- 0
tests/ut/cpp/dataset/c_api_dataset_ir_node_test.cc View File

@@ -0,0 +1,142 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include <string>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::MsLogLevel::INFO;

/// \brief Test fixture for IR node tests; provides helpers to compare and print IR trees.
class MindDataTestIRNodes : public UT::DatasetOpTesting {
 public:
  MindDataTestIRNodes() = default;

  /// \brief Runs GlobalInit() before each test case.
  void SetUp() override { GlobalInit(); }

  /// \brief Walk two IR trees in lock-step and compare the node pointers at each position.
  /// \param[in] root1 Root of the first tree (must not be null)
  /// \param[in] root2 Root of the second tree (must not be null)
  /// \param[in] flag Expected result of (ptr1 == ptr2) for every pair of corresponding nodes;
  ///     pass true to check that both trees are the same objects, false to check a deep copy
  /// \return Status::OK() if every pair of nodes matches the expectation; an error if a pointer
  ///     comparison differs from flag, if node names differ, if the trees have different shapes,
  ///     or if either root is null
  Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) {
    CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr.");
    if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) {
      std::string err_msg =
        "Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }

    // Fetch each child list once instead of calling Children() again for every index.
    const auto &children1 = root1->Children();
    const auto &children2 = root2->Children();
    size_t num_child = children1.size();

    CHECK_FAIL_RETURN_UNEXPECTED(num_child == children2.size(),
                                 root1->Name() + " has " + std::to_string(num_child) + " child, node #2 has " +
                                   std::to_string(children2.size()) + " child.");

    for (size_t ind = 0; ind < num_child; ind++) {
      RETURN_IF_NOT_OK(CompareTwoTrees(children1[ind], children2[ind], flag));
    }
    return Status::OK();
  }

  /// \brief Append the names of all nodes in the tree to a string, in post order.
  /// \param[in] ir Root of the tree to print (must not be null)
  /// \param[in,out] names String that receives each node name followed by "->"
  /// \return Status::OK() on success, or an error if a null node is encountered
  Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) {
    RETURN_UNEXPECTED_IF_NULL(ir);
    // Children first, then the node itself; bind by reference to avoid copying each shared_ptr.
    for (const auto &child : ir->Children()) {
      RETURN_IF_NOT_OK(PostOrderPrintTree(child, names));
    }
    names += (ir->Name() + "->");
    return Status::OK();
  }
};

// Deep-copy a linear (single-branch) IR tree and verify that the copy has the same structure
// but shares no node objects with the original.
TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy.";

  auto original = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode();
  auto clone = original->DeepCopy();

  // Both traversals are expected to yield:
  // RandomDataset->Repeat->Project->Shuffle->Batch->
  std::string original_names;
  std::string clone_names;
  ASSERT_OK(PostOrderPrintTree(original, original_names));
  ASSERT_OK(PostOrderPrintTree(clone, clone_names));
  EXPECT_EQ(original_names, clone_names);

  // A tree compared with itself has identical node pointers everywhere.
  ASSERT_OK(CompareTwoTrees(original, original, true));
  // The deep copy must not reuse any node of the original tree.
  ASSERT_OK(CompareTwoTrees(original, clone, false));

  // Sanity check of the helper itself: a tree can never differ, pointer-wise, from itself.
  EXPECT_TRUE(CompareTwoTrees(clone, clone, false).IsError());
}

// Deep-copy an IR tree that contains a multi-child node (Zip) and verify that the copy has the
// same structure but shares no node objects with the original.
TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy.";

  auto left_branch = RandomData(44)->Project({"label"});
  auto right_branch = RandomData(44)->Shuffle(10);

  auto original = Zip({left_branch, right_branch})->Batch(2)->IRNode();
  auto clone = original->DeepCopy();

  // Both traversals are expected to yield:
  // RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch->
  std::string original_names;
  std::string clone_names;
  ASSERT_OK(PostOrderPrintTree(original, original_names));
  ASSERT_OK(PostOrderPrintTree(clone, clone_names));
  EXPECT_EQ(original_names, clone_names);

  // Within one tree, every node pointer equals itself.
  ASSERT_OK(CompareTwoTrees(original, original, true));
  // Across the two trees, no node pointer may be shared.
  ASSERT_OK(CompareTwoTrees(original, clone, false));
}

// Remove an inner node and a leaf node from copies of the same IR tree and verify the
// resulting tree structures.
TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) {
  MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove.";

  auto branch1 = RandomData(44)->Project({"label"});
  auto branch2 = ImageFolder("path");
  auto tree = Zip({branch1, branch2})->IRNode();
  /***
   tree looks like this, we will remove node and test its functionalities
              Zip
             /   \
       Project   ImageFolder
          /
     RandomData
  ***/

  // Case 1: remove Project, an inner node with a single child.
  auto tree_copy_1 = tree->DeepCopy();
  ASSERT_EQ(tree_copy_1->Children().size(), 2);
  ASSERT_OK(tree_copy_1->Children()[0]->Remove());  // remove Project from tree
  // The result must have the same shape as a freshly built Zip(RandomData, ImageFolder) tree,
  // while sharing no node objects with it.
  ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false));
  std::string tree_1_names, tree_2_names;
  ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names));
  EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->");

  // Case 2: remove the ImageFolder, a leaf node, from a fresh copy of the tree.
  auto tree_copy_2 = tree->DeepCopy();
  ASSERT_EQ(tree_copy_2->Children().size(), 2);
  // Check the returned Status so that a failed removal is reported here, not as a later name mismatch.
  ASSERT_OK(tree_copy_2->Children()[1]->Remove());
  ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names));
  EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->");
}

Loading…
Cancel
Save