From 24f35ae3229f492379781dee9d7e6f1b3bcaf6e5 Mon Sep 17 00:00:00 2001 From: alex-yuyue Date: Mon, 19 Oct 2020 23:09:17 -0400 Subject: [PATCH] Change IR to take dataset Obj as input and change name *Dataset to *Node Signed-off-by: alex-yuyue --- .../ccsrc/minddata/dataset/api/datasets.cc | 661 +++++++++--------- .../ccsrc/minddata/dataset/include/datasets.h | 541 +++++++------- 2 files changed, 598 insertions(+), 604 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 15cf38430c..09247f4d6b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -128,165 +128,165 @@ std::shared_ptr Schema(const std::string &schema_file) { // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS // (In alphabetical order) -// Function to create a AlbumDataset. -std::shared_ptr Album(const std::string &dataset_dir, const std::string &data_schema, - const std::vector &column_names, bool decode, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, data_schema, column_names, decode, sampler); +// Function to create a AlbumNode. +std::shared_ptr Album(const std::string &dataset_dir, const std::string &data_schema, + const std::vector &column_names, bool decode, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, data_schema, column_names, decode, sampler); return ds->ValidateParams() ? ds : nullptr; } -// Function to create a CelebADataset. -std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler, bool decode, - const std::set &extensions) { - auto ds = std::make_shared(dataset_dir, usage, sampler, decode, extensions); +// Function to create a CelebANode. 
+std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &usage, + const std::shared_ptr &sampler, bool decode, + const std::set &extensions) { + auto ds = std::make_shared(dataset_dir, usage, sampler, decode, extensions); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a Cifar10Dataset. -std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, usage, sampler); +// Function to create a Cifar10Node. +std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, usage, sampler); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a Cifar100Dataset. -std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, usage, sampler); +// Function to create a Cifar100Node. +std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, usage, sampler); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a CLUEDataset. -std::shared_ptr CLUE(const std::vector &clue_files, const std::string &task, - const std::string &usage, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id) { - auto ds = std::make_shared(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id); +// Function to create a CLUENode. 
+std::shared_ptr CLUE(const std::vector &clue_files, const std::string &task, + const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, + int32_t shard_id) { + auto ds = std::make_shared(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a CocoDataset. -std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, - const std::string &task, const bool &decode, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, annotation_file, task, decode, sampler); +// Function to create a CocoNode. +std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, + const std::string &task, const bool &decode, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, annotation_file, task, decode, sampler); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a CSVDataset. -std::shared_ptr CSV(const std::vector &dataset_files, char field_delim, - const std::vector> &column_defaults, - const std::vector &column_names, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id) { - auto ds = std::make_shared(dataset_files, field_delim, column_defaults, column_names, num_samples, - shuffle, num_shards, shard_id); +// Function to create a CSVNode. +std::shared_ptr CSV(const std::vector &dataset_files, char field_delim, + const std::vector> &column_defaults, + const std::vector &column_names, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id) { + auto ds = std::make_shared(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle, + num_shards, shard_id); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a ImageFolderDataset. 
-std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode, - const std::shared_ptr &sampler, - const std::set &extensions, - const std::map &class_indexing) { +// Function to create a ImageFolderNode. +std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode, + const std::shared_ptr &sampler, + const std::set &extensions, + const std::map &class_indexing) { // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false. bool recursive = false; - // Create logical representation of ImageFolderDataset. - auto ds = std::make_shared(dataset_dir, decode, sampler, recursive, extensions, class_indexing); + // Create logical representation of ImageFolderNode. + auto ds = std::make_shared(dataset_dir, decode, sampler, recursive, extensions, class_indexing); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } #ifndef ENABLE_ANDROID -// Function to create a ManifestDataset. -std::shared_ptr Manifest(const std::string &dataset_file, const std::string &usage, - const std::shared_ptr &sampler, - const std::map &class_indexing, bool decode) { - auto ds = std::make_shared(dataset_file, usage, sampler, class_indexing, decode); +// Function to create a ManifestNode. +std::shared_ptr Manifest(const std::string &dataset_file, const std::string &usage, + const std::shared_ptr &sampler, + const std::map &class_indexing, bool decode) { + auto ds = std::make_shared(dataset_file, usage, sampler, class_indexing, decode); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } #endif -// Function to create a MindDataDataset. -std::shared_ptr MindData(const std::string &dataset_file, const std::vector &columns_list, - const std::shared_ptr &sampler, nlohmann::json padded_sample, - int64_t num_padded) { - auto ds = std::make_shared(dataset_file, columns_list, sampler, padded_sample, num_padded); +// Function to create a MindDataNode. 
+std::shared_ptr MindData(const std::string &dataset_file, const std::vector &columns_list, + const std::shared_ptr &sampler, nlohmann::json padded_sample, + int64_t num_padded) { + auto ds = std::make_shared(dataset_file, columns_list, sampler, padded_sample, num_padded); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a MindDataDataset. -std::shared_ptr MindData(const std::vector &dataset_files, - const std::vector &columns_list, - const std::shared_ptr &sampler, nlohmann::json padded_sample, - int64_t num_padded) { - auto ds = std::make_shared(dataset_files, columns_list, sampler, padded_sample, num_padded); +// Function to create a MindDataNode. +std::shared_ptr MindData(const std::vector &dataset_files, + const std::vector &columns_list, + const std::shared_ptr &sampler, nlohmann::json padded_sample, + int64_t num_padded) { + auto ds = std::make_shared(dataset_files, columns_list, sampler, padded_sample, num_padded); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a MnistDataset. -std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, usage, sampler); +// Function to create a MnistNode. +std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, usage, sampler); // Call derived class validation method. return ds->ValidateParams() ? 
ds : nullptr; } // Function to overload "+" operator to concat two datasets -std::shared_ptr operator+(const std::shared_ptr &datasets1, - const std::shared_ptr &datasets2) { - std::shared_ptr ds = std::make_shared(std::vector({datasets2, datasets1})); +std::shared_ptr operator+(const std::shared_ptr &datasets1, + const std::shared_ptr &datasets2) { + std::shared_ptr ds = std::make_shared(std::vector({datasets2, datasets1})); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } -// Function to create a TextFileDataset. -std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) { - auto ds = std::make_shared(dataset_files, num_samples, shuffle, num_shards, shard_id); +// Function to create a TextFileNode. +std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) { + auto ds = std::make_shared(dataset_files, num_samples, shuffle, num_shards, shard_id); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } #ifndef ENABLE_ANDROID -// Function to create a VOCDataset. -std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage, - const std::map &class_indexing, bool decode, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, task, usage, class_indexing, decode, sampler); +// Function to create a VOCNode. +std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage, + const std::map &class_indexing, bool decode, + const std::shared_ptr &sampler) { + auto ds = std::make_shared(dataset_dir, task, usage, class_indexing, decode, sampler); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; } #endif -// Function to create a ZipDataset. 
-std::shared_ptr Zip(const std::vector> &datasets) { - auto ds = std::make_shared(datasets); +// Function to create a ZipNode. +std::shared_ptr Zip(const std::vector> &datasets) { + auto ds = std::make_shared(datasets); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -296,39 +296,35 @@ std::shared_ptr Zip(const std::vector> &dat // (In alphabetical order) // Function to create a Batch dataset -std::shared_ptr Dataset::Batch(int32_t batch_size, bool drop_remainder) { +std::shared_ptr Dataset::Batch(int32_t batch_size, bool drop_remainder) { // Default values std::vector cols_to_map = {}; std::map>> pad_map; bool pad = false; - auto ds = std::make_shared(batch_size, drop_remainder, pad, cols_to_map, pad_map); + auto ds = std::make_shared(shared_from_this(), batch_size, drop_remainder, pad, cols_to_map, pad_map); if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } #ifndef ENABLE_ANDROID // Function to create a BucketBatchByLength dataset -std::shared_ptr Dataset::BucketBatchByLength( +std::shared_ptr Dataset::BucketBatchByLength( const std::vector &column_names, const std::vector &bucket_boundaries, const std::vector &bucket_batch_sizes, std::function element_length_function, const std::map>> &pad_info, bool pad_to_bucket_boundary, bool drop_remainder) { - auto ds = std::make_shared(column_names, bucket_boundaries, bucket_batch_sizes, - element_length_function, pad_info, pad_to_bucket_boundary, - drop_remainder); + auto ds = std::make_shared(shared_from_this(), column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, pad_info, + pad_to_bucket_boundary, drop_remainder); if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } @@ -337,14 +333,13 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum const std::pair &freq_range, int64_t top_k, const std::vector &special_tokens, bool 
special_first) { auto vocab = std::make_shared(); - auto ds = std::make_shared(vocab, columns, freq_range, top_k, special_tokens, special_first); + auto ds = std::make_shared(shared_from_this(), vocab, columns, freq_range, top_k, special_tokens, + special_first); if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - // Run tree here to starting building vocab std::shared_ptr iter = ds->CreateIterator(); if (iter == nullptr) { @@ -363,53 +358,46 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum #endif // Function to create a Concat dataset -std::shared_ptr Dataset::Concat(const std::vector> &datasets) { - auto ds = std::make_shared(datasets); +std::shared_ptr Dataset::Concat(const std::vector> &datasets) { + auto ds = std::make_shared(datasets); ds->children.push_back(shared_from_this()); return ds->ValidateParams() ? ds : nullptr; } // Function to create a Map dataset. -std::shared_ptr Dataset::Map(std::vector> operations, - std::vector input_columns, - std::vector output_columns, - const std::vector &project_columns) { - auto ds = std::make_shared(operations, input_columns, output_columns, project_columns); +std::shared_ptr Dataset::Map(std::vector> operations, + std::vector input_columns, std::vector output_columns, + const std::vector &project_columns) { + auto ds = std::make_shared(shared_from_this(), operations, input_columns, output_columns, project_columns); if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } -// Function to create a ProjectDataset. -std::shared_ptr Dataset::Project(const std::vector &columns) { - auto ds = std::make_shared(columns); +// Function to create a ProjectNode. +std::shared_ptr Dataset::Project(const std::vector &columns) { + auto ds = std::make_shared(shared_from_this(), columns); // Call derived class validation method. 
if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } -// Function to create a RenameDataset. -std::shared_ptr Dataset::Rename(const std::vector &input_columns, - const std::vector &output_columns) { - auto ds = std::make_shared(input_columns, output_columns); +// Function to create a RenameNode. +std::shared_ptr Dataset::Rename(const std::vector &input_columns, + const std::vector &output_columns) { + auto ds = std::make_shared(shared_from_this(), input_columns, output_columns); // Call derived class validation method. if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } @@ -420,46 +408,40 @@ std::shared_ptr Dataset::Repeat(int32_t count) { return shared_from_this(); } - auto ds = std::make_shared(count); + auto ds = std::make_shared(shared_from_this(), count); if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } // Function to create a ShuffleOp -std::shared_ptr Dataset::Shuffle(int32_t buffer_size) { +std::shared_ptr Dataset::Shuffle(int32_t buffer_size) { // Pass in reshuffle_each_epoch with true - auto ds = std::make_shared(buffer_size, true); + auto ds = std::make_shared(shared_from_this(), buffer_size, true); if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } -// Function to create a SkipDataset. -std::shared_ptr Dataset::Skip(int32_t count) { - auto ds = std::make_shared(count); +// Function to create a SkipNode. +std::shared_ptr Dataset::Skip(int32_t count) { + auto ds = std::make_shared(shared_from_this(), count); // Call derived class validation method. if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } -// Function to create a TakeDataset. +// Function to create a TakeNode. 
std::shared_ptr Dataset::Take(int32_t count) { // If count is greater than the number of element in dataset or equal to -1, // all the element in dataset will be taken @@ -467,22 +449,20 @@ std::shared_ptr Dataset::Take(int32_t count) { return shared_from_this(); } - auto ds = std::make_shared(count); + auto ds = std::make_shared(shared_from_this(), count); // Call derived class validation method. if (!ds->ValidateParams()) { return nullptr; } - ds->children.push_back(shared_from_this()); - return ds; } // Function to create a Zip dataset -std::shared_ptr Dataset::Zip(const std::vector> &datasets) { +std::shared_ptr Dataset::Zip(const std::vector> &datasets) { // Default values - auto ds = std::make_shared(datasets); + auto ds = std::make_shared(datasets); ds->children.push_back(shared_from_this()); return ds->ValidateParams() ? ds : nullptr; @@ -811,32 +791,32 @@ Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::st // DERIVED DATASET CLASSES LEAF-NODE DATASETS // (In alphabetical order) -// Constructor for AlbumDataset -AlbumDataset::AlbumDataset(const std::string &dataset_dir, const std::string &data_schema, - const std::vector &column_names, bool decode, - const std::shared_ptr &sampler) +// Constructor for AlbumNode +AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, + const std::vector &column_names, bool decode, + const std::shared_ptr &sampler) : dataset_dir_(dataset_dir), schema_path_(data_schema), column_names_(column_names), decode_(decode), sampler_(sampler) {} -Status AlbumDataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumDataset", dataset_dir_)); +Status AlbumNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumDataset", {schema_path_})); + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_})); - RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumDataset", 
sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumNode", sampler_)); if (!column_names_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumDataset", "column_names", column_names_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumNode", "column_names", column_names_)); } return Status::OK(); } -// Function to build AlbumDataset -std::vector> AlbumDataset::Build() { +// Function to build AlbumNode +std::vector> AlbumNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -851,24 +831,24 @@ std::vector> AlbumDataset::Build() { return node_ops; } -// Constructor for CelebADataset -CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler, const bool &decode, - const std::set &extensions) +// Constructor for CelebANode +CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, + const std::shared_ptr &sampler, const bool &decode, + const std::set &extensions) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {} -Status CelebADataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebADataset", dataset_dir_)); +Status CelebANode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebADataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"all", "train", "valid", "test"})); return Status::OK(); } -// Function to build CelebADataset -std::vector> CelebADataset::Build() { +// Function to build CelebANode +std::vector> CelebANode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -884,15 +864,14 @@ std::vector> CelebADataset::Build() { return node_ops; } -// Constructor for Cifar10Dataset 
-Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage, - std::shared_ptr sampler) +// Constructor for Cifar10Node +Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} -Status Cifar10Dataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_)); +Status Cifar10Node::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Dataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_)); RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); @@ -900,7 +879,7 @@ Status Cifar10Dataset::ValidateParams() { } // Function to build CifarOp for Cifar10 -std::vector> Cifar10Dataset::Build() { +std::vector> Cifar10Node::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -917,15 +896,15 @@ std::vector> Cifar10Dataset::Build() { return node_ops; } -// Constructor for Cifar100Dataset -Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage, - std::shared_ptr sampler) +// Constructor for Cifar100Node +Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, + std::shared_ptr sampler) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} -Status Cifar100Dataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_)); +Status Cifar100Node::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Dataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_)); RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); @@ -933,7 +912,7 @@ Status 
Cifar100Dataset::ValidateParams() { } // Function to build CifarOp for Cifar100 -std::vector> Cifar100Dataset::Build() { +std::vector> Cifar100Node::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -952,9 +931,9 @@ std::vector> Cifar100Dataset::Build() { return node_ops; } -// Constructor for CLUEDataset -CLUEDataset::CLUEDataset(const std::vector clue_files, std::string task, std::string usage, - int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) +// Constructor for CLUENode +CLUENode::CLUENode(const std::vector clue_files, std::string task, std::string usage, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) : dataset_files_(clue_files), task_(task), usage_(usage), @@ -963,8 +942,8 @@ CLUEDataset::CLUEDataset(const std::vector clue_files, std::string num_shards_(num_shards), shard_id_(shard_id) {} -Status CLUEDataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_)); +Status CLUENode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_)); std::vector task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}; std::vector usage_list = {"train", "test", "eval"}; @@ -982,18 +961,18 @@ Status CLUEDataset::ValidateParams() { } if (num_samples_ < 0) { - std::string err_msg = "CLUEDataset: Invalid number of samples: " + num_samples_; + std::string err_msg = "CLUENode: Invalid number of samples: " + num_samples_; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)); + RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUENode", num_shards_, shard_id_)); return Status::OK(); } // Function to split string based on a character delimiter -std::vector CLUEDataset::split(const std::string &s, char delim) { +std::vector CLUENode::split(const std::string &s, char 
delim) { std::vector res; std::stringstream ss(s); std::string item; @@ -1004,8 +983,8 @@ std::vector CLUEDataset::split(const std::string &s, char delim) { return res; } -// Function to build CLUEDataset -std::vector> CLUEDataset::Build() { +// Function to build CLUENode +std::vector> CLUENode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; std::map key_map; @@ -1142,15 +1121,15 @@ std::vector> CLUEDataset::Build() { return node_ops; } -// Constructor for CocoDataset -CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, - const bool &decode, const std::shared_ptr &sampler) +// Constructor for CocoNode +CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, + const bool &decode, const std::shared_ptr &sampler) : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {} -Status CocoDataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoDataset", dataset_dir_)); +Status CocoNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_)); Path annotation_file(annotation_file_); if (!annotation_file.Exists()) { @@ -1170,8 +1149,8 @@ Status CocoDataset::ValidateParams() { return Status::OK(); } -// Function to build CocoDataset -std::vector> CocoDataset::Build() { +// Function to build CocoNode +std::vector> CocoNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1221,7 +1200,7 @@ std::vector> CocoDataset::Build() { schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); break; default: - MS_LOG(ERROR) << 
"CocoDataset::Build : Invalid task type: " << task_type; + MS_LOG(ERROR) << "CocoNode::Build : Invalid task type: " << task_type; return {}; } std::shared_ptr op = @@ -1231,11 +1210,11 @@ std::vector> CocoDataset::Build() { return node_ops; } -// Constructor for CSVDataset -CSVDataset::CSVDataset(const std::vector &csv_files, char field_delim, - const std::vector> &column_defaults, - const std::vector &column_names, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id) +// Constructor for CSVNode +CSVNode::CSVNode(const std::vector &csv_files, char field_delim, + const std::vector> &column_defaults, + const std::vector &column_names, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id) : dataset_files_(csv_files), field_delim_(field_delim), column_defaults_(column_defaults), @@ -1245,38 +1224,38 @@ CSVDataset::CSVDataset(const std::vector &csv_files, char field_del num_shards_(num_shards), shard_id_(shard_id) {} -Status CSVDataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVDataset", dataset_files_)); +Status CSVNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') { - std::string err_msg = "CSVDataset: The field delimiter should not be \", \\r, \\n"; + std::string err_msg = "CSVNode: The field delimiter should not be \", \\r, \\n"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (num_samples_ < 0) { - std::string err_msg = "CSVDataset: Invalid number of samples: " + num_samples_; + std::string err_msg = "CSVNode: Invalid number of samples: " + num_samples_; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVDataset", num_shards_, shard_id_)); + RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVNode", num_shards_, shard_id_)); if (find(column_defaults_.begin(), column_defaults_.end(), 
nullptr) != column_defaults_.end()) { - std::string err_msg = "CSVDataset: column_default should not be null."; + std::string err_msg = "CSVNode: column_default should not be null."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (!column_names_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVDataset", "column_names", column_names_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVNode", "column_names", column_names_)); } return Status::OK(); } -// Function to build CSVDataset -std::vector> CSVDataset::Build() { +// Function to build CSVNode +std::vector> CSVNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1322,9 +1301,9 @@ std::vector> CSVDataset::Build() { return node_ops; } -ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, - bool recursive, std::set extensions, - std::map class_indexing) +ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr sampler, + bool recursive, std::set extensions, + std::map class_indexing) : dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler), @@ -1332,15 +1311,15 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std class_indexing_(class_indexing), exts_(extensions) {} -Status ImageFolderDataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_)); +Status ImageFolderNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderNode", sampler_)); return Status::OK(); } -std::vector> ImageFolderDataset::Build() { +std::vector> ImageFolderNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1359,12 +1338,12 @@ 
std::vector> ImageFolderDataset::Build() { } #ifndef ENABLE_ANDROID -ManifestDataset::ManifestDataset(const std::string &dataset_file, const std::string &usage, - const std::shared_ptr &sampler, - const std::map &class_indexing, bool decode) +ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage, + const std::shared_ptr &sampler, + const std::map &class_indexing, bool decode) : dataset_file_(dataset_file), usage_(usage), decode_(decode), class_index_(class_indexing), sampler_(sampler) {} -Status ManifestDataset::ValidateParams() { +Status ManifestNode::ValidateParams() { std::vector forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; for (char c : dataset_file_) { auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); @@ -1382,7 +1361,7 @@ Status ManifestDataset::ValidateParams() { RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_)); std::vector usage_list = {"train", "eval", "inference"}; if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) { @@ -1394,7 +1373,7 @@ Status ManifestDataset::ValidateParams() { return Status::OK(); } -std::vector> ManifestDataset::Build() { +std::vector> ManifestNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1416,10 +1395,8 @@ std::vector> ManifestDataset::Build() { #endif #ifndef ENABLE_ANDROID -MindDataDataset::MindDataDataset(const std::vector &dataset_files, - const std::vector &columns_list, - const std::shared_ptr &sampler, nlohmann::json padded_sample, - int64_t num_padded) +MindDataNode::MindDataNode(const std::vector &dataset_files, const std::vector &columns_list, + const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded) : dataset_file_(std::string()), dataset_files_(dataset_files), 
search_for_pattern_(false), @@ -1429,9 +1406,8 @@ MindDataDataset::MindDataDataset(const std::vector &dataset_files, sample_bytes_({}), num_padded_(num_padded) {} -MindDataDataset::MindDataDataset(const std::string &dataset_file, const std::vector &columns_list, - const std::shared_ptr &sampler, nlohmann::json padded_sample, - int64_t num_padded) +MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector &columns_list, + const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded) : dataset_file_(dataset_file), dataset_files_({}), search_for_pattern_(true), @@ -1441,10 +1417,10 @@ MindDataDataset::MindDataDataset(const std::string &dataset_file, const std::vec sample_bytes_({}), num_padded_(num_padded) {} -Status MindDataDataset::ValidateParams() { +Status MindDataNode::ValidateParams() { if (!search_for_pattern_ && dataset_files_.size() > 4096) { std::string err_msg = - "MindDataDataset: length of dataset_file must be less than or equal to 4096, dataset_file length: " + + "MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " + std::to_string(dataset_file_.size()); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); @@ -1452,30 +1428,29 @@ Status MindDataDataset::ValidateParams() { std::vector dataset_file_vec = search_for_pattern_ ? 
std::vector{dataset_file_} : dataset_files_; - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataDataset", dataset_file_vec)); + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataNode", dataset_file_vec)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataNode", sampler_)); if (!columns_list_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataDataset", "columns_list", columns_list_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataNode", "columns_list", columns_list_)); } if (padded_sample_ != nullptr) { if (num_padded_ < 0) { std::string err_msg = - "MindDataDataset: num_padded must be greater than or equal to zero, num_padded: " + std::to_string(num_padded_); + "MindDataNode: num_padded must be greater than or equal to zero, num_padded: " + std::to_string(num_padded_); MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (columns_list_.empty()) { - std::string err_msg = "MindDataDataset: padded_sample is specified and requires columns_list as well"; + std::string err_msg = "MindDataNode: padded_sample is specified and requires columns_list as well"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } for (std::string &column : columns_list_) { if (padded_sample_.find(column) == padded_sample_.end()) { - std::string err_msg = - "MindDataDataset: " + column + " in columns_list does not match any column in padded_sample"; + std::string err_msg = "MindDataNode: " + column + " in columns_list does not match any column in padded_sample"; MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -1483,7 +1458,7 @@ Status MindDataDataset::ValidateParams() { } if (num_padded_ > 0) { if (padded_sample_ == nullptr) { - std::string err_msg = "MindDataDataset: num_padded is specified but padded_sample is not"; + std::string err_msg = "MindDataNode: num_padded is specified but padded_sample is not"; 
MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -1493,13 +1468,13 @@ Status MindDataDataset::ValidateParams() { } // Helper function to create runtime sampler for minddata dataset -Status MindDataDataset::BuildMindDatasetSamplerChain( - const std::shared_ptr &sampler, std::vector> *operators_, - int64_t num_padded) { +Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr &sampler, + std::vector> *operators_, + int64_t num_padded) { std::shared_ptr op = sampler->BuildForMindDataset(); if (op == nullptr) { std::string err_msg = - "MindDataDataset: Unsupported sampler is supplied for MindDataset. Supported sampler list: " + "MindDataNode: Unsupported sampler is supplied for MindDataset. Supported sampler list: " "SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); @@ -1523,11 +1498,9 @@ Status MindDataDataset::BuildMindDatasetSamplerChain( } // Helper function to set sample_bytes from py::byte type -void MindDataDataset::SetSampleBytes(std::map *sample_bytes) { - sample_bytes_ = *sample_bytes; -} +void MindDataNode::SetSampleBytes(std::map *sample_bytes) { sample_bytes_ = *sample_bytes; } -std::vector> MindDataDataset::Build() { +std::vector> MindDataNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1555,20 +1528,20 @@ std::vector> MindDataDataset::Build() { } #endif -MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr sampler) +MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler) : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} -Status MnistDataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistDataset", dataset_dir_)); +Status MnistNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); - 
RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_)); RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"})); return Status::OK(); } -std::vector> MnistDataset::Build() { +std::vector> MnistNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1584,30 +1557,30 @@ std::vector> MnistDataset::Build() { return node_ops; } -// ValideParams for RandomDataset -Status RandomDataset::ValidateParams() { +// ValidateParams for RandomNode +Status RandomNode::ValidateParams() { if (total_rows_ < 0) { - std::string err_msg = "RandomDataset: total_rows must be greater than or equal 0, now get " + total_rows_; + std::string err_msg = "RandomNode: total_rows must be greater than or equal 0, now get " + total_rows_; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomNode", sampler_)); if (!columns_list_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomDataset", "columns_list", columns_list_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_)); } return Status::OK(); } -int32_t RandomDataset::GenRandomInt(int32_t min, int32_t max) { +int32_t RandomNode::GenRandomInt(int32_t min, int32_t max) { std::uniform_int_distribution uniDist(min, max); return uniDist(rand_gen_); } -// Build for RandomDataset -std::vector> RandomDataset::Build() { +// Build for RandomNode +std::vector> RandomNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1652,31 +1625,31 @@ std::vector> RandomDataset::Build() { return node_ops; } -// Constructor for TextFileDataset -TextFileDataset::TextFileDataset(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, - int32_t
num_shards, int32_t shard_id) +// Constructor for TextFileNode +TextFileNode::TextFileNode(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id) : dataset_files_(dataset_files), num_samples_(num_samples), shuffle_(shuffle), num_shards_(num_shards), shard_id_(shard_id) {} -Status TextFileDataset::ValidateParams() { - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_)); +Status TextFileNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); if (num_samples_ < 0) { - std::string err_msg = "TextFileDataset: Invalid number of samples: " + num_samples_; + std::string err_msg = "TextFileNode: Invalid number of samples: " + num_samples_; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_)); + RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileNode", num_shards_, shard_id_)); return Status::OK(); } -// Function to build TextFileDataset -std::vector> TextFileDataset::Build() { +// Function to build TextFileNode +std::vector> TextFileNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1717,11 +1690,11 @@ std::vector> TextFileDataset::Build() { } #ifndef ENABLE_ANDROID -// Validator for TFRecordDataset -Status TFRecordDataset::ValidateParams() { return Status::OK(); } +// Validator for TFRecordNode +Status TFRecordNode::ValidateParams() { return Status::OK(); } -// Function to build TFRecordDataset -std::vector> TFRecordDataset::Build() { +// Function to build TFRecordNode +std::vector> TFRecordNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1767,10 +1740,9 @@ std::vector> TFRecordDataset::Build() { return node_ops; } -// Constructor for VOCDataset -VOCDataset::VOCDataset(const std::string 
&dataset_dir, const std::string &task, const std::string &usage, - const std::map &class_indexing, bool decode, - std::shared_ptr sampler) +// Constructor for VOCNode +VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, + const std::map &class_indexing, bool decode, std::shared_ptr sampler) : dataset_dir_(dataset_dir), task_(task), usage_(usage), @@ -1778,7 +1750,7 @@ VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, decode_(decode), sampler_(sampler) {} -Status VOCDataset::ValidateParams() { +Status VOCNode::ValidateParams() { Path dir(dataset_dir_); if (!dir.IsDirectory()) { std::string err_msg = "Invalid dataset path or no dataset path is specified."; @@ -1786,7 +1758,7 @@ Status VOCDataset::ValidateParams() { RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_)); if (task_ == "Segmentation") { if (!class_index_.empty()) { @@ -1816,8 +1788,8 @@ Status VOCDataset::ValidateParams() { return Status::OK(); } -// Function to build VOCDataset -std::vector> VOCDataset::Build() { +// Function to build VOCNode +std::vector> VOCNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1855,15 +1827,18 @@ std::vector> VOCDataset::Build() { // DERIVED DATASET CLASSES LEAF-NODE DATASETS // (In alphabetical order) -BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, - std::map>> pad_map) +BatchNode::BatchNode(std::shared_ptr child, int32_t batch_size, bool drop_remainder, bool pad, + std::vector cols_to_map, + std::map>> pad_map) : batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(pad), cols_to_map_(cols_to_map), - pad_map_(pad_map) {} + pad_map_(pad_map) { + this->children.push_back(child); +} -std::vector> BatchDataset::Build() { +std::vector> 
BatchNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1881,7 +1856,7 @@ std::vector> BatchDataset::Build() { return node_ops; } -Status BatchDataset::ValidateParams() { +Status BatchNode::ValidateParams() { if (batch_size_ <= 0) { std::string err_msg = "Batch: batch_size should be positive integer, but got: " + batch_size_; MS_LOG(ERROR) << err_msg; @@ -1896,9 +1871,10 @@ Status BatchDataset::ValidateParams() { } #ifndef ENABLE_ANDROID -BucketBatchByLengthDataset::BucketBatchByLengthDataset( - const std::vector &column_names, const std::vector &bucket_boundaries, - const std::vector &bucket_batch_sizes, std::function element_length_function, +BucketBatchByLengthNode::BucketBatchByLengthNode( + std::shared_ptr child, const std::vector &column_names, + const std::vector &bucket_boundaries, const std::vector &bucket_batch_sizes, + std::function element_length_function, const std::map>> &pad_info, bool pad_to_bucket_boundary, bool drop_remainder) : column_names_(column_names), @@ -1907,9 +1883,11 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset( element_length_function_(element_length_function), pad_info_(pad_info), pad_to_bucket_boundary_(pad_to_bucket_boundary), - drop_remainder_(drop_remainder) {} + drop_remainder_(drop_remainder) { + this->children.push_back(child); +} -std::vector> BucketBatchByLengthDataset::Build() { +std::vector> BucketBatchByLengthNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1925,7 +1903,7 @@ std::vector> BucketBatchByLengthDataset::Build() { return node_ops; } -Status BucketBatchByLengthDataset::ValidateParams() { +Status BucketBatchByLengthNode::ValidateParams() { if (element_length_function_ == nullptr && column_names_.size() != 1) { std::string err_msg = "BucketBatchByLength: element_length_function not specified, but not one column name: " + column_names_.size(); @@ 
-1976,18 +1954,20 @@ Status BucketBatchByLengthDataset::ValidateParams() { return Status::OK(); } -BuildVocabDataset::BuildVocabDataset(std::shared_ptr vocab, const std::vector &columns, - const std::pair &freq_range, int64_t top_k, - const std::vector &special_tokens, bool special_first) +BuildVocabNode::BuildVocabNode(std::shared_ptr child, std::shared_ptr vocab, + const std::vector &columns, const std::pair &freq_range, + int64_t top_k, const std::vector &special_tokens, bool special_first) : vocab_(vocab), columns_(columns), freq_range_(freq_range), top_k_(top_k), special_tokens_(special_tokens), - special_first_(special_first) {} + special_first_(special_first) { + this->children.push_back(child); +} -// Function to build BuildVocabDataset -std::vector> BuildVocabDataset::Build() { +// Function to build BuildVocabNode +std::vector> BuildVocabNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -1998,7 +1978,7 @@ std::vector> BuildVocabDataset::Build() { return node_ops; } -Status BuildVocabDataset::ValidateParams() { +Status BuildVocabNode::ValidateParams() { if (vocab_ == nullptr) { std::string err_msg = "BuildVocab: vocab is null."; MS_LOG(ERROR) << err_msg; @@ -2027,11 +2007,11 @@ Status BuildVocabDataset::ValidateParams() { #endif // Function to build ConcatOp -ConcatDataset::ConcatDataset(const std::vector> &datasets) : datasets_(datasets) { +ConcatNode::ConcatNode(const std::vector> &datasets) : datasets_(datasets) { this->children = datasets_; } -Status ConcatDataset::ValidateParams() { +Status ConcatNode::ValidateParams() { if (datasets_.empty()) { std::string err_msg = "Concat: concatenated datasets are not specified."; MS_LOG(ERROR) << err_msg; @@ -2045,7 +2025,7 @@ Status ConcatDataset::ValidateParams() { return Status::OK(); } -std::vector> ConcatDataset::Build() { +std::vector> ConcatNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object 
will create std::vector> node_ops; @@ -2053,14 +2033,17 @@ std::vector> ConcatDataset::Build() { return node_ops; } -MapDataset::MapDataset(std::vector> operations, std::vector input_columns, - std::vector output_columns, const std::vector &project_columns) +MapNode::MapNode(std::shared_ptr child, std::vector> operations, + std::vector input_columns, std::vector output_columns, + const std::vector &project_columns) : operations_(operations), input_columns_(input_columns), output_columns_(output_columns), - project_columns_(project_columns) {} + project_columns_(project_columns) { + this->children.push_back(child); +} -std::vector> MapDataset::Build() { +std::vector> MapNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -2084,7 +2067,7 @@ std::vector> MapDataset::Build() { return node_ops; } -Status MapDataset::ValidateParams() { +Status MapNode::ValidateParams() { if (operations_.empty()) { std::string err_msg = "Map: No operation is specified."; MS_LOG(ERROR) << err_msg; @@ -2092,36 +2075,38 @@ Status MapDataset::ValidateParams() { } if (!input_columns_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "input_columns", input_columns_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_)); } if (!output_columns_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "output_columns", output_columns_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "output_columns", output_columns_)); } if (!project_columns_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapDataset", "project_columns", project_columns_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "project_columns", project_columns_)); } return Status::OK(); } // Function to build ProjectOp -ProjectDataset::ProjectDataset(const std::vector &columns) : columns_(columns) {} +ProjectNode::ProjectNode(std::shared_ptr child, const 
std::vector &columns) : columns_(columns) { + this->children.push_back(child); +} -Status ProjectDataset::ValidateParams() { +Status ProjectNode::ValidateParams() { if (columns_.empty()) { - std::string err_msg = "ProjectDataset: No columns are specified."; + std::string err_msg = "ProjectNode: No columns are specified."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectDataset", "columns", columns_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_)); return Status::OK(); } -std::vector> ProjectDataset::Build() { +std::vector> ProjectNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -2130,25 +2115,27 @@ std::vector> ProjectDataset::Build() { } // Function to build RenameOp -RenameDataset::RenameDataset(const std::vector &input_columns, - const std::vector &output_columns) - : input_columns_(input_columns), output_columns_(output_columns) {} +RenameNode::RenameNode(std::shared_ptr child, const std::vector &input_columns, + const std::vector &output_columns) + : input_columns_(input_columns), output_columns_(output_columns) { + this->children.push_back(child); +} -Status RenameDataset::ValidateParams() { +Status RenameNode::ValidateParams() { if (input_columns_.size() != output_columns_.size()) { - std::string err_msg = "RenameDataset: input and output columns must be the same size"; + std::string err_msg = "RenameNode: input and output columns must be the same size"; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameDataset", "input_columns", input_columns_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_)); - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameDataset", "output_columns", output_columns_)); + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", 
"output_columns", output_columns_)); return Status::OK(); } -std::vector> RenameDataset::Build() { +std::vector> RenameNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -2156,9 +2143,11 @@ std::vector> RenameDataset::Build() { return node_ops; } -RepeatDataset::RepeatDataset(int32_t count) : repeat_count_(count) {} +RepeatNode::RepeatNode(std::shared_ptr child, int32_t count) : repeat_count_(count) { + this->children.push_back(child); +} -std::vector> RepeatDataset::Build() { +std::vector> RepeatNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -2166,7 +2155,7 @@ std::vector> RepeatDataset::Build() { return node_ops; } -Status RepeatDataset::ValidateParams() { +Status RepeatNode::ValidateParams() { if (repeat_count_ <= 0 && repeat_count_ != -1) { std::string err_msg = "Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + repeat_count_; @@ -2177,12 +2166,14 @@ Status RepeatDataset::ValidateParams() { return Status::OK(); } -// Constructor for ShuffleDataset -ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch) - : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {} +// Constructor for ShuffleNode +ShuffleNode::ShuffleNode(std::shared_ptr child, int32_t shuffle_size, bool reset_every_epoch) + : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { + this->children.push_back(child); +} // Function to build the ShuffleOp -std::vector> ShuffleDataset::Build() { +std::vector> ShuffleNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -2191,10 +2182,10 @@ std::vector> ShuffleDataset::Build() { return node_ops; } -// Function to validate the parameters for ShuffleDataset -Status ShuffleDataset::ValidateParams() 
{ +// Function to validate the parameters for ShuffleNode +Status ShuffleNode::ValidateParams() { if (shuffle_size_ <= 1) { - std::string err_msg = "ShuffleDataset: Invalid input, shuffle_size: " + shuffle_size_; + std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + shuffle_size_; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } @@ -2202,11 +2193,13 @@ Status ShuffleDataset::ValidateParams() { return Status::OK(); } -// Constructor for SkipDataset -SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {} +// Constructor for SkipNode +SkipNode::SkipNode(std::shared_ptr child, int32_t count) : skip_count_(count) { + this->children.push_back(child); +} // Function to build the SkipOp -std::vector> SkipDataset::Build() { +std::vector> SkipNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -2214,8 +2207,8 @@ std::vector> SkipDataset::Build() { return node_ops; } -// Function to validate the parameters for SkipDataset -Status SkipDataset::ValidateParams() { +// Function to validate the parameters for SkipNode +Status SkipNode::ValidateParams() { if (skip_count_ <= -1) { std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + skip_count_; MS_LOG(ERROR) << err_msg; @@ -2224,11 +2217,13 @@ Status SkipDataset::ValidateParams() { return Status::OK(); } -// Constructor for TakeDataset -TakeDataset::TakeDataset(int32_t count) : take_count_(count) {} +// Constructor for TakeNode +TakeNode::TakeNode(std::shared_ptr child, int32_t count) : take_count_(count) { + this->children.push_back(child); +} // Function to build the TakeOp -std::vector> TakeDataset::Build() { +std::vector> TakeNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; @@ -2236,8 +2231,8 @@ std::vector> TakeDataset::Build() { return node_ops; } -// Function to validate the parameters for TakeDataset 
-Status TakeDataset::ValidateParams() { +// Function to validate the parameters for TakeNode +Status TakeNode::ValidateParams() { if (take_count_ <= 0 && take_count_ != -1) { std::string err_msg = "Take: take_count should be either -1 or positive integer, take_count: " + take_count_; MS_LOG(ERROR) << err_msg; @@ -2247,27 +2242,27 @@ Status TakeDataset::ValidateParams() { } // Function to build ZipOp -ZipDataset::ZipDataset(const std::vector> &datasets) : datasets_(datasets) { +ZipNode::ZipNode(const std::vector> &datasets) : datasets_(datasets) { for (auto dataset : datasets_) { this->children.push_back(dataset); } } -Status ZipDataset::ValidateParams() { +Status ZipNode::ValidateParams() { if (datasets_.empty()) { std::string err_msg = "Zip: datasets to zip are not specified."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { - std::string err_msg = "ZipDataset: zip datasets should not be null."; + std::string err_msg = "ZipNode: zip datasets should not be null."; MS_LOG(ERROR) << err_msg; RETURN_STATUS_SYNTAX_ERROR(err_msg); } return Status::OK(); } -std::vector> ZipDataset::Build() { +std::vector> ZipNode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index d261888f29..229de34957 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -55,48 +55,48 @@ class TensorOperation; class SchemaObj; class SamplerObj; // Datasets classes (in alphabetical order) -class AlbumDataset; -class CelebADataset; -class Cifar10Dataset; -class Cifar100Dataset; -class CLUEDataset; -class CocoDataset; -class CSVDataset; +class AlbumNode; +class CelebANode; +class Cifar10Node; +class Cifar100Node; +class CLUENode; +class CocoNode; +class 
CSVNode; class CsvBase; -class ImageFolderDataset; +class ImageFolderNode; #ifndef ENABLE_ANDROID -class ManifestDataset; -class MindDataDataset; +class ManifestNode; +class MindDataNode; #endif -class MnistDataset; -class RandomDataset; -class TextFileDataset; +class MnistNode; +class RandomNode; +class TextFileNode; #ifndef ENABLE_ANDROID -class TFRecordDataset; -class VOCDataset; +class TFRecordNode; +class VOCNode; #endif // Dataset Op classes (in alphabetical order) -class BatchDataset; +class BatchNode; #ifndef ENABLE_ANDROID -class BucketBatchByLengthDataset; -class BuildVocabDataset; +class BucketBatchByLengthNode; +class BuildVocabNode; #endif -class ConcatDataset; -class MapDataset; -class ProjectDataset; -class RenameDataset; -class RepeatDataset; -class ShuffleDataset; -class SkipDataset; -class TakeDataset; -class ZipDataset; +class ConcatNode; +class MapNode; +class ProjectNode; +class RenameNode; +class RepeatNode; +class ShuffleNode; +class SkipNode; +class TakeNode; +class ZipNode; /// \brief Function to create a SchemaObj /// \param[in] schema_file Path of schema file /// \return Shared pointer to the current schema std::shared_ptr Schema(const std::string &schema_file = ""); -/// \brief Function to create an AlbumDataset +/// \brief Function to create an AlbumNode /// \notes The generated dataset is specified through setting a schema /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] data_schema Path to dataset schema file @@ -106,11 +106,11 @@ std::shared_ptr Schema(const std::string &schema_file = ""); /// \param[in] sampler Object used to choose samples from the dataset. 
If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset -std::shared_ptr Album(const std::string &dataset_dir, const std::string &data_schema, - const std::vector &column_names = {}, bool decode = false, - const std::shared_ptr &sampler = RandomSampler()); +std::shared_ptr Album(const std::string &dataset_dir, const std::string &data_schema, + const std::vector &column_names = {}, bool decode = false, + const std::shared_ptr &sampler = RandomSampler()); -/// \brief Function to create a CelebADataset +/// \brief Function to create a CelebANode /// \notes The generated dataset has two columns ['image', 'attr']. /// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type. /// \param[in] dataset_dir Path to the root directory that contains the dataset. @@ -120,9 +120,9 @@ std::shared_ptr Album(const std::string &dataset_dir, const std::s /// \param[in] decode Decode the images after reading (default=false). /// \param[in] extensions Set of file extensions to be included in the dataset (default={}). /// \return Shared pointer to the current Dataset -std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &usage = "all", - const std::shared_ptr &sampler = RandomSampler(), bool decode = false, - const std::set &extensions = {}); +std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &usage = "all", + const std::shared_ptr &sampler = RandomSampler(), bool decode = false, + const std::set &extensions = {}); /// \brief Function to create a Cifar10 Dataset /// \notes The generated dataset has two columns ["image", "label"] @@ -131,8 +131,8 @@ std::shared_ptr CelebA(const std::string &dataset_dir, const std: /// \param[in] sampler Object used to choose samples from the dataset. 
If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset -std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage = "all", - const std::shared_ptr &sampler = RandomSampler()); +std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage = "all", + const std::shared_ptr &sampler = RandomSampler()); /// \brief Function to create a Cifar100 Dataset /// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"] @@ -141,10 +141,10 @@ std::shared_ptr Cifar10(const std::string &dataset_dir, const st /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset -std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage = "all", - const std::shared_ptr &sampler = RandomSampler()); +std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage = "all", + const std::shared_ptr &sampler = RandomSampler()); -/// \brief Function to create a CLUEDataset +/// \brief Function to create a CLUENode /// \notes The generated dataset has a variable number of columns depending on the task and usage /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list /// will be sorted in a lexicographical order. @@ -160,13 +160,13 @@ std::shared_ptr Cifar100(const std::string &dataset_dir, const /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. 
(Default = 0) -/// \return Shared pointer to the current CLUEDataset -std::shared_ptr CLUE(const std::vector &dataset_files, const std::string &task = "AFQMC", - const std::string &usage = "train", int64_t num_samples = 0, - ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, - int32_t shard_id = 0); +/// \return Shared pointer to the current CLUENode +std::shared_ptr CLUE(const std::vector &dataset_files, const std::string &task = "AFQMC", + const std::string &usage = "train", int64_t num_samples = 0, + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, + int32_t shard_id = 0); -/// \brief Function to create a CocoDataset +/// \brief Function to create a CocoNode /// \notes The generated dataset has multi-columns : /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32], /// ['iscrowd', dtype=uint32]]. @@ -182,11 +182,11 @@ std::shared_ptr CLUE(const std::vector &dataset_files, /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset -std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, - const std::string &task = "Detection", const bool &decode = false, - const std::shared_ptr &sampler = RandomSampler()); +std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, + const std::string &task = "Detection", const bool &decode = false, + const std::shared_ptr &sampler = RandomSampler()); -/// \brief Function to create a CSVDataset +/// \brief Function to create a CSVNode /// \notes The generated dataset has a variable number of columns /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list /// will be sorted in a lexicographical order. 
@@ -206,13 +206,12 @@ std::shared_ptr Coco(const std::string &dataset_dir, const std::str /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) /// \return Shared pointer to the current Dataset -std::shared_ptr CSV(const std::vector &dataset_files, char field_delim = ',', - const std::vector> &column_defaults = {}, - const std::vector &column_names = {}, int64_t num_samples = 0, - ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, - int32_t shard_id = 0); +std::shared_ptr CSV(const std::vector &dataset_files, char field_delim = ',', + const std::vector> &column_defaults = {}, + const std::vector &column_names = {}, int64_t num_samples = 0, + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); -/// \brief Function to create an ImageFolderDataset +/// \brief Function to create an ImageFolderNode /// \notes A source dataset that reads images from a tree of directories /// All images within one folder have the same label /// The generated dataset has two columns ["image", "label"] @@ -222,14 +221,14 @@ std::shared_ptr CSV(const std::vector &dataset_files, c /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \param[in] extensions File extensions to be read /// \param[in] class_indexing a class name to label map -/// \return Shared pointer to the current ImageFolderDataset -std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode = false, - const std::shared_ptr &sampler = RandomSampler(), - const std::set &extensions = {}, - const std::map &class_indexing = {}); +/// \return Shared pointer to the current ImageFolderNode +std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode = false, + const std::shared_ptr &sampler = RandomSampler(), + const std::set &extensions = {}, + const std::map &class_indexing = {}); #ifndef ENABLE_ANDROID -/// 
\brief Function to create a ManifestDataset +/// \brief Function to create a ManifestNode /// \notes The generated dataset has two columns ["image", "label"] /// \param[in] dataset_file The dataset file to be read /// \param[in] usage Need "train", "eval" or "inference" data (default="train") @@ -238,15 +237,14 @@ std::shared_ptr ImageFolder(const std::string &dataset_dir, /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder /// names will be sorted alphabetically and each class will be given a unique index starting from 0). /// \param[in] decode Decode the images after reading (default=false). -/// \return Shared pointer to the current ManifestDataset -std::shared_ptr Manifest(const std::string &dataset_file, const std::string &usage = "train", - const std::shared_ptr &sampler = RandomSampler(), - const std::map &class_indexing = {}, - bool decode = false); +/// \return Shared pointer to the current ManifestNode +std::shared_ptr Manifest(const std::string &dataset_file, const std::string &usage = "train", + const std::shared_ptr &sampler = RandomSampler(), + const std::map &class_indexing = {}, bool decode = false); #endif #ifndef ENABLE_ANDROID -/// \brief Function to create a MindDataDataset +/// \brief Function to create a MindDataNode /// \param[in] dataset_file File name of one component of a mindrecord source. Other files with identical source /// in the same path will be found and loaded automatically. /// \param[in] columns_list List of columns to be read (default={}) @@ -255,13 +253,13 @@ std::shared_ptr Manifest(const std::string &dataset_file, const /// supported sampler list: SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler. /// \param[in] padded_sample Samples will be appended to dataset, where keys are the same as column_list. /// \param[in] num_padded Number of padding samples. Dataset size plus num_padded should be divisible by num_shards. 
-/// \return Shared pointer to the current MindDataDataset -std::shared_ptr MindData(const std::string &dataset_file, - const std::vector &columns_list = {}, - const std::shared_ptr &sampler = RandomSampler(), - nlohmann::json padded_sample = nullptr, int64_t num_padded = 0); +/// \return Shared pointer to the current MindDataNode +std::shared_ptr MindData(const std::string &dataset_file, + const std::vector &columns_list = {}, + const std::shared_ptr &sampler = RandomSampler(), + nlohmann::json padded_sample = nullptr, int64_t num_padded = 0); -/// \brief Function to create a MindDataDataset +/// \brief Function to create a MindDataNode /// \param[in] dataset_files List of dataset files to be read directly. /// \param[in] columns_list List of columns to be read (default={}) /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, @@ -269,32 +267,32 @@ std::shared_ptr MindData(const std::string &dataset_file, /// supported sampler list: SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler. /// \param[in] padded_sample Samples will be appended to dataset, where keys are the same as column_list. /// \param[in] num_padded Number of padding samples. Dataset size plus num_padded should be divisible by num_shards. 
-/// \return Shared pointer to the current MindDataDataset -std::shared_ptr MindData(const std::vector &dataset_files, - const std::vector &columns_list = {}, - const std::shared_ptr &sampler = RandomSampler(), - nlohmann::json padded_sample = nullptr, int64_t num_padded = 0); +/// \return Shared pointer to the current MindDataNode +std::shared_ptr MindData(const std::vector &dataset_files, + const std::vector &columns_list = {}, + const std::shared_ptr &sampler = RandomSampler(), + nlohmann::json padded_sample = nullptr, int64_t num_padded = 0); #endif -/// \brief Function to create a MnistDataset +/// \brief Function to create a MnistNode /// \notes The generated dataset has two columns ["image", "label"] /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) -/// \return Shared pointer to the current MnistDataset -std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage = "all", - const std::shared_ptr &sampler = RandomSampler()); +/// \return Shared pointer to the current MnistNode +std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage = "all", + const std::shared_ptr &sampler = RandomSampler()); -/// \brief Function to create a ConcatDataset +/// \brief Function to create a ConcatNode /// \notes Reload "+" operator to concat two datasets /// \param[in] datasets1 Shared pointer to the first dataset to be concatenated /// \param[in] datasets2 Shared pointer to the second dataset to be concatenated -/// \return Shared pointer to the current ConcatDataset -std::shared_ptr operator+(const std::shared_ptr &datasets1, - const std::shared_ptr &datasets2); +/// \return Shared pointer to the current ConcatNode 
+std::shared_ptr operator+(const std::shared_ptr &datasets1, + const std::shared_ptr &datasets2); -/// \brief Function to create a RandomDataset +/// \brief Function to create a RandomNode /// \param[in] total_rows Number of rows for the dataset to generate (default=0, number of rows is random) /// \param[in] schema SchemaObj to set column type, data type and data shape /// \param[in] columns_list List of columns to be read (default={}, read all columns) @@ -302,43 +300,42 @@ std::shared_ptr operator+(const std::shared_ptr &dataset /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset template > -std::shared_ptr RandomData(const int32_t &total_rows = 0, const T &schema = nullptr, - const std::vector &columns_list = {}, - const std::shared_ptr &sampler = RandomSampler()) { +std::shared_ptr RandomData(const int32_t &total_rows = 0, const T &schema = nullptr, + const std::vector &columns_list = {}, + const std::shared_ptr &sampler = RandomSampler()) { if (total_rows < 0) { - MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows; + MS_LOG(ERROR) << "RandomNode: total_rows must be greater than or equal 0, now get " << total_rows; return nullptr; } if (sampler == nullptr) { - MS_LOG(ERROR) << "RandomDataset: Sampler is not constructed correctly, sampler: nullptr"; + MS_LOG(ERROR) << "RandomNode: Sampler is not constructed correctly, sampler: nullptr"; return nullptr; } if (!columns_list.empty()) { for (uint32_t i = 0; i < columns_list.size(); ++i) { if (columns_list[i].empty()) { - MS_LOG(ERROR) << "RandomDataset:columns_list" + MS_LOG(ERROR) << "RandomNode:columns_list" << "[" << i << "] should not be empty"; return nullptr; } } std::set columns_set(columns_list.begin(), columns_list.end()); if (columns_set.size() != columns_list.size()) { - MS_LOG(ERROR) << "RandomDataset:columns_list: Every column name should not be same with 
others"; + MS_LOG(ERROR) << "RandomNode:columns_list: Every column name should not be same with others"; return nullptr; } } - std::shared_ptr ds; + std::shared_ptr ds; if constexpr (std::is_same::value || std::is_same>::value) { std::shared_ptr schema_obj = schema; - ds = - std::make_shared(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler)); + ds = std::make_shared(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler)); } else { - ds = std::make_shared(total_rows, std::move(schema), std::move(columns_list), std::move(sampler)); + ds = std::make_shared(total_rows, std::move(schema), std::move(columns_list), std::move(sampler)); } return ds; } -/// \brief Function to create a TextFileDataset +/// \brief Function to create a TextFileNode /// \notes The generated dataset has one column ['text'] /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list /// will be sorted in a lexicographical order. @@ -352,13 +349,13 @@ std::shared_ptr RandomData(const int32_t &total_rows = 0, const T /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. 
(Default = 0) -/// \return Shared pointer to the current TextFileDataset -std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples = 0, - ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, - int32_t shard_id = 0); +/// \return Shared pointer to the current TextFileNode +std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples = 0, + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, + int32_t shard_id = 0); #ifndef ENABLE_ANDROID -/// \brief Function to create a TFRecordDataset +/// \brief Function to create a TFRecordNode /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list /// will be sorted in a lexicographical order. /// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the @@ -379,60 +376,60 @@ std::shared_ptr TextFile(const std::vector &datase /// when num_shards is also specified. (Default = 0) /// \param[in] shard_equal_rows Get equal rows for all shards. 
(Default = False, number of rows of /// each shard may be not equal) -/// \return Shared pointer to the current TFRecordDataset +/// \return Shared pointer to the current TFRecordNode template > -std::shared_ptr TFRecord(const std::vector &dataset_files, const T &schema = nullptr, - const std::vector &columns_list = {}, int64_t num_samples = 0, - ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, - int32_t shard_id = 0, bool shard_equal_rows = false) { +std::shared_ptr TFRecord(const std::vector &dataset_files, const T &schema = nullptr, + const std::vector &columns_list = {}, int64_t num_samples = 0, + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, + int32_t shard_id = 0, bool shard_equal_rows = false) { if (dataset_files.empty()) { - MS_LOG(ERROR) << "TFRecordDataset: dataset_files is not specified."; + MS_LOG(ERROR) << "TFRecordNode: dataset_files is not specified."; return nullptr; } for (auto f : dataset_files) { Path dataset_file(f); if (!dataset_file.Exists()) { - MS_LOG(ERROR) << "TFRecordDataset: dataset file: [" << f << "] is invalid or does not exist."; + MS_LOG(ERROR) << "TFRecordNode: dataset file: [" << f << "] is invalid or does not exist."; return nullptr; } } if (num_samples < 0) { - MS_LOG(ERROR) << "TFRecordDataset: Invalid number of samples: " << num_samples; + MS_LOG(ERROR) << "TFRecordNode: Invalid number of samples: " << num_samples; return nullptr; } if (num_shards <= 0) { - MS_LOG(ERROR) << "TFRecordDataset: Invalid num_shards: " << num_shards; + MS_LOG(ERROR) << "TFRecordNode: Invalid num_shards: " << num_shards; return nullptr; } if (shard_id < 0 || shard_id >= num_shards) { - MS_LOG(ERROR) << "TFRecordDataset: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; + MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; return nullptr; } - std::shared_ptr ds = nullptr; + std::shared_ptr ds = nullptr; if constexpr 
(std::is_same::value || std::is_same>::value) { std::shared_ptr schema_obj = schema; - ds = std::make_shared(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards, - shard_id, shard_equal_rows); + ds = std::make_shared(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards, + shard_id, shard_equal_rows); } else { std::string schema_path = schema; if (!schema_path.empty()) { Path schema_file(schema_path); if (!schema_file.Exists()) { - MS_LOG(ERROR) << "TFRecordDataset: schema path [" << schema_path << "] is invalid or does not exist."; + MS_LOG(ERROR) << "TFRecordNode: schema path [" << schema_path << "] is invalid or does not exist."; return nullptr; } } - ds = std::make_shared(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards, - shard_id, shard_equal_rows); + ds = std::make_shared(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards, + shard_id, shard_equal_rows); } return ds; } -/// \brief Function to create a VOCDataset +/// \brief Function to create a VOCNode /// \notes The generated dataset has multi-columns : /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32], /// ['difficult', dtype=uint32], ['truncate', dtype=uint32]]. @@ -445,17 +442,17 @@ std::shared_ptr TFRecord(const std::vector &datase /// \param[in] sampler Object used to choose samples from the dataset. 
If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset -std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", - const std::string &usage = "train", - const std::map &class_indexing = {}, bool decode = false, - const std::shared_ptr &sampler = RandomSampler()); +std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", + const std::string &usage = "train", + const std::map &class_indexing = {}, bool decode = false, + const std::shared_ptr &sampler = RandomSampler()); #endif -/// \brief Function to create a ZipDataset +/// \brief Function to create a ZipNode /// \notes Applies zip to the dataset /// \param[in] datasets List of shared pointers to the datasets that we want to zip /// \return Shared pointer to the current Dataset -std::shared_ptr Zip(const std::vector> &datasets); +std::shared_ptr Zip(const std::vector> &datasets); /// \class Dataset datasets.h /// \brief A base class to represent a dataset in the data pipeline. @@ -503,18 +500,18 @@ class Dataset : public std::enable_shared_from_this { /// \return Shared pointer to the Iterator std::shared_ptr CreateIterator(std::vector columns = {}); - /// \brief Function to create a BatchDataset + /// \brief Function to create a BatchNode /// \notes Combines batch_size number of consecutive rows into batches /// \param[in] batch_size Path to the root directory that contains the dataset /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete /// batch. 
If true, and if there are less than batch_size rows /// available to make the last batch, then those rows will /// be dropped and not propagated to the next node - /// \return Shared pointer to the current BatchDataset - std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); + /// \return Shared pointer to the current BatchNode + std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); #ifndef ENABLE_ANDROID - /// \brief Function to create a BucketBatchByLengthDataset + /// \brief Function to create a BucketBatchByLengthNode /// \notes Combines batch_size number of consecutive rows into batches /// \param[in] column_names Columns passed to element_length_function /// \param[in] bucket_boundaries A list consisting of the upper boundaries of the buckets. @@ -536,8 +533,8 @@ class Dataset : public std::enable_shared_from_this { /// minus 1. If there are any elements that fall into the last bucket, an error will occur (default=false). /// \param[in] drop_remainder If true, will drop the last batch for each bucket if it is not a full batch /// (default=false). 
- /// \return Shared pointer to the current BucketBatchByLengthDataset - std::shared_ptr BucketBatchByLength( + /// \return Shared pointer to the current BucketBatchByLengthNode + std::shared_ptr BucketBatchByLength( const std::vector &column_names, const std::vector &bucket_boundaries, const std::vector &bucket_batch_sizes, std::function element_length_function = nullptr, @@ -563,13 +560,13 @@ class Dataset : public std::enable_shared_from_this { bool special_first = true); #endif - /// \brief Function to create a ConcatDataset + /// \brief Function to create a ConcatNode /// \notes Concat the datasets in the input /// \param[in] datasets List of shared pointers to the dataset that should be concatenated together - /// \return Shared pointer to the current ConcatDataset - std::shared_ptr Concat(const std::vector> &datasets); + /// \return Shared pointer to the current ConcatNode + std::shared_ptr Concat(const std::vector> &datasets); - /// \brief Function to create a MapDataset + /// \brief Function to create a MapNode /// \notes Applies each operation in operations to this dataset /// \param[in] operations Vector of operations to be applied on the dataset. Operations are /// applied in the order they appear in this list @@ -583,47 +580,47 @@ class Dataset : public std::enable_shared_from_this { /// last operation. 
The default output_columns will have the same /// name as the input columns, i.e., the columns will be replaced /// \param[in] project_columns A list of column names to project - /// \return Shared pointer to the current MapDataset - std::shared_ptr Map(std::vector> operations, - std::vector input_columns = {}, - std::vector output_columns = {}, - const std::vector &project_columns = {}); + /// \return Shared pointer to the current MapNode + std::shared_ptr Map(std::vector> operations, + std::vector input_columns = {}, + std::vector output_columns = {}, + const std::vector &project_columns = {}); /// \brief Function to create a Project Dataset /// \notes Applies project to the dataset /// \param[in] columns The name of columns to project /// \return Shared pointer to the current Dataset - std::shared_ptr Project(const std::vector &columns); + std::shared_ptr Project(const std::vector &columns); /// \brief Function to create a Rename Dataset /// \notes Renames the columns in the input dataset /// \param[in] input_columns List of the input columns to rename /// \param[in] output_columns List of the output columns /// \return Shared pointer to the current Dataset - std::shared_ptr Rename(const std::vector &input_columns, - const std::vector &output_columns); + std::shared_ptr Rename(const std::vector &input_columns, + const std::vector &output_columns); - /// \brief Function to create a RepeatDataset + /// \brief Function to create a RepeatNode /// \notes Repeats this dataset count times. 
Repeat indefinitely if count is -1 /// \param[in] count Number of times the dataset should be repeated /// \return Shared pointer to the current Dataset - /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset` + /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatNode` /// due to a limitation in the current implementation std::shared_ptr Repeat(int32_t count = -1); /// \brief Function to create a Shuffle Dataset /// \notes Randomly shuffles the rows of this dataset /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling - /// \return Shared pointer to the current ShuffleDataset - std::shared_ptr Shuffle(int32_t buffer_size); + /// \return Shared pointer to the current ShuffleNode + std::shared_ptr Shuffle(int32_t buffer_size); - /// \brief Function to create a SkipDataset + /// \brief Function to create a SkipNode /// \notes Skips count elements in this dataset. /// \param[in] count Number of elements the dataset to be skipped. - /// \return Shared pointer to the current SkipDataset - std::shared_ptr Skip(int32_t count); + /// \return Shared pointer to the current SkipNode + std::shared_ptr Skip(int32_t count); - /// \brief Function to create a TakeDataset + /// \brief Function to create a TakeNode /// \notes Takes count elements in this dataset. /// \param[in] count Number of elements the dataset to be taken. 
/// \return Shared pointer to the current Dataset @@ -633,7 +630,7 @@ class Dataset : public std::enable_shared_from_this { /// \notes Applies zip to the dataset /// \param[in] datasets A list of shared pointers to the datasets that we want to zip /// \return Shared pointer to the current Dataset - std::shared_ptr Zip(const std::vector> &datasets); + std::shared_ptr Zip(const std::vector> &datasets); protected: std::vector> children; @@ -710,14 +707,14 @@ class SchemaObj { // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS // (In alphabetical order) -class AlbumDataset : public Dataset { +class AlbumNode : public Dataset { public: /// \brief Constructor - AlbumDataset(const std::string &dataset_dir, const std::string &data_schema, - const std::vector &column_names, bool decode, const std::shared_ptr &sampler); + AlbumNode(const std::string &dataset_dir, const std::string &data_schema, + const std::vector &column_names, bool decode, const std::shared_ptr &sampler); /// \brief Destructor - ~AlbumDataset() = default; + ~AlbumNode() = default; /// \brief a base class override function to create a runtime dataset op object from this class /// \return shared pointer to the newly created DatasetOp @@ -735,14 +732,14 @@ class AlbumDataset : public Dataset { std::shared_ptr sampler_; }; -class CelebADataset : public Dataset { +class CelebANode : public Dataset { public: /// \brief Constructor - CelebADataset(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, - const bool &decode, const std::set &extensions); + CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, + const bool &decode, const std::set &extensions); /// \brief Destructor - ~CelebADataset() = default; + ~CelebANode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps @@ -762,13 +759,13 @@ class 
CelebADataset : public Dataset { // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS // (In alphabetical order) -class Cifar10Dataset : public Dataset { +class Cifar10Node : public Dataset { public: /// \brief Constructor - Cifar10Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler); + Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler); /// \brief Destructor - ~Cifar10Dataset() = default; + ~Cifar10Node() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -784,13 +781,13 @@ class Cifar10Dataset : public Dataset { std::shared_ptr sampler_; }; -class Cifar100Dataset : public Dataset { +class Cifar100Node : public Dataset { public: /// \brief Constructor - Cifar100Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler); + Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler); /// \brief Destructor - ~Cifar100Dataset() = default; + ~Cifar100Node() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -806,16 +803,16 @@ class Cifar100Dataset : public Dataset { std::shared_ptr sampler_; }; -/// \class CLUEDataset +/// \class CLUENode /// \brief A Dataset derived class to represent CLUE dataset -class CLUEDataset : public Dataset { +class CLUENode : public Dataset { public: /// \brief Constructor - CLUEDataset(const std::vector dataset_files, std::string task, std::string usage, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); + CLUENode(const std::vector dataset_files, std::string task, std::string usage, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); /// \brief Destructor 
- ~CLUEDataset() = default; + ~CLUENode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -839,14 +836,14 @@ class CLUEDataset : public Dataset { int32_t shard_id_; }; -class CocoDataset : public Dataset { +class CocoNode : public Dataset { public: /// \brief Constructor - CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, - const bool &decode, const std::shared_ptr &sampler); + CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, + const bool &decode, const std::shared_ptr &sampler); /// \brief Destructor - ~CocoDataset() = default; + ~CocoNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps @@ -886,15 +883,15 @@ class CsvRecord : public CsvBase { T value; }; -class CSVDataset : public Dataset { +class CSVNode : public Dataset { public: /// \brief Constructor - CSVDataset(const std::vector &dataset_files, char field_delim, - const std::vector> &column_defaults, const std::vector &column_names, - int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); + CSVNode(const std::vector &dataset_files, char field_delim, + const std::vector> &column_defaults, const std::vector &column_names, + int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); /// \brief Destructor - ~CSVDataset() = default; + ~CSVNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps @@ -915,16 +912,16 @@ class CSVDataset : public Dataset { int32_t shard_id_; }; -/// \class ImageFolderDataset +/// \class ImageFolderNode /// \brief A 
Dataset derived class to represent ImageFolder dataset -class ImageFolderDataset : public Dataset { +class ImageFolderNode : public Dataset { public: /// \brief Constructor - ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, - std::set extensions, std::map class_indexing); + ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, + std::set extensions, std::map class_indexing); /// \brief Destructor - ~ImageFolderDataset() = default; + ~ImageFolderNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -944,14 +941,14 @@ class ImageFolderDataset : public Dataset { }; #ifndef ENABLE_ANDROID -class ManifestDataset : public Dataset { +class ManifestNode : public Dataset { public: /// \brief Constructor - ManifestDataset(const std::string &dataset_file, const std::string &usage, const std::shared_ptr &sampler, - const std::map &class_indexing, bool decode); + ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr &sampler, + const std::map &class_indexing, bool decode); /// \brief Destructor - ~ManifestDataset() = default; + ~ManifestNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -971,18 +968,18 @@ class ManifestDataset : public Dataset { #endif #ifndef ENABLE_ANDROID -class MindDataDataset : public Dataset { +class MindDataNode : public Dataset { public: /// \brief Constructor - MindDataDataset(const std::vector &dataset_files, const std::vector &columns_list, - const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded); + MindDataNode(const std::vector &dataset_files, const std::vector &columns_list, + const std::shared_ptr 
&sampler, nlohmann::json padded_sample, int64_t num_padded); /// \brief Constructor - MindDataDataset(const std::string &dataset_file, const std::vector &columns_list, - const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded); + MindDataNode(const std::string &dataset_file, const std::vector &columns_list, + const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded); /// \brief Destructor - ~MindDataDataset() = default; + ~MindDataNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -999,7 +996,7 @@ class MindDataDataset : public Dataset { int64_t num_padded); /// \brief Set sample_bytes when padded_sample has py::byte value - /// \note Pybind will use this function to set sample_bytes into MindDataDataset + /// \note Pybind will use this function to set sample_bytes into MindDataNode void SetSampleBytes(std::map *sample_bytes); private: @@ -1014,13 +1011,13 @@ class MindDataDataset : public Dataset { }; #endif -class MnistDataset : public Dataset { +class MnistNode : public Dataset { public: /// \brief Constructor - MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr sampler); + MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler); /// \brief Destructor - ~MnistDataset() = default; + ~MnistNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1036,7 +1033,7 @@ class MnistDataset : public Dataset { std::shared_ptr sampler_; }; -class RandomDataset : public Dataset { +class RandomNode : public Dataset { public: // Some constants to provide limits to random generation. 
static constexpr int32_t kMaxNumColumns = 4; @@ -1044,8 +1041,8 @@ class RandomDataset : public Dataset { static constexpr int32_t kMaxDimValue = 32; /// \brief Constructor - RandomDataset(const int32_t &total_rows, std::shared_ptr schema, - const std::vector &columns_list, const std::shared_ptr &sampler) + RandomNode(const int32_t &total_rows, std::shared_ptr schema, const std::vector &columns_list, + const std::shared_ptr &sampler) : total_rows_(total_rows), schema_path_(""), schema_(std::move(schema)), @@ -1053,12 +1050,12 @@ class RandomDataset : public Dataset { sampler_(std::move(sampler)) {} /// \brief Constructor - RandomDataset(const int32_t &total_rows, std::string schema_path, const std::vector &columns_list, - const std::shared_ptr &sampler) + RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector &columns_list, + const std::shared_ptr &sampler) : total_rows_(total_rows), schema_path_(schema_path), columns_list_(columns_list), sampler_(std::move(sampler)) {} /// \brief Destructor - ~RandomDataset() = default; + ~RandomNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1083,16 +1080,16 @@ class RandomDataset : public Dataset { std::mt19937 rand_gen_; }; -/// \class TextFileDataset +/// \class TextFileNode /// \brief A Dataset derived class to represent TextFile dataset -class TextFileDataset : public Dataset { +class TextFileNode : public Dataset { public: /// \brief Constructor - TextFileDataset(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, - int32_t shard_id); + TextFileNode(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, + int32_t shard_id); /// \brief Destructor - ~TextFileDataset() = default; + ~TextFileNode() = default; /// \brief a base class override function to create the required runtime 
dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1110,15 +1107,15 @@ class TextFileDataset : public Dataset { ShuffleMode shuffle_; }; -/// \class TFRecordDataset +/// \class TFRecordNode /// \brief A Dataset derived class to represent TFRecord dataset -class TFRecordDataset : public Dataset { +class TFRecordNode : public Dataset { public: /// \brief Constructor /// \note Parameter 'schema' is the path to the schema file - TFRecordDataset(const std::vector &dataset_files, std::string schema, - const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, bool shard_equal_rows) + TFRecordNode(const std::vector &dataset_files, std::string schema, + const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, bool shard_equal_rows) : dataset_files_(dataset_files), schema_path_(schema), columns_list_(columns_list), @@ -1130,9 +1127,9 @@ class TFRecordDataset : public Dataset { /// \brief Constructor /// \note Parameter 'schema' is shared pointer to Schema object - TFRecordDataset(const std::vector &dataset_files, std::shared_ptr schema, - const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, bool shard_equal_rows) + TFRecordNode(const std::vector &dataset_files, std::shared_ptr schema, + const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, bool shard_equal_rows) : dataset_files_(dataset_files), schema_obj_(schema), columns_list_(columns_list), @@ -1143,7 +1140,7 @@ class TFRecordDataset : public Dataset { shard_equal_rows_(shard_equal_rows) {} /// \brief Destructor - ~TFRecordDataset() = default; + ~TFRecordNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly 
created DatasetOps @@ -1166,14 +1163,14 @@ class TFRecordDataset : public Dataset { }; #ifndef ENABLE_ANDROID -class VOCDataset : public Dataset { +class VOCNode : public Dataset { public: /// \brief Constructor - VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage, - const std::map &class_indexing, bool decode, std::shared_ptr sampler); + VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, + const std::map &class_indexing, bool decode, std::shared_ptr sampler); /// \brief Destructor - ~VOCDataset() = default; + ~VOCNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps @@ -1202,14 +1199,15 @@ class VOCDataset : public Dataset { // DERIVED DATASET CLASSES FOR DATASET OPS // (In alphabetical order) -class BatchDataset : public Dataset { +class BatchNode : public Dataset { public: /// \brief Constructor - BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, - std::map>> pad_map); + BatchNode(std::shared_ptr child, int32_t batch_size, bool drop_remainder, bool pad, + std::vector cols_to_map, + std::map>> pad_map); /// \brief Destructor - ~BatchDataset() = default; + ~BatchNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1228,18 +1226,17 @@ class BatchDataset : public Dataset { }; #ifndef ENABLE_ANDROID -class BucketBatchByLengthDataset : public Dataset { +class BucketBatchByLengthNode : public Dataset { public: /// \brief Constructor - BucketBatchByLengthDataset( - const std::vector &column_names, const std::vector &bucket_boundaries, - const std::vector &bucket_batch_sizes, - std::function element_length_function = nullptr, - const std::map>> &pad_info = {}, - bool 
pad_to_bucket_boundary = false, bool drop_remainder = false); + BucketBatchByLengthNode(std::shared_ptr child, const std::vector &column_names, + const std::vector &bucket_boundaries, const std::vector &bucket_batch_sizes, + std::function element_length_function = nullptr, + const std::map>> &pad_info = {}, + bool pad_to_bucket_boundary = false, bool drop_remainder = false); /// \brief Destructor - ~BucketBatchByLengthDataset() = default; + ~BucketBatchByLengthNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1259,15 +1256,15 @@ class BucketBatchByLengthDataset : public Dataset { bool drop_remainder_; }; -class BuildVocabDataset : public Dataset { +class BuildVocabNode : public Dataset { public: /// \brief Constructor - BuildVocabDataset(std::shared_ptr vocab, const std::vector &columns, - const std::pair &freq_range, int64_t top_k, - const std::vector &special_tokens, bool special_first); + BuildVocabNode(std::shared_ptr child, std::shared_ptr vocab, const std::vector &columns, + const std::pair &freq_range, int64_t top_k, + const std::vector &special_tokens, bool special_first); /// \brief Destructor - ~BuildVocabDataset() = default; + ~BuildVocabNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1287,13 +1284,13 @@ class BuildVocabDataset : public Dataset { }; #endif -class ConcatDataset : public Dataset { +class ConcatNode : public Dataset { public: /// \brief Constructor - explicit ConcatDataset(const std::vector> &datasets); + explicit ConcatNode(const std::vector> &datasets); /// \brief Destructor - ~ConcatDataset() = default; + ~ConcatNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// 
\return The list of shared pointers to the newly created DatasetOps @@ -1307,14 +1304,15 @@ class ConcatDataset : public Dataset { std::vector> datasets_; }; -class MapDataset : public Dataset { +class MapNode : public Dataset { public: /// \brief Constructor - MapDataset(std::vector> operations, std::vector input_columns = {}, - std::vector output_columns = {}, const std::vector &columns = {}); + MapNode(std::shared_ptr child, std::vector> operations, + std::vector input_columns = {}, std::vector output_columns = {}, + const std::vector &columns = {}); /// \brief Destructor - ~MapDataset() = default; + ~MapNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1331,13 +1329,13 @@ class MapDataset : public Dataset { std::vector project_columns_; }; -class ProjectDataset : public Dataset { +class ProjectNode : public Dataset { public: /// \brief Constructor - explicit ProjectDataset(const std::vector &columns); + explicit ProjectNode(std::shared_ptr child, const std::vector &columns); /// \brief Destructor - ~ProjectDataset() = default; + ~ProjectNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1351,13 +1349,14 @@ class ProjectDataset : public Dataset { std::vector columns_; }; -class RenameDataset : public Dataset { +class RenameNode : public Dataset { public: /// \brief Constructor - explicit RenameDataset(const std::vector &input_columns, const std::vector &output_columns); + explicit RenameNode(std::shared_ptr child, const std::vector &input_columns, + const std::vector &output_columns); /// \brief Destructor - ~RenameDataset() = default; + ~RenameNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class 
/// \return The list of shared pointers to the newly created DatasetOps @@ -1372,13 +1371,13 @@ class RenameDataset : public Dataset { std::vector output_columns_; }; -class RepeatDataset : public Dataset { +class RepeatNode : public Dataset { public: /// \brief Constructor - explicit RepeatDataset(int32_t count); + explicit RepeatNode(std::shared_ptr child, int32_t count); /// \brief Destructor - ~RepeatDataset() = default; + ~RepeatNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1392,11 +1391,11 @@ class RepeatDataset : public Dataset { int32_t repeat_count_; }; -class ShuffleDataset : public Dataset { +class ShuffleNode : public Dataset { public: - ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch); + ShuffleNode(std::shared_ptr child, int32_t shuffle_size, bool reset_every_epoch); - ~ShuffleDataset() = default; + ~ShuffleNode() = default; std::vector> Build() override; @@ -1408,13 +1407,13 @@ class ShuffleDataset : public Dataset { bool reset_every_epoch_; }; -class SkipDataset : public Dataset { +class SkipNode : public Dataset { public: /// \brief Constructor - explicit SkipDataset(int32_t count); + explicit SkipNode(std::shared_ptr child, int32_t count); /// \brief Destructor - ~SkipDataset() = default; + ~SkipNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps @@ -1428,13 +1427,13 @@ class SkipDataset : public Dataset { int32_t skip_count_; }; -class TakeDataset : public Dataset { +class TakeNode : public Dataset { public: /// \brief Constructor - explicit TakeDataset(int32_t count); + explicit TakeNode(std::shared_ptr child, int32_t count); /// \brief Destructor - ~TakeDataset() = default; + ~TakeNode() = default; /// \brief a base class override 
function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps @@ -1448,13 +1447,13 @@ class TakeDataset : public Dataset { int32_t take_count_; }; -class ZipDataset : public Dataset { +class ZipNode : public Dataset { public: /// \brief Constructor - explicit ZipDataset(const std::vector> &datasets); + explicit ZipNode(const std::vector> &datasets); /// \brief Destructor - ~ZipDataset() = default; + ~ZipNode() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps