
!8096 Add DatasetNode as a base Class for IR nodes

Merge pull request !8096 from h.farahat/datasetNode
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 6e2241d64f
82 changed files with 1696 additions and 1118 deletions
  1. +348 -239 mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +1 -1 mindspore/ccsrc/minddata/dataset/api/iterator.cc
  3. +12 -3 mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt
  4. +46 -0 mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc
  5. +13 -16 mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h
  6. +16 -4 mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  7. +29 -7 mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  8. +1 -0 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt
  9. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc
  10. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h
  11. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc
  12. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h
  13. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc
  14. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h
  15. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc
  16. +4 -4 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h
  17. +3 -5 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
  18. +3 -6 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
  19. +65 -0 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  20. +126 -0 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  21. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc
  22. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h
  23. +2 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc
  24. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h
  25. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc
  26. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h
  27. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  28. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  29. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc
  30. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h
  31. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc
  32. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h
  33. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc
  34. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h
  35. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc
  36. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h
  37. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc
  38. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h
  39. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc
  40. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h
  41. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc
  42. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h
  43. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc
  44. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h
  45. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc
  46. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h
  47. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
  48. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  49. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h
  50. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc
  51. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h
  52. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h
  53. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc
  54. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h
  55. +4 -4 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h
  56. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc
  57. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h
  58. +47 -0 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
  59. +4 -4 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h
  60. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc
  61. +2 -2 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h
  62. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc
  63. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h
  64. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc
  65. +3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h
  66. +14 -11 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc
  67. +4 -5 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h
  68. +1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc
  69. +4 -4 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h
  70. +27 -0 mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc
  71. +48 -0 mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h
  72. +1 -1 mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc
  73. +5 -3 mindspore/ccsrc/minddata/dataset/engine/runtime_context.h
  74. +3 -3 mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc
  75. +4 -4 mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
  76. +704 -643 mindspore/ccsrc/minddata/dataset/include/datasets.h
  77. +1 -1 mindspore/ccsrc/minddata/dataset/include/iterator.h
  78. +59 -51 mindspore/lite/minddata/CMakeLists.txt
  79. +1 -1 tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
  80. +2 -2 tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
  81. +7 -7 tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc
  82. +3 -3 tests/ut/cpp/dataset/tree_adapter_test.cc

+348 -239 mindspore/ccsrc/minddata/dataset/api/datasets.cc

@@ -86,6 +86,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"

// IR leaf nodes disabled for android
@@ -140,26 +141,11 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
return false;
}

// Get a uuid for queue name
std::string queue_name = Services::GetUniqueID();

// TODO(CRC):
// Get device type from ms context
std::string device_type = "CPU";

// Get device ID from children
int32_t device_id = 0;
rc = TransferNode::get_distribution(shared_from_this(), &device_id);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc;
return false;
}

// Add TransferNode IR on top of dataset d
auto ds = std::make_shared<TransferNode>(shared_from_this(), queue_name, device_id, device_type, send_epoch_end);
auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end);

// Get ToDevice consumer
auto consumer = std::make_unique<ToDevice>(device_type, send_epoch_end, -1);
auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1);
ToDevice *consumer_ = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
@@ -199,7 +185,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
return false;
}
SaveToDisk *consumer_ = consumer.get();
rc = consumer->Init(ds);
rc = consumer->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "CreateSaver failed." << rc;
return false;
@@ -225,19 +211,10 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
#endif

// Constructor
Dataset::Dataset() {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
tree_getters_ = std::make_shared<TreeGetters>();
}
Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }

int64_t Dataset::GetDatasetSize() {
int64_t dataset_size;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
@@ -246,7 +223,7 @@ int64_t Dataset::GetDatasetSize() {
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
@@ -267,7 +244,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
return types;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(shared_from_this());
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
types.clear();
@@ -294,7 +271,7 @@ std::vector<TensorShape> Dataset::GetOutputShapes() {
return shapes;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(shared_from_this());
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
shapes.clear();
@@ -321,7 +298,7 @@ int64_t Dataset::GetNumClasses() {
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
return -1;
@@ -331,9 +308,6 @@ int64_t Dataset::GetNumClasses() {
return rc.IsError() ? -1 : num_classes;
}

// Constructor to initialize the cache
Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }

/// \brief Function to create a SchemaObj
/// \param[in] schema_file Path of schema file
/// \return Shared pointer to the current schema
@@ -346,161 +320,155 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// (In alphabetical order)

// Function to create a AlbumNode.
std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
// Function to create a AlbumDataset.
std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<AlbumDataset>(dataset_dir, data_schema, column_names, decode, sampler, cache);

return ds;
}

// Function to create a CelebANode.
std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, bool decode,
const std::set<std::string> &extensions,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions, cache);
// Function to create a CelebADataset.
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, bool decode,
const std::set<std::string> &extensions,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CelebADataset>(dataset_dir, usage, sampler, decode, extensions, cache);

return ds;
}

// Function to create a Cifar10Node.
std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
// Function to create a Cifar10Dataset.
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, usage, sampler, cache);

return ds;
}

// Function to create a Cifar100Node.
std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
// Function to create a Cifar100Dataset.
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, usage, sampler, cache);

return ds;
}

// Function to create a CLUENode.
std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CLUENode>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
// Function to create a CLUEDataset.
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);

return ds;
}

// Function to create a CocoNode.
std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task, const bool &decode, const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache);
// Function to create a CocoDataset.
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task, const bool &decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler, cache);

return ds;
}

// Function to create a CSVNode.
std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle,
num_shards, shard_id, cache);
// Function to create a CSVDataset.
std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CSVDataset>(dataset_files, field_delim, column_defaults, column_names, num_samples,
shuffle, num_shards, shard_id, cache);

return ds;
}

// Function to create a ImageFolderNode.
std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::set<std::string> &extensions,
const std::map<std::string, int32_t> &class_indexing,
const std::shared_ptr<DatasetCache> &cache) {
// This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
bool recursive = false;

// Create logical representation of ImageFolderNode.
auto ds =
std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache);
// Function to create a ImageFolderDataset.
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::set<std::string> &extensions,
const std::map<std::string, int32_t> &class_indexing,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<ImageFolderDataset>(dataset_dir, decode, sampler, extensions, class_indexing, cache);

return ds;
}

#ifndef ENABLE_ANDROID
// Function to create a ManifestNode.
std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
// Function to create a ManifestDataset.
std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<ManifestDataset>(dataset_file, usage, sampler, class_indexing, decode, cache);

return ds;
}

// Function to create a MindDataNode.
std::shared_ptr<MindDataNode> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataNode>(dataset_file, columns_list, sampler, padded_sample, num_padded);
// Function to create a MindDataDataset.
std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataDataset>(dataset_file, columns_list, sampler, padded_sample, num_padded);

return ds;
}

// Function to create a MindDataNode.
std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataNode>(dataset_files, columns_list, sampler, padded_sample, num_padded);
// Function to create a MindDataDataset.
std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataDataset>(dataset_files, columns_list, sampler, padded_sample, num_padded);

return ds;
}
#endif

// Function to create a MnistNode.
std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
// Function to create a MnistDataset.
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<MnistDataset>(dataset_dir, usage, sampler, cache);

return ds;
}

// Function to overload "+" operator to concat two datasets
std::shared_ptr<ConcatNode> operator+(const std::shared_ptr<Dataset> &datasets1,
const std::shared_ptr<Dataset> &datasets2) {
std::shared_ptr<ConcatNode> ds = std::make_shared<ConcatNode>(std::vector({datasets2, datasets1}));

return ds;
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
const std::shared_ptr<Dataset> &datasets2) {
return std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1}));
}

// Function to create a TextFileNode.
std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
// Function to create a TextFileDataset.
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);

return ds;
}

#ifndef ENABLE_ANDROID
// Function to create a VOCNode.
std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
// Function to create a VOCDataset.
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);

return ds;
}
#endif

// Function to create a ZipNode.
std::shared_ptr<ZipNode> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ZipNode>(datasets);

std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ZipDataset>(datasets);
return ds;
}

@@ -508,170 +476,112 @@ std::shared_ptr<ZipNode> Zip(const std::vector<std::shared_ptr<Dataset>> &datase
// (In alphabetical order)

// Function to create a Batch dataset
std::shared_ptr<BatchNode> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) {
// Default values
std::vector<std::string> cols_to_map = {};
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
bool pad = false;
auto ds = std::make_shared<BatchNode>(shared_from_this(), batch_size, drop_remainder, pad, cols_to_map, pad_map);

return ds;
auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder, pad, cols_to_map, pad_map);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

#ifndef ENABLE_ANDROID
// Function to create a BucketBatchByLength dataset
std::shared_ptr<BucketBatchByLengthNode> Dataset::BucketBatchByLength(
const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
const std::vector<int32_t> &bucket_batch_sizes, std::function<TensorRow(TensorRow)> element_length_function,
BucketBatchByLengthDataset::BucketBatchByLengthDataset(
std::shared_ptr<Dataset> input, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) {
auto ds = std::make_shared<BucketBatchByLengthNode>(shared_from_this(), column_names, bucket_boundaries,
auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function, pad_info,
pad_to_bucket_boundary, drop_remainder);

return ds;
}

// Function to create a SentencePieceVocab from dataset
std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto vocab = std::make_shared<SentencePieceVocab>();
auto ds = std::make_shared<BuildSentenceVocabNode>(shared_from_this(), vocab, col_names, vocab_size,
character_coverage, model_type, params);

// Run tree here to start building vocab
std::shared_ptr<Iterator> iter = ds->CreateIterator();
if (iter == nullptr) {
MS_LOG(ERROR) << "Fail to run iterator in BuildSentencePieceVocab.";
return nullptr;
}

// Finish building vocab by triggering GetNextRow
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
if (!iter->GetNextRow(&row)) {
return nullptr;
}

return vocab;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a Vocab from dataset
std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first) {
auto vocab = std::make_shared<Vocab>();
auto ds = std::make_shared<BuildVocabNode>(shared_from_this(), vocab, columns, freq_range, top_k, special_tokens,
special_first);

// Run tree here to start building vocab
std::shared_ptr<Iterator> iter = ds->CreateIterator();
if (iter == nullptr) {
MS_LOG(ERROR) << "Fail to run iterator in BuildVocab.";
return nullptr;
}

// Finish building vocab by triggering GetNextRow
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
if (!iter->GetNextRow(&row)) {
return nullptr;
}

return vocab;
}
#endif

// Function to create a Concat dataset
std::shared_ptr<ConcatNode> Dataset::Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ConcatNode>(datasets);
ds->children.push_back(shared_from_this());
ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
(void)std::transform(
datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });

return ds;
auto ds = std::make_shared<ConcatNode>(all_datasets);

ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a Map dataset.
std::shared_ptr<MapNode> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns,
const std::shared_ptr<DatasetCache> &cache) {
MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache) {
auto ds =
std::make_shared<MapNode>(shared_from_this(), operations, input_columns, output_columns, project_columns, cache);
std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns, cache);

return ds;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a ProjectNode.
std::shared_ptr<ProjectNode> Dataset::Project(const std::vector<std::string> &columns) {
auto ds = std::make_shared<ProjectNode>(shared_from_this(), columns);
ProjectDataset::ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &columns) {
auto ds = std::make_shared<ProjectNode>(input->IRNode(), columns);

return ds;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a RenameNode.
std::shared_ptr<RenameNode> Dataset::Rename(const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns) {
auto ds = std::make_shared<RenameNode>(shared_from_this(), input_columns, output_columns);
RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns) {
auto ds = std::make_shared<RenameNode>(input->IRNode(), input_columns, output_columns);

return ds;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create Repeat dataset.
std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
// Workaround for repeat == 1, do not inject repeat.
if (count == 1) {
return shared_from_this();
ir_node_ = input->IRNode();
return;
}

auto ds = std::make_shared<RepeatNode>(shared_from_this(), count);
auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);

return ds;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a ShuffleOp
std::shared_ptr<ShuffleNode> Dataset::Shuffle(int32_t buffer_size) {
ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size) {
// Pass in reshuffle_each_epoch with true
auto ds = std::make_shared<ShuffleNode>(shared_from_this(), buffer_size, true);
auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);

return ds;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a SkipNode.
std::shared_ptr<SkipNode> Dataset::Skip(int32_t count) {
auto ds = std::make_shared<SkipNode>(shared_from_this(), count);
SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
auto ds = std::make_shared<SkipNode>(input->IRNode(), count);

return ds;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a TakeNode.
std::shared_ptr<Dataset> Dataset::Take(int32_t count) {
TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
// If count is greater than the number of elements in the dataset or equal to -1,
// all the elements in the dataset will be taken
if (count == -1) {
return shared_from_this();
ir_node_ = input->IRNode();
return;
}

auto ds = std::make_shared<TakeNode>(shared_from_this(), count);
auto ds = std::make_shared<TakeNode>(input->IRNode(), count);

return ds;
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Function to create a Zip dataset
std::shared_ptr<ZipNode> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
// Default values
auto ds = std::make_shared<ZipNode>(datasets);
ds->children.push_back(shared_from_this());
ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
std::vector<std::shared_ptr<DatasetNode>> all_datasets;
(void)std::transform(
datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
[](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });

return ds;
}
auto ds = std::make_shared<ZipNode>(all_datasets);

Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
if (cache_ != nullptr) {
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> cache_op;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
node_ops->push_back(cache_op);
}
return Status::OK();
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

int64_t Dataset::GetBatchSize() {
@@ -685,7 +595,7 @@ int64_t Dataset::GetBatchSize() {
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
@@ -706,7 +616,7 @@ int64_t Dataset::GetRepeatCount() {
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
@@ -715,7 +625,77 @@ int64_t Dataset::GetRepeatCount() {
rc = tree_getters_->GetRepeatCount(&repeat_count);
return rc.IsError() ? 0 : repeat_count;
}
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
return nullptr;
}
return shared_from_this();
}
#ifndef ENABLE_ANDROID
std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
auto vocab = std::make_shared<SentencePieceVocab>();
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params);

std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
return nullptr;
}

auto consumer = std::make_unique<BuildVocabConsumer>();
BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
return nullptr;
}
runtime_context->AssignConsumer(std::move(consumer));

// Run tree here to start building vocab
rc = bv_consumer->Start();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
return nullptr;
}
return vocab;
}

std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first) {
auto vocab = std::make_shared<Vocab>();
auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);

std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
return nullptr;
}

auto consumer = std::make_unique<BuildVocabConsumer>();
BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
return nullptr;
}
runtime_context->AssignConsumer(std::move(consumer));

// Run tree here to start building vocab
rc = bv_consumer->Start();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
return nullptr;
}
return vocab;
}
#endif
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}

// SchemaObj init function
@@ -1046,6 +1026,136 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me
}
#endif

AlbumDataset::AlbumDataset(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, bool decode,
const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
CLUEDataset::CLUEDataset(const std::vector<std::string> &dataset_files, const std::string &task,
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
CSVDataset::CSVDataset(const std::vector<std::string> &dataset_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle,
num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
ImageFolderDataset::ImageFolderDataset(const std::string &dataset_dir, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::set<std::string> &extensions,
const std::map<std::string, int32_t> &class_indexing,
const std::shared_ptr<DatasetCache> &cache) {
// This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
bool recursive = false;

// Create logical representation of ImageFolderDataset.
auto ds =
std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

#ifndef ENABLE_ANDROID
ManifestDataset::ManifestDataset(const std::string &dataset_file, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
MindDataDataset::MindDataDataset(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataNode>(dataset_file, columns_list, sampler, padded_sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
MindDataDataset::MindDataDataset(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataNode>(dataset_files, columns_list, sampler, padded_sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#endif
MnistDataset::MnistDataset(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TextFileDataset::TextFileDataset(const std::vector<std::string> &dataset_files, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#ifndef ENABLE_ANDROID
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#endif
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
auto ds =
std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::string schema_path,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema_path), std::move(columns_list),
std::move(sampler), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#ifndef ENABLE_ANDROID
TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
std::shared_ptr<DatasetCache> cache) {
auto ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
std::shared_ptr<DatasetCache> cache) {
auto ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#endif
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
if (shuffle) {
if (num_shards > 1) {
@@ -1062,7 +1172,6 @@ std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int
// If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(0, num_samples);
}

} // namespace api
} // namespace dataset
} // namespace mindspore
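
Taken together, the datasets.cc changes invert the old structure: the user-facing classes (AlbumDataset, BatchDataset, and so on) are now thin facades whose constructors build the matching IR node and store it in ir_node_, while every consumer is initialized from IRNode() rather than from the Dataset itself. A minimal usage sketch under the signatures shown in this diff (the dataset path and the null sampler/cache arguments are placeholders, and namespaces are elided):

// Build a pipeline through the new facade classes; each constructor creates
// the corresponding IR node (MnistNode, BatchNode) behind the scenes.
std::shared_ptr<Dataset> ds =
    Mnist("/path/to/mnist", "train", /*sampler=*/nullptr, /*cache=*/nullptr);
auto batched = std::make_shared<BatchDataset>(ds, /*batch_size=*/32,
                                              /*drop_remainder=*/true);

// Consumers no longer receive the facade; they receive the IR tree root,
// exactly as IteratorConsumer and TreeGetters do above.
std::shared_ptr<DatasetNode> root = batched->IRNode();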

+1 -1 mindspore/ccsrc/minddata/dataset/api/iterator.cc

@@ -53,7 +53,7 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
RETURN_IF_NOT_OK(runtime_context->Init());
auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get();
RETURN_IF_NOT_OK(consumer->Init(ds));
RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
runtime_context->AssignConsumer(std::move(consumer));
return Status::OK();
}


+12 -3 mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt

@@ -11,7 +11,7 @@ endif ()

file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine OBJECT
set(SRC_FILES_LIST
execution_tree.cc
data_buffer.cc
data_schema.cc
@@ -20,10 +20,19 @@ add_library(engine OBJECT
runtime_context.cc
consumers/tree_consumer.cc
)
if (ENABLE_PYTHON)
set(SRC_FILES_LIST
${SRC_FILES_LIST}
python_runtime_context.cc
consumers/python_tree_consumer.cc
)
endif ()

add_library(engine OBJECT ${SRC_FILES_LIST})

if (ENABLE_PYTHON)
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
endif()
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
endif ()

add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)



+46 -0 mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc

@@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"

namespace mindspore::dataset {

Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
std::vector<TensorPtr> row;
{
py::gil_scoped_release gil_release;
RETURN_IF_NOT_OK(GetNextAsVector(&row));
}
for (auto el : row) {
(*out).append(el);
}
return Status::OK();
}
Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) {
std::unordered_map<std::string, TensorPtr> row;
{
py::gil_scoped_release gil_release;
RETURN_IF_NOT_OK(GetNextAsMap(&row));
}
for (auto el : row) {
(*out)[common::SafeCStr(el.first)] = el.second;
}
return Status::OK();
}
} // namespace mindspore::dataset
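
Both getters release the GIL while the C++ pipeline fetches the row, then reacquire it before constructing Python objects. As a hypothetical illustration only (this binding is not part of the diff; the module name, the Python-side method name, and the Status::ToString error handling are assumptions), a pybind11 sketch exposing the consumer could look like:

#include <stdexcept>
#include "pybind11/pybind11.h"
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"

namespace py = pybind11;
using mindspore::dataset::PythonIteratorConsumer;

// Hypothetical module: surface GetNextAsDict to Python, translating a bad
// Status into a Python RuntimeError.
PYBIND11_MODULE(_md_example, m) {
  py::class_<PythonIteratorConsumer>(m, "PythonIteratorConsumer")
      .def(py::init<int32_t>(), py::arg("num_epochs") = -1)
      .def("get_next_as_dict", [](PythonIteratorConsumer &self) {
        py::dict out;
        auto rc = self.GetNextAsDict(&out);
        if (rc.IsError()) throw std::runtime_error(rc.ToString());
        return out;
      });
}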

+13 -16 mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h

@@ -26,24 +26,21 @@
namespace mindspore::dataset {

/// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict
class PythonIterator : public IteratorConsumer {
/// Constructor

class PythonIteratorConsumer : public IteratorConsumer {
public:
/// Constructor which will call the base class default constructor.
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
explicit PythonIterator(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {}
explicit PythonIteratorConsumer(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {}
/// Returns the next row in a vector format
/// \param[out] out std::vector of Tensors
/// \return Status error code
Status GetNextAsList(py::list *out);

/// Get the next row as a python dict
/// \param[out] output python dict
/// \return Status error code
Status GetNextAsMap(py::dict *output) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
/// Get the next row as a python dict
/// \param[out] output python dict
/// \return Status error code
Status GetNextAsList(py::list *output) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
/// Returns the next row in as a map
/// \param[out] out std::map of string to Tensor
/// \return Status error code
Status GetNextAsDict(py::dict *out);
};

} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_

+16 -4 mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc

@@ -34,10 +34,11 @@ namespace mindspore::dataset {
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }

Status TreeConsumer::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }

// IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<api::Dataset> d) {
Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}

@@ -73,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
}

// ToDevice
Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}

@@ -384,7 +385,7 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
tree_adapter_ = std::make_unique<TreeAdapter>();
}

Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) {
Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d));
if (!s.IsError()) {
init_flag_ = true;
@@ -463,4 +464,15 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK();
}
Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), 1);
}
Status BuildVocabConsumer::Start() {
// Getting one row would trigger building the vocab
TensorRow row;
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
// The returned row would be an EOE, which is an empty row
CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
return Status::OK();
}
} // namespace mindspore::dataset

+29 -7 mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h

@@ -22,14 +22,16 @@
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/tree_adapter.h"
#include "minddata/dataset/text/vocab.h"

namespace mindspore::dataset {
// Forward declare
class TreeAdapter;

namespace api {
class Dataset;
class DatasetNode;
}

/// A base class for tree consumers which would fetch rows from the tree pipeline
@@ -40,7 +42,9 @@ class TreeConsumer {
/// Initializes the consumer, this involves constructing and preparing the tree.
/// \param d The dataset node that represent the root of the IR tree.
/// \return Status error code.
virtual Status Init(std::shared_ptr<api::Dataset> d);
virtual Status Init(std::shared_ptr<api::DatasetNode> d);

Status Terminate();

protected:
/// The class owns the tree_adapter that handles execution tree operations.
@@ -57,7 +61,7 @@ class IteratorConsumer : public TreeConsumer {
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}

Status Init(std::shared_ptr<api::Dataset> d) override;
Status Init(std::shared_ptr<api::DatasetNode> d) override;

/// Returns the next row in a vector format
/// \param[out] out std::vector of Tensors
@@ -126,10 +130,10 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and send it to a device
class ToDevice : public TreeConsumer {
public:
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}

Status Init(std::shared_ptr<api::Dataset> d) override;
Status Init(std::shared_ptr<api::DatasetNode> d) override;

/// Send the data to device
/// \return Status error code
@@ -158,7 +162,7 @@ class ToDevice : public TreeConsumer {
class TreeGetters : public TreeConsumer {
public:
TreeGetters();
Status Init(std::shared_ptr<api::Dataset> d) override;
Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
@@ -176,5 +180,23 @@ class TreeGetters : public TreeConsumer {
bool row_flag_; // indicate whether the first row has been stored in row_
};

class BuildVocabConsumer : public TreeConsumer {
public:
/// BuildVocabConsumer Constructor which will call the base class default constructor.
BuildVocabConsumer() = default;

Status Init(std::shared_ptr<api::DatasetNode> d) override;

/// Start consuming the dataset to build the vocab. This is a blocking method (i.e., after returning, the vocab
/// has been fully built)
/// \return Status error code
Status Start();

protected:
/// Method to return the name of the consumer
/// \return string
std::string Name() override { return "BuildVocab"; }
};

} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_
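
Every consumer in this header follows the same lifecycle, visible in the datasets.cc changes above: initialize a RuntimeContext, Init the consumer with the IR root, transfer ownership via AssignConsumer, then drive the consumer through a retained raw pointer. A condensed sketch of that flow (assumes a Status-returning caller; namespaces are elided):

Status RunPipeline(std::shared_ptr<api::Dataset> ds) {
  auto runtime_context = std::make_unique<RuntimeContext>();
  RETURN_IF_NOT_OK(runtime_context->Init());

  auto consumer = std::make_unique<IteratorConsumer>(/*num_epochs=*/-1);
  IteratorConsumer *raw = consumer.get();          // keep a handle for driving
  RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));  // consumers now take the IR root
  runtime_context->AssignConsumer(std::move(consumer));

  std::vector<TensorPtr> row;
  RETURN_IF_NOT_OK(raw->GetNextAsVector(&row));    // an empty row signals EOE
  return Status::OK();
}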

+1 -0 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt

@@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_subdirectory(source)

set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
dataset_node.cc
batch_node.cc
bucket_batch_by_length_node.cc
build_sentence_piece_vocab_node.cc


+1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc

@@ -28,7 +28,7 @@ namespace mindspore {
namespace dataset {
namespace api {

BatchNode::BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad,
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
std::vector<std::string> cols_to_map,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
: batch_size_(batch_size),


+3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h

@@ -23,16 +23,16 @@
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class BatchNode : public Dataset {
class BatchNode : public DatasetNode {
public:
/// \brief Constructor
BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad,
BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
std::vector<std::string> cols_to_map,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
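
Because BatchNode (like every IR op in this PR) now derives from DatasetNode and takes a DatasetNode child, an IR tree can be composed without going through the user-facing API. A small sketch using the constructor declared above together with the MnistNode signature shown elsewhere in this diff (the path and the null sampler/cache arguments are placeholders):

// Compose the IR directly: a leaf MnistNode feeding a BatchNode.
std::shared_ptr<DatasetNode> leaf =
    std::make_shared<MnistNode>("/path/to/mnist", "train",
                                /*sampler=*/nullptr, /*cache=*/nullptr);

std::vector<std::string> cols_to_map;  // no columns to pad
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
auto batch = std::make_shared<BatchNode>(leaf, /*batch_size=*/32,
                                         /*drop_remainder=*/false, /*pad=*/false,
                                         cols_to_map, pad_map);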



+1 -1 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc

@@ -29,7 +29,7 @@ namespace mindspore {
namespace dataset {
namespace api {
BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,


+3 -3 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h

@@ -23,15 +23,15 @@
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
class BucketBatchByLengthNode : public Dataset {
class BucketBatchByLengthNode : public DatasetNode {
public:
/// \brief Constructor
BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc View File

@@ -28,7 +28,7 @@ namespace mindspore {
namespace dataset {
namespace api {

BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<Dataset> child,
BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child,
std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size,
float character_coverage, SentencePieceModel model_type,


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h View File

@@ -29,10 +29,10 @@ namespace mindspore {
namespace dataset {
namespace api {

class BuildSentenceVocabNode : public Dataset {
class BuildSentenceVocabNode : public DatasetNode {
public:
/// \brief Constructor
BuildSentenceVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<SentencePieceVocab> vocab,
BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params);



+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc View File

@@ -28,7 +28,7 @@ namespace mindspore {
namespace dataset {
namespace api {

BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab,
const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
: vocab_(vocab),


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h View File

@@ -22,17 +22,17 @@
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class BuildVocabNode : public Dataset {
class BuildVocabNode : public DatasetNode {
public:
/// \brief Constructor
BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab,
const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first);

/// \brief Destructor


+ 3
- 5
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc View File

@@ -27,18 +27,16 @@ namespace mindspore {
namespace dataset {
namespace api {
// Constructor for ConcatNode
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
this->children = datasets_;
}
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }

Status ConcatNode::ValidateParams() {
if (datasets_.empty()) {
if (children.size() < 2) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
if (find(children.begin(), children.end(), nullptr) != children.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
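With the `datasets_` member gone, the child list is the single source of truth for validation; a hedged call-site sketch (ds1 and ds2 stand for any two valid DatasetNode pipelines):

// Hypothetical call site: validation now runs entirely against `children`.
std::vector<std::shared_ptr<DatasetNode>> inputs = {ds1, ds2};
auto concat = std::make_shared<ConcatNode>(inputs);
Status rc = concat->ValidateParams();  // errors if fewer than 2 inputs or any is null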


+ 3
- 6
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h View File

@@ -21,16 +21,16 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class ConcatNode : public Dataset {
class ConcatNode : public DatasetNode {
public:
/// \brief Constructor
explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets);

/// \brief Destructor
~ConcatNode() = default;
@@ -42,9 +42,6 @@ class ConcatNode : public Dataset {
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::shared_ptr<Dataset>> datasets_;
};

} // namespace api


+ 65
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

#include <memory>

namespace mindspore {
namespace dataset {
namespace api {

Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
if (cache_ != nullptr) {
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> cache_op;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
node_ops->push_back(cache_op);
}
return Status::OK();
}
// Constructor to initialize the cache
DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }

std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
#if !defined(_WIN32) && !defined(_WIN64)
#ifndef ENABLE_ANDROID
int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF);
if (cpu_count < 0 || cpu_count > INT32_MAX) {
MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count;
return nullptr;
}
if (num_workers < 1 || num_workers > cpu_count) {
MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count;
return nullptr;
}
#endif
#endif
num_workers_ = num_workers;
return shared_from_this();
}
DatasetNode::DatasetNode() {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
}

} // namespace api
} // namespace dataset
} // namespace mindspore
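A hedged usage sketch of SetNumWorkers (RepeatNode as the concrete node type and a valid `child` pipeline are assumptions; any DatasetNode subclass works the same way):

// The setter returns the node itself so calls can be chained; on non-Windows,
// non-Android builds an out-of-range value is logged and nullptr is returned.
std::shared_ptr<DatasetNode> node = std::make_shared<RepeatNode>(child, /*count=*/2);
node = node->SetNumWorkers(4);
if (node == nullptr) {
  // handle the rejected num_workers value
}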

+ 126
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -0,0 +1,126 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_

#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class Dataset;
class SamplerObj;

#define RETURN_EMPTY_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return {}; \
} \
} while (false)

Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op);

// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files);

// Helper function to validate dataset num_shards and shard_id parameters
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id);

// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler);

Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings);

// Helper function to validate dataset input/output column parameter
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns);

// Helper function to validate dataset directory parameter
Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir);

/// \brief Function to create a sampler for non-mappable dataset (to be used by cache op later).
/// \notes A non-mappable dataset does not directly support a sampler. It provides sampling arguments (shuffle,
///     num_samples, num_shards, shard_id), and it DOES support sampling if a cache exists somewhere above it in
///     the pipeline. If there is no cache above it, then the sampler is not used.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle If true, the indices are shuffled.
/// \param[in] num_shards Number of shards to divide the dataset into.
/// \param[in] shard_id Shard ID of the current shard within num_shards.
/// \return Shared pointer to the current Sampler.
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id);

class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
public:
/// \brief Constructor
DatasetNode();

/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache);

/// \brief Destructor
~DatasetNode() = default;

/// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
/// \return The list of shared pointers to the newly created DatasetOps
virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;

/// \brief Pure virtual function for derived class to implement parameters validation
/// \return Status Status::OK() if all the parameters are valid
virtual Status ValidateParams() = 0;

const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; }

/// \brief Pure virtual function for derived class to get the shard id of specific node
/// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}

/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);

protected:
std::vector<std::shared_ptr<DatasetNode>> children;
std::shared_ptr<DatasetNode> parent;
std::shared_ptr<DatasetCache> cache_;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);

int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
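To make the contract above concrete, a minimal hypothetical subclass (NoopNode is invented for illustration; the overrides match the pure-virtual declarations in this header):

// Hypothetical IR node that passes rows through unchanged: it validates nothing
// and emits no op of its own, but shows the two required overrides.
class NoopNode : public DatasetNode {
 public:
  explicit NoopNode(std::shared_ptr<DatasetNode> child) { this->children.push_back(child); }

  ~NoopNode() = default;

  std::vector<std::shared_ptr<DatasetOp>> Build() override {
    std::vector<std::shared_ptr<DatasetOp>> node_ops;
    RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));  // attaches a CacheOp only if a cache was set
    return node_ops;
  }

  Status ValidateParams() override { return Status::OK(); }
};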

+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc View File

@@ -28,14 +28,14 @@ namespace mindspore {
namespace dataset {
namespace api {

MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
: operations_(operations),
input_columns_(input_columns),
output_columns_(output_columns),
project_columns_(project_columns),
Dataset(std::move(cache)) {
DatasetNode(std::move(cache)) {
this->children.push_back(child);
}



+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h View File

@@ -21,15 +21,15 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
class MapNode : public Dataset {
class MapNode : public DatasetNode {
public:
/// \brief Constructor
MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr);



+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc View File

@@ -28,7 +28,8 @@ namespace dataset {
namespace api {

// Constructor for ProjectNode
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
: columns_(columns) {
this->children.push_back(child);
}



+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h View File

@@ -21,17 +21,17 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

namespace api {

class ProjectNode : public Dataset {
class ProjectNode : public DatasetNode {
public:
/// \brief Constructor
explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);
explicit ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns);

/// \brief Destructor
~ProjectNode() = default;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc View File

@@ -27,7 +27,7 @@ namespace mindspore {
namespace dataset {
namespace api {
// Constructor for RenameNode
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {
this->children.push_back(child);


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h View File

@@ -21,17 +21,17 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

namespace api {

class RenameNode : public Dataset {
class RenameNode : public DatasetNode {
public:
/// \brief Constructor
explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
explicit RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);

/// \brief Destructor


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc View File

@@ -27,7 +27,7 @@ namespace mindspore {
namespace dataset {
namespace api {

RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child);
}



+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h View File

@@ -23,17 +23,17 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

namespace api {

class RepeatNode : public Dataset {
class RepeatNode : public DatasetNode {
public:
/// \brief Constructor
explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);
explicit RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count);

/// \brief Destructor
~RepeatNode() = default;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc View File

@@ -28,7 +28,7 @@ namespace dataset {
namespace api {

// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
this->children.push_back(child);
}


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h View File

@@ -23,16 +23,16 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

namespace api {

class ShuffleNode : public Dataset {
class ShuffleNode : public DatasetNode {
public:
ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);
ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch);

~ShuffleNode() = default;



+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc View File

@@ -28,7 +28,7 @@ namespace dataset {
namespace api {

// Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) {
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) {
this->children.push_back(child);
}



+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h View File

@@ -21,16 +21,16 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

namespace api {
class SkipNode : public Dataset {
class SkipNode : public DatasetNode {
public:
/// \brief Constructor
explicit SkipNode(std::shared_ptr<Dataset> child, int32_t count);
explicit SkipNode(std::shared_ptr<DatasetNode> child, int32_t count);

/// \brief Destructor
~SkipNode() = default;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc View File

@@ -32,7 +32,7 @@ namespace api {
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_dir_(dataset_dir),
schema_path_(data_schema),
column_names_(column_names),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h View File

@@ -21,13 +21,13 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class AlbumNode : public Dataset {
class AlbumNode : public DatasetNode {
public:
/// \brief Constructor
AlbumNode(const std::string &dataset_dir, const std::string &data_schema,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -31,7 +31,7 @@ namespace api {
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
sampler_(sampler),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h View File

@@ -23,13 +23,13 @@
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class CelebANode : public Dataset {
class CelebANode : public DatasetNode {
public:
/// \brief Constructor
CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc View File

@@ -31,7 +31,7 @@ namespace api {
// Constructor for Cifar100Node
Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

Status Cifar100Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h View File

@@ -21,13 +21,13 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class Cifar100Node : public Dataset {
class Cifar100Node : public DatasetNode {
public:
/// \brief Constructor
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc View File

@@ -31,7 +31,7 @@ namespace api {
// Constructor for Cifar10Node
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

Status Cifar10Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h View File

@@ -21,13 +21,13 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class Cifar10Node : public Dataset {
class Cifar10Node : public DatasetNode {
public:
/// \brief Constructor
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc View File

@@ -33,7 +33,7 @@ namespace api {
// Constructor for CLUENode
CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_files_(clue_files),
task_(task),
usage_(usage),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h View File

@@ -21,14 +21,14 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
/// \class CLUENode
/// \brief A Dataset derived class to represent CLUE dataset
class CLUENode : public Dataset {
class CLUENode : public DatasetNode {
public:
/// \brief Constructor
CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc View File

@@ -30,7 +30,7 @@ namespace api {
// Constructor for CocoNode
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_dir_(dataset_dir),
annotation_file_(annotation_file),
task_(task),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h View File

@@ -21,12 +21,12 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
class CocoNode : public Dataset {
class CocoNode : public DatasetNode {
public:
/// \brief Constructor
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc View File

@@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_files_(csv_files),
field_delim_(field_delim),
column_defaults_(column_defaults),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h View File

@@ -21,7 +21,7 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
@@ -47,7 +47,7 @@ class CsvRecord : public CsvBase {
T value;
};

class CSVNode : public Dataset {
class CSVNode : public DatasetNode {
public:
/// \brief Constructor
CSVNode(const std::vector<std::string> &dataset_files, char field_delim,


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h View File

@@ -21,7 +21,7 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
@@ -31,7 +31,7 @@ namespace api {

/// \class GeneratorNode
/// \brief A Dataset derived class to represent GeneratorNode dataset
class GeneratorNode : public Dataset {
class GeneratorNode : public DatasetNode {
public:
/// \brief Constructor
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -40,7 +40,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
recursive_(recursive),
class_indexing_(class_indexing),
exts_(extensions),
Dataset(std::move(cache)) {}
DatasetNode(std::move(cache)) {}

Status ImageFolderNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h View File

@@ -24,7 +24,7 @@
#include <vector>

#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
@@ -33,7 +33,7 @@ namespace api {

/// \class ImageFolderNode
/// \brief A Dataset derived class to represent ImageFolder dataset
class ImageFolderNode : public Dataset {
class ImageFolderNode : public DatasetNode {
public:
/// \brief Constructor
ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc View File

@@ -32,7 +32,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_file_(dataset_file),
usage_(usage),
decode_(decode),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h View File

@@ -22,12 +22,12 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
class ManifestNode : public Dataset {
class ManifestNode : public DatasetNode {
public:
/// \brief Constructor
ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h View File

@@ -22,12 +22,12 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
class MindDataNode : public Dataset {
class MindDataNode : public DatasetNode {
public:
/// \brief Constructor
MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc View File

@@ -30,7 +30,7 @@ namespace api {

MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

Status MnistNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h View File

@@ -21,13 +21,13 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class MnistNode : public Dataset {
class MnistNode : public DatasetNode {
public:
/// \brief Constructor
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h View File

@@ -22,13 +22,13 @@
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class RandomNode : public Dataset {
class RandomNode : public DatasetNode {
public:
// Some constants to provide limits to random generation.
static constexpr int32_t kMaxNumColumns = 4;
@@ -38,7 +38,7 @@ class RandomNode : public Dataset {
/// \brief Constructor
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
total_rows_(total_rows),
schema_path_(""),
schema_(std::move(schema)),
@@ -48,7 +48,7 @@ class RandomNode : public Dataset {
/// \brief Constructor
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
total_rows_(total_rows),
schema_path_(schema_path),
columns_list_(columns_list),


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc View File

@@ -31,7 +31,7 @@ namespace api {
// Constructor for TextFileNode
TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_files_(dataset_files),
num_samples_(num_samples),
shuffle_(shuffle),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h View File

@@ -21,14 +21,14 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
/// \class TextFileNode
/// \brief A Dataset derived class to represent TextFile dataset
class TextFileNode : public Dataset {
class TextFileNode : public DatasetNode {
public:
/// \brief Constructor
TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,


+ 47
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc View File

@@ -55,6 +55,53 @@ bool ValidateFirstRowCrc(const std::string &filename) {

// Validator for TFRecordNode
Status TFRecordNode::ValidateParams() {
if (dataset_files_.empty()) {
std::string err_msg = "TFRecordNode: dataset_files is not specified.";
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

for (const auto &f : dataset_files_) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
std::string err_msg = "TFRecordNode: dataset file: [" + f + "] is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;

return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
}

if (num_samples_ < 0) {
std::string err_msg = "TFRecordNode: Invalid number of samples: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;

return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

if (num_shards_ <= 0) {
std::string err_msg = "TFRecordNode: Invalid num_shards: " + std::to_string(num_shards_);
MS_LOG(ERROR) << err_msg;

return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

if (shard_id_ < 0 || shard_id_ >= num_shards_) {
std::string err_msg = "TFRecordNode: Invalid input, shard_id: " + std::to_string(shard_id_) +
", num_shards: " + std::to_string(num_shards_);
MS_LOG(ERROR) << err_msg;

return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

if (cache_ == nullptr && !shard_equal_rows_ && dataset_files_.size() < num_shards_) {
// This check only makes sense in a non-cache path. We should make sure there is at least one file per
// shard in file-based sharding
std::string err_msg =
"TFRecordNode: Invalid number of dataset files; it should be at least " + std::to_string(num_shards_);
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

std::vector<std::string> invalid_files(dataset_files_.size());
auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(),
[](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
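The hunk is truncated here; a hedged, self-contained sketch of the copy_if idiom it uses (shrinking invalid_files to the copied range is an assumption about the elided lines):

// copy_if fills the front of the pre-sized vector and returns an iterator one
// past the last element written, so the vector is trimmed to that range before use.
std::vector<std::string> invalid_files(dataset_files_.size());
auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(),
                       [](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
invalid_files.resize(std::distance(invalid_files.begin(), it));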


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h View File

@@ -22,21 +22,21 @@
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
/// \class TFRecordNode
/// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordNode : public Dataset {
class TFRecordNode : public DatasetNode {
public:
/// \brief Constructor
/// \note Parameter 'schema' is the path to the schema file
TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_files_(dataset_files),
schema_path_(schema),
columns_list_(columns_list),
@@ -51,7 +51,7 @@ class TFRecordNode : public Dataset {
TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_files_(dataset_files),
schema_obj_(schema),
columns_list_(columns_list),


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc View File

@@ -32,7 +32,7 @@ namespace api {
VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
: DatasetNode(std::move(cache)),
dataset_dir_(dataset_dir),
task_(task),
usage_(usage),


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h View File

@@ -22,12 +22,12 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
class VOCNode : public Dataset {
class VOCNode : public DatasetNode {
public:
/// \brief Constructor
VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc View File

@@ -27,7 +27,7 @@ namespace mindspore {
namespace dataset {
namespace api {
// Constructor for SyncWaitNode
SyncWaitNode::SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch,
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
py::function callback)
: condition_name_(condition_name), num_batch_(num_batch), callback_(callback) {
this->children.push_back(child);


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h View File

@@ -21,7 +21,7 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
@@ -30,10 +30,10 @@ namespace api {

/// \class SyncWaitNode
/// \brief A Dataset derived class to represent SyncWaitNode dataset
class SyncWaitNode : public Dataset {
class SyncWaitNode : public DatasetNode {
public:
/// \brief Constructor
explicit SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch,
explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
py::function callback);

/// \brief Destructor


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc View File

@@ -27,7 +27,7 @@ namespace mindspore {
namespace dataset {
namespace api {
// Constructor for TakeNode
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) {
this->children.push_back(child);
}



+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h View File

@@ -21,17 +21,17 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

namespace api {

class TakeNode : public Dataset {
class TakeNode : public DatasetNode {
public:
/// \brief Constructor
explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count);
explicit TakeNode(std::shared_ptr<DatasetNode> child, int32_t count);

/// \brief Destructor
~TakeNode() = default;


+ 14
- 11
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc View File

@@ -28,14 +28,8 @@ namespace dataset {
namespace api {

// Constructor for TransferNode
TransferNode::TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
const std::string &device_type, bool send_epoch_end)
: queue_name_(queue_name),
device_id_(device_id),
device_type_(device_type),
prefetch_size_(16),
send_epoch_end_(send_epoch_end),
total_batch_(0) {
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end)
: prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) {
this->children.push_back(child);
}

@@ -48,6 +42,15 @@ Status TransferNode::ValidateParams() {

// Function to build TransferNode
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
// Get a uuid for queue name
queue_name_ = Services::GetUniqueID();
// TODO(CRC):
// Get device type from ms context
device_type_ = "CPU";
// Get device ID from children
device_id_ = 0;
RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_));

// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

@@ -67,13 +70,13 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
}

// Function to get the device_id
Status TransferNode::get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id) {
Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) {
// Get device id according to the type of dataset
Status rc = ds->GetShardId(device_id);
if (rc != Status::OK()) {
// Get device id from the child node
if (ds->children.size()) {
ds = ds->children[0];
if (ds->Children().size()) {
ds = ds->Children()[0];
return TransferNode::get_distribution(ds, device_id);
} else {
std::string err_msg = "Unknown dataset type.";
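In words: the device id is resolved by asking the node itself, then walking down the first-child chain until some node implements GetShardId; a leaf that never answers is an error. A hedged call-site sketch (`transfer_node` is assumed to be a valid std::shared_ptr<DatasetNode>):

int32_t device_id = 0;
Status rc = TransferNode::get_distribution(transfer_node, &device_id);
if (rc.IsError()) {
  MS_LOG(ERROR) << "Could not resolve a device id from the pipeline";
}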


+ 4
- 5
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h View File

@@ -21,18 +21,17 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {

namespace api {

class TransferNode : public Dataset {
class TransferNode : public DatasetNode {
public:
/// \brief Constructor
TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
const std::string &device_type, bool send_epoch_end);
TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end);

/// \brief Destructor
~TransferNode() = default;
@@ -45,7 +44,7 @@ class TransferNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

static Status get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id);
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);

private:
std::string queue_name_;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc View File

@@ -27,7 +27,7 @@ namespace mindspore {
namespace dataset {
namespace api {

ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) : datasets_(datasets) {
for (auto dataset : datasets_) {
this->children.push_back(dataset);
}


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h View File

@@ -21,16 +21,16 @@
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {

class ZipNode : public Dataset {
class ZipNode : public DatasetNode {
public:
/// \brief Constructor
explicit ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
explicit ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets);

/// \brief Destructor
~ZipNode() = default;
@@ -44,7 +44,7 @@ class ZipNode : public Dataset {
Status ValidateParams() override;

private:
std::vector<std::shared_ptr<Dataset>> datasets_;
std::vector<std::shared_ptr<DatasetNode>> datasets_;
};

} // namespace api


+ 27
- 0
mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc View File

@@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/python_runtime_context.h"
#include "pybind11/pybind11.h"

namespace mindspore::dataset {

Status PythonRuntimeContext::Terminate() {
// Release GIL before joining all threads
py::gil_scoped_release gil_release;
return tree_consumer_->Terminate();
}
} // namespace mindspore::dataset
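The release matters because pipeline worker threads may be executing Python UDFs that themselves need the GIL; joining them while holding it would deadlock. A general hedged sketch of the pybind11 pattern (join_all_workers is hypothetical):

// Drop the GIL for the duration of a blocking join, so worker threads that
// need it to finish their Python work can make progress.
{
  py::gil_scoped_release gil_release;  // GIL released here
  join_all_workers();                  // hypothetical blocking join of pipeline threads
}                                      // GIL re-acquired when gil_release is destroyed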

+ 48
- 0
mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PYTHON_RUNTIME_CONTEXT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PYTHON_RUNTIME_CONTEXT_H_

#include <memory>
#include <utility>
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"
#include "minddata/dataset/engine/runtime_context.h"

namespace mindspore::dataset {
class RuntimeContext;

/// Class that represents a single runtime instance which can consume data from a data pipeline
class PythonRuntimeContext : public RuntimeContext {
public:
/// Method to terminate the runtime; this will not release the resources
/// \return Status error code
Status Terminate() override;

~PythonRuntimeContext() {
Terminate();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}

PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); }
};

} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PYTHON_RUNTIME_CONTEXT_H_

+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc View File

@@ -19,7 +19,7 @@
#include <utility>
namespace mindspore::dataset {

void RuntimeContext::AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer) {
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer);
}
} // namespace mindspore::dataset
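A hedged wiring sketch of the shared-ownership change above (the PythonIteratorConsumer constructor arguments are an assumption):

// The context takes shared ownership of the consumer; GetConsumer() hands back
// a non-owning raw pointer for use within the context's lifetime.
auto ctx = std::make_shared<PythonRuntimeContext>();
std::shared_ptr<TreeConsumer> consumer = std::make_shared<PythonIteratorConsumer>(/*num_epochs=*/1);
ctx->AssignConsumer(consumer);
TreeConsumer *raw = ctx->GetConsumer();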

+ 5
- 3
mindspore/ccsrc/minddata/dataset/engine/runtime_context.h View File

@@ -40,14 +40,16 @@ class RuntimeContext {

/// Set the tree consumer
/// \param tree_consumer to be assigned
void AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer);
void AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer);

/// Get the tree consumer
/// \return Raw pointer to the tree consumer.
TreeConsumer *GetConsumer() { return tree_consumer_.get(); }

private:
std::unique_ptr<TreeConsumer> tree_consumer_;
~RuntimeContext() { Terminate(); }

protected:
std::shared_ptr<TreeConsumer> tree_consumer_;
};

} // namespace mindspore::dataset


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc View File

@@ -22,7 +22,7 @@
namespace mindspore {
namespace dataset {

Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::Dataset> root_ir, int32_t num_epoch) {
Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::DatasetNode> root_ir, int32_t num_epoch) {
// Check whether this function has been called before. If so, return failure
CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built.");
RETURN_UNEXPECTED_IF_NULL(root_ir);
@@ -65,7 +65,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
return Status::OK();
}

Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op) {
Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
// validate the op can be built first before building the DatasetOp
RETURN_IF_NOT_OK(ir->ValidateParams());
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
@@ -80,7 +80,7 @@ Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_p
}

// Build the children of ir, once they return, add the return value to *op
for (std::shared_ptr<api::Dataset> child_ir : ir->children) {
for (const auto &child_ir : ir->Children()) {
std::shared_ptr<DatasetOp> child_op;
RETURN_IF_NOT_OK(DFSBuildTree(child_ir, &child_op));
RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops
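A hedged illustration of the resulting shape when one IR node builds two ops (names invented): the first op is handed to the parent, and children hang off the last op.

// IR tree                 DatasetOp tree built by DFSBuildTree
//
//   MapNode        ->       map_op        // ops.front(), returned to the parent
//      |                      |
//  ChildNode                project_op    // ops.back(), receives the children
//                             |
//                           child_op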


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h View File

@@ -24,12 +24,12 @@
#include <vector>

#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

namespace mindspore {
namespace dataset {
namespace api {
class Dataset;
class DatasetNode;
}
class TreeAdapter {
public:
@@ -40,7 +40,7 @@ class TreeAdapter {
// This will construct an ExeTree from a Dataset root and Prepare() the ExeTree
// This function is only meant to be called once and needs to be called before GetNext
// ExeTree will be launched when the first GetNext is called
Status BuildAndPrepare(std::shared_ptr<api::Dataset> root, int32_t num_epoch = -1);
Status BuildAndPrepare(std::shared_ptr<api::DatasetNode> root, int32_t num_epoch = -1);

// This is the main method TreeConsumer uses to interact with TreeAdapter
// 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared)
@@ -62,7 +62,7 @@ class TreeAdapter {
private:
// This RECURSIVE function converts IR nodes into DatasetOps in the ExecutionTree. An IR node could build a vector
// of ops; in such a case, the first op is returned. The op is added as a child when the current function returns.
Status DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op);
Status DFSBuildTree(std::shared_ptr<api::DatasetNode> ir, std::shared_ptr<DatasetOp> *op);

std::unique_ptr<DataBuffer> cur_db_;
std::unordered_map<std::string, int32_t> column_name_map_;


+ 704
- 643
mindspore/ccsrc/minddata/dataset/include/datasets.h
File diff suppressed because it is too large
View File


+ 1
- 1
mindspore/ccsrc/minddata/dataset/include/iterator.h View File

@@ -49,7 +49,7 @@ class Iterator {
Iterator() : consumer_(nullptr) {}

/// \brief Destructor
~Iterator() = default;
~Iterator() { Stop(); }

/// \brief Method for building and launching the pipeline.
/// \param[in] ops - a vector of DatasetOp in the data pipeline.
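With this change an iterator stops its pipeline on scope exit; a hedged sketch (`ds` is assumed to be a valid dataset):

{
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // ... consume rows ...
}  // ~Iterator() now calls Stop(), so pipeline resources are released here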


+ 59
- 51
mindspore/lite/minddata/CMakeLists.txt View File

@@ -82,30 +82,30 @@ AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/image/lite_cv MINDDATA_KERNELS_IMA

if (BUILD_MINDDATA STREQUAL "full")
include_directories("${CMAKE_SOURCE_DIR}/../ccsrc/minddata/dataset/kernels/image")
list(REMOVE_ITEM MINDDATA_API_SRC_FILES
"${MINDDATA_DIR}/api/text.cc"
)
list(REMOVE_ITEM MINDDATA_API_SRC_FILES
"${MINDDATA_DIR}/api/text.cc"
)

list(REMOVE_ITEM MINDDATA_CALLBACK_SRC_FILES
"${MINDDATA_DIR}/callback/py_ds_callback.cc"
)
list(REMOVE_ITEM MINDDATA_CALLBACK_SRC_FILES
"${MINDDATA_DIR}/callback/py_ds_callback.cc"
)

list(REMOVE_ITEM MINDDATA_CORE_SRC_FILES
"${MINDDATA_DIR}/core/cv_tensor.cc"
)
"${MINDDATA_DIR}/core/cv_tensor.cc"
)

list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc")
list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SRC_FILES
"${MINDDATA_DIR}/engine/datasetops/build_sentence_piece_vocab_op.cc"
"${MINDDATA_DIR}/engine/datasetops/filter_op.cc"
"${MINDDATA_DIR}/engine/datasetops/barrier_op.cc"
"${MINDDATA_DIR}/engine/datasetops/bucket_batch_by_length_op.cc"
"${MINDDATA_DIR}/engine/datasetops/build_vocab_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_merge_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_base_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_lookup_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_op.cc"
)
"${MINDDATA_DIR}/engine/datasetops/build_sentence_piece_vocab_op.cc"
"${MINDDATA_DIR}/engine/datasetops/filter_op.cc"
"${MINDDATA_DIR}/engine/datasetops/barrier_op.cc"
"${MINDDATA_DIR}/engine/datasetops/bucket_batch_by_length_op.cc"
"${MINDDATA_DIR}/engine/datasetops/build_vocab_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_merge_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_base_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_lookup_op.cc"
"${MINDDATA_DIR}/engine/datasetops/cache_op.cc"
)

list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SOURCE_SRC_FILES
"${MINDDATA_DIR}/engine/datasetops/source/generator_op.cc"
@@ -161,47 +161,55 @@ if (BUILD_MINDDATA STREQUAL "full")
"${MINDDATA_DIR}/kernels/image/random_crop_and_resize_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_crop_decode_resize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_crop_and_resize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_crop_op.cc"
"${MINDDATA_DIR}/kernels/image/random_crop_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_horizontal_flip_op.cc"
"${MINDDATA_DIR}/kernels/image/random_horizontal_flip_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_posterize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_resize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_rotation_op.cc"
"${MINDDATA_DIR}/kernels/image/random_select_subpolicy_op.cc"
"${MINDDATA_DIR}/kernels/image/random_solarize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_vertical_flip_op.cc"
"${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc"
"${MINDDATA_DIR}/kernels/image/rescale_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc"
"${MINDDATA_DIR}/kernels/image/sharpness_op.cc"
"${MINDDATA_DIR}/kernels/image/solarize_op.cc"
"${MINDDATA_DIR}/kernels/image/swap_red_blue_op.cc"
"${MINDDATA_DIR}/kernels/image/uniform_aug_op.cc"
"${MINDDATA_DIR}/kernels/image/resize_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_resize_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_color_op.cc"
)
"${MINDDATA_DIR}/kernels/image/random_crop_op.cc"
"${MINDDATA_DIR}/kernels/image/random_crop_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_horizontal_flip_op.cc"
"${MINDDATA_DIR}/kernels/image/random_horizontal_flip_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_posterize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_resize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_rotation_op.cc"
"${MINDDATA_DIR}/kernels/image/random_select_subpolicy_op.cc"
"${MINDDATA_DIR}/kernels/image/random_solarize_op.cc"
"${MINDDATA_DIR}/kernels/image/random_vertical_flip_op.cc"
"${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc"
"${MINDDATA_DIR}/kernels/image/rescale_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc"
"${MINDDATA_DIR}/kernels/image/sharpness_op.cc"
"${MINDDATA_DIR}/kernels/image/solarize_op.cc"
"${MINDDATA_DIR}/kernels/image/swap_red_blue_op.cc"
"${MINDDATA_DIR}/kernels/image/uniform_aug_op.cc"
"${MINDDATA_DIR}/kernels/image/resize_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_resize_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_color_op.cc"
)

list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES
"${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/build_sentence_piece_vocab_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc"
)
"${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/build_sentence_piece_vocab_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc"
)
list(REMOVE_ITEM MINDDATA_ENGINE_CONSUMERS_SRC_FILES
"${MINDDATA_DIR}/engine/consumers/python_tree_consumer.cc"
)

list(REMOVE_ITEM MINDDATA_ENGINE_SRC_FILES
"${MINDDATA_DIR}/engine/python_runtime_context.cc"
)

list(REMOVE_ITEM MINDDATA_KERNELS_DATA_SRC_FILES
"${MINDDATA_DIR}/kernels/data/unique_op.cc"
)
"${MINDDATA_DIR}/kernels/data/unique_op.cc"
)
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")

if (BUILD_MINDDATA_EXAMPLE AND (PLATFORM_ARM32 OR PLATFORM_ARM64))
set(MINDDATA_EXAMPLE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/example/jni-example.cc)
endif()

add_library(minddata-lite SHARED
${MINDDATA_API_SRC_FILES}
${MINDDATA_CALLBACK_SRC_FILES}
${MINDDATA_CORE_SRC_FILES}
${MINDDATA_ENGINE_SRC_FILES}


+1 -1 tests/ut/cpp/dataset/c_api_dataset_ops_test.cc

@@ -1093,7 +1093,7 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetDefault) {
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 7));
EXPECT_NE(ds, nullptr);

- // Create a Take operation on ds, dafault count = -1
+ // Create a Take operation on ds, default count = -1
ds = ds->Take();
EXPECT_NE(ds, nullptr);
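
Note: Take's count parameter defaults to -1, which is interpreted as "take the whole dataset". A minimal sketch of both call forms, reusing the ds pipeline from this test (illustration only, not part of the diff):

// count omitted: equivalent to Take(-1), i.e. every row passes through
std::shared_ptr<Dataset> all = ds->Take();
// explicit count: only the first 5 rows
std::shared_ptr<Dataset> first5 = ds->Take(5);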



+2 -2 tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc

@@ -429,7 +429,7 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetWithNullSampler) {
schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
std::shared_ptr<Dataset> ds = RandomData(50, schema, {}, nullptr);
// Expect failure: sampler can not be nullptr
- EXPECT_EQ(ds, nullptr);
+ EXPECT_EQ(ds->CreateIterator(), nullptr);
}

TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) {
@@ -441,5 +441,5 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) {
schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
std::shared_ptr<Dataset> ds = RandomData(50, schema, {"image", "image"});
// Expect failure: duplicate column names
- EXPECT_EQ(ds, nullptr);
+ EXPECT_EQ(ds->CreateIterator(), nullptr);
}
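
These updated assertions reflect the IR refactor: factory functions such as RandomData now always return a Dataset handle wrapping an IR node, and parameter validation is deferred until the pipeline is consumed, so an invalid argument surfaces as CreateIterator() returning nullptr rather than as a null Dataset. A minimal sketch of the before/after contract, using the null-sampler case above (illustration only):

std::shared_ptr<Dataset> ds = RandomData(50, schema, {}, nullptr);  // invalid: null sampler
// Before the refactor, construction itself failed, so EXPECT_EQ(ds, nullptr) held.
// After it, construction succeeds and validation runs when the tree is built:
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->CreateIterator(), nullptr);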

+7 -7 tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc

@@ -443,34 +443,34 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) {

// This case expected to fail because the list of dir_path cannot be empty.
std::shared_ptr<Dataset> ds1 = TFRecord({});
- EXPECT_EQ(ds1, nullptr);
+ EXPECT_EQ(ds1->CreateIterator(), nullptr);

// This case expected to fail because the file in dir_path is not exist.
std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
std::shared_ptr<Dataset> ds2 = TFRecord({file_path, "noexist.data"});
- EXPECT_EQ(ds2, nullptr);
+ EXPECT_EQ(ds2->CreateIterator(), nullptr);

// This case expected to fail because the file of schema is not exist.
std::shared_ptr<Dataset> ds4 = TFRecord({file_path, "notexist.json"});
- EXPECT_EQ(ds4, nullptr);
+ EXPECT_EQ(ds4->CreateIterator(), nullptr);

// This case expected to fail because num_samples is negative.
std::shared_ptr<Dataset> ds5 = TFRecord({file_path}, "", {}, -1);
- EXPECT_EQ(ds5, nullptr);
+ EXPECT_EQ(ds5->CreateIterator(), nullptr);

// This case expected to fail because num_shards is negative.
std::shared_ptr<Dataset> ds6 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 0);
- EXPECT_EQ(ds6, nullptr);
+ EXPECT_EQ(ds6->CreateIterator(), nullptr);

// This case expected to fail because shard_id is out_of_bound.
std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3);
- EXPECT_EQ(ds7, nullptr);
+ EXPECT_EQ(ds7->CreateIterator(), nullptr);

// This case expected to fail because the provided number of files < num_shards in file-based sharding.
std::string file_path1 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
std::string file_path2 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data";
std::shared_ptr<Dataset> ds8 = TFRecord({file_path1, file_path2}, "", {}, 0, ShuffleMode::kFalse, 3);
- EXPECT_EQ(ds8, nullptr);
+ EXPECT_EQ(ds8->CreateIterator(), nullptr);
}

TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {


+3 -3 tests/ut/cpp/dataset/tree_adapter_test.cc

@@ -56,7 +56,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {

mindspore::dataset::TreeAdapter tree_adapter;

- Status rc = tree_adapter.BuildAndPrepare(ds, 1);
+ Status rc = tree_adapter.BuildAndPrepare(ds->IRNode(), 1);

EXPECT_TRUE(rc.IsOk());

@@ -91,7 +91,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {

mindspore::dataset::TreeAdapter tree_adapter;

- Status rc = tree_adapter.BuildAndPrepare(ds, 2);
+ Status rc = tree_adapter.BuildAndPrepare(ds->IRNode(), 2);
EXPECT_TRUE(rc.IsOk());

const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap();
@@ -128,7 +128,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {

mindspore::dataset::TreeAdapter tree_adapter;

- Status rc = tree_adapter.BuildAndPrepare(ds, 2);
+ Status rc = tree_adapter.BuildAndPrepare(ds->IRNode(), 2);

EXPECT_TRUE(rc.IsOk());
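
These call-site updates follow from the same refactor: TreeAdapter::BuildAndPrepare now consumes the underlying IR tree, which Dataset::IRNode() exposes, rather than the user-facing Dataset handle. A minimal sketch of the updated usage, assuming ds is one of the pipelines built in these tests and that the second argument is the epoch count, as the repeat test above suggests:

mindspore::dataset::TreeAdapter tree_adapter;
// IRNode() hands the root DatasetNode of the IR tree to the execution-tree builder.
Status rc = tree_adapter.BuildAndPrepare(ds->IRNode(), /*num_epochs=*/2);
EXPECT_TRUE(rc.IsOk());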


