Browse Source

Remove api namespace

tags/v1.1.0
hesham 5 years ago
parent
commit
5169fb4c42
100 changed files with 490 additions and 565 deletions
  1. +0
    -2
      mindspore/ccsrc/minddata/dataset/api/config.cc
  2. +14
    -207
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  3. +0
    -2
      mindspore/ccsrc/minddata/dataset/api/execute.cc
  4. +4
    -6
      mindspore/ccsrc/minddata/dataset/api/iterator.cc
  5. +22
    -22
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
  6. +24
    -24
      mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc
  7. +13
    -15
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  8. +0
    -2
      mindspore/ccsrc/minddata/dataset/api/text.cc
  9. +0
    -2
      mindspore/ccsrc/minddata/dataset/api/transforms.cc
  10. +0
    -2
      mindspore/ccsrc/minddata/dataset/api/vision.cc
  11. +6
    -6
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  12. +6
    -10
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  13. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
  14. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h
  15. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc
  16. +5
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h
  17. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
  18. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h
  19. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
  20. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h
  21. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
  22. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
  23. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
  24. +6
    -6
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
  25. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc
  26. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
  27. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc
  28. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h
  29. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc
  30. +4
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h
  31. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
  32. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h
  33. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
  34. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h
  35. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
  36. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h
  37. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
  38. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h
  39. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
  40. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
  41. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
  42. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
  43. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
  44. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h
  45. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
  46. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h
  47. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
  48. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
  49. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  50. +5
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
  51. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc
  52. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h
  53. +7
    -7
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc
  54. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h
  55. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc
  56. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h
  57. +14
    -14
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
  58. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
  59. +7
    -7
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
  60. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h
  61. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
  62. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
  63. +9
    -9
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
  64. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
  65. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc
  66. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h
  67. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc
  68. +5
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h
  69. +3
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
  70. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h
  71. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h
  72. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc
  73. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h
  74. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc
  75. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h
  76. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc
  77. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h
  78. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc
  79. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h
  80. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc
  81. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h
  82. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
  83. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
  84. +177
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  85. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  86. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc
  87. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h
  88. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc
  89. +0
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h
  90. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc
  91. +0
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h
  92. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  93. +0
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  94. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc
  95. +0
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h
  96. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc
  97. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h
  98. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc
  99. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h
  100. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc

+ 0
- 2
mindspore/ccsrc/minddata/dataset/api/config.cc View File

@@ -21,7 +21,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Config operations for setting and getting the configuration. // Config operations for setting and getting the configuration.
namespace config { namespace config {
@@ -104,6 +103,5 @@ bool load(std::string file) {
} }


} // namespace config } // namespace config
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 14
- 207
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -21,36 +21,14 @@
#include <utility> #include <utility>
#include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h" #include "minddata/dataset/include/transforms.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"

#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#endif #endif
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"


// Sampler headers (in alphabetical order) // Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"


// IR non-leaf nodes // IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h" #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
@@ -99,7 +77,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Function to create the iterator, which will build and launch the execution tree. // Function to create the iterator, which will build and launch the execution tree.
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) { std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
@@ -317,7 +294,7 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
return schema->init() ? schema : nullptr; return schema->init() ? schema : nullptr;
} }


// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// FUNCTIONS TO CREATE DATASETS FOR LEAF CLASSES
// (In alphabetical order) // (In alphabetical order)


// Function to create a AlbumDataset. // Function to create a AlbumDataset.
@@ -466,7 +443,7 @@ std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::strin
} }
#endif #endif


// Function to create a ZipNode.
// Function to create a ZipDatset.
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) { std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ZipDataset>(datasets); auto ds = std::make_shared<ZipDataset>(datasets);
return ds; return ds;
@@ -639,7 +616,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
return nullptr; return nullptr;
} }


@@ -647,15 +624,15 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
BuildVocabConsumer *bv_consumer = consumer.get(); BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds); rc = consumer->Init(ds);
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
return nullptr; return nullptr;
} }
runtime_context->AssignConsumer(std::move(consumer)); runtime_context->AssignConsumer(std::move(consumer));


// Run tree here to starting building vocab
// Run tree here to starting building SentencePieceVocab
rc = bv_consumer->Start(); rc = bv_consumer->Start();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc;
return nullptr; return nullptr;
} }
return vocab; return vocab;
@@ -671,7 +648,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
return nullptr; return nullptr;
} }


@@ -679,7 +656,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
BuildVocabConsumer *bv_consumer = consumer.get(); BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds); rc = consumer->Init(ds);
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
return nullptr; return nullptr;
} }
runtime_context->AssignConsumer(std::move(consumer)); runtime_context->AssignConsumer(std::move(consumer));
@@ -687,11 +664,14 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
// Run tree here to starting building vocab // Run tree here to starting building vocab
rc = bv_consumer->Start(); rc = bv_consumer->Start();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc;
return nullptr; return nullptr;
} }
return vocab; return vocab;
} }
std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
}
#endif #endif
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}


@@ -856,162 +836,6 @@ bool SchemaObj::from_json(nlohmann::json json_obj) {


// OTHER FUNCTIONS // OTHER FUNCTIONS


// Helper function to compute a default shuffle size
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size) {
const int64_t average_files_multiplier = 4;
const int64_t shuffle_max = 10000;
int64_t avg_rows_per_file = 0;

// Adjust the num rows per shard if sharding was given
if (num_devices > 0) {
if (num_rows % num_devices == 0) {
num_rows = num_rows / num_devices;
} else {
num_rows = (num_rows / num_devices) + 1;
}
}

// Cap based on total rows directive. Some ops do not have this and give value of 0.
if (total_rows > 0) {
num_rows = std::min(num_rows, total_rows);
}

// get the average per file
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0.");
avg_rows_per_file = num_rows / num_files;

*shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
return Status::OK();
}

// Helper function to inject a shuffle operator over top of current operator being built
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) {
std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
int64_t shuffle_size = 0;
RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
// Add the shuffle op
*shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer);
return Status::OK();
}

// Helper function to validate dataset directory parameter
Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
if (dataset_dir.empty()) {
std::string err_msg = dataset_name + ": dataset_dir is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

Path dir(dataset_dir);
if (!dir.IsDirectory()) {
std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (access(dataset_dir.c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
if (dataset_files.empty()) {
std::string err_msg = dataset_name + ": dataset_files is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(dataset_file.toString().c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}

return Status::OK();
}

// Helper function to validate dataset num_shards and shard_id parameters
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
if (num_shards <= 0) {
std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (shard_id < 0 || shard_id >= num_shards) {
// num_shards;
std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
", num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
if (sampler == nullptr) {
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings) {
if (valid_strings.find(str) == valid_strings.end()) {
std::string mode;
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode,
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

// Helper function to validate dataset input/output column parameter
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns) {
if (columns.empty()) {
std::string err_msg = dataset_name + ":" + column_param + " should not be empty string";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (uint32_t i = 0; i < columns.size(); ++i) {
if (columns[i].empty()) {
std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
std::set<std::string> columns_set(columns.begin(), columns.end());
if (columns_set.size() != columns.size()) {
std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID


std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill, std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
@@ -1153,22 +977,5 @@ TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files,
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
#endif #endif
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
if (shuffle) {
if (num_shards > 1) {
// If shuffle enabled, sharding enabled, use distributed random sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle enabled, sharding disabled, use random sampler
return RandomSampler(num_samples >= 0, num_samples);
}
if (num_shards > 1) {
// If shuffle disabled, sharding enabled, use distributed sequential sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(0, num_samples);
}
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/api/execute.cc View File

@@ -26,7 +26,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {} Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}


@@ -54,6 +53,5 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS
return std::make_shared<tensor::DETensor>(std::move(de_output)); return std::make_shared<tensor::DETensor>(std::move(de_output));
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 4
- 6
mindspore/ccsrc/minddata/dataset/api/iterator.cc View File

@@ -20,7 +20,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Get the next row from the data pipeline. // Get the next row from the data pipeline.
bool Iterator::GetNextRow(TensorMap *row) { bool Iterator::GetNextRow(TensorMap *row) {
@@ -45,19 +44,18 @@ bool Iterator::GetNextRow(TensorVec *row) {
} }


// Shut down the data pipeline. // Shut down the data pipeline.
void Iterator::Stop() { runtime_context->Terminate(); }
void Iterator::Stop() { runtime_context_->Terminate(); }


// Function to build and launch the execution tree. // Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
runtime_context = std::make_unique<RuntimeContext>();
RETURN_IF_NOT_OK(runtime_context->Init());
runtime_context_ = std::make_unique<RuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>(); auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get(); consumer_ = consumer.get();
RETURN_IF_NOT_OK(consumer->Init(ds->IRNode())); RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
runtime_context->AssignConsumer(std::move(consumer));
runtime_context_->AssignConsumer(std::move(consumer));
return Status::OK(); return Status::OK();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 22
- 22
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc View File

@@ -27,59 +27,59 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) {
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
(void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
.def("set_num_rows", .def("set_num_rows",
[](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
[](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples", .def("set_num_samples",
[](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
[](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices", .def("get_indices",
[](Sampler &self) {
[](SamplerRT &self) {
py::array ret; py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret; return ret;
}) })
.def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) {
.def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
THROW_IF_ERROR(self->AddChild(child)); THROW_IF_ERROR(self->AddChild(child));
}); });
})); }));


PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) {
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(
PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
*m, "DistributedSampler") *m, "DistributedSampler")
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>()); .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
})); }));


PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) {
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
.def(py::init<int64_t, int64_t, bool>()); .def(py::init<int64_t, int64_t, bool>());
})); }));


PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) {
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
.def(py::init<int64_t, py::object>()); .def(py::init<int64_t, py::object>());
})); }));


PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) {
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
.def(py::init<int64_t, bool, bool>()); .def(py::init<int64_t, bool, bool>());
})); }));


PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) {
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m,
"SequentialSampler")
PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
*m, "SequentialSampler")
.def(py::init<int64_t, int64_t>()); .def(py::init<int64_t, int64_t>());
})); }));


PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(
PYBIND_REGISTER(SubsetRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SubsetRandomSamplerRT, SamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
*m, "SubsetRandomSampler") *m, "SubsetRandomSampler")
.def(py::init<int64_t, std::vector<int64_t>>()); .def(py::init<int64_t, std::vector<int64_t>>());
})); }));


PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(
PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
*m, "WeightedRandomSampler") *m, "WeightedRandomSampler")
.def(py::init<int64_t, std::vector<double>, bool>()); .def(py::init<int64_t, std::vector<double>, bool>());
})); }));


+ 24
- 24
mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc View File

@@ -1140,7 +1140,7 @@ Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp
if (!value.is_none()) { if (!value.is_none()) {
if (key == "sampler") { if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} }
if (key == "children_flag_and_nums") { if (key == "children_flag_and_nums") {
@@ -1164,7 +1164,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
// Required arguments // Required arguments
std::vector<std::string> files_list; std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr; std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0; int num_workers = 0;
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>(); std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
if (!args["dataset_files"].is_none()) { if (!args["dataset_files"].is_none()) {
@@ -1210,7 +1210,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
cache_client = value.cast<std::shared_ptr<CacheClient>>(); cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
} }
} }
} }
@@ -1234,7 +1234,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
} else if (cache_client) { } else if (cache_client) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} }


@@ -1308,7 +1308,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "extensions") { } else if (key == "extensions") {
(void)builder->SetExtensions(ToStringSet(value)); (void)builder->SetExtensions(ToStringSet(value));
@@ -1363,7 +1363,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "class_indexing") { } else if (key == "class_indexing") {
(void)builder->SetClassIndex(ToStringMap(value)); (void)builder->SetClassIndex(ToStringMap(value));
@@ -1416,7 +1416,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") { } else if (key == "decode") {
(void)builder->SetDecode(ToBool(value)); (void)builder->SetDecode(ToBool(value));
@@ -1478,7 +1478,7 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") { } else if (key == "decode") {
(void)builder->SetDecode(ToBool(value)); (void)builder->SetDecode(ToBool(value));
@@ -1529,7 +1529,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") { } else if (key == "usage") {
(void)builder->SetUsage(ToString(value)); (void)builder->SetUsage(ToString(value));
@@ -1583,7 +1583,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") { } else if (key == "usage") {
(void)builder->SetUsage(ToString(value)); (void)builder->SetUsage(ToString(value));
@@ -1618,7 +1618,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
// Required arguments // Required arguments
RandomDataOp::Builder builder; RandomDataOp::Builder builder;
std::shared_ptr<CacheClient> cache_client = nullptr; std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0; int num_workers = 0;


if (args["total_rows"].is_none()) { if (args["total_rows"].is_none()) {
@@ -1646,7 +1646,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
cache_client = value.cast<std::shared_ptr<CacheClient>>(); cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
} }
} }
} }
@@ -1670,7 +1670,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
} else if (cache_client) { } else if (cache_client) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder.SetSampler(std::move(sampler)); (void)builder.SetSampler(std::move(sampler));
} }


@@ -1715,7 +1715,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") { } else if (key == "usage") {
(void)builder->SetUsage(ToString(value)); (void)builder->SetUsage(ToString(value));
@@ -1768,7 +1768,7 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetNumWorkers(num_workers); (void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") { } else if (key == "decode") {
(void)builder->SetDecode(ToBool(value)); (void)builder->SetDecode(ToBool(value));
@@ -1806,7 +1806,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
// Required arguments // Required arguments
std::vector<std::string> files_list; std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr; std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0; int num_workers = 0;
std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>(); std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
if (!args["dataset_files"].is_none()) { if (!args["dataset_files"].is_none()) {
@@ -1840,7 +1840,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
cache_client = value.cast<std::shared_ptr<CacheClient>>(); cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
} }
} }
} }
@@ -1855,7 +1855,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
} else if (cache_client) { } else if (cache_client) {
int64_t num_samples = 0; int64_t num_samples = 0;
int64_t start_index = 0; int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} }


@@ -1991,7 +1991,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
std::shared_ptr<DatasetOp> *bottom) { std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list; std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr; std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0; int num_workers = 0;


std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>(); std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
@@ -2036,7 +2036,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
cache_client = value.cast<std::shared_ptr<CacheClient>>(); cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
} }
} }
} }
@@ -2051,7 +2051,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
} else if (cache_client) { } else if (cache_client) {
int64_t num_samples = 0; int64_t num_samples = 0;
int64_t start_index = 0; int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} }


@@ -2116,7 +2116,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
std::shared_ptr<DatasetOp> *bottom) { std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list; std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr; std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0; int num_workers = 0;
std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>(); std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
if (!args["dataset_files"].is_none()) { if (!args["dataset_files"].is_none()) {
@@ -2173,7 +2173,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
cache_client = value.cast<std::shared_ptr<CacheClient>>(); cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") { } else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create"); auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
} }
} }
} }
@@ -2188,7 +2188,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
} else if (cache_client) { } else if (cache_client) {
int64_t num_samples = 0; int64_t num_samples = 0;
int64_t start_index = 0; int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler)); (void)builder->SetSampler(std::move(sampler));
} }




+ 13
- 15
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -35,7 +35,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


#define RETURN_NULL_IF_ERROR(_s) \ #define RETURN_NULL_IF_ERROR(_s) \
do { \ do { \
@@ -151,10 +150,10 @@ bool DistributedSamplerObj::ValidateParams() {
return true; return true;
} }


std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
return sampler; return sampler;
} }


@@ -184,9 +183,9 @@ bool PKSamplerObj::ValidateParams() {
return true; return true;
} }


std::shared_ptr<Sampler> PKSamplerObj::Build() {
std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);


return sampler; return sampler;
} }
@@ -218,10 +217,10 @@ bool RandomSamplerObj::ValidateParams() {
return true; return true;
} }


std::shared_ptr<Sampler> RandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
// runtime sampler object // runtime sampler object
bool reshuffle_each_epoch = true; bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch);
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);


return sampler; return sampler;
} }
@@ -255,9 +254,9 @@ bool SequentialSamplerObj::ValidateParams() {
return true; return true;
} }


std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_);
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);


return sampler; return sampler;
} }
@@ -284,9 +283,9 @@ bool SubsetRandomSamplerObj::ValidateParams() {
return true; return true;
} }


std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_);
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);


return sampler; return sampler;
} }
@@ -330,11 +329,10 @@ bool WeightedRandomSamplerObj::ValidateParams() {
return true; return true;
} }


std::shared_ptr<Sampler> WeightedRandomSamplerObj::Build() {
auto sampler = std::make_shared<dataset::WeightedRandomSampler>(num_samples_, weights_, replacement_);
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
return sampler; return sampler;
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/api/text.cc View File

@@ -22,7 +22,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Transform operations for text. // Transform operations for text.
namespace text { namespace text {
@@ -130,6 +129,5 @@ std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
} }


} // namespace text } // namespace text
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/api/transforms.cc View File

@@ -22,7 +22,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


TensorOperation::TensorOperation() {} TensorOperation::TensorOperation() {}


@@ -94,6 +93,5 @@ Status TypeCastOperation::ValidateParams() {
std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); } std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }


} // namespace transforms } // namespace transforms
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/api/vision.cc View File

@@ -65,7 +65,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Transform operations for computer vision. // Transform operations for computer vision.
namespace vision { namespace vision {
@@ -1702,6 +1701,5 @@ std::shared_ptr<TensorOp> UniformAugOperation::Build() {
#endif #endif


} // namespace vision } // namespace vision
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 6
- 6
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -34,11 +34,11 @@ namespace mindspore::dataset {
// TreeConsumer // TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }


Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); } Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }


// IteratorConsumer // IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
} }


@@ -74,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
} }


// ToDevice // ToDevice
Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) {
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
} }


@@ -385,8 +385,8 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
tree_adapter_ = std::make_unique<TreeAdapter>(); tree_adapter_ = std::make_unique<TreeAdapter>();
} }


Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d));
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d), 1);
if (!s.IsError()) { if (!s.IsError()) {
init_flag_ = true; init_flag_ = true;
} }
@@ -464,7 +464,7 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK(); return Status::OK();
} }
Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), 1); return tree_adapter_->BuildAndPrepare(std::move(d), 1);
} }
Status BuildVocabConsumer::Start() { Status BuildVocabConsumer::Start() {


+ 6
- 10
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h View File

@@ -29,10 +29,7 @@
namespace mindspore::dataset { namespace mindspore::dataset {
// Forward declare // Forward declare
class TreeAdapter; class TreeAdapter;

namespace api {
class DatasetNode; class DatasetNode;
}


/// A base class for tree consumers which would fetch rows from the tree pipeline /// A base class for tree consumers which would fetch rows from the tree pipeline
class TreeConsumer { class TreeConsumer {
@@ -42,7 +39,7 @@ class TreeConsumer {
/// Initializes the consumer, this involves constructing and preparing the tree. /// Initializes the consumer, this involves constructing and preparing the tree.
/// \param d The dataset node that represent the root of the IR tree. /// \param d The dataset node that represent the root of the IR tree.
/// \return Status error code. /// \return Status error code.
virtual Status Init(std::shared_ptr<api::DatasetNode> d);
virtual Status Init(std::shared_ptr<DatasetNode> d);


Status Terminate(); Status Terminate();


@@ -61,7 +58,7 @@ class IteratorConsumer : public TreeConsumer {
/// \param num_epochs number of epochs. Default to -1 (infinite epochs). /// \param num_epochs number of epochs. Default to -1 (infinite epochs).
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}


Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;


/// Returns the next row in a vector format /// Returns the next row in a vector format
/// \param[out] out std::vector of Tensors /// \param[out] out std::vector of Tensors
@@ -133,7 +130,7 @@ class ToDevice : public TreeConsumer {
explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}


Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;


/// Send the data to device /// Send the data to device
/// \return Status error code /// \return Status error code
@@ -162,7 +159,7 @@ class ToDevice : public TreeConsumer {
class TreeGetters : public TreeConsumer { class TreeGetters : public TreeConsumer {
public: public:
TreeGetters(); TreeGetters();
Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status GetDatasetSize(int64_t *size); Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types); Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes); Status GetOutputShapes(std::vector<TensorShape> *shapes);
@@ -185,10 +182,9 @@ class BuildVocabConsumer : public TreeConsumer {
/// BuildVocabConsumer Constructor which will call the base class default constructor. /// BuildVocabConsumer Constructor which will call the base class default constructor.
BuildVocabConsumer() = default; BuildVocabConsumer() = default;


Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;


/// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
/// would be written to disk)
/// Start consuming
/// \return Status error code /// \return Status error code
Status Start(); Status Start();




+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc View File

@@ -46,7 +46,7 @@ Status CacheBase::Reset() {
return Status::OK(); return Status::OK();
} }
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)), : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
row_cnt_(0), row_cnt_(0),
num_cache_miss_(0), num_cache_miss_(0),


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h View File

@@ -46,7 +46,7 @@ class CacheBase : public ParallelOp {
/// \param cache_client CacheClient for communication to the CacheServer /// \param cache_client CacheClient for communication to the CacheServer
/// \param sampler Sampler which is mandatory /// \param sampler Sampler which is mandatory
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor /// \brief Destructor
~CacheBase(); ~CacheBase();




+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc View File

@@ -87,7 +87,7 @@ Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
leaf_op_wp_.Set(); leaf_op_wp_.Set();
return Status::OK(); return Status::OK();
} }
Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); }
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
std::vector<row_id_type> cache_miss; std::vector<row_id_type> cache_miss;


+ 5
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h View File

@@ -28,7 +28,7 @@ namespace dataset {
/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. /// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset.
/// \note For non-mappable dataset, please see CacheOp /// \note For non-mappable dataset, please see CacheOp
/// \see CacheOp /// \see CacheOp
class CacheLookupOp : public CacheBase, public Sampler {
class CacheLookupOp : public CacheBase, public SamplerRT {
public: public:
class Builder { class Builder {
public: public:
@@ -62,7 +62,7 @@ class CacheLookupOp : public CacheBase, public Sampler {


/// \brief Setter method. /// \brief Setter method.
/// \return Builder setter method returns reference to the builder. /// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler); build_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -77,7 +77,7 @@ class CacheLookupOp : public CacheBase, public Sampler {
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
int32_t build_op_connector_size_; int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_; std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;


// Check if the required parameters are set by the builder. // Check if the required parameters are set by the builder.
// \return Status The error code return // \return Status The error code return
@@ -87,8 +87,8 @@ class CacheLookupOp : public CacheBase, public Sampler {
/// \note It takes the same argument as the base class. /// \note It takes the same argument as the base class.
/// \see CacheBase /// \see CacheBase
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), SamplerRT(*(sampler.get())) {}
~CacheLookupOp() = default; ~CacheLookupOp() = default;
// As a parallel op, we override these two functions // As a parallel op, we override these two functions
Status operator()() override; Status operator()() override;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc View File

@@ -46,7 +46,7 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
} }


CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler), : ParallelOp(numWorkers, opConnectorSize, sampler),
num_cleaners_(numCleaners), num_cleaners_(numCleaners),
cache_client_(std::move(cache_client)), cache_client_(std::move(cache_client)),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h View File

@@ -110,7 +110,7 @@ class CacheMergeOp : public ParallelOp {
/// \brief Setter method /// \brief Setter method
/// \param sampler /// \param sampler
/// \return Builder setter method returns reference to the builder. /// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler); build_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -133,7 +133,7 @@ class CacheMergeOp : public ParallelOp {
int32_t build_op_connector_size_; int32_t build_op_connector_size_;
int32_t build_num_cleaners_; int32_t build_num_cleaners_;
std::shared_ptr<CacheClient> build_cache_client_; std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;


/// Check if the required parameters are set by the builder. /// Check if the required parameters are set by the builder.
/// \return Status The error code return /// \return Status The error code return
@@ -147,7 +147,7 @@ class CacheMergeOp : public ParallelOp {
/// \param cache_client CacheClient to commmunicate with the Cache server /// \param cache_client CacheClient to commmunicate with the Cache server
/// \param sampler as a derived class of ParallelOp /// \param sampler as a derived class of ParallelOp
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler);
~CacheMergeOp(); ~CacheMergeOp();
void Print(std::ostream &out, bool show_all) const override; void Print(std::ostream &out, bool show_all) const override;
std::string Name() const override { return kCacheMergeOp; } std::string Name() const override { return kCacheMergeOp; }


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc View File

@@ -68,7 +68,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {


// Constructor of CacheOp // Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)), : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)),
num_guys_in_(0), num_guys_in_(0),
phase_(Phase::kBuildPhase) {} phase_(Phase::kBuildPhase) {}


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h View File

@@ -81,7 +81,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \brief Setter method /// \brief Setter method
/// \param sampler /// \param sampler
/// \return Builder setter method returns reference to the builder. /// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler); build_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -96,7 +96,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
int32_t build_op_connector_size_; int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_; std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;


/// \brief Check if the required parameters are set by the builder. /// \brief Check if the required parameters are set by the builder.
/// \return Status The error code return /// \return Status The error code return
@@ -108,7 +108,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \param num_workers The number of worker threads. /// \param num_workers The number of worker threads.
/// \param op_connector_size The size of each queue in the connector. /// \param op_connector_size The size of each queue in the connector.
CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);


// Destructor // Destructor
~CacheOp(); ~CacheOp();


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc View File

@@ -36,7 +36,7 @@ ConcatOp::Builder::Builder() {
// The builder "build" method creates the final object. // The builder "build" method creates the final object.
Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) { Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<DistributedSampler>(0, 1, 0, false);
builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false);
} }
*ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_, *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
children_start_end_index_); children_start_end_index_);
@@ -44,7 +44,7 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
} }


// Constructor of the ConcatOp. // Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums, std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index) std::vector<std::pair<int, int>> children_start_end_index)
: PipelineOp(op_connector_size), : PipelineOp(op_connector_size),
@@ -80,7 +80,7 @@ Status ConcatOp::operator()() {
bool is_not_mappable = true; bool is_not_mappable = true;
int num_shard = 1; int num_shard = 1;
int shard_index = 0; int shard_index = 0;
std::shared_ptr<DistributedSampler> distribute_sampler = std::dynamic_pointer_cast<DistributedSampler>(sampler_);
std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_);
if (distribute_sampler != nullptr) { if (distribute_sampler != nullptr) {
num_shard = distribute_sampler->GetDeviceNum(); num_shard = distribute_sampler->GetDeviceNum();
shard_index = distribute_sampler->GetDeviceID(); shard_index = distribute_sampler->GetDeviceID();


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h View File

@@ -44,7 +44,7 @@ class ConcatOp : public PipelineOp {
// The builder "build" method creates the final object. // The builder "build" method creates the final object.
// @return shared_ptr to the new ConcatOp object // @return shared_ptr to the new ConcatOp object
Status Build(std::shared_ptr<ConcatOp> *); Status Build(std::shared_ptr<ConcatOp> *);
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -61,7 +61,7 @@ class ConcatOp : public PipelineOp {


private: private:
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_; std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_; std::vector<std::pair<int, int>> children_start_end_index_;
}; };
@@ -70,7 +70,7 @@ class ConcatOp : public PipelineOp {
// @note The builder class should be used to call it // @note The builder class should be used to call it
// @param op_connector_size - connector size // @param op_connector_size - connector size
explicit ConcatOp(int32_t op_connector_size); explicit ConcatOp(int32_t op_connector_size);
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums, std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index); std::vector<std::pair<int, int>> children_start_end_index);


@@ -123,7 +123,7 @@ class ConcatOp : public PipelineOp {
std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name
std::vector<DataType> data_type_; std::vector<DataType> data_type_;
std::vector<dsize_t> data_rank_; std::vector<dsize_t> data_rank_;
std::shared_ptr<Sampler> sampler_;
std::shared_ptr<SamplerRT> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_; std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_; std::vector<std::pair<int, int>> children_start_end_index_;
}; };


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc View File

@@ -40,7 +40,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor // Constructor
DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
: oc_queue_size_(op_connector_size), : oc_queue_size_(op_connector_size),
sampler_(sampler), sampler_(sampler),
operator_id_(kInvalidOperatorId), operator_id_(kInvalidOperatorId),
@@ -409,7 +409,7 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
} }


// Getter for the sampler, and it also removes the sampler from the op // Getter for the sampler, and it also removes the sampler from the op
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler) {
*sampler = sampler_; // It's okay if it sampler_ points to nullptr *sampler = sampler_; // It's okay if it sampler_ points to nullptr
sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler
return Status::OK(); return Status::OK();


+ 6
- 6
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h View File

@@ -62,7 +62,7 @@ class DataBuffer;


class NodePass; class NodePass;


class Sampler;
class SamplerRT;


/// \brief The base class DatasetOp is the main tree node. It is an abstract class, so /// \brief The base class DatasetOp is the main tree node. It is an abstract class, so
/// the actual implementation of the operators will be derived from here. /// the actual implementation of the operators will be derived from here.
@@ -80,7 +80,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// Constructor /// Constructor
/// \param op_connector_size - The size for the output connector of this operator. /// \param op_connector_size - The size for the output connector of this operator.
/// \param sampler - The sampler for the op /// \param sampler - The sampler for the op
explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);
explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);


/// Destructor /// Destructor
virtual ~DatasetOp() { tree_ = nullptr; } virtual ~DatasetOp() { tree_ = nullptr; }
@@ -347,12 +347,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {


/// Getter for the sampler /// Getter for the sampler
/// \return Shared pointer to the sampler (may return nullptr) /// \return Shared pointer to the sampler (may return nullptr)
std::shared_ptr<Sampler> sampler() { return sampler_; }
std::shared_ptr<SamplerRT> sampler() { return sampler_; }


/// \brief Getter for the sampler, and it also removes the sampler from the op /// \brief Getter for the sampler, and it also removes the sampler from the op
/// \param[out] sampler A pointer to the output sampler that was removed /// \param[out] sampler A pointer to the output sampler that was removed
/// \return Status error code /// \return Status error code
Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler);


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Computes a CRC value for the operator // Computes a CRC value for the operator
@@ -368,7 +368,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
} }


/// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; }


/// \brief Checks if this is a leaf node (0 children) /// \brief Checks if this is a leaf node (0 children)
/// \return boolean returns true if it's a leaf /// \return boolean returns true if it's a leaf
@@ -409,7 +409,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {


std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
std::shared_ptr<SamplerRT> sampler_; // Some leaf ops might have a sampler
int32_t oc_queue_size_; // Capacity for each out_connector_ int32_t oc_queue_size_; // Capacity for each out_connector_
int32_t operator_id_; // Generated id for the node int32_t operator_id_; // Generated id for the node
ExecutionTree *tree_; // Back pointer to our tree. ExecutionTree *tree_; // Back pointer to our tree.


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc View File

@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor // Constructor
ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
: DatasetOp(op_connector_size, sampler), : DatasetOp(op_connector_size, sampler),
num_workers_(num_workers), num_workers_(num_workers),
num_producers_(num_workers), num_producers_(num_workers),


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h View File

@@ -41,7 +41,7 @@ class ParallelOp : public DatasetOp {
// @param num_workers // @param num_workers
// @param op_connector_size - size of the output connector for this operator // @param op_connector_size - size of the output connector for this operator
// @param sampler - The sampler for the op // @param sampler - The sampler for the op
ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr);


// Destructor // Destructor
~ParallelOp() = default; ~ParallelOp() = default;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc View File

@@ -20,7 +20,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor // Constructor
PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
: DatasetOp(op_connector_size, sampler) {} : DatasetOp(op_connector_size, sampler) {}


// A print method typically used for debugging // A print method typically used for debugging


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h View File

@@ -34,7 +34,7 @@ class PipelineOp : public DatasetOp {
// @param op_connector_size - size of the output connector // @param op_connector_size - size of the output connector
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
// @param sampler - The sampler for the op // @param sampler - The sampler for the op
explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr);


// Destructor // Destructor
~PipelineOp() = default; ~PipelineOp() = default;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc View File

@@ -42,7 +42,7 @@ Status AlbumOp::Builder::Build(std::shared_ptr<AlbumOp> *ptr) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data
const int64_t start_index = 0; const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }


builder_schema_ = std::make_unique<DataSchema>(); builder_schema_ = std::make_unique<DataSchema>();
@@ -73,7 +73,7 @@ Status AlbumOp::Builder::SanityCheck() {


AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_wkrs, queue_size, std::move(sampler)), : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer), rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir), folder_path_(file_dir),


+ 4
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h View File

@@ -100,7 +100,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
/// \brief Setter method /// \brief Setter method
/// \param[in] sampler /// \param[in] sampler
/// \return Builder setter method returns reference to the builder /// \return Builder setter method returns reference to the builder
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -147,7 +147,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
int32_t builder_rows_per_buffer_; int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_; std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_; std::unique_ptr<DataSchema> builder_schema_;
}; };


@@ -161,7 +161,8 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
/// \param[in] data_schema - schema of dataset /// \param[in] data_schema - schema of dataset
/// \param[in] sampler - sampler tells AlbumOp what to read /// \param[in] sampler - sampler tells AlbumOp what to read
AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler);


/// \brief Destructor. /// \brief Destructor.
~AlbumOp() = default; ~AlbumOp() = default;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc View File

@@ -46,7 +46,7 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }


builder_schema_ = std::make_unique<DataSchema>(); builder_schema_ = std::make_unique<DataSchema>();
@@ -79,7 +79,7 @@ Status CelebAOp::Builder::SanityCheck() {


CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
bool decode, const std::string &usage, const std::set<std::string> &exts, bool decode, const std::string &usage, const std::set<std::string> &exts,
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)), : ParallelOp(num_workers, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer), rows_per_buffer_(rows_per_buffer),
folder_path_(dir), folder_path_(dir),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h View File

@@ -95,7 +95,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -131,7 +131,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
int32_t builder_rows_per_buffer_; int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_; std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_; std::unique_ptr<DataSchema> builder_schema_;
std::string builder_usage_; std::string builder_usage_;
}; };
@@ -144,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema, const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
std::shared_ptr<Sampler> sampler);
std::shared_ptr<SamplerRT> sampler);


~CelebAOp() override = default; ~CelebAOp() override = default;




+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc View File

@@ -50,7 +50,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
if (sampler_ == nullptr) { if (sampler_ == nullptr) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }
schema_ = std::make_unique<DataSchema>(); schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
@@ -88,7 +88,7 @@ Status CifarOp::Builder::SanityCheck() {


CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_works, queue_size, std::move(sampler)), : ParallelOp(num_works, queue_size, std::move(sampler)),
cifar_type_(type), cifar_type_(type),
usage_(usage), usage_(usage),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h View File

@@ -75,7 +75,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
sampler_ = std::move(sampler); sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -123,7 +123,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
int32_t num_workers_; int32_t num_workers_;
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
int32_t op_connect_size_; int32_t op_connect_size_;
std::shared_ptr<Sampler> sampler_;
std::shared_ptr<SamplerRT> sampler_;
std::unique_ptr<DataSchema> schema_; std::unique_ptr<DataSchema> schema_;
CifarType cifar_type_; CifarType cifar_type_;
}; };
@@ -138,7 +138,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler);
std::shared_ptr<SamplerRT> sampler);
// Destructor. // Destructor.
~CifarOp() = default; ~CifarOp() = default;




+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc View File

@@ -94,7 +94,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim


ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)), : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer), rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0), num_rows_per_shard_(0),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h View File

@@ -125,7 +125,7 @@ class ClueOp : public ParallelOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -141,13 +141,13 @@ class ClueOp : public ParallelOp {
std::vector<std::string> builder_clue_files_list_; std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_; bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_; std::map<std::string, std::string> builder_cols_to_keyword_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
}; };


// Constructor of ClueOp // Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);


// Default destructor // Default destructor
~ClueOp() = default; ~ClueOp() = default;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc View File

@@ -60,7 +60,7 @@ Status CocoOp::Builder::Build(std::shared_ptr<CocoOp> *ptr) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }
builder_schema_ = std::make_unique<DataSchema>(); builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(builder_schema_->AddColumn( RETURN_IF_NOT_OK(builder_schema_->AddColumn(
@@ -123,7 +123,7 @@ Status CocoOp::Builder::SanityCheck() {


CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)), : ParallelOp(num_workers, queue_size, std::move(sampler)),
decode_(decode), decode_(decode),
row_cnt_(0), row_cnt_(0),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h View File

@@ -119,7 +119,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// Setter method. // Setter method.
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -149,7 +149,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
int32_t builder_num_workers_; int32_t builder_num_workers_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
int32_t builder_rows_per_buffer_; int32_t builder_rows_per_buffer_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_; std::unique_ptr<DataSchema> builder_schema_;
}; };


@@ -166,7 +166,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read // @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read
CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);


// Destructor // Destructor
~CocoOp() = default; ~CocoOp() = default;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc View File

@@ -77,7 +77,7 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)), : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
csv_files_list_(std::move(csv_files_list)), csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim), field_delim_(field_delim),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h View File

@@ -243,7 +243,7 @@ class CsvOp : public ParallelOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -261,7 +261,7 @@ class CsvOp : public ParallelOp {
char builder_field_delim_; char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_; std::vector<std::string> builder_column_name_list_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
}; };


// Constructor of CsvOp // Constructor of CsvOp
@@ -271,7 +271,7 @@ class CsvOp : public ParallelOp {
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
std::shared_ptr<Sampler> sampler);
std::shared_ptr<SamplerRT> sampler);


// Default destructor // Default destructor
~CsvOp() = default; ~CsvOp() = default;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc View File

@@ -38,7 +38,7 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data
const int64_t start_index = 0; const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }
builder_schema_ = std::make_unique<DataSchema>(); builder_schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
@@ -68,7 +68,7 @@ Status ImageFolderOp::Builder::SanityCheck() {
ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
bool recursive, bool do_decode, const std::set<std::string> &exts, bool recursive, bool do_decode, const std::set<std::string> &exts,
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema, const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_wkrs, queue_size, std::move(sampler)), : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer), rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir), folder_path_(file_dir),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h View File

@@ -113,7 +113,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -151,7 +151,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
int32_t builder_rows_per_buffer_; int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_; std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_; std::unique_ptr<DataSchema> builder_schema_;
std::map<std::string, int32_t> builder_labels_to_read_; std::map<std::string, int32_t> builder_labels_to_read_;
}; };
@@ -165,7 +165,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map, bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler);


// Destructor. // Destructor.
~ImageFolderOp() = default; ~ImageFolderOp() = default;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc View File

@@ -43,7 +43,7 @@ Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }
builder_schema_ = std::make_unique<DataSchema>(); builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
@@ -67,7 +67,7 @@ Status ManifestOp::Builder::SanityCheck() {


ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage)
std::shared_ptr<SamplerRT> sampler, std::string usage)
: ParallelOp(num_works, queue_size, std::move(sampler)), : ParallelOp(num_works, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer), rows_per_buffer_(rows_per_buffer),
io_block_pushed_(0), io_block_pushed_(0),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h View File

@@ -88,7 +88,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -119,7 +119,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
Status Build(std::shared_ptr<ManifestOp> *op); Status Build(std::shared_ptr<ManifestOp> *op);


private: private:
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
bool builder_decode_; bool builder_decode_;


std::string builder_file_; std::string builder_file_;
@@ -139,7 +139,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage);
std::shared_ptr<SamplerRT> sampler, std::string usage);
// Destructor. // Destructor.
~ManifestOp() = default; ~ManifestOp() = default;




+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc View File

@@ -45,7 +45,7 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }
builder_schema_ = std::make_unique<DataSchema>(); builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
@@ -75,7 +75,7 @@ Status MnistOp::Builder::SanityCheck() {
} }


MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)), : ParallelOp(num_workers, queue_size, std::move(sampler)),
usage_(usage), usage_(usage),
buf_cnt_(0), buf_cnt_(0),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h View File

@@ -78,7 +78,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -113,7 +113,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
int32_t builder_num_workers_; int32_t builder_num_workers_;
int32_t builder_rows_per_buffer_; int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_; std::unique_ptr<DataSchema> builder_schema_;
}; };


@@ -126,7 +126,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
// @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read // @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);


// Destructor. // Destructor.
~MnistOp() = default; ~MnistOp() = default;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc View File

@@ -65,7 +65,7 @@ Status RandomDataOp::Builder::SanityCheck() const {


// Constructor for RandomDataOp // Constructor for RandomDataOp
RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)), : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
buffer_id_(0), buffer_id_(0),
rows_per_buffer_(rows_per_buffer), rows_per_buffer_(rows_per_buffer),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h View File

@@ -120,7 +120,7 @@ class RandomDataOp : public ParallelOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -133,7 +133,7 @@ class RandomDataOp : public ParallelOp {
Status SanityCheck() const; Status SanityCheck() const;


std::unique_ptr<DataSchema> builder_data_schema_; std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_num_workers_; int32_t builder_num_workers_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_; int64_t builder_rows_per_buffer_;
@@ -152,7 +152,7 @@ class RandomDataOp : public ParallelOp {
* @return Builder - The modified builder by reference * @return Builder - The modified builder by reference
*/ */
RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);


/** /**
* Destructor * Destructor


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -23,9 +23,9 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed, int64_t offset, bool even_dist)
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed, int64_t offset, bool even_dist)
: SamplerRT(num_samples, std::numeric_limits<int64_t>::max()),
cnt_(0), cnt_(0),
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed), seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
device_id_(dev_id), device_id_(dev_id),
@@ -35,7 +35,7 @@ DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int
offset_(offset), offset_(offset),
non_empty_(true) {} non_empty_(true) {}


Status DistributedSampler::InitSampler() {
Status DistributedSamplerRT::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) { if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -74,7 +74,7 @@ Status DistributedSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status DistributedSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (cnt_ > samples_per_buffer_) { if (cnt_ > samples_per_buffer_) {
RETURN_STATUS_UNEXPECTED( RETURN_STATUS_UNEXPECTED(
"Number of samples(cnt) that have already been filled in to buffer should be less than or " "Number of samples(cnt) that have already been filled in to buffer should be less than or "
@@ -143,7 +143,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
return Status::OK(); return Status::OK();
} }


Status DistributedSampler::ResetSampler() {
Status DistributedSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
cnt_ = 0; cnt_ = 0;


@@ -160,10 +160,10 @@ Status DistributedSampler::ResetSampler() {
return Status::OK(); return Status::OK();
} }


void DistributedSampler::Print(std::ostream &out, bool show_all) const {
void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: DistributedSampler"; out << "\nSampler: DistributedSampler";
if (show_all) { if (show_all) {
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
<< "\nshuffle: " << shuffle_; << "\nshuffle: " << shuffle_;
} }


+ 5
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h View File

@@ -25,7 +25,7 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class DistributedSampler : public Sampler {
class DistributedSamplerRT : public SamplerRT {
public: public:
/// \brief Constructor /// \brief Constructor
/// \param[in] num_samples The total number of rows in the dataset /// \param[in] num_samples The total number of rows in the dataset
@@ -40,11 +40,12 @@ class DistributedSampler : public Sampler {
/// This option is not exposed in the python API. Current behavior is that the remainder will always /// This option is not exposed in the python API. Current behavior is that the remainder will always
/// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set, /// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set,
/// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario. /// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario.
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);
DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1,
bool even_dist = true);


/// \brief default destructor /// \brief default destructor
~DistributedSampler() = default;
~DistributedSamplerRT() = default;


/// \param std::unique_ptr<DataBuffer> * pBuffer /// \param std::unique_ptr<DataBuffer> * pBuffer
/// \param int32_t workerId /// \param int32_t workerId


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc View File

@@ -20,14 +20,14 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer),
shuffle_(shuffle), shuffle_(shuffle),
seed_(GetSeed()), seed_(GetSeed()),
next_id_(0), next_id_(0),
samples_per_class_(val) {} samples_per_class_(val) {}


Status PKSampler::InitSampler() {
Status PKSamplerRT::InitSampler() {
labels_.reserve(label_to_ids_.size()); labels_.reserve(label_to_ids_.size());
for (const auto &pair : label_to_ids_) { for (const auto &pair : label_to_ids_) {
if (pair.second.empty() == false) { if (pair.second.empty() == false) {
@@ -61,7 +61,7 @@ Status PKSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_ || num_samples_ == 0) { if (next_id_ > num_samples_ || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_)); RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_));
} else if (next_id_ == num_samples_) { } else if (next_id_ == num_samples_) {
@@ -96,7 +96,7 @@ Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK(); return Status::OK();
} }


Status PKSampler::ResetSampler() {
Status PKSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0; next_id_ = 0;
rnd_.seed(seed_++); rnd_.seed(seed_++);
@@ -108,18 +108,18 @@ Status PKSampler::ResetSampler() {
return Status::OK(); return Status::OK();
} }


Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
RETURN_UNEXPECTED_IF_NULL(op); RETURN_UNEXPECTED_IF_NULL(op);
RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
RETURN_IF_NOT_OK(InitSampler()); RETURN_IF_NOT_OK(InitSampler());
return Status::OK(); return Status::OK();
} }


void PKSampler::Print(std::ostream &out, bool show_all) const {
void PKSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: PKSampler"; out << "\nSampler: PKSampler";
if (show_all) { if (show_all) {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any // Then add our own info if any
} }
} }


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h View File

@@ -26,17 +26,17 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class PKSampler : public Sampler { // NOT YET FINISHED
class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
public: public:
// @param num_samples - the number of samples to draw. value of 0 means to take the full amount // @param num_samples - the number of samples to draw. value of 0 means to take the full amount
// @param int64_t val // @param int64_t val
// @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());


// default destructor // default destructor
~PKSampler() = default;
~PKSamplerRT() = default;


// @param std::unique_ptr<DataBuffer pBuffer // @param std::unique_ptr<DataBuffer pBuffer
// @param int32_t workerId // @param int32_t workerId


+ 7
- 7
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc View File

@@ -20,10 +20,10 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
PythonSamplerRT::PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}


Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) { if (need_to_reset_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else { } else {
@@ -64,7 +64,7 @@ Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK(); return Status::OK();
} }


Status PythonSampler::InitSampler() {
Status PythonSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED( CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_));
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
@@ -86,7 +86,7 @@ Status PythonSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status PythonSampler::ResetSampler() {
Status PythonSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
need_to_reset_ = false; need_to_reset_ = false;
py::gil_scoped_acquire gil_acquire; py::gil_scoped_acquire gil_acquire;
@@ -106,11 +106,11 @@ Status PythonSampler::ResetSampler() {
return Status::OK(); return Status::OK();
} }


void PythonSampler::Print(std::ostream &out, bool show_all) const {
void PythonSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: PythonSampler"; out << "\nSampler: PythonSampler";
if (show_all) { if (show_all) {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any // Then add our own info if any
} }
} }


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h View File

@@ -23,18 +23,18 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class PythonSampler : public Sampler {
class PythonSamplerRT : public SamplerRT {
public: public:
// Constructor // Constructor
// @param num_samples - the number of samples to draw. Value of 0 means to sample all of the // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the
// data from the dataset. // data from the dataset.
// @param py_sampler_instance - the python instance of the sampler // @param py_sampler_instance - the python instance of the sampler
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());


// Destructor. // Destructor.
~PythonSampler() = default;
~PythonSamplerRT() = default;


// Initialize the sampler. // Initialize the sampler.
// @return Status // @return Status


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc View File

@@ -22,16 +22,16 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer),
seed_(GetSeed()), seed_(GetSeed()),
replacement_(replacement), replacement_(replacement),
next_id_(0), next_id_(0),
reshuffle_each_epoch_(reshuffle_each_epoch), reshuffle_each_epoch_(reshuffle_each_epoch),
dist(nullptr) {} dist(nullptr) {}


Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_) { if (next_id_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error");
} else if (next_id_ == num_samples_) { } else if (next_id_ == num_samples_) {
@@ -68,7 +68,7 @@ Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK(); return Status::OK();
} }


Status RandomSampler::InitSampler() {
Status RandomSamplerRT::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) { if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -94,7 +94,7 @@ Status RandomSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status RandomSampler::ResetSampler() {
Status RandomSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0; next_id_ = 0;


@@ -115,11 +115,11 @@ Status RandomSampler::ResetSampler() {
return Status::OK(); return Status::OK();
} }


void RandomSampler::Print(std::ostream &out, bool show_all) const {
void RandomSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: RandomSampler"; out << "\nSampler: RandomSampler";
if (show_all) { if (show_all) {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any // Then add our own info if any
} }
} }


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h View File

@@ -24,18 +24,18 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class RandomSampler : public Sampler {
class RandomSamplerRT : public SamplerRT {
public: public:
// Constructor // Constructor
// @param int64_t num_samples - number samples to draw // @param int64_t num_samples - number samples to draw
// @param bool replacement - put he id back / or not after a sample // @param bool replacement - put he id back / or not after a sample
// @param reshuffle_each_epoch - T/F to reshuffle after epoch // @param reshuffle_each_epoch - T/F to reshuffle after epoch
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());


// Destructor. // Destructor.
~RandomSampler() = default;
~RandomSamplerRT() = default;


// Op calls this to get next Buffer that contains all the sampleIds // Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp


+ 14
- 14
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc View File

@@ -32,13 +32,13 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
return Status::OK(); return Status::OK();
} }


Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer)
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}


Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<Sampler> child_sampler;
Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<SamplerRT> child_sampler;
if (HasChildSampler()) { if (HasChildSampler()) {
child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]);
if (!child_sampler) { if (!child_sampler) {
std::string err_msg("Cannot handshake, child is not a sampler object."); std::string err_msg("Cannot handshake, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
@@ -64,7 +64,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
return Status::OK(); return Status::OK();
} }


Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
if (num_elements == 0) { if (num_elements == 0) {
RETURN_STATUS_UNEXPECTED("Invalid data, num of elements cannot be 0."); RETURN_STATUS_UNEXPECTED("Invalid data, num of elements cannot be 0.");
} }
@@ -77,7 +77,7 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
return Status::OK(); return Status::OK();
} }


void Sampler::Print(std::ostream &out, bool show_all) const {
void SamplerRT::Print(std::ostream &out, bool show_all) const {
// Sampler printing is usually only called in the show_all mode. // Sampler printing is usually only called in the show_all mode.
// Derived classes will display the name, then call back to this base // Derived classes will display the name, then call back to this base
// for common info. // for common info.
@@ -88,7 +88,7 @@ void Sampler::Print(std::ostream &out, bool show_all) const {
} }


#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
Status Sampler::GetAllIdsThenReset(py::array *data) {
Status SamplerRT::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids; std::shared_ptr<Tensor> sample_ids;
TensorRow sample_row; TensorRow sample_row;
@@ -123,27 +123,27 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
} }
#endif #endif


Status Sampler::SetNumSamples(int64_t num_samples) {
Status SamplerRT::SetNumSamples(int64_t num_samples) {
CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0."); CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0.");
num_samples_ = num_samples; num_samples_ = num_samples;
return Status::OK(); return Status::OK();
} }


int64_t Sampler::GetNumSamples() { return num_samples_; }
int64_t SamplerRT::GetNumSamples() { return num_samples_; }


Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0.");
num_rows_ = num_rows; num_rows_ = num_rows;
return Status::OK(); return Status::OK();
} }


Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) {
if (child == nullptr) { if (child == nullptr) {
return Status::OK(); return Status::OK();
} }


// Only samplers can be added, not any other DatasetOp. // Only samplers can be added, not any other DatasetOp.
std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child);
if (!sampler) { if (!sampler) {
std::string err_msg("Cannot add child, child is not a sampler object."); std::string err_msg("Cannot add child, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
@@ -160,9 +160,9 @@ Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
return Status::OK(); return Status::OK();
} }


bool Sampler::HasChildSampler() { return !child_.empty(); }
bool SamplerRT::HasChildSampler() { return !child_.empty(); }


Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
if (child_ids_ == nullptr) { if (child_ids_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
} }


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h View File

@@ -51,21 +51,21 @@ class RandomAccessOp {
protected: protected:
// The amount of rows in the dataset itself. This is the before-sampling value, the // The amount of rows in the dataset itself. This is the before-sampling value, the
// total count of rows. A sampler may choose to sample less than this amount. // total count of rows. A sampler may choose to sample less than this amount.
int64_t num_rows_;
int64_t num_rows_ = -1;
}; };


class Sampler {
class SamplerRT {
public: public:
// Constructor // Constructor
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
// indicates that the sampler should produce the complete set of ids. // indicates that the sampler should produce the complete set of ids.
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
explicit SamplerRT(int64_t num_samples, int64_t samples_per_buffer);


Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {}
SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {}


// default destructor // default destructor
~Sampler() = default;
~SamplerRT() = default;


// Get a list of sample ids. // Get a list of sample ids.
// @note It is Sampler responsibility to make sure that the id is not out of bound. // @note It is Sampler responsibility to make sure that the id is not out of bound.
@@ -111,7 +111,7 @@ class Sampler {
// Adds a sampler to become our child. // Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child. // @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned. // @return - The error code returned.
Status AddChild(std::shared_ptr<Sampler> child);
Status AddChild(std::shared_ptr<SamplerRT> child);


// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds // @param std::shared_ptr<Tensor>* sampleIds
@@ -129,7 +129,7 @@ class Sampler {
// @param out - reference to the output stream being overloaded // @param out - reference to the output stream being overloaded
// @param sampler - reference to teh sampler to print // @param sampler - reference to teh sampler to print
// @return - the output stream must be returned // @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) {
sampler.Print(out, false); sampler.Print(out, false);
return out; return out;
} }
@@ -158,7 +158,7 @@ class Sampler {


int64_t samples_per_buffer_; int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_; std::unique_ptr<ColDescriptor> col_desc_;
std::vector<std::shared_ptr<Sampler>> child_; // Child nodes
std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes
std::unique_ptr<DataBuffer> child_ids_; std::unique_ptr<DataBuffer> child_ids_;
}; };
} // namespace dataset } // namespace dataset


+ 7
- 7
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc View File

@@ -20,10 +20,10 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}


Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (id_count_ > num_samples_) { if (id_count_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
} else if (id_count_ == num_samples_) { } else if (id_count_ == num_samples_) {
@@ -62,7 +62,7 @@ Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer)
return Status::OK(); return Status::OK();
} }


Status SequentialSampler::InitSampler() {
Status SequentialSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
"Invalid parameter, start_index must be greater than or equal to 0, but got " + "Invalid parameter, start_index must be greater than or equal to 0, but got " +
std::to_string(start_index_) + ".\n"); std::to_string(start_index_) + ".\n");
@@ -85,7 +85,7 @@ Status SequentialSampler::InitSampler() {
return Status::OK(); return Status::OK();
} }


Status SequentialSampler::ResetSampler() {
Status SequentialSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
current_id_ = start_index_; current_id_ = start_index_;
id_count_ = 0; id_count_ = 0;
@@ -97,11 +97,11 @@ Status SequentialSampler::ResetSampler() {
return Status::OK(); return Status::OK();
} }


void SequentialSampler::Print(std::ostream &out, bool show_all) const {
void SequentialSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: SequentialSampler"; out << "\nSampler: SequentialSampler";
if (show_all) { if (show_all) {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info // Then add our own info
out << "\nStart index: " << start_index_; out << "\nStart index: " << start_index_;
} }


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h View File

@@ -23,18 +23,18 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class SequentialSampler : public Sampler {
class SequentialSamplerRT : public SamplerRT {
public: public:
// Constructor // Constructor
// @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the
// full amount of ids from the dataset // full amount of ids from the dataset
// @param start_index - The starting index value // @param start_index - The starting index value
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit SequentialSampler(int64_t num_samples, int64_t start_index,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit SequentialSamplerRT(int64_t num_samples, int64_t start_index,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());


// Destructor. // Destructor.
~SequentialSampler() = default;
~SequentialSamplerRT() = default;


// init sampler, called by python // init sampler, called by python
Status InitSampler() override; Status InitSampler() override;


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc View File

@@ -27,12 +27,12 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor. // Constructor.
SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices,
int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}


// Initialized this Sampler. // Initialized this Sampler.
Status SubsetRandomSampler::InitSampler() {
Status SubsetRandomSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED( CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");


@@ -56,7 +56,7 @@ Status SubsetRandomSampler::InitSampler() {
} }


// Reset the internal variable to the initial state. // Reset the internal variable to the initial state.
Status SubsetRandomSampler::ResetSampler() {
Status SubsetRandomSamplerRT::ResetSampler() {
// Reset the internal counters. // Reset the internal counters.
sample_id_ = 0; sample_id_ = 0;
buffer_id_ = 0; buffer_id_ = 0;
@@ -73,7 +73,7 @@ Status SubsetRandomSampler::ResetSampler() {
} }


// Get the sample ids. // Get the sample ids.
Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status SubsetRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
// All samples have been drawn // All samples have been drawn
if (sample_id_ == num_samples_) { if (sample_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
@@ -120,11 +120,11 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
return Status::OK(); return Status::OK();
} }


void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const {
void SubsetRandomSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: SubsetRandomSampler"; out << "\nSampler: SubsetRandomSampler";
if (show_all) { if (show_all) {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any // Then add our own info if any
} }
} }


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h View File

@@ -25,18 +25,18 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Randomly samples elements from a given list of indices, without replacement. // Randomly samples elements from a given list of indices, without replacement.
class SubsetRandomSampler : public Sampler {
class SubsetRandomSamplerRT : public SamplerRT {
public: public:
// Constructor. // Constructor.
// @param num_samples The number of samples to draw. 0 for the full amount. // @param num_samples The number of samples to draw. 0 for the full amount.
// @param indices List of indices from where we will randomly draw samples. // @param indices List of indices from where we will randomly draw samples.
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices,
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());


// Destructor. // Destructor.
~SubsetRandomSampler() = default;
~SubsetRandomSamplerRT() = default;


// Initialize the sampler. // Initialize the sampler.
// @return Status // @return Status


+ 9
- 9
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc View File

@@ -27,16 +27,16 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor. // Constructor.
WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights,
bool replacement, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer),
weights_(weights), weights_(weights),
replacement_(replacement), replacement_(replacement),
sample_id_(0), sample_id_(0),
buffer_id_(0) {} buffer_id_(0) {}


// Initialized this Sampler. // Initialized this Sampler.
Status WeightedRandomSampler::InitSampler() {
Status WeightedRandomSamplerRT::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data. // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) { if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -78,7 +78,7 @@ Status WeightedRandomSampler::InitSampler() {
} }


// Initialized the computation for generating weighted random numbers without replacement using onepass method. // Initialized the computation for generating weighted random numbers without replacement using onepass method.
void WeightedRandomSampler::InitOnePassSampling() {
void WeightedRandomSamplerRT::InitOnePassSampling() {
exp_dist_->reset(); exp_dist_->reset();
onepass_ids_.clear(); onepass_ids_.clear();
std::vector<std::pair<double, int64_t>> val_idx; std::vector<std::pair<double, int64_t>> val_idx;
@@ -94,7 +94,7 @@ void WeightedRandomSampler::InitOnePassSampling() {
} }


// Reset the internal variable to the initial state and reshuffle the indices. // Reset the internal variable to the initial state and reshuffle the indices.
Status WeightedRandomSampler::ResetSampler() {
Status WeightedRandomSamplerRT::ResetSampler() {
sample_id_ = 0; sample_id_ = 0;
buffer_id_ = 0; buffer_id_ = 0;
rand_gen_.seed(GetSeed()); rand_gen_.seed(GetSeed());
@@ -112,7 +112,7 @@ Status WeightedRandomSampler::ResetSampler() {
} }


// Get the sample ids. // Get the sample ids.
Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (weights_.size() > static_cast<size_t>(num_rows_)) { if (weights_.size() > static_cast<size_t>(num_rows_)) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, size of sample weights must be less than or equal to num of data, " "Invalid parameter, size of sample weights must be less than or equal to num of data, "
@@ -180,11 +180,11 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf
return Status::OK(); return Status::OK();
} }


void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const {
void WeightedRandomSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: WeightedRandomSampler"; out << "\nSampler: WeightedRandomSampler";
if (show_all) { if (show_all) {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any // Then add our own info if any
} }
} }


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h View File

@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights). // Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights).
class WeightedRandomSampler : public Sampler {
class WeightedRandomSamplerRT : public SamplerRT {
public: public:
// Constructor. // Constructor.
// @param num_samples Number of samples to be drawn. // @param num_samples Number of samples to be drawn.
@@ -34,11 +34,11 @@ class WeightedRandomSampler : public Sampler {
// @param replacement Determine if samples are drawn with/without replacement. // @param replacement Determine if samples are drawn with/without replacement.
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());


// Destructor. // Destructor.
~WeightedRandomSampler() = default;
~WeightedRandomSamplerRT() = default;


// Initialize the sampler. // Initialize the sampler.
// @param op (Not used in this sampler) // @param op (Not used in this sampler)


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc View File

@@ -84,7 +84,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)), : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
device_id_(device_id), device_id_(device_id),
num_devices_(num_device), num_devices_(num_device),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h View File

@@ -115,7 +115,7 @@ class TextFileOp : public ParallelOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -131,7 +131,7 @@ class TextFileOp : public ParallelOp {
std::vector<std::string> builder_text_files_list_; std::vector<std::string> builder_text_files_list_;
bool builder_shuffle_files_; bool builder_shuffle_files_;
std::unique_ptr<DataSchema> builder_schema_; std::unique_ptr<DataSchema> builder_schema_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
}; };


// Constructor of TextFileOp // Constructor of TextFileOp
@@ -148,7 +148,7 @@ class TextFileOp : public ParallelOp {
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);


// Default destructor // Default destructor
~TextFileOp() = default; ~TextFileOp() = default;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc View File

@@ -58,7 +58,7 @@ TFReaderOp::Builder::Builder()
builder_data_schema_ = std::make_unique<DataSchema>(); builder_data_schema_ = std::make_unique<DataSchema>();
} }


bool ValidateFirstRowCrc(const std::string &filename) {
bool TFReaderOp::ValidateFirstRowCrc(const std::string &filename) {
std::ifstream reader; std::ifstream reader;
reader.open(filename); reader.open(filename);
if (!reader) { if (!reader) {
@@ -134,7 +134,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
int64_t total_num_rows, std::vector<std::string> dataset_files_list, int64_t total_num_rows, std::vector<std::string> dataset_files_list,
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler)
int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)), : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
device_id_(device_id), device_id_(device_id),
num_devices_(num_device), num_devices_(num_device),


+ 5
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h View File

@@ -156,14 +156,14 @@ class TFReaderOp : public ParallelOp {
// Setter method // Setter method
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }


private: private:
std::unique_ptr<DataSchema> builder_data_schema_; std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_device_id_; int32_t builder_device_id_;
int32_t builder_num_devices_; int32_t builder_num_devices_;
int32_t builder_num_workers_; int32_t builder_num_workers_;
@@ -193,7 +193,7 @@ class TFReaderOp : public ParallelOp {
TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema, std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files, int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler);
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler);


// Default destructor // Default destructor
~TFReaderOp() = default; ~TFReaderOp() = default;
@@ -262,6 +262,8 @@ class TFReaderOp : public ParallelOp {
/// \return Status of the function /// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override; Status GetDatasetSize(int64_t *dataset_size) override;


static bool ValidateFirstRowCrc(const std::string &filename);

private: private:
// The entry point for when workers are launched. // The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.


+ 3
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc View File

@@ -62,7 +62,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
if (builder_sampler_ == nullptr) { if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; const int64_t num_samples = 0;
const int64_t start_index = 0; const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
} }
builder_schema_ = std::make_unique<DataSchema>(); builder_schema_ = std::make_unique<DataSchema>();
if (builder_task_type_ == TaskType::Segmentation) { if (builder_task_type_ == TaskType::Segmentation) {
@@ -102,7 +102,8 @@ Status VOCOp::Builder::SanityCheck() {


VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)), : ParallelOp(num_workers, queue_size, std::move(sampler)),
decode_(decode), decode_(decode),
row_cnt_(0), row_cnt_(0),


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h View File

@@ -118,7 +118,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// Setter method. // Setter method.
// @param std::shared_ptr<Sampler> sampler // @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder. // @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler); builder_sampler_ = std::move(sampler);
return *this; return *this;
} }
@@ -148,7 +148,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
int32_t builder_num_workers_; int32_t builder_num_workers_;
int32_t builder_op_connector_size_; int32_t builder_op_connector_size_;
int32_t builder_rows_per_buffer_; int32_t builder_rows_per_buffer_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_; std::unique_ptr<DataSchema> builder_schema_;
std::map<std::string, int32_t> builder_labels_to_read_; std::map<std::string, int32_t> builder_labels_to_read_;
}; };
@@ -166,7 +166,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read // @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read
VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);


// Destructor // Destructor
~VOCOp() = default; ~VOCOp() = default;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h View File

@@ -21,7 +21,7 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/dataset_op.h"


namespace mindspore::dataset::api {
namespace mindspore::dataset {


class DatasetCache { class DatasetCache {
public: public:
@@ -29,6 +29,6 @@ class DatasetCache {
virtual Status ValidateParams() = 0; virtual Status ValidateParams() = 0;
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0; virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
}; };
} // namespace mindspore::dataset::api
} // namespace mindspore::dataset


#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_

+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc View File

@@ -18,7 +18,7 @@
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h"


namespace mindspore::dataset::api {
namespace mindspore::dataset {


/// Method to initialize the DatasetCache by creating an instance of a CacheClient /// Method to initialize the DatasetCache by creating an instance of a CacheClient
/// \return Status Error code /// \return Status Error code
@@ -41,4 +41,4 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
return Status::OK(); return Status::OK();
} }


} // namespace mindspore::dataset::api
} // namespace mindspore::dataset

+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h View File

@@ -24,7 +24,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h" #include "minddata/dataset/engine/ir/cache/dataset_cache.h"


namespace mindspore::dataset::api {
namespace mindspore::dataset {


/// DatasetCache is the IR of CacheClient /// DatasetCache is the IR of CacheClient
class DatasetCacheImpl : public DatasetCache { class DatasetCacheImpl : public DatasetCache {
@@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache {
std::optional<int32_t> num_connections_; std::optional<int32_t> num_connections_;
std::optional<int32_t> prefetch_sz_; std::optional<int32_t> prefetch_sz_;
}; };
} // namespace mindspore::dataset::api
} // namespace mindspore::dataset


#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc View File

@@ -26,7 +26,6 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
// constructor #1, called by Pybind // constructor #1, called by Pybind
@@ -96,6 +95,5 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
return node_ops; return node_ops;
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h View File

@@ -27,7 +27,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


class BatchNode : public DatasetNode { class BatchNode : public DatasetNode {
public: public:
@@ -66,7 +65,6 @@ class BatchNode : public DatasetNode {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_; std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc View File

@@ -27,7 +27,7 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {
BucketBatchByLengthNode::BucketBatchByLengthNode( BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
@@ -121,6 +121,5 @@ Status BucketBatchByLengthNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h View File

@@ -27,7 +27,7 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {
class BucketBatchByLengthNode : public DatasetNode { class BucketBatchByLengthNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
@@ -58,7 +58,6 @@ class BucketBatchByLengthNode : public DatasetNode {
bool drop_remainder_; bool drop_remainder_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc View File

@@ -26,7 +26,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child,
std::shared_ptr<SentencePieceVocab> vocab, std::shared_ptr<SentencePieceVocab> vocab,
@@ -77,6 +76,6 @@ Status BuildSentenceVocabNode::ValidateParams() {


return Status::OK(); return Status::OK();
} }
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h View File

@@ -27,7 +27,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


class BuildSentenceVocabNode : public DatasetNode { class BuildSentenceVocabNode : public DatasetNode {
public: public:
@@ -56,7 +55,6 @@ class BuildSentenceVocabNode : public DatasetNode {
std::unordered_map<std::string, std::string> params_; std::unordered_map<std::string, std::string> params_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_ #endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc View File

@@ -26,7 +26,6 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab, BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab,
const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
@@ -78,6 +77,6 @@ Status BuildVocabNode::ValidateParams() {


return Status::OK(); return Status::OK();
} }
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h View File

@@ -26,7 +26,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


class BuildVocabNode : public DatasetNode { class BuildVocabNode : public DatasetNode {
public: public:
@@ -55,7 +54,6 @@ class BuildVocabNode : public DatasetNode {
bool special_first_; bool special_first_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc View File

@@ -25,7 +25,7 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {
// Function to build ConcatOp // Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; } ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }


@@ -53,6 +53,5 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
return node_ops; return node_ops;
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h View File

@@ -25,7 +25,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


class ConcatNode : public DatasetNode { class ConcatNode : public DatasetNode {
public: public:
@@ -44,7 +43,6 @@ class ConcatNode : public DatasetNode {
Status ValidateParams() override; Status ValidateParams() override;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_

+ 177
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -16,11 +16,187 @@


#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"


#include <algorithm>
#include <memory> #include <memory>
#include <set>

#include "minddata/dataset/util/random.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {

// Helper function to compute a default shuffle size
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size) {
const int64_t average_files_multiplier = 4;
const int64_t shuffle_max = 10000;
int64_t avg_rows_per_file = 0;

// Adjust the num rows per shard if sharding was given
if (num_devices > 0) {
if (num_rows % num_devices == 0) {
num_rows = num_rows / num_devices;
} else {
num_rows = (num_rows / num_devices) + 1;
}
}

// Cap based on total rows directive. Some ops do not have this and give value of 0.
if (total_rows > 0) {
num_rows = std::min(num_rows, total_rows);
}

// get the average per file
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0.");
avg_rows_per_file = num_rows / num_files;

*shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
return Status::OK();
}

// Helper function to inject a shuffle operator over top of current operator being built
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) {
std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
int64_t shuffle_size = 0;
RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
// Add the shuffle op
*shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer);
return Status::OK();
}

// Helper function to validate dataset directory parameter
Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
if (dataset_dir.empty()) {
std::string err_msg = dataset_name + ": dataset_dir is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

Path dir(dataset_dir);
if (!dir.IsDirectory()) {
std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (access(dataset_dir.c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
if (dataset_files.empty()) {
std::string err_msg = dataset_name + ": dataset_files is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(dataset_file.toString().c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}

return Status::OK();
}

// Helper function to validate dataset num_shards and shard_id parameters
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
if (num_shards <= 0) {
std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (shard_id < 0 || shard_id >= num_shards) {
// num_shards;
std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
", num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
if (sampler == nullptr) {
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings) {
if (valid_strings.find(str) == valid_strings.end()) {
std::string mode;
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode,
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

// Helper function to validate dataset input/output column parameter
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns) {
if (columns.empty()) {
std::string err_msg = dataset_name + ":" + column_param + " should not be empty string";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (uint32_t i = 0; i < columns.size(); ++i) {
if (columns[i].empty()) {
std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
std::set<std::string> columns_set(columns.begin(), columns.end());
if (columns_set.size() != columns.size()) {
std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
if (shuffle) {
if (num_shards > 1) {
// If shuffle enabled, sharding enabled, use distributed random sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle enabled, sharding disabled, use random sampler
return RandomSampler(num_samples >= 0, num_samples);
}
if (num_shards > 1) {
// If shuffle disabled, sharding enabled, use distributed sequential sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(0, num_samples);
}


Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
if (cache_ != nullptr) { if (cache_ != nullptr) {
@@ -60,6 +236,5 @@ DatasetNode::DatasetNode() {
worker_connector_size_ = cfg->worker_connector_size(); worker_connector_size_ = cfg->worker_connector_size();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -28,7 +28,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


class Dataset; class Dataset;
class SamplerObj; class SamplerObj;
@@ -120,7 +119,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
int32_t worker_connector_size_; int32_t worker_connector_size_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc View File

@@ -26,7 +26,6 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns, std::vector<std::string> input_columns, std::vector<std::string> output_columns,
@@ -86,6 +85,5 @@ Status MapNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h View File

@@ -25,7 +25,7 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {
class MapNode : public DatasetNode { class MapNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
@@ -51,7 +51,6 @@ class MapNode : public DatasetNode {
std::vector<std::string> project_columns_; std::vector<std::string> project_columns_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc View File

@@ -25,7 +25,6 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Function to build ProjectOp // Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
@@ -53,6 +52,5 @@ std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
return node_ops; return node_ops;
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h View File

@@ -26,8 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


namespace api {

class ProjectNode : public DatasetNode { class ProjectNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
@@ -48,7 +46,6 @@ class ProjectNode : public DatasetNode {
std::vector<std::string> columns_; std::vector<std::string> columns_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_

+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc View File

@@ -25,7 +25,7 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {
// Function to build RenameOp // Function to build RenameOp
RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns) const std::vector<std::string> &output_columns)
@@ -54,6 +54,6 @@ std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_)); node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops; return node_ops;
} }
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h View File

@@ -26,8 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


namespace api {

class RenameNode : public DatasetNode { class RenameNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
@@ -50,7 +48,6 @@ class RenameNode : public DatasetNode {
std::vector<std::string> output_columns_; std::vector<std::string> output_columns_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc View File

@@ -25,7 +25,6 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child); this->children.push_back(child);
@@ -49,6 +48,6 @@ Status RepeatNode::ValidateParams() {


return Status::OK(); return Status::OK();
} }
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h View File

@@ -28,8 +28,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


namespace api {

class RepeatNode : public DatasetNode { class RepeatNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
@@ -50,7 +48,6 @@ class RepeatNode : public DatasetNode {
int32_t repeat_count_; int32_t repeat_count_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc View File

@@ -25,7 +25,6 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Constructor for ShuffleNode // Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch)
@@ -54,6 +53,5 @@ Status ShuffleNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h View File

@@ -28,8 +28,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


namespace api {

class ShuffleNode : public DatasetNode { class ShuffleNode : public DatasetNode {
public: public:
ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch); ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch);
@@ -46,7 +44,6 @@ class ShuffleNode : public DatasetNode {
bool reset_every_epoch_; bool reset_every_epoch_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc View File

@@ -25,7 +25,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


// Constructor for SkipNode // Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) {
@@ -52,6 +51,5 @@ Status SkipNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h View File

@@ -26,7 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


namespace api {
class SkipNode : public DatasetNode { class SkipNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
@@ -46,7 +45,7 @@ class SkipNode : public DatasetNode {
private: private:
int32_t skip_count_; int32_t skip_count_;
}; };
} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc View File

@@ -27,7 +27,7 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {
// Constructor for AlbumNode // Constructor for AlbumNode
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode, const std::vector<std::string> &column_names, bool decode,
@@ -78,6 +78,5 @@ Status AlbumNode::GetShardId(int32_t *shard_id) {
return Status::OK(); return Status::OK();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h View File

@@ -25,7 +25,6 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {


class AlbumNode : public DatasetNode { class AlbumNode : public DatasetNode {
public: public:
@@ -57,7 +56,6 @@ class AlbumNode : public DatasetNode {
std::shared_ptr<SamplerObj> sampler_; std::shared_ptr<SamplerObj> sampler_;
}; };


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_

+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -26,7 +26,7 @@
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace api {
// Constructor for CelebANode // Constructor for CelebANode
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode, const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
@@ -76,6 +76,5 @@ Status CelebANode::GetShardId(int32_t *shard_id) {
return Status::OK(); return Status::OK();
} }


} // namespace api
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save