/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ #include #include #include #include #include #include #include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/include/tensor.h" #include "minddata/dataset/include/iterator.h" #include "minddata/dataset/include/samplers.h" namespace mindspore { namespace dataset { // Forward declare class DatasetOp; class DataSchema; class Tensor; class TensorShape; namespace api { class TensorOperation; class SamplerObj; // Datasets classes (in alphabetical order) class CelebADataset; class Cifar10Dataset; class Cifar100Dataset; class CLUEDataset; class CocoDataset; class ImageFolderDataset; class MnistDataset; class TextFileDataset; class VOCDataset; // Dataset Op classes (in alphabetical order) class BatchDataset; class ConcatDataset; class MapDataset; class ProjectDataset; class RenameDataset; class RepeatDataset; class ShuffleDataset; class SkipDataset; class TakeDataset; class ZipDataset; /// \brief Function to create a CelebADataset /// \notes The generated dataset has two columns ['image', 'attr']. // The type of the image tensor is uint8. The attr tensor is uint32 and one hot type. /// \param[in] dataset_dir Path to the root directory that contains the dataset. /// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'. /// \param[in] decode Decode the images after reading (default=False). /// \param[in] extensions List of file extensions to be included in the dataset (default=None). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all", const std::shared_ptr &sampler = nullptr, bool decode = false, const std::set &extensions = {}); /// \brief Function to create a Cifar10 Dataset /// \notes The generated dataset has two columns ['image', 'label'] /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset std::shared_ptr Cifar10(const std::string &dataset_dir, const std::shared_ptr &sampler = nullptr); /// \brief Function to create a Cifar100 Dataset /// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label'] /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset std::shared_ptr Cifar100(const std::string &dataset_dir, const std::shared_ptr &sampler = nullptr); /// \brief Function to create a CLUEDataset /// \notes The generated dataset has a variable number of columns depending on the task and usage /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list /// will be sorted in a lexicographical order. /// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC"). /// \param[in] usage Be used to "train", "test" or "eval" data (default="train"). /// \param[in] num_samples The number of samples to be included in the dataset. /// (Default = 0 means all samples.) /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal) /// Can be any of: /// ShuffleMode.kFalse - No shuffling is performed. /// ShuffleMode.kFiles - Shuffle files only. /// ShuffleMode.kGlobal - Shuffle both the files and samples. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) /// \return Shared pointer to the current CLUEDataset std::shared_ptr CLUE(const std::vector &dataset_files, const std::string &task = "AFQMC", const std::string &usage = "train", int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); /// \brief Function to create a CocoDataset /// \notes The generated dataset has multi-columns : /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32], /// ['iscrowd', dtype=uint32]]. /// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]]. /// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32], /// ['num_keypoints', dtype=uint32]]. /// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32], /// ['iscrowd', dtype=uint32], ['area', dtype=uitn32]]. /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] annotation_file Path to the annotation json /// \param[in] task Set the task type of reading coco data, now support 'Detection'/'Stuff'/'Panoptic'/'Keypoint' /// \param[in] decode Decode the images after reading /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task = "Detection", const bool &decode = false, const std::shared_ptr &sampler = nullptr); /// \brief Function to create an ImageFolderDataset /// \notes A source dataset that reads images from a tree of directories /// All images within one folder have the same label /// The generated dataset has two columns ['image', 'label'] /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] decode A flag to decode in ImageFolder /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, /// A `RandomSampler` will be used to randomly iterate the entire dataset /// \param[in] extensions File extensions to be read /// \param[in] class_indexing a class name to label map /// \return Shared pointer to the current ImageFolderDataset std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode = false, const std::shared_ptr &sampler = nullptr, const std::set &extensions = {}, const std::map &class_indexing = {}); /// \brief Function to create a MnistDataset /// \notes The generated dataset has two columns ['image', 'label'] /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, /// A `RandomSampler` will be used to randomly iterate the entire dataset /// \return Shared pointer to the current MnistDataset std::shared_ptr Mnist(const std::string &dataset_dir, const std::shared_ptr &sampler = nullptr); /// \brief Function to create a ConcatDataset /// \notes Reload "+" operator to concat two datasets /// \param[in] datasets1 Shared pointer to the first dataset to be concatenated /// \param[in] datasets2 Shared pointer to the second dataset to be concatenated /// \return Shared pointer to the current ConcatDataset std::shared_ptr operator+(const std::shared_ptr &datasets1, const std::shared_ptr &datasets2); /// \brief Function to create a TextFileDataset /// \notes The generated dataset has one column ['text'] /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list /// will be sorted in a lexicographical order. /// \param[in] num_samples The number of samples to be included in the dataset. /// (Default = 0 means all samples.) /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal) /// Can be any of: /// ShuffleMode.kFalse - No shuffling is performed. /// ShuffleMode.kFiles - Shuffle files only. /// ShuffleMode.kGlobal - Shuffle both the files and samples. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) /// \return Shared pointer to the current TextFileDataset std::shared_ptr TextFile(const std::vector &dataset_files, int32_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); /// \brief Function to create a VOCDataset /// \notes The generated dataset has multi-columns : /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32], /// ['difficult', dtype=uint32], ['truncate', dtype=uint32]]. /// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]]. /// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection" /// \param[in] mode Set the data list txt file to be readed /// \param[in] class_indexing A str-to-int mapping from label name to index /// \param[in] decode Decode the images after reading /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// will be used to randomly iterate the entire dataset /// \return Shared pointer to the current Dataset std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", const std::string &mode = "train", const std::map &class_indexing = {}, bool decode = false, const std::shared_ptr &sampler = nullptr); /// \brief Function to create a ZipDataset /// \notes Applies zip to the dataset /// \param[in] datasets List of shared pointers to the datasets that we want to zip /// \return Shared pointer to the current Dataset std::shared_ptr Zip(const std::vector> &datasets); /// \class Dataset datasets.h /// \brief A base class to represent a dataset in the data pipeline. class Dataset : public std::enable_shared_from_this { public: friend class Iterator; /// \brief Constructor Dataset(); /// \brief Destructor ~Dataset() = default; /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object /// \return The list of shared pointers to the newly created DatasetOps virtual std::vector> Build() = 0; /// \brief Pure virtual function for derived class to implement parameters validation /// \return bool True if all the params are valid virtual bool ValidateParams() = 0; /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object std::shared_ptr SetNumWorkers(int32_t num_workers) { num_workers_ = num_workers; return shared_from_this(); } /// \brief Function to create an Iterator over the Dataset pipeline /// \param[in] columns List of columns to be used to specify the order of columns /// \return Shared pointer to the Iterator std::shared_ptr CreateIterator(std::vector columns = {}); /// \brief Function to create a BatchDataset /// \notes Combines batch_size number of consecutive rows into batches /// \param[in] batch_size Path to the root directory that contains the dataset /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete /// batch. If true, and if there are less than batch_size rows /// available to make the last batch, then those rows will /// be dropped and not propagated to the next node /// \return Shared pointer to the current BatchDataset std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); /// \brief Function to create a ConcatDataset /// \notes Concat the datasets in the input /// \param[in] datasets List of shared pointers to the dataset that should be concatenated together /// \return Shared pointer to the current ConcatDataset std::shared_ptr Concat(const std::vector> &datasets); /// \brief Function to create a MapDataset /// \notes Applies each operation in operations to this dataset /// \param[in] operations Vector of operations to be applied on the dataset. Operations are /// applied in the order they appear in this list /// \param[in] input_columns Vector of the names of the columns that will be passed to the first /// operation as input. The size of this list must match the number of /// input columns expected by the first operator. The default input_columns /// is the first column /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation /// This parameter is mandatory if len(input_columns) != len(output_columns) /// The size of this list must match the number of output columns of the /// last operation. The default output_columns will have the same /// name as the input columns, i.e., the columns will be replaced /// \param[in] project_columns A list of column names to project /// \return Shared pointer to the current MapDataset std::shared_ptr Map(std::vector> operations, std::vector input_columns = {}, std::vector output_columns = {}, const std::vector &project_columns = {}); /// \brief Function to create a Project Dataset /// \notes Applies project to the dataset /// \param[in] columns The name of columns to project /// \return Shared pointer to the current Dataset std::shared_ptr Project(const std::vector &columns); /// \brief Function to create a Rename Dataset /// \notes Renames the columns in the input dataset /// \param[in] input_columns List of the input columns to rename /// \param[in] output_columns List of the output columns /// \return Shared pointer to the current Dataset std::shared_ptr Rename(const std::vector &input_columns, const std::vector &output_columns); /// \brief Function to create a RepeatDataset /// \notes Repeats this dataset count times. Repeat indefinitely if count is -1 /// \param[in] count Number of times the dataset should be repeated /// \return Shared pointer to the current Dataset /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset` /// due to a limitation in the current implementation std::shared_ptr Repeat(int32_t count = -1); /// \brief Function to create a Shuffle Dataset /// \notes Randomly shuffles the rows of this dataset /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling /// \return Shared pointer to the current ShuffleDataset std::shared_ptr Shuffle(int32_t buffer_size); /// \brief Function to create a SkipDataset /// \notes Skips count elements in this dataset. /// \param[in] count Number of elements the dataset to be skipped. /// \return Shared pointer to the current SkipDataset std::shared_ptr Skip(int32_t count); /// \brief Function to create a TakeDataset /// \notes Takes count elements in this dataset. /// \param[in] count Number of elements the dataset to be taken. /// \return Shared pointer to the current Dataset std::shared_ptr Take(int32_t count = -1); /// \brief Function to create a Zip Dataset /// \notes Applies zip to the dataset /// \param[in] datasets A list of shared pointers to the datasets that we want to zip /// \return Shared pointer to the current Dataset std::shared_ptr Zip(const std::vector> &datasets); protected: std::vector> children; std::shared_ptr parent; int32_t num_workers_; int32_t rows_per_buffer_; int32_t connector_que_size_; int32_t worker_connector_size_; }; /* ####################################### Derived Dataset classes ################################# */ // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS // (In alphabetical order) class CelebADataset : public Dataset { public: /// \brief Constructor CelebADataset(const std::string &dataset_dir, const std::string &dataset_type, const std::shared_ptr &sampler, const bool &decode, const std::set &extensions); /// \brief Destructor ~CelebADataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::string dataset_dir_; std::string dataset_type_; bool decode_; std::set extensions_; std::shared_ptr sampler_; }; // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS // (In alphabetical order) class Cifar10Dataset : public Dataset { public: /// \brief Constructor Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr sampler); /// \brief Destructor ~Cifar10Dataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::string dataset_dir_; std::shared_ptr sampler_; }; class Cifar100Dataset : public Dataset { public: /// \brief Constructor Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr sampler); /// \brief Destructor ~Cifar100Dataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::string dataset_dir_; std::shared_ptr sampler_; }; /// \class CLUEDataset /// \brief A Dataset derived class to represent CLUE dataset class CLUEDataset : public Dataset { public: /// \brief Constructor CLUEDataset(const std::vector dataset_files, std::string task, std::string usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); /// \brief Destructor ~CLUEDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: /// \brief Split string based on a character delimiter /// \return A string vector std::vector split(const std::string &s, char delim); std::vector dataset_files_; std::string task_; std::string usage_; int64_t num_samples_; ShuffleMode shuffle_; int32_t num_shards_; int32_t shard_id_; }; class CocoDataset : public Dataset { public: /// \brief Constructor CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, const bool &decode, const std::shared_ptr &sampler); /// \brief Destructor ~CocoDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::string dataset_dir_; std::string annotation_file_; std::string task_; bool decode_; std::shared_ptr sampler_; }; /// \class ImageFolderDataset /// \brief A Dataset derived class to represent ImageFolder dataset class ImageFolderDataset : public Dataset { public: /// \brief Constructor ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, std::set extensions, std::map class_indexing); /// \brief Destructor ~ImageFolderDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::string dataset_dir_; bool decode_; bool recursive_; std::shared_ptr sampler_; std::map class_indexing_; std::set exts_; }; class MnistDataset : public Dataset { public: /// \brief Constructor MnistDataset(std::string dataset_dir, std::shared_ptr sampler); /// \brief Destructor ~MnistDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::string dataset_dir_; std::shared_ptr sampler_; }; /// \class TextFileDataset /// \brief A Dataset derived class to represent TextFile dataset class TextFileDataset : public Dataset { public: /// \brief Constructor TextFileDataset(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); /// \brief Destructor ~TextFileDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::vector dataset_files_; int32_t num_samples_; int32_t num_shards_; int32_t shard_id_; ShuffleMode shuffle_; }; class VOCDataset : public Dataset { public: /// \brief Constructor VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, const std::map &class_indexing, bool decode, std::shared_ptr sampler); /// \brief Destructor ~VOCDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: const std::string kColumnImage = "image"; const std::string kColumnTarget = "target"; const std::string kColumnBbox = "bbox"; const std::string kColumnLabel = "label"; const std::string kColumnDifficult = "difficult"; const std::string kColumnTruncate = "truncate"; std::string dataset_dir_; std::string task_; std::string mode_; std::map class_index_; bool decode_; std::shared_ptr sampler_; }; // DERIVED DATASET CLASSES FOR DATASET OPS // (In alphabetical order) class BatchDataset : public Dataset { public: /// \brief Constructor BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, std::map>> pad_map); /// \brief Destructor ~BatchDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: int32_t batch_size_; bool drop_remainder_; bool pad_; std::vector cols_to_map_; std::map>> pad_map_; }; class ConcatDataset : public Dataset { public: /// \brief Constructor explicit ConcatDataset(const std::vector> &datasets); /// \brief Destructor ~ConcatDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::vector> datasets_; }; class MapDataset : public Dataset { public: /// \brief Constructor MapDataset(std::vector> operations, std::vector input_columns = {}, std::vector output_columns = {}, const std::vector &columns = {}); /// \brief Destructor ~MapDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::vector> operations_; std::vector input_columns_; std::vector output_columns_; std::vector project_columns_; }; class ProjectDataset : public Dataset { public: /// \brief Constructor explicit ProjectDataset(const std::vector &columns); /// \brief Destructor ~ProjectDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::vector columns_; }; class RenameDataset : public Dataset { public: /// \brief Constructor explicit RenameDataset(const std::vector &input_columns, const std::vector &output_columns); /// \brief Destructor ~RenameDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::vector input_columns_; std::vector output_columns_; }; class RepeatDataset : public Dataset { public: /// \brief Constructor explicit RepeatDataset(int32_t count); /// \brief Destructor ~RepeatDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: int32_t repeat_count_; }; class ShuffleDataset : public Dataset { public: ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch); ~ShuffleDataset() = default; std::vector> Build() override; bool ValidateParams() override; private: int32_t shuffle_size_; uint32_t shuffle_seed_; bool reset_every_epoch_; }; class SkipDataset : public Dataset { public: /// \brief Constructor explicit SkipDataset(int32_t count); /// \brief Destructor ~SkipDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: int32_t skip_count_; }; class TakeDataset : public Dataset { public: /// \brief Constructor explicit TakeDataset(int32_t count); /// \brief Destructor ~TakeDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return shared pointer to the list of newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: int32_t take_count_; }; class ZipDataset : public Dataset { public: /// \brief Constructor explicit ZipDataset(const std::vector> &datasets); /// \brief Destructor ~ZipDataset() = default; /// \brief a base class override function to create the required runtime dataset op objects for this class /// \return The list of shared pointers to the newly created DatasetOps std::vector> Build() override; /// \brief Parameters validation /// \return bool true if all the params are valid bool ValidateParams() override; private: std::vector> datasets_; }; } // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_