Remove api namespace

tags/v1.1.0
hesham, 5 years ago
commit 5169fb4c42
100 changed files with 490 additions and 565 deletions
  1. +0 -2    mindspore/ccsrc/minddata/dataset/api/config.cc
  2. +14 -207 mindspore/ccsrc/minddata/dataset/api/datasets.cc
  3. +0 -2    mindspore/ccsrc/minddata/dataset/api/execute.cc
  4. +4 -6    mindspore/ccsrc/minddata/dataset/api/iterator.cc
  5. +22 -22  mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc
  6. +24 -24  mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc
  7. +13 -15  mindspore/ccsrc/minddata/dataset/api/samplers.cc
  8. +0 -2    mindspore/ccsrc/minddata/dataset/api/text.cc
  9. +0 -2    mindspore/ccsrc/minddata/dataset/api/transforms.cc
  10. +0 -2   mindspore/ccsrc/minddata/dataset/api/vision.cc
  11. +6 -6   mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  12. +6 -10  mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  13. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
  14. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h
  15. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc
  16. +5 -5   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h
  17. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
  18. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h
  19. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
  20. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h
  21. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
  22. +4 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
  23. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
  24. +6 -6   mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
  25. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc
  26. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
  27. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc
  28. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h
  29. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc
  30. +4 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h
  31. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
  32. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h
  33. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
  34. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h
  35. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
  36. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h
  37. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
  38. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h
  39. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
  40. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
  41. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
  42. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
  43. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
  44. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h
  45. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
  46. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h
  47. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
  48. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
  49. +8 -8   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  50. +5 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
  51. +8 -8   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc
  52. +4 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h
  53. +7 -7   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc
  54. +4 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h
  55. +8 -8   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc
  56. +4 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h
  57. +14 -14 mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
  58. +8 -8   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
  59. +7 -7   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
  60. +4 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h
  61. +8 -8   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
  62. +4 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
  63. +9 -9   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
  64. +4 -4   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
  65. +1 -1   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc
  66. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h
  67. +2 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc
  68. +5 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h
  69. +3 -2   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
  70. +3 -3   mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h
  71. +2 -2   mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h
  72. +2 -2   mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc
  73. +2 -2   mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h
  74. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc
  75. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h
  76. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc
  77. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h
  78. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc
  79. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h
  80. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc
  81. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h
  82. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
  83. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
  84. +177 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  85. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  86. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc
  87. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h
  88. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc
  89. +0 -3   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h
  90. +2 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc
  91. +0 -3   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h
  92. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  93. +0 -3   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  94. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc
  95. +0 -3   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h
  96. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc
  97. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h
  98. +1 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc
  99. +0 -2   mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h
  100. +1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc

+0 -2  mindspore/ccsrc/minddata/dataset/api/config.cc

@@ -21,7 +21,6 @@

namespace mindspore {
namespace dataset {
namespace api {

// Config operations for setting and getting the configuration.
namespace config {
@@ -104,6 +103,5 @@ bool load(std::string file) {
}

} // namespace config
} // namespace api
} // namespace dataset
} // namespace mindspore
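
The hunk above is the template for most of this commit: the namespace api line and its matching closer are deleted, so every symbol under mindspore::dataset::api moves up one level into mindspore::dataset. A toy, compilable mirror of the caller-side effect (SchemaObj here is only a stand-in, not the real class):

namespace mindspore {
namespace dataset {
// Before this commit the declaration sat inside an extra `namespace api { ... }`.
class SchemaObj {};
}  // namespace dataset
}  // namespace mindspore

int main() {
  // Callers simply drop the api segment: this used to be
  // mindspore::dataset::api::SchemaObj.
  mindspore::dataset::SchemaObj schema;
  (void)schema;
  return 0;
}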

+14 -207  mindspore/ccsrc/minddata/dataset/api/datasets.cc

@@ -21,36 +21,14 @@
#include <utility>
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"

#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"

#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#endif
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"

// Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"

// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
@@ -99,7 +77,6 @@

namespace mindspore {
namespace dataset {
namespace api {

// Function to create the iterator, which will build and launch the execution tree.
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
@@ -317,7 +294,7 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
return schema->init() ? schema : nullptr;
}

// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// FUNCTIONS TO CREATE DATASETS FOR LEAF CLASSES
// (In alphabetical order)

// Function to create a AlbumDataset.
@@ -466,7 +443,7 @@ std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::strin
}
#endif

// Function to create a ZipNode.
// Function to create a ZipDataset.
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ZipDataset>(datasets);
return ds;
@@ -639,7 +616,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
return nullptr;
}

@@ -647,15 +624,15 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
return nullptr;
}
runtime_context->AssignConsumer(std::move(consumer));

// Run tree here to start building vocab
// Run tree here to start building SentencePieceVocab
rc = bv_consumer->Start();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc;
return nullptr;
}
return vocab;
@@ -671,7 +648,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
return nullptr;
}

@@ -679,7 +656,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
return nullptr;
}
runtime_context->AssignConsumer(std::move(consumer));
@@ -687,11 +664,14 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
// Run tree here to start building vocab
rc = bv_consumer->Start();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc;
return nullptr;
}
return vocab;
}
std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
}
#endif
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}

@@ -856,162 +836,6 @@ bool SchemaObj::from_json(nlohmann::json json_obj) {

// OTHER FUNCTIONS

// Helper function to compute a default shuffle size
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size) {
const int64_t average_files_multiplier = 4;
const int64_t shuffle_max = 10000;
int64_t avg_rows_per_file = 0;

// Adjust the num rows per shard if sharding was given
if (num_devices > 0) {
if (num_rows % num_devices == 0) {
num_rows = num_rows / num_devices;
} else {
num_rows = (num_rows / num_devices) + 1;
}
}

// Cap based on total rows directive. Some ops do not have this and give value of 0.
if (total_rows > 0) {
num_rows = std::min(num_rows, total_rows);
}

// get the average per file
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
avg_rows_per_file = num_rows / num_files;

*shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
return Status::OK();
}

// Helper function to inject a shuffle operator over top of current operator being built
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) {
std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
int64_t shuffle_size = 0;
RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
// Add the shuffle op
*shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer);
return Status::OK();
}

// Helper function to validate dataset directory parameter
Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
if (dataset_dir.empty()) {
std::string err_msg = dataset_name + ": dataset_dir is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

Path dir(dataset_dir);
if (!dir.IsDirectory()) {
std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (access(dataset_dir.c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
if (dataset_files.empty()) {
std::string err_msg = dataset_name + ": dataset_files is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(dataset_file.toString().c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}

return Status::OK();
}

// Helper function to validate dataset num_shards and shard_id parameters
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
if (num_shards <= 0) {
std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (shard_id < 0 || shard_id >= num_shards) {
// num_shards;
std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
", num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
if (sampler == nullptr) {
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings) {
if (valid_strings.find(str) == valid_strings.end()) {
std::string mode;
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode,
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

// Helper function to validate dataset input/output column parameter
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns) {
if (columns.empty()) {
std::string err_msg = dataset_name + ":" + column_param + " should not be empty string";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (uint32_t i = 0; i < columns.size(); ++i) {
if (columns[i].empty()) {
std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
std::set<std::string> columns_set(columns.begin(), columns.end());
if (columns_set.size() != columns.size()) {
std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

#ifndef ENABLE_ANDROID

std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
@@ -1153,22 +977,5 @@ TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files,
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#endif
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
if (shuffle) {
if (num_shards > 1) {
// If shuffle enabled, sharding enabled, use distributed random sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle enabled, sharding disabled, use random sampler
return RandomSampler(num_samples >= 0, num_samples);
}
if (num_shards > 1) {
// If shuffle disabled, sharding enabled, use distributed sequential sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(0, num_samples);
}
} // namespace api
} // namespace dataset
} // namespace mindspore
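
The large removal above takes the shuffle, validation, and sampler-selection helpers out of datasets.cc; since dataset_node.cc gains 177 lines in the same commit, they presumably move there rather than disappear. The shuffle-size arithmetic is easy to misread, so here is a self-contained sketch of it with illustrative numbers (not part of the diff):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Same arithmetic as ComputeShuffleSize, minus the Status plumbing.
int64_t ShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows,
                    int64_t total_rows) {
  const int64_t average_files_multiplier = 4;
  const int64_t shuffle_max = 10000;
  if (num_devices > 0) {
    num_rows = (num_rows + num_devices - 1) / num_devices;  // rows per shard, rounded up
  }
  if (total_rows > 0) {
    num_rows = std::min(num_rows, total_rows);
  }
  int64_t avg_rows_per_file = num_rows / num_files;
  // Despite its name, shuffle_max acts as a floor here: it is the second
  // argument to std::max, so the result is never below 10000.
  return std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
}

int main() {
  // 4 files, 2 shards, 20000 rows, no total-rows cap:
  // 20000 / 2 = 10000 rows per shard, 10000 / 4 = 2500 per file,
  // max(2500 * 4, 10000) = 10000.
  std::cout << ShuffleSize(4, 2, 20000, 0) << std::endl;  // prints 10000
  return 0;
}

SelectSampler, just above the namespace closers, covers the four shuffle/shard combinations: shuffle with num_shards > 1 yields a DistributedSampler, shuffle alone a RandomSampler, sharding alone a distributed sequential pass, and neither a plain SequentialSampler.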

+0 -2  mindspore/ccsrc/minddata/dataset/api/execute.cc

@@ -26,7 +26,6 @@

namespace mindspore {
namespace dataset {
namespace api {

Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}

@@ -54,6 +53,5 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS
return std::make_shared<tensor::DETensor>(std::move(de_output));
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+4 -6  mindspore/ccsrc/minddata/dataset/api/iterator.cc

@@ -20,7 +20,6 @@

namespace mindspore {
namespace dataset {
namespace api {

// Get the next row from the data pipeline.
bool Iterator::GetNextRow(TensorMap *row) {
@@ -45,19 +44,18 @@ bool Iterator::GetNextRow(TensorVec *row) {
}

// Shut down the data pipeline.
void Iterator::Stop() { runtime_context->Terminate(); }
void Iterator::Stop() { runtime_context_->Terminate(); }

// Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
runtime_context = std::make_unique<RuntimeContext>();
RETURN_IF_NOT_OK(runtime_context->Init());
runtime_context_ = std::make_unique<RuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get();
RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
runtime_context->AssignConsumer(std::move(consumer));
runtime_context_->AssignConsumer(std::move(consumer));
return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+22 -22  mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc

@@ -27,59 +27,59 @@
namespace mindspore {
namespace dataset {

PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) {
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
(void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
.def("set_num_rows",
[](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
[](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples",
[](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
[](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices",
[](Sampler &self) {
[](SamplerRT &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
})
.def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) {
.def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
THROW_IF_ERROR(self->AddChild(child));
});
}));

PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) {
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(
PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
}));

PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) {
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
.def(py::init<int64_t, int64_t, bool>());
}));

PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) {
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
.def(py::init<int64_t, py::object>());
}));

PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) {
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
.def(py::init<int64_t, bool, bool>());
}));

PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) {
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m,
"SequentialSampler")
PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
*m, "SequentialSampler")
.def(py::init<int64_t, int64_t>());
}));

PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(
PYBIND_REGISTER(SubsetRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SubsetRandomSamplerRT, SamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
*m, "SubsetRandomSampler")
.def(py::init<int64_t, std::vector<int64_t>>());
}));

PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(
PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
*m, "WeightedRandomSampler")
.def(py::init<int64_t, std::vector<double>, bool>());
}));
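
Note what the binding hunks do not change: every py::class_ still registers under its old Python name ("Sampler", "PKSampler", and so on), so only C++ identifiers gain the RT suffix and Python scripts are untouched. A minimal, self-contained pybind11 sketch of that decoupling (toy class, not MindSpore's PYBIND_REGISTER machinery):

#include <cstdint>
#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Toy stand-in for a renamed runtime class.
class SamplerRT {
 public:
  void SetNumSamples(int64_t samples) { num_samples_ = samples; }
 private:
  int64_t num_samples_ = 0;
};

PYBIND11_MODULE(example, m) {
  // The Python-visible name stays "Sampler" even though the C++ type is
  // SamplerRT; this is why the rename is invisible to Python callers.
  py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(m, "Sampler")
      .def(py::init<>())
      .def("set_num_samples", &SamplerRT::SetNumSamples);
}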


+24 -24  mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc

@@ -1140,7 +1140,7 @@ Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp
if (!value.is_none()) {
if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
}
if (key == "children_flag_and_nums") {
@@ -1164,7 +1164,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
// Required arguments
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
if (!args["dataset_files"].is_none()) {
@@ -1210,7 +1210,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -1234,7 +1234,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
} else if (cache_client) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}

@@ -1308,7 +1308,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "extensions") {
(void)builder->SetExtensions(ToStringSet(value));
@@ -1363,7 +1363,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "class_indexing") {
(void)builder->SetClassIndex(ToStringMap(value));
@@ -1416,7 +1416,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
@@ -1478,7 +1478,7 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
@@ -1529,7 +1529,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
@@ -1583,7 +1583,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
@@ -1618,7 +1618,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
// Required arguments
RandomDataOp::Builder builder;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;

if (args["total_rows"].is_none()) {
@@ -1646,7 +1646,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -1670,7 +1670,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
} else if (cache_client) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder.SetSampler(std::move(sampler));
}

@@ -1715,7 +1715,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
@@ -1768,7 +1768,7 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
@@ -1806,7 +1806,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
// Required arguments
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
if (!args["dataset_files"].is_none()) {
@@ -1840,7 +1840,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -1855,7 +1855,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}

@@ -1991,7 +1991,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;

std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
@@ -2036,7 +2036,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -2051,7 +2051,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}

@@ -2116,7 +2116,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
if (!args["dataset_files"].is_none()) {
@@ -2173,7 +2173,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -2188,7 +2188,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
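
A pattern repeated across the de_pipeline hunks above: whenever a cache client is attached but the user gave no sampler, the parser injects a SequentialSamplerRT with num_samples = 0 and start_index = 0 (0 appears to be the "no limit" value for these samplers), since the cache needs some sampler to drive row lookups. A stub-typed sketch of the fallback, not the real builder code:

#include <cstdint>
#include <memory>

class SamplerRT {
 public:
  virtual ~SamplerRT() = default;
};

class SequentialSamplerRT : public SamplerRT {
 public:
  SequentialSamplerRT(int64_t /*num_samples*/, int64_t /*start_index*/) {}
};

std::shared_ptr<SamplerRT> SamplerOrCacheDefault(std::shared_ptr<SamplerRT> sampler,
                                                 bool has_cache_client) {
  if (sampler == nullptr && has_cache_client) {
    // Mirrors the diff: a sequential pass over all rows, starting at row 0.
    return std::make_shared<SequentialSamplerRT>(0, 0);
  }
  return sampler;
}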



+13 -15  mindspore/ccsrc/minddata/dataset/api/samplers.cc

@@ -35,7 +35,6 @@

namespace mindspore {
namespace dataset {
namespace api {

#define RETURN_NULL_IF_ERROR(_s) \
do { \
@@ -151,10 +150,10 @@ bool DistributedSamplerObj::ValidateParams() {
return true;
}

std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
return sampler;
}

@@ -184,9 +183,9 @@ bool PKSamplerObj::ValidateParams() {
return true;
}

std::shared_ptr<Sampler> PKSamplerObj::Build() {
std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);

return sampler;
}
@@ -218,10 +217,10 @@ bool RandomSamplerObj::ValidateParams() {
return true;
}

std::shared_ptr<Sampler> RandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
// runtime sampler object
bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch);
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);

return sampler;
}
@@ -255,9 +254,9 @@ bool SequentialSamplerObj::ValidateParams() {
return true;
}

std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_);
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);

return sampler;
}
@@ -284,9 +283,9 @@ bool SubsetRandomSamplerObj::ValidateParams() {
return true;
}

std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_);
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);

return sampler;
}
@@ -330,11 +329,10 @@ bool WeightedRandomSamplerObj::ValidateParams() {
return true;
}

std::shared_ptr<Sampler> WeightedRandomSamplerObj::Build() {
auto sampler = std::make_shared<dataset::WeightedRandomSampler>(num_samples_, weights_, replacement_);
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
return sampler;
}

} // namespace api
} // namespace dataset
} // namespace mindspore
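
The samplers.cc hunks show the shape of the rename in one file: the IR-facing *SamplerObj classes keep their names, and only their Build() methods now return the RT-suffixed runtime types. A compilable toy of the two layers (real class names, illustrative bodies):

#include <cstdint>
#include <memory>

// Runtime layer: the object that actually produces sample ids.
class SamplerRT {
 public:
  explicit SamplerRT(int64_t num_samples) : num_samples_(num_samples) {}
  virtual ~SamplerRT() = default;
 private:
  int64_t num_samples_;
};

class SequentialSamplerRT : public SamplerRT {
 public:
  SequentialSamplerRT(int64_t num_samples, int64_t start_index)
      : SamplerRT(num_samples), start_index_(start_index) {}
 private:
  int64_t start_index_;
};

// IR layer: validates user parameters, then builds the runtime sampler.
class SequentialSamplerObj {
 public:
  SequentialSamplerObj(int64_t start_index, int64_t num_samples)
      : start_index_(start_index), num_samples_(num_samples) {}
  bool ValidateParams() const { return start_index_ >= 0 && num_samples_ >= 0; }
  std::shared_ptr<SamplerRT> Build() const {
    return std::make_shared<SequentialSamplerRT>(num_samples_, start_index_);
  }
 private:
  int64_t start_index_;
  int64_t num_samples_;
};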

+0 -2  mindspore/ccsrc/minddata/dataset/api/text.cc

@@ -22,7 +22,6 @@

namespace mindspore {
namespace dataset {
namespace api {

// Transform operations for text.
namespace text {
@@ -130,6 +129,5 @@ std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
}

} // namespace text
} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -2  mindspore/ccsrc/minddata/dataset/api/transforms.cc

@@ -22,7 +22,6 @@

namespace mindspore {
namespace dataset {
namespace api {

TensorOperation::TensorOperation() {}

@@ -94,6 +93,5 @@ Status TypeCastOperation::ValidateParams() {
std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }

} // namespace transforms
} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -2  mindspore/ccsrc/minddata/dataset/api/vision.cc

@@ -65,7 +65,6 @@

namespace mindspore {
namespace dataset {
namespace api {

// Transform operations for computer vision.
namespace vision {
@@ -1702,6 +1701,5 @@ std::shared_ptr<TensorOp> UniformAugOperation::Build() {
#endif

} // namespace vision
} // namespace api
} // namespace dataset
} // namespace mindspore

+6 -6  mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc

@@ -34,11 +34,11 @@ namespace mindspore::dataset {
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }

Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }

// IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}

@@ -74,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
}

// ToDevice
Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) {
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}

@@ -385,8 +385,8 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
tree_adapter_ = std::make_unique<TreeAdapter>();
}

Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d));
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d), 1);
if (!s.IsError()) {
init_flag_ = true;
}
@@ -464,7 +464,7 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK();
}
Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), 1);
}
Status BuildVocabConsumer::Start() {


+6 -10  mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h

@@ -29,10 +29,7 @@
namespace mindspore::dataset {
// Forward declare
class TreeAdapter;

namespace api {
class DatasetNode;
}

/// A base class for tree consumers which would fetch rows from the tree pipeline
class TreeConsumer {
@@ -42,7 +39,7 @@ class TreeConsumer {
/// Initializes the consumer, this involves constructing and preparing the tree.
/// \param d The dataset node that represent the root of the IR tree.
/// \return Status error code.
virtual Status Init(std::shared_ptr<api::DatasetNode> d);
virtual Status Init(std::shared_ptr<DatasetNode> d);

Status Terminate();

@@ -61,7 +58,7 @@ class IteratorConsumer : public TreeConsumer {
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}

Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;

/// Returns the next row in a vector format
/// \param[out] out std::vector of Tensors
@@ -133,7 +130,7 @@ class ToDevice : public TreeConsumer {
explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}

Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;

/// Send the data to device
/// \return Status error code
@@ -162,7 +159,7 @@ class ToDevice : public TreeConsumer {
class TreeGetters : public TreeConsumer {
public:
TreeGetters();
Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
@@ -185,10 +182,9 @@ class BuildVocabConsumer : public TreeConsumer {
/// BuildVocabConsumer Constructor which will call the base class default constructor.
BuildVocabConsumer() = default;

Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;

/// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
/// would be written to disk)
/// Start consuming
/// \return Status error code
Status Start();
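
The header change also explains why every consumer's Init appears in the tree_consumer.cc diff: Init is virtual on TreeConsumer, so changing the base signature from api::DatasetNode to DatasetNode forces every override to change in lockstep. A reduced sketch of that coupling (Status and DatasetNode shrunk to empty stubs):

#include <memory>

struct Status {};      // stub
class DatasetNode {};  // no api:: wrapper after this commit

class TreeConsumer {
 public:
  virtual ~TreeConsumer() = default;
  virtual Status Init(std::shared_ptr<DatasetNode> /*d*/) { return Status{}; }
};

class BuildVocabConsumer : public TreeConsumer {
 public:
  // Must match the new base signature exactly, hence the one-line edits in
  // IteratorConsumer, ToDevice, TreeGetters and BuildVocabConsumer alike.
  Status Init(std::shared_ptr<DatasetNode> /*d*/) override { return Status{}; }
};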



+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc

@@ -46,7 +46,7 @@ Status CacheBase::Reset() {
return Status::OK();
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
row_cnt_(0),
num_cache_miss_(0),


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h

@@ -46,7 +46,7 @@ class CacheBase : public ParallelOp {
/// \param cache_client CacheClient for communication to the CacheServer
/// \param sampler Sampler which is mandatory
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor
~CacheBase();



+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc

@@ -87,7 +87,7 @@ Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
leaf_op_wp_.Set();
return Status::OK();
}
Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); }
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
std::vector<row_id_type> cache_miss;


+5 -5  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h

@@ -28,7 +28,7 @@ namespace dataset {
/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset.
/// \note For non-mappable dataset, please see CacheOp
/// \see CacheOp
class CacheLookupOp : public CacheBase, public Sampler {
class CacheLookupOp : public CacheBase, public SamplerRT {
public:
class Builder {
public:
@@ -62,7 +62,7 @@ class CacheLookupOp : public CacheBase, public Sampler {

/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
@@ -77,7 +77,7 @@ class CacheLookupOp : public CacheBase, public Sampler {
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;

// Check if the required parameters are set by the builder.
// \return Status The error code return
@@ -87,8 +87,8 @@ class CacheLookupOp : public CacheBase, public Sampler {
/// \note It takes the same argument as the base class.
/// \see CacheBase
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), SamplerRT(*(sampler.get())) {}
~CacheLookupOp() = default;
// As a parallel op, we override these two functions
Status operator()() override;
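
CacheLookupOp is unusual among these files: it inherits from SamplerRT instead of merely holding one, so its constructor both hands the shared_ptr to CacheBase and copy-constructs its own SamplerRT base from the pointee. A reduced sketch of the pattern (stub classes, illustrative only):

#include <memory>
#include <utility>

class SamplerRT {
 public:
  SamplerRT() = default;
  SamplerRT(const SamplerRT &) = default;  // the copy used below
  virtual ~SamplerRT() = default;
};

class CacheBase {
 public:
  explicit CacheBase(std::shared_ptr<SamplerRT> sampler)
      : sampler_(std::move(sampler)) {}
 protected:
  std::shared_ptr<SamplerRT> sampler_;
};

// Holds the sampler via CacheBase and also *is* a sampler: the SamplerRT
// base subobject is copy-constructed from the sampler that was passed in.
class CacheLookupOp : public CacheBase, public SamplerRT {
 public:
  explicit CacheLookupOp(std::shared_ptr<SamplerRT> sampler)
      : CacheBase(sampler), SamplerRT(*sampler) {}
};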


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc

@@ -46,7 +46,7 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
}

CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler),
num_cleaners_(numCleaners),
cache_client_(std::move(cache_client)),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h

@@ -110,7 +110,7 @@ class CacheMergeOp : public ParallelOp {
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
@@ -133,7 +133,7 @@ class CacheMergeOp : public ParallelOp {
int32_t build_op_connector_size_;
int32_t build_num_cleaners_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;

/// Check if the required parameters are set by the builder.
/// \return Status The error code return
@@ -147,7 +147,7 @@ class CacheMergeOp : public ParallelOp {
/// \param cache_client CacheClient to communicate with the Cache server
/// \param sampler as a derived class of ParallelOp
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler);
~CacheMergeOp();
void Print(std::ostream &out, bool show_all) const override;
std::string Name() const override { return kCacheMergeOp; }


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc

@@ -68,7 +68,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {

// Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)),
num_guys_in_(0),
phase_(Phase::kBuildPhase) {}


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h

@@ -81,7 +81,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
@@ -96,7 +96,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;

/// \brief Check if the required parameters are set by the builder.
/// \return Status The error code return
@@ -108,7 +108,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \param num_workers The number of worker threads.
/// \param op_connector_size The size of each queue in the connector.
CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);

// Destructor
~CacheOp();


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc

@@ -36,7 +36,7 @@ ConcatOp::Builder::Builder() {
// The builder "build" method creates the final object.
Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<DistributedSampler>(0, 1, 0, false);
builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false);
}
*ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
children_start_end_index_);
@@ -44,7 +44,7 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
}

// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index)
: PipelineOp(op_connector_size),
@@ -80,7 +80,7 @@ Status ConcatOp::operator()() {
bool is_not_mappable = true;
int num_shard = 1;
int shard_index = 0;
std::shared_ptr<DistributedSampler> distribute_sampler = std::dynamic_pointer_cast<DistributedSampler>(sampler_);
std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_);
if (distribute_sampler != nullptr) {
num_shard = distribute_sampler->GetDeviceNum();
shard_index = distribute_sampler->GetDeviceID();
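
The ConcatOp hunk keeps its logic and only renames the types: it probes the generic sampler with dynamic_pointer_cast to find out whether a distributed sampler (and therefore shard information) is attached. A runnable sketch of the probe with stub classes:

#include <iostream>
#include <memory>

class SamplerRT {
 public:
  virtual ~SamplerRT() = default;
};

class DistributedSamplerRT : public SamplerRT {
 public:
  DistributedSamplerRT(int device_num, int device_id)
      : device_num_(device_num), device_id_(device_id) {}
  int GetDeviceNum() const { return device_num_; }
  int GetDeviceID() const { return device_id_; }
 private:
  int device_num_;
  int device_id_;
};

int main() {
  std::shared_ptr<SamplerRT> sampler =
      std::make_shared<DistributedSamplerRT>(4, 1);
  // Same probe as ConcatOp::operator()(): only a distributed sampler
  // carries the shard count and shard index.
  if (auto ds = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler)) {
    std::cout << ds->GetDeviceNum() << " shards, this is shard "
              << ds->GetDeviceID() << std::endl;
  }
  return 0;
}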


+4 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h

@@ -44,7 +44,7 @@ class ConcatOp : public PipelineOp {
// The builder "build" method creates the final object.
// @return shared_ptr to the new ConcatOp object
Status Build(std::shared_ptr<ConcatOp> *);
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -61,7 +61,7 @@ class ConcatOp : public PipelineOp {

private:
int32_t builder_op_connector_size_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
};
@@ -70,7 +70,7 @@ class ConcatOp : public PipelineOp {
// @note The builder class should be used to call it
// @param op_connector_size - connector size
explicit ConcatOp(int32_t op_connector_size);
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index);

@@ -123,7 +123,7 @@ class ConcatOp : public PipelineOp {
std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name
std::vector<DataType> data_type_;
std::vector<dsize_t> data_rank_;
std::shared_ptr<Sampler> sampler_;
std::shared_ptr<SamplerRT> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
};


+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc

@@ -40,7 +40,7 @@
namespace mindspore {
namespace dataset {
// Constructor
DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
: oc_queue_size_(op_connector_size),
sampler_(sampler),
operator_id_(kInvalidOperatorId),
@@ -409,7 +409,7 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
}

// Getter for the sampler, and it also removes the sampler from the op
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler) {
*sampler = sampler_; // It's okay if sampler_ points to nullptr
sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler
return Status::OK();
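A short hedged sketch of the intended use, with hypothetical op names: the sampler is detached from one op and handed to another through the SetSampler setter declared in dataset_op.h below.

  std::shared_ptr<SamplerRT> sampler;
  RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler));  // leaf_op no longer owns it
  lookup_op->SetSampler(sampler);                           // lookup_op samples in its place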


+6 -6  mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h

@@ -62,7 +62,7 @@ class DataBuffer;

class NodePass;

class Sampler;
class SamplerRT;

/// \brief The base class DatasetOp is the main tree node. It is an abstract class, so
/// the actual implementation of the operators will be derived from here.
@@ -80,7 +80,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// Constructor
/// \param op_connector_size - The size for the output connector of this operator.
/// \param sampler - The sampler for the op
explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);
explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);

/// Destructor
virtual ~DatasetOp() { tree_ = nullptr; }
@@ -347,12 +347,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {

/// Getter for the sampler
/// \return Shared pointer to the sampler (may return nullptr)
std::shared_ptr<Sampler> sampler() { return sampler_; }
std::shared_ptr<SamplerRT> sampler() { return sampler_; }

/// \brief Getter for the sampler, and it also removes the sampler from the op
/// \param[out] sampler A pointer to the output sampler that was removed
/// \return Status error code
Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler);

#ifndef ENABLE_ANDROID
// Computes a CRC value for the operator
@@ -368,7 +368,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
}

/// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; }

/// \brief Checks if this is a leaf node (0 children)
/// \return boolean returns true if it's a leaf
@@ -409,7 +409,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {

std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
std::shared_ptr<SamplerRT> sampler_; // Some leaf ops might have a sampler
int32_t oc_queue_size_; // Capacity for each out_connector_
int32_t operator_id_; // Generated id for the node
ExecutionTree *tree_; // Back pointer to our tree.


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc

@@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
// Constructor
ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
: DatasetOp(op_connector_size, sampler),
num_workers_(num_workers),
num_producers_(num_workers),


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h

@@ -41,7 +41,7 @@ class ParallelOp : public DatasetOp {
// @param num_workers
// @param op_connector_size - size of the output connector for this operator
// @param sampler - The sampler for the op
ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr);

// Destructor
~ParallelOp() = default;


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc

@@ -20,7 +20,7 @@
namespace mindspore {
namespace dataset {
// Constructor
PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
: DatasetOp(op_connector_size, sampler) {}

// A print method typically used for debugging


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h

@@ -34,7 +34,7 @@ class PipelineOp : public DatasetOp {
// @param op_connector_size - size of the output connector
// @return Builder setter method returns reference to the builder.
// @param sampler - The sampler for the op
explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr);

// Destructor
~PipelineOp() = default;


+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc

@@ -42,7 +42,7 @@ Status AlbumOp::Builder::Build(std::shared_ptr<AlbumOp> *ptr) {
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data
const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}

builder_schema_ = std::make_unique<DataSchema>();
@@ -73,7 +73,7 @@ Status AlbumOp::Builder::SanityCheck() {

AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_wkrs, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir),


+4 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h

@@ -100,7 +100,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
/// \brief Setter method
/// \param[in] sampler
/// \return Builder setter method returns reference to the builder
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -147,7 +147,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
};

@@ -161,7 +161,8 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
/// \param[in] data_schema - schema of dataset
/// \param[in] sampler - sampler tells AlbumOp what to read
AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler);

/// \brief Destructor.
~AlbumOp() = default;


+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc

@@ -46,7 +46,7 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}

builder_schema_ = std::make_unique<DataSchema>();
@@ -79,7 +79,7 @@ Status CelebAOp::Builder::SanityCheck() {

CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
bool decode, const std::string &usage, const std::set<std::string> &exts,
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
folder_path_(dir),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h

@@ -95,7 +95,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -131,7 +131,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
std::string builder_usage_;
};
@@ -144,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
std::shared_ptr<Sampler> sampler);
std::shared_ptr<SamplerRT> sampler);

~CelebAOp() override = default;



+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc

@@ -50,7 +50,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
if (sampler_ == nullptr) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}
schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
@@ -88,7 +88,7 @@ Status CifarOp::Builder::SanityCheck() {

CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_works, queue_size, std::move(sampler)),
cifar_type_(type),
usage_(usage),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h

@@ -75,7 +75,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
sampler_ = std::move(sampler);
return *this;
}
@@ -123,7 +123,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t op_connect_size_;
std::shared_ptr<Sampler> sampler_;
std::shared_ptr<SamplerRT> sampler_;
std::unique_ptr<DataSchema> schema_;
CifarType cifar_type_;
};
@@ -138,7 +138,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells CifarOp what to read
CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler);
std::shared_ptr<SamplerRT> sampler);
// Destructor.
~CifarOp() = default;



+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc

@@ -94,7 +94,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim

ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h

@@ -125,7 +125,7 @@ class ClueOp : public ParallelOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -141,13 +141,13 @@ class ClueOp : public ParallelOp {
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
};

// Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);

// Default destructor
~ClueOp() = default;


+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc

@@ -60,7 +60,7 @@ Status CocoOp::Builder::Build(std::shared_ptr<CocoOp> *ptr) {
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
@@ -123,7 +123,7 @@ Status CocoOp::Builder::SanityCheck() {

CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)),
decode_(decode),
row_cnt_(0),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h

@@ -119,7 +119,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// Setter method.
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -149,7 +149,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int32_t builder_rows_per_buffer_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
};

@@ -166,7 +166,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read
CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);

// Destructor
~CocoOp() = default;


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc

@@ -77,7 +77,7 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h

@@ -243,7 +243,7 @@ class CsvOp : public ParallelOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -261,7 +261,7 @@ class CsvOp : public ParallelOp {
char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
};

// Constructor of CsvOp
@@ -271,7 +271,7 @@ class CsvOp : public ParallelOp {
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
std::shared_ptr<Sampler> sampler);
std::shared_ptr<SamplerRT> sampler);

// Default destructor
~CsvOp() = default;


+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc

@@ -38,7 +38,7 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data
const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
@@ -68,7 +68,7 @@ Status ImageFolderOp::Builder::SanityCheck() {
ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
bool recursive, bool do_decode, const std::set<std::string> &exts,
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_wkrs, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h

@@ -113,7 +113,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -151,7 +151,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
std::map<std::string, int32_t> builder_labels_to_read_;
};
@@ -165,7 +165,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler);

// Destructor.
~ImageFolderOp() = default;


+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc

@@ -43,7 +43,7 @@ Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) {
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
@@ -67,7 +67,7 @@ Status ManifestOp::Builder::SanityCheck() {

ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage)
std::shared_ptr<SamplerRT> sampler, std::string usage)
: ParallelOp(num_works, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
io_block_pushed_(0),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h

@@ -88,7 +88,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -119,7 +119,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
Status Build(std::shared_ptr<ManifestOp> *op);

private:
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
bool builder_decode_;

std::string builder_file_;
@@ -139,7 +139,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells ManifestOp what to read
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage);
std::shared_ptr<SamplerRT> sampler, std::string usage);
// Destructor.
~ManifestOp() = default;



+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc

@@ -45,7 +45,7 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
@@ -75,7 +75,7 @@ Status MnistOp::Builder::SanityCheck() {
}

MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)),
usage_(usage),
buf_cnt_(0),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h

@@ -78,7 +78,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -113,7 +113,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
int32_t builder_num_workers_;
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
};

@@ -126,7 +126,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
// @param std::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);

// Destructor.
~MnistOp() = default;


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc

@@ -65,7 +65,7 @@ Status RandomDataOp::Builder::SanityCheck() const {

// Constructor for RandomDataOp
RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
buffer_id_(0),
rows_per_buffer_(rows_per_buffer),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h

@@ -120,7 +120,7 @@ class RandomDataOp : public ParallelOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -133,7 +133,7 @@ class RandomDataOp : public ParallelOp {
Status SanityCheck() const;

std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
@@ -152,7 +152,7 @@ class RandomDataOp : public ParallelOp {
* @return Builder - The modified builder by reference
*/
RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);

/**
* Destructor


+8 -8  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc

@@ -23,9 +23,9 @@

namespace mindspore {
namespace dataset {
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed, int64_t offset, bool even_dist)
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed, int64_t offset, bool even_dist)
: SamplerRT(num_samples, std::numeric_limits<int64_t>::max()),
cnt_(0),
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
device_id_(dev_id),
@@ -35,7 +35,7 @@ DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int
offset_(offset),
non_empty_(true) {}

Status DistributedSampler::InitSampler() {
Status DistributedSamplerRT::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -74,7 +74,7 @@ Status DistributedSampler::InitSampler() {
return Status::OK();
}

Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status DistributedSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (cnt_ > samples_per_buffer_) {
RETURN_STATUS_UNEXPECTED(
"Number of samples(cnt) that have already been filled in to buffer should be less than or "
@@ -143,7 +143,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
return Status::OK();
}

Status DistributedSampler::ResetSampler() {
Status DistributedSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
cnt_ = 0;

@@ -160,10 +160,10 @@ Status DistributedSampler::ResetSampler() {
return Status::OK();
}

void DistributedSampler::Print(std::ostream &out, bool show_all) const {
void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: DistributedSampler";
if (show_all) {
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
<< "\nshuffle: " << shuffle_;
}


+5 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h

@@ -25,7 +25,7 @@

namespace mindspore {
namespace dataset {
class DistributedSampler : public Sampler {
class DistributedSamplerRT : public SamplerRT {
public:
/// \brief Constructor
/// \param[in] num_samples The total number of rows in the dataset
@@ -40,11 +40,12 @@ class DistributedSampler : public Sampler {
/// This option is not exposed in the python API. Current behavior is that the remainder will always
/// be handled by the first n shards, n being the corresponding device id. Note that when offset is set,
/// even_dist is forcibly set to false so that the remaining data can be distributed in the ConcatDataset scenario.
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);
DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1,
bool even_dist = true);

/// \brief default destructor
~DistributedSampler() = default;
~DistributedSamplerRT() = default;

/// \param std::unique_ptr<DataBuffer> * pBuffer
/// \param int32_t workerId
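Based on the parameter comments above, a hedged construction example (values illustrative):

  // Shard 3 of 8; num_samples = 0 takes the whole dataset; shuffled.
  auto shard_sampler = std::make_shared<DistributedSamplerRT>(
      /*num_samples=*/0, /*num_dev=*/8, /*dev_id=*/3, /*shuffle=*/true);
  // seed, offset and even_dist keep their defaults; per the comment above,
  // a non-negative offset forces even_dist to false.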


+8 -8  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc

@@ -20,14 +20,14 @@

namespace mindspore {
namespace dataset {
PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer),
shuffle_(shuffle),
seed_(GetSeed()),
next_id_(0),
samples_per_class_(val) {}

Status PKSampler::InitSampler() {
Status PKSamplerRT::InitSampler() {
labels_.reserve(label_to_ids_.size());
for (const auto &pair : label_to_ids_) {
if (pair.second.empty() == false) {
@@ -61,7 +61,7 @@ Status PKSampler::InitSampler() {
return Status::OK();
}

Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_ || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_));
} else if (next_id_ == num_samples_) {
@@ -96,7 +96,7 @@ Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK();
}

Status PKSampler::ResetSampler() {
Status PKSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
rnd_.seed(seed_++);
@@ -108,18 +108,18 @@ Status PKSampler::ResetSampler() {
return Status::OK();
}

Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
RETURN_UNEXPECTED_IF_NULL(op);
RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
RETURN_IF_NOT_OK(InitSampler());
return Status::OK();
}

void PKSampler::Print(std::ostream &out, bool show_all) const {
void PKSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: PKSampler";
if (show_all) {
// Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any
}
}


+4 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h

@@ -26,17 +26,17 @@

namespace mindspore {
namespace dataset {
class PKSampler : public Sampler { // NOT YET FINISHED
class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
public:
// @param num_samples - the number of samples to draw. value of 0 means to take the full amount
// @param int64_t val - the number of samples to draw from each class
// @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

// default destructor
~PKSampler() = default;
~PKSamplerRT() = default;

// @param std::unique_ptr<DataBuffer pBuffer
// @param int32_t workerId
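From the constructor comments above, a hedged example: visit every class, drawing two samples per class, with the class order shuffled (values illustrative):

  auto pk_sampler = std::make_shared<PKSamplerRT>(/*num_samples=*/0, /*val=*/2, /*shuffle=*/true);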


+7 -7  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc

@@ -20,10 +20,10 @@
namespace mindspore {
namespace dataset {

PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
PythonSamplerRT::PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}

Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
@@ -64,7 +64,7 @@ Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK();
}

Status PythonSampler::InitSampler() {
Status PythonSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_));
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
@@ -86,7 +86,7 @@ Status PythonSampler::InitSampler() {
return Status::OK();
}

Status PythonSampler::ResetSampler() {
Status PythonSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
need_to_reset_ = false;
py::gil_scoped_acquire gil_acquire;
@@ -106,11 +106,11 @@ Status PythonSampler::ResetSampler() {
return Status::OK();
}

void PythonSampler::Print(std::ostream &out, bool show_all) const {
void PythonSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: PythonSampler";
if (show_all) {
// Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any
}
}


+4 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h

@@ -23,18 +23,18 @@

namespace mindspore {
namespace dataset {
class PythonSampler : public Sampler {
class PythonSamplerRT : public SamplerRT {
public:
// Constructor
// @param num_samples - the number of samples to draw. Value of 0 means to sample all of the
// data from the dataset.
// @param py_sampler_instance - the python instance of the sampler
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

// Destructor.
~PythonSampler() = default;
~PythonSamplerRT() = default;

// Initialize the sampler.
// @return Status


+8 -8  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc

@@ -22,16 +22,16 @@

namespace mindspore {
namespace dataset {
RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer),
seed_(GetSeed()),
replacement_(replacement),
next_id_(0),
reshuffle_each_epoch_(reshuffle_each_epoch),
dist(nullptr) {}

Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error");
} else if (next_id_ == num_samples_) {
@@ -68,7 +68,7 @@ Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK();
}

Status RandomSampler::InitSampler() {
Status RandomSamplerRT::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -94,7 +94,7 @@ Status RandomSampler::InitSampler() {
return Status::OK();
}

Status RandomSampler::ResetSampler() {
Status RandomSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;

@@ -115,11 +115,11 @@ Status RandomSampler::ResetSampler() {
return Status::OK();
}

void RandomSampler::Print(std::ostream &out, bool show_all) const {
void RandomSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: RandomSampler";
if (show_all) {
// Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any
}
}


+4 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h

@@ -24,18 +24,18 @@

namespace mindspore {
namespace dataset {
class RandomSampler : public Sampler {
class RandomSamplerRT : public SamplerRT {
public:
// Constructor
// @param int64_t num_samples - number of samples to draw
// @param bool replacement - put the id back (or not) after a sample
// @param reshuffle_each_epoch - T/F to reshuffle after epoch
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

// Destructor.
~RandomSampler() = default;
~RandomSamplerRT() = default;

// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
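For illustration, a hedged instantiation matching the parameter comments above:

  // Draw 1000 ids with replacement, reshuffling at every epoch boundary.
  auto random_sampler = std::make_shared<RandomSamplerRT>(
      /*num_samples=*/1000, /*replacement=*/true, /*reshuffle_each_epoch=*/true);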


+14 -14  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc

@@ -32,13 +32,13 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
return Status::OK();
}

Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer)
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}

Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<Sampler> child_sampler;
Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<SamplerRT> child_sampler;
if (HasChildSampler()) {
child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]);
if (!child_sampler) {
std::string err_msg("Cannot handshake, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg);
@@ -64,7 +64,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
return Status::OK();
}

Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
if (num_elements == 0) {
RETURN_STATUS_UNEXPECTED("Invalid data, num of elements cannot be 0.");
}
@@ -77,7 +77,7 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
return Status::OK();
}

void Sampler::Print(std::ostream &out, bool show_all) const {
void SamplerRT::Print(std::ostream &out, bool show_all) const {
// Sampler printing is usually only called in the show_all mode.
// Derived classes will display the name, then call back to this base
// for common info.
@@ -88,7 +88,7 @@ void Sampler::Print(std::ostream &out, bool show_all) const {
}

#ifdef ENABLE_PYTHON
Status Sampler::GetAllIdsThenReset(py::array *data) {
Status SamplerRT::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids;
TensorRow sample_row;
@@ -123,27 +123,27 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
}
#endif

Status Sampler::SetNumSamples(int64_t num_samples) {
Status SamplerRT::SetNumSamples(int64_t num_samples) {
CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0.");
num_samples_ = num_samples;
return Status::OK();
}

int64_t Sampler::GetNumSamples() { return num_samples_; }
int64_t SamplerRT::GetNumSamples() { return num_samples_; }

Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0.");
num_rows_ = num_rows;
return Status::OK();
}

Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) {
if (child == nullptr) {
return Status::OK();
}

// Only samplers can be added, not any other DatasetOp.
std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child);
if (!sampler) {
std::string err_msg("Cannot add child, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg);
@@ -160,9 +160,9 @@ Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
return Status::OK();
}

bool Sampler::HasChildSampler() { return !child_.empty(); }
bool SamplerRT::HasChildSampler() { return !child_.empty(); }

Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
if (child_ids_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
}


+8 -8  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h

@@ -51,21 +51,21 @@ class RandomAccessOp {
protected:
// The amount of rows in the dataset itself. This is the before-sampling value, the
// total count of rows. A sampler may choose to sample less than this amount.
int64_t num_rows_;
int64_t num_rows_ = -1;
};

class Sampler {
class SamplerRT {
public:
// Constructor
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
// indicates that the sampler should produce the complete set of ids.
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
explicit SamplerRT(int64_t num_samples, int64_t samples_per_buffer);

Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {}
SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {}

// default destructor
~Sampler() = default;
~SamplerRT() = default;

// Get a list of sample ids.
// @note It is the Sampler's responsibility to make sure that the id is not out of bounds.
@@ -111,7 +111,7 @@ class Sampler {
// Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
Status AddChild(std::shared_ptr<Sampler> child);
Status AddChild(std::shared_ptr<SamplerRT> child);

// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds
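A hedged sketch of the chaining this enables: the parent draws positions for its shard, and each id is then resolved through the child during the handshake (see GetAssociatedChildId in sampler.cc above); values illustrative.

  auto parent = std::make_shared<DistributedSamplerRT>(/*num_samples=*/0, /*num_dev=*/2,
                                                       /*dev_id=*/0, /*shuffle=*/false);
  auto child = std::make_shared<SequentialSamplerRT>(/*num_samples=*/0, /*start_index=*/0);
  RETURN_IF_NOT_OK(parent->AddChild(child));  // parent ids now map through the child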
@@ -129,7 +129,7 @@ class Sampler {
// @param out - reference to the output stream being overloaded
// @param sampler - reference to the sampler to print
// @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) {
sampler.Print(out, false);
return out;
}
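Since the friend operator forwards to Print(out, false), streaming any sampler yields just its summary line, e.g. (sampler construction assumed):

  std::cout << *sampler;  // prints e.g. "Sampler: SequentialSampler"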
@@ -158,7 +158,7 @@ class Sampler {

int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::vector<std::shared_ptr<Sampler>> child_; // Child nodes
std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes
std::unique_ptr<DataBuffer> child_ids_;
};
} // namespace dataset


+7 -7  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc

@@ -20,10 +20,10 @@

namespace mindspore {
namespace dataset {
SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}

Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (id_count_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
} else if (id_count_ == num_samples_) {
@@ -62,7 +62,7 @@ Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer)
return Status::OK();
}

Status SequentialSampler::InitSampler() {
Status SequentialSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
"Invalid parameter, start_index must be greater than or equal to 0, but got " +
std::to_string(start_index_) + ".\n");
@@ -85,7 +85,7 @@ Status SequentialSampler::InitSampler() {
return Status::OK();
}

Status SequentialSampler::ResetSampler() {
Status SequentialSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
current_id_ = start_index_;
id_count_ = 0;
@@ -97,11 +97,11 @@ Status SequentialSampler::ResetSampler() {
return Status::OK();
}

void SequentialSampler::Print(std::ostream &out, bool show_all) const {
void SequentialSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: SequentialSampler";
if (show_all) {
// Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info
out << "\nStart index: " << start_index_;
}


+4 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h

@@ -23,18 +23,18 @@

namespace mindspore {
namespace dataset {
class SequentialSampler : public Sampler {
class SequentialSamplerRT : public SamplerRT {
public:
// Constructor
// @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the
// full amount of ids from the dataset
// @param start_index - The starting index value
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit SequentialSampler(int64_t num_samples, int64_t start_index,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit SequentialSamplerRT(int64_t num_samples, int64_t start_index,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

// Destructor.
~SequentialSampler() = default;
~SequentialSamplerRT() = default;

// init sampler, called by python
Status InitSampler() override;
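A hedged example following the comments above: stream the rest of the dataset starting at row 100 (values illustrative):

  auto seq_sampler = std::make_shared<SequentialSamplerRT>(/*num_samples=*/0, /*start_index=*/100);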


+8 -8  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc

@@ -27,12 +27,12 @@
namespace mindspore {
namespace dataset {
// Constructor.
SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices,
int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}

// Initialize this Sampler.
Status SubsetRandomSampler::InitSampler() {
Status SubsetRandomSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");

@@ -56,7 +56,7 @@ Status SubsetRandomSampler::InitSampler() {
}

// Reset the internal variable to the initial state.
Status SubsetRandomSampler::ResetSampler() {
Status SubsetRandomSamplerRT::ResetSampler() {
// Reset the internal counters.
sample_id_ = 0;
buffer_id_ = 0;
@@ -73,7 +73,7 @@ Status SubsetRandomSampler::ResetSampler() {
}

// Get the sample ids.
Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status SubsetRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
// All samples have been drawn
if (sample_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
@@ -120,11 +120,11 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
return Status::OK();
}

void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const {
void SubsetRandomSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: SubsetRandomSampler";
if (show_all) {
// Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any
}
}


+4 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h

@@ -25,18 +25,18 @@
namespace mindspore {
namespace dataset {
// Randomly samples elements from a given list of indices, without replacement.
class SubsetRandomSampler : public Sampler {
class SubsetRandomSamplerRT : public SamplerRT {
public:
// Constructor.
// @param num_samples The number of samples to draw. 0 for the full amount.
// @param indices List of indices from where we will randomly draw samples.
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices,
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

// Destructor.
~SubsetRandomSampler() = default;
~SubsetRandomSamplerRT() = default;

// Initialize the sampler.
// @return Status
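Hedged usage example: draw each of the listed rows exactly once, in random order (indices illustrative):

  std::vector<int64_t> indices = {3, 7, 11, 42};
  auto subset_sampler = std::make_shared<SubsetRandomSamplerRT>(/*num_samples=*/0, indices);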


+9 -9  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc

@@ -27,16 +27,16 @@
namespace mindspore {
namespace dataset {
// Constructor.
WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights,
bool replacement, int64_t samples_per_buffer)
: SamplerRT(num_samples, samples_per_buffer),
weights_(weights),
replacement_(replacement),
sample_id_(0),
buffer_id_(0) {}

// Initialize this Sampler.
Status WeightedRandomSampler::InitSampler() {
Status WeightedRandomSamplerRT::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@@ -78,7 +78,7 @@ Status WeightedRandomSampler::InitSampler() {
}

// Initialize the computation for generating weighted random numbers without replacement, using the one-pass method.
void WeightedRandomSampler::InitOnePassSampling() {
void WeightedRandomSamplerRT::InitOnePassSampling() {
exp_dist_->reset();
onepass_ids_.clear();
std::vector<std::pair<double, int64_t>> val_idx;
@@ -94,7 +94,7 @@ void WeightedRandomSampler::InitOnePassSampling() {
}
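The one-pass method named in the comment above is only partially visible in this hunk. As an aside, here is a self-contained sketch of the usual exponential-keys technique (Efraimidis-Spirakis style) that the exp_dist_ member and val_idx buffer suggest; it is an illustration under that assumption, not the operator's member implementation:

  #include <algorithm>
  #include <cstdint>
  #include <random>
  #include <utility>
  #include <vector>

  // Weighted sampling without replacement in one pass: give row i the key
  // Exp(1) / w_i and keep the k smallest keys. Assumes 0 < k <= weights.size()
  // and strictly positive weights.
  std::vector<int64_t> OnePassWeightedIds(const std::vector<double> &weights,
                                          int64_t k, uint32_t seed) {
    std::mt19937 gen(seed);
    std::exponential_distribution<double> exp_dist(1.0);
    std::vector<std::pair<double, int64_t>> val_idx;
    val_idx.reserve(weights.size());
    for (int64_t i = 0; i < static_cast<int64_t>(weights.size()); ++i) {
      val_idx.emplace_back(exp_dist(gen) / weights[i], i);
    }
    std::partial_sort(val_idx.begin(), val_idx.begin() + k, val_idx.end());
    std::vector<int64_t> ids;
    ids.reserve(k);
    for (int64_t i = 0; i < k; ++i) ids.push_back(val_idx[i].second);
    return ids;
  }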

// Reset the internal variable to the initial state and reshuffle the indices.
Status WeightedRandomSampler::ResetSampler() {
Status WeightedRandomSamplerRT::ResetSampler() {
sample_id_ = 0;
buffer_id_ = 0;
rand_gen_.seed(GetSeed());
@@ -112,7 +112,7 @@ Status WeightedRandomSampler::ResetSampler() {
}

// Get the sample ids.
Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (weights_.size() > static_cast<size_t>(num_rows_)) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, size of sample weights must be less than or equal to num of data, "
@@ -180,11 +180,11 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf
return Status::OK();
}

void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const {
void WeightedRandomSamplerRT::Print(std::ostream &out, bool show_all) const {
out << "\nSampler: WeightedRandomSampler";
if (show_all) {
// Call the super class for displaying any common detailed info
Sampler::Print(out, show_all);
SamplerRT::Print(out, show_all);
// Then add our own info if any
}
}


+4 -4  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h

@@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
// Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights).
class WeightedRandomSampler : public Sampler {
class WeightedRandomSamplerRT : public SamplerRT {
public:
// Constructor.
// @param num_samples Number of samples to be drawn.
@@ -34,11 +34,11 @@ class WeightedRandomSampler : public Sampler {
// @param replacement Determine if samples are drawn with/without replacement.
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

// Destructor.
~WeightedRandomSampler() = default;
~WeightedRandomSamplerRT() = default;

// Initialize the sampler.
// @param op (Not used in this sampler)
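
For orientation, a minimal sketch of constructing the renamed runtime sampler, assuming only the constructor declared in the header above; the weights and sample count are illustrative, not part of this commit:

  #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"

  // Three rows with unequal draw probabilities; draw two sample ids with replacement.
  // samples_per_buffer keeps its default of std::numeric_limits<int64_t>::max().
  std::vector<double> weights = {0.5, 0.3, 0.2};
  auto sampler = std::make_shared<mindspore::dataset::WeightedRandomSamplerRT>(
      /*num_samples=*/2, weights, /*replacement=*/true);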


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc

@@ -84,7 +84,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id,
std::shared_ptr<Sampler> sampler)
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
device_id_(device_id),
num_devices_(num_device),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h

@@ -115,7 +115,7 @@ class TextFileOp : public ParallelOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -131,7 +131,7 @@ class TextFileOp : public ParallelOp {
std::vector<std::string> builder_text_files_list_;
bool builder_shuffle_files_;
std::unique_ptr<DataSchema> builder_schema_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
};

// Constructor of TextFileOp
@@ -148,7 +148,7 @@ class TextFileOp : public ParallelOp {
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);

// Default destructor
~TextFileOp() = default;
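
Call sites change only in the pointer type they hand to the builder; a hedged sketch against the setter above, where the sampler variable is illustrative:

  // SetSampler now takes std::shared_ptr<SamplerRT> instead of std::shared_ptr<Sampler>.
  std::shared_ptr<mindspore::dataset::SamplerRT> my_sampler =
      std::make_shared<mindspore::dataset::SequentialSamplerRT>(/*start_index=*/0, /*num_samples=*/0);
  mindspore::dataset::TextFileOp::Builder builder;
  builder.SetSampler(std::move(my_sampler));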


+2 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc

@@ -58,7 +58,7 @@ TFReaderOp::Builder::Builder()
builder_data_schema_ = std::make_unique<DataSchema>();
}

bool ValidateFirstRowCrc(const std::string &filename) {
bool TFReaderOp::ValidateFirstRowCrc(const std::string &filename) {
std::ifstream reader;
reader.open(filename);
if (!reader) {
@@ -134,7 +134,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
int64_t total_num_rows, std::vector<std::string> dataset_files_list,
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler)
int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
device_id_(device_id),
num_devices_(num_device),


+5 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h

@@ -156,14 +156,14 @@ class TFReaderOp : public ParallelOp {
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}

private:
std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_device_id_;
int32_t builder_num_devices_;
int32_t builder_num_workers_;
@@ -193,7 +193,7 @@ class TFReaderOp : public ParallelOp {
TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler);
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler);

// Default destructor
~TFReaderOp() = default;
@@ -262,6 +262,8 @@ class TFReaderOp : public ParallelOp {
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

static bool ValidateFirstRowCrc(const std::string &filename);

private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
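
ValidateFirstRowCrc is promoted from a file-local helper in tf_reader_op.cc to a static public member, so callers no longer need a TFReaderOp instance; a hedged sketch, with the filename purely illustrative:

  // Judging from the definition above, this returns false when the file cannot be
  // opened, and otherwise checks the CRC of the first record.
  bool looks_valid = mindspore::dataset::TFReaderOp::ValidateFirstRowCrc("train-00000.tfrecord");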


+3 -2  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc

@@ -62,7 +62,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
if (builder_task_type_ == TaskType::Segmentation) {
@@ -102,7 +102,8 @@ Status VOCOp::Builder::SanityCheck() {

VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)),
decode_(decode),
row_cnt_(0),


+3 -3  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h

@@ -118,7 +118,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// Setter method.
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -148,7 +148,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int32_t builder_rows_per_buffer_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
std::map<std::string, int32_t> builder_labels_to_read_;
};
@@ -166,7 +166,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read
VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);

// Destructor
~VOCOp() = default;


+2 -2  mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h

@@ -21,7 +21,7 @@
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"

namespace mindspore::dataset::api {
namespace mindspore::dataset {

class DatasetCache {
public:
@@ -29,6 +29,6 @@ class DatasetCache {
virtual Status ValidateParams() = 0;
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
};
} // namespace mindspore::dataset::api
} // namespace mindspore::dataset

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
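
For client code the namespace removal is purely mechanical: types keep their names and drop one qualifier. A hedged before/after sketch using the class above:

  // Before this commit:
  //   std::shared_ptr<mindspore::dataset::api::DatasetCache> cache;
  // After:
  std::shared_ptr<mindspore::dataset::DatasetCache> cache;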

+2 -2  mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc

@@ -18,7 +18,7 @@
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"

namespace mindspore::dataset::api {
namespace mindspore::dataset {

/// Method to initialize the DatasetCache by creating an instance of a CacheClient
/// \return Status Error code
@@ -41,4 +41,4 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
return Status::OK();
}

} // namespace mindspore::dataset::api
} // namespace mindspore::dataset

+2 -2  mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h

@@ -24,7 +24,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"

namespace mindspore::dataset::api {
namespace mindspore::dataset {

/// DatasetCache is the IR of CacheClient
class DatasetCacheImpl : public DatasetCache {
@@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache {
std::optional<int32_t> num_connections_;
std::optional<int32_t> prefetch_sz_;
};
} // namespace mindspore::dataset::api
} // namespace mindspore::dataset

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc

@@ -26,7 +26,6 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

#ifdef ENABLE_PYTHON
// constructor #1, called by Pybind
@@ -96,6 +95,5 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h

@@ -27,7 +27,6 @@

namespace mindspore {
namespace dataset {
namespace api {

class BatchNode : public DatasetNode {
public:
@@ -66,7 +65,6 @@ class BatchNode : public DatasetNode {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc

@@ -27,7 +27,7 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
@@ -121,6 +121,5 @@ Status BucketBatchByLengthNode::ValidateParams() {
return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h

@@ -27,7 +27,7 @@

namespace mindspore {
namespace dataset {
namespace api {
class BucketBatchByLengthNode : public DatasetNode {
public:
/// \brief Constructor
@@ -58,7 +58,6 @@ class BucketBatchByLengthNode : public DatasetNode {
bool drop_remainder_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc

@@ -26,7 +26,6 @@

namespace mindspore {
namespace dataset {
namespace api {

BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child,
std::shared_ptr<SentencePieceVocab> vocab,
@@ -77,6 +76,6 @@ Status BuildSentenceVocabNode::ValidateParams() {

return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h

@@ -27,7 +27,6 @@

namespace mindspore {
namespace dataset {
namespace api {

class BuildSentenceVocabNode : public DatasetNode {
public:
@@ -56,7 +55,6 @@ class BuildSentenceVocabNode : public DatasetNode {
std::unordered_map<std::string, std::string> params_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc

@@ -26,7 +26,6 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab,
const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
@@ -78,6 +77,6 @@ Status BuildVocabNode::ValidateParams() {

return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h

@@ -26,7 +26,6 @@

namespace mindspore {
namespace dataset {
namespace api {

class BuildVocabNode : public DatasetNode {
public:
@@ -55,7 +54,6 @@ class BuildVocabNode : public DatasetNode {
bool special_first_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc

@@ -25,7 +25,7 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }

@@ -53,6 +53,5 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h

@@ -25,7 +25,6 @@

namespace mindspore {
namespace dataset {
namespace api {

class ConcatNode : public DatasetNode {
public:
@@ -44,7 +43,6 @@ class ConcatNode : public DatasetNode {
Status ValidateParams() override;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_

+177 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc

@@ -16,11 +16,187 @@

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

#include <algorithm>
#include <memory>
#include <set>

#include "minddata/dataset/util/random.h"

namespace mindspore {
namespace dataset {
namespace api {

// Helper function to compute a default shuffle size
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size) {
const int64_t average_files_multiplier = 4;
const int64_t shuffle_max = 10000;
int64_t avg_rows_per_file = 0;

// Adjust the num rows per shard if sharding was given
if (num_devices > 0) {
if (num_rows % num_devices == 0) {
num_rows = num_rows / num_devices;
} else {
num_rows = (num_rows / num_devices) + 1;
}
}

// Cap based on total rows directive. Some ops do not have this and give value of 0.
if (total_rows > 0) {
num_rows = std::min(num_rows, total_rows);
}

// get the average per file
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0.");
avg_rows_per_file = num_rows / num_files;

*shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
return Status::OK();
}

// Helper function to inject a shuffle operator over top of current operator being built
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) {
std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
int64_t shuffle_size = 0;
RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
// Add the shuffle op
*shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer);
return Status::OK();
}

// Helper function to validate dataset directory parameter
Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
if (dataset_dir.empty()) {
std::string err_msg = dataset_name + ": dataset_dir is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

Path dir(dataset_dir);
if (!dir.IsDirectory()) {
std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (access(dataset_dir.c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
if (dataset_files.empty()) {
std::string err_msg = dataset_name + ": dataset_files is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(dataset_file.toString().c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}

return Status::OK();
}

// Helper function to validate dataset num_shards and shard_id parameters
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
if (num_shards <= 0) {
std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (shard_id < 0 || shard_id >= num_shards) {
// num_shards;
std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) +
", num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
if (sampler == nullptr) {
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings) {
if (valid_strings.find(str) == valid_strings.end()) {
std::string mode;
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode,
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

// Helper function to validate dataset input/output column parameter
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns) {
if (columns.empty()) {
std::string err_msg = dataset_name + ":" + column_param + " should not be empty string";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (uint32_t i = 0; i < columns.size(); ++i) {
if (columns[i].empty()) {
std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
std::set<std::string> columns_set(columns.begin(), columns.end());
if (columns_set.size() != columns.size()) {
std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
if (shuffle) {
if (num_shards > 1) {
// If shuffle enabled, sharding enabled, use distributed random sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle enabled, sharding disabled, use random sampler
return RandomSampler(num_samples >= 0, num_samples);
}
if (num_shards > 1) {
// If shuffle disabled, sharding enabled, use distributed sequential sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(0, num_samples);
}

Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
if (cache_ != nullptr) {
@@ -60,6 +236,5 @@ DatasetNode::DatasetNode() {
worker_connector_size_ = cfg->worker_connector_size();
}

} // namespace api
} // namespace dataset
} // namespace mindspore
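
Two of the newly added helpers are easiest to read with concrete numbers; a hedged sketch in which every value is illustrative and only the signatures come from the hunk above:

  // ComputeShuffleSize, worked through once:
  //   num_files=10, num_devices=4, num_rows=100000, total_rows=0
  //   per-shard rows = 100000 / 4           = 25000
  //   avg per file   = 25000 / 10           = 2500
  //   shuffle size   = max(2500 * 4, 10000) = 10000  (never below the 10000 floor)
  int64_t shuffle_size = 0;
  Status rc = ComputeShuffleSize(10, 4, 100000, 0, &shuffle_size);

  // SelectSampler maps the four shuffle/shard combinations onto samplers:
  //   shuffle=true,  num_shards>1 -> DistributedSampler (random order)
  //   shuffle=true,  num_shards=1 -> RandomSampler
  //   shuffle=false, num_shards>1 -> DistributedSampler (sequential order)
  //   shuffle=false, num_shards=1 -> SequentialSampler
  auto sampler = SelectSampler(/*num_samples=*/0, /*shuffle=*/true, /*num_shards=*/4, /*shard_id=*/1);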

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h

@@ -28,7 +28,6 @@

namespace mindspore {
namespace dataset {
namespace api {

class Dataset;
class SamplerObj;
@@ -120,7 +119,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
int32_t worker_connector_size_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc

@@ -26,7 +26,6 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
@@ -86,6 +85,5 @@ Status MapNode::ValidateParams() {
return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h

@@ -25,7 +25,7 @@

namespace mindspore {
namespace dataset {
namespace api {
class MapNode : public DatasetNode {
public:
/// \brief Constructor
@@ -51,7 +51,6 @@ class MapNode : public DatasetNode {
std::vector<std::string> project_columns_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc

@@ -25,7 +25,6 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
@@ -53,6 +52,5 @@ std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -3  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h

@@ -26,8 +26,6 @@
namespace mindspore {
namespace dataset {

namespace api {

class ProjectNode : public DatasetNode {
public:
/// \brief Constructor
@@ -48,7 +46,6 @@ class ProjectNode : public DatasetNode {
std::vector<std::string> columns_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_

+2 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc

@@ -25,7 +25,7 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Function to build RenameOp
RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
@@ -54,6 +54,6 @@ std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}
} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -3  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h

@@ -26,8 +26,6 @@
namespace mindspore {
namespace dataset {

namespace api {

class RenameNode : public DatasetNode {
public:
/// \brief Constructor
@@ -50,7 +48,6 @@ class RenameNode : public DatasetNode {
std::vector<std::string> output_columns_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc

@@ -25,7 +25,6 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child);
@@ -49,6 +48,6 @@ Status RepeatNode::ValidateParams() {

return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -3  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h

@@ -28,8 +28,6 @@
namespace mindspore {
namespace dataset {

namespace api {

class RepeatNode : public DatasetNode {
public:
/// \brief Constructor
@@ -50,7 +48,6 @@ class RepeatNode : public DatasetNode {
int32_t repeat_count_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc

@@ -25,7 +25,6 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch)
@@ -54,6 +53,5 @@ Status ShuffleNode::ValidateParams() {
return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -3  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h

@@ -28,8 +28,6 @@
namespace mindspore {
namespace dataset {

namespace api {

class ShuffleNode : public DatasetNode {
public:
ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch);
@@ -46,7 +44,6 @@ class ShuffleNode : public DatasetNode {
bool reset_every_epoch_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc

@@ -25,7 +25,6 @@

namespace mindspore {
namespace dataset {
namespace api {

// Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) {
@@ -52,6 +51,5 @@ Status SkipNode::ValidateParams() {
return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h

@@ -26,7 +26,6 @@
namespace mindspore {
namespace dataset {

namespace api {
class SkipNode : public DatasetNode {
public:
/// \brief Constructor
@@ -46,7 +45,7 @@ class SkipNode : public DatasetNode {
private:
int32_t skip_count_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc

@@ -27,7 +27,7 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for AlbumNode
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
@@ -78,6 +78,5 @@ Status AlbumNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+0 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h

@@ -25,7 +25,6 @@

namespace mindspore {
namespace dataset {
namespace api {

class AlbumNode : public DatasetNode {
public:
@@ -57,7 +56,6 @@ class AlbumNode : public DatasetNode {
std::shared_ptr<SamplerObj> sampler_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_

+1 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc

@@ -26,7 +26,7 @@
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for CelebANode
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
@@ -76,6 +76,5 @@ Status CelebANode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

Some files were not shown because too many files changed in this diff
