You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

datasets.h 49 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  18. #include <unistd.h>
  19. #include <vector>
  20. #include <memory>
  21. #include <set>
  22. #include <map>
  23. #include <utility>
  24. #include <string>
  25. #include "minddata/dataset/core/constants.h"
  26. #include "minddata/dataset/engine/data_schema.h"
  27. #include "minddata/dataset/include/tensor.h"
  28. #include "minddata/dataset/include/iterator.h"
  29. #include "minddata/dataset/include/samplers.h"
  30. #include "minddata/dataset/include/type_id.h"
  31. #include "minddata/dataset/text/vocab.h"
  32. namespace mindspore {
  33. namespace dataset {
  34. // Forward declare
  35. class DatasetOp;
  36. class DataSchema;
  37. class Tensor;
  38. class TensorShape;
  39. class Vocab;
  40. namespace api {
  41. class TensorOperation;
  42. class SchemaObj;
  43. class SamplerObj;
  44. // Datasets classes (in alphabetical order)
  45. class AlbumDataset;
  46. class CelebADataset;
  47. class Cifar10Dataset;
  48. class Cifar100Dataset;
  49. class CLUEDataset;
  50. class CocoDataset;
  51. class CSVDataset;
  52. class CsvBase;
  53. class ImageFolderDataset;
  54. class ManifestDataset;
  55. class MnistDataset;
  56. class RandomDataset;
  57. class TextFileDataset;
  58. class VOCDataset;
  59. // Dataset Op classes (in alphabetical order)
  60. class BatchDataset;
  61. class BuildVocabDataset;
  62. class ConcatDataset;
  63. class MapDataset;
  64. class ProjectDataset;
  65. class RenameDataset;
  66. class RepeatDataset;
  67. class ShuffleDataset;
  68. class SkipDataset;
  69. class TakeDataset;
  70. class ZipDataset;
  71. /// \brief Function to create a SchemaObj
  72. /// \param[in] schema_file Path of the schema file (default = "", i.e. an empty path)
  73. /// \return Shared pointer to the created SchemaObj
  74. std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "");
  75. /// \brief Function to create an AlbumDataset
  76. /// \notes The generated dataset is specified through setting a schema
  77. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  78. /// \param[in] data_schema Path to the dataset schema file
  79. /// \param[in] column_names Column names used to specify the columns to load; if empty, all columns are read.
  80. /// (default = {})
  81. /// \param[in] decode Whether to decode the images in the dataset (default = false)
  82. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
  83. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = nullptr)
  84. /// \return Shared pointer to the current Dataset
  85. std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  86. const std::vector<std::string> &column_names = {}, bool decode = false,
  87. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  88. /// \brief Function to create a CelebADataset
  89. /// \notes The generated dataset has two columns ['image', 'attr'].
  90. /// The type of the image tensor is uint8. The attr tensor is uint32 and one-hot encoded.
  91. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  92. /// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test' (default="all").
  93. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  94. /// will be used to randomly iterate the entire dataset (default=nullptr).
  95. /// \param[in] decode Decode the images after reading (default=false).
  96. /// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
  97. /// \return Shared pointer to the current Dataset
  98. std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
  99. const std::shared_ptr<SamplerObj> &sampler = nullptr, bool decode = false,
  100. const std::set<std::string> &extensions = {});
  101. /// \brief Function to create a Cifar10 Dataset
  102. /// \notes The generated dataset has two columns ['image', 'label']
  103. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  104. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  105. /// will be used to randomly iterate the entire dataset (default=nullptr)
  106. /// \return Shared pointer to the current Dataset
  107. std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
  108. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  109. /// \brief Function to create a Cifar100 Dataset
  110. /// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
  111. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  112. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  113. /// will be used to randomly iterate the entire dataset (default=nullptr)
  114. /// \return Shared pointer to the current Dataset
  115. std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
  116. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  117. /// \brief Function to create a CLUEDataset
  118. /// \notes The generated dataset has a variable number of columns depending on the task and usage
  119. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  120. /// will be sorted in a lexicographical order.
  121. /// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC").
  122. /// \param[in] usage Which part of the data to read: "train", "test" or "eval" (default="train").
  123. /// \param[in] num_samples The number of samples to be included in the dataset.
  124. /// (Default = 0 means all samples.)
  125. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  126. /// Can be any of:
  127. /// ShuffleMode::kFalse - No shuffling is performed.
  128. /// ShuffleMode::kFiles - Shuffle files only.
  129. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  130. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  131. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  132. /// specified only when num_shards is also specified. (Default = 0)
  133. /// \return Shared pointer to the current CLUEDataset
  134. std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
  135. const std::string &usage = "train", int64_t num_samples = 0,
  136. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
  137. int32_t shard_id = 0);
  138. /// \brief Function to create a CocoDataset
  139. /// \notes The generated dataset has multiple columns, depending on the task:
  140. /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  141. /// ['iscrowd', dtype=uint32]].
  142. /// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]].
  143. /// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
  144. /// ['num_keypoints', dtype=uint32]].
  145. /// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  146. /// ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
  147. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  148. /// \param[in] annotation_file Path to the annotation JSON file
  149. /// \param[in] task Task type for reading COCO data; one of 'Detection'/'Stuff'/'Panoptic'/'Keypoint' (default="Detection")
  150. /// \param[in] decode Decode the images after reading (default=false)
  151. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  152. /// will be used to randomly iterate the entire dataset (default=nullptr)
  153. /// \return Shared pointer to the current Dataset
  154. std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
  155. const std::string &task = "Detection", const bool &decode = false,
  156. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  157. /// \brief Function to create a CSVDataset
  158. /// \notes The generated dataset has a variable number of columns
  159. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  160. /// will be sorted in a lexicographical order.
  161. /// \param[in] field_delim A char that indicates the delimiter to separate fields (default=',').
  162. /// \param[in] column_defaults List of default values for the CSV fields (default={}). Each item in the list must be
  163. /// of a valid type (float, int, or string). If this is not provided, all columns are treated as strings.
  164. /// \param[in] column_names List of column names of the dataset (default={}). If this is not provided, the
  165. /// column names are inferred from the first row of the CSV file.
  166. /// \param[in] num_samples The number of samples to be included in the dataset.
  167. /// (Default = -1 means all samples.)
  168. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  169. /// Can be any of:
  170. /// ShuffleMode::kFalse - No shuffling is performed.
  171. /// ShuffleMode::kFiles - Shuffle files only.
  172. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  173. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  174. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  175. /// specified only when num_shards is also specified. (Default = 0)
  176. /// \return Shared pointer to the current Dataset
  177. std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
  178. const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
  179. const std::vector<std::string> &column_names = {}, int64_t num_samples = -1,
  180. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
  181. int32_t shard_id = 0);
  182. /// \brief Function to create an ImageFolderDataset
  183. /// \notes A source dataset that reads images from a tree of directories.
  184. /// All images within one folder have the same label.
  185. /// The generated dataset has two columns ['image', 'label']
  186. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  187. /// \param[in] decode Whether to decode the images after reading (default=false)
  188. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
  189. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default=nullptr)
  190. /// \param[in] extensions File extensions to be read (default={})
  191. /// \param[in] class_indexing A class-name-to-label map (default={})
  192. /// \return Shared pointer to the current ImageFolderDataset
  193. std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode = false,
  194. const std::shared_ptr<SamplerObj> &sampler = nullptr,
  195. const std::set<std::string> &extensions = {},
  196. const std::map<std::string, int32_t> &class_indexing = {});
  197. /// \brief Function to create a ManifestDataset
  198. /// \notes The generated dataset has two columns ['image', 'label']
  199. /// \param[in] dataset_file The dataset file to be read
  200. /// \param[in] usage Which part of the data to read: "train", "eval" or "inference" (default="train")
  201. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
  202. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default=nullptr)
  203. /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
  204. /// names will be sorted alphabetically and each class will be given a unique index starting from 0).
  205. /// \param[in] decode Decode the images after reading (default=false).
  206. /// \return Shared pointer to the current ManifestDataset
  207. std::shared_ptr<ManifestDataset> Manifest(std::string dataset_file, std::string usage = "train",
  208. std::shared_ptr<SamplerObj> sampler = nullptr,
  209. const std::map<std::string, int32_t> &class_indexing = {},
  210. bool decode = false);
  211. /// \brief Function to create a MnistDataset
  212. /// \notes The generated dataset has two columns ['image', 'label']
  213. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  214. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
  215. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default=nullptr)
  216. /// \return Shared pointer to the current MnistDataset
  217. std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
  218. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  219. /// \brief Function to create a ConcatDataset
  220. /// \notes Overloads the "+" operator to concatenate two datasets
  221. /// \param[in] datasets1 Shared pointer to the first dataset to be concatenated
  222. /// \param[in] datasets2 Shared pointer to the second dataset to be concatenated
  223. /// \return Shared pointer to the current ConcatDataset
  224. std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
  225. const std::shared_ptr<Dataset> &datasets2);
  226. /// \brief Function to create a RandomDataset
  227. /// \param[in] total_rows Number of rows for the dataset to generate (default=0, number of rows is random)
  228. /// \param[in] schema SchemaObj to set column type, data type and data shape
  229. /// \param[in] columns_list List of columns to be read (default={}, read all columns)
  230. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
  231. /// will be used to randomly iterate the entire dataset
  232. /// \return Shared pointer to the current Dataset
  233. template <typename T = std::shared_ptr<SchemaObj>>
  234. std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schema = nullptr,
  235. const std::vector<std::string> &columns_list = {},
  236. std::shared_ptr<SamplerObj> sampler = nullptr) {
  237. auto ds = std::make_shared<RandomDataset>(total_rows, schema, std::move(columns_list), std::move(sampler));
  238. return ds->ValidateParams() ? ds : nullptr;
  239. }
  240. /// \brief Function to create a TextFileDataset
  241. /// \notes The generated dataset has one column ['text']
  242. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  243. /// will be sorted in a lexicographical order.
  244. /// \param[in] num_samples The number of samples to be included in the dataset.
  245. /// (Default = 0 means all samples.)
  246. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  247. /// Can be any of:
  248. /// ShuffleMode::kFalse - No shuffling is performed.
  249. /// ShuffleMode::kFiles - Shuffle files only.
  250. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  251. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  252. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  253. /// specified only when num_shards is also specified. (Default = 0)
  254. /// \return Shared pointer to the current TextFileDataset
  255. std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0,
  256. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
  257. int32_t shard_id = 0);
  258. /// \brief Function to create a VOCDataset
  259. /// \notes The generated dataset has multiple columns, depending on the task:
  260. /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
  261. /// ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
  262. /// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
  263. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  264. /// \param[in] task Task type for reading VOC data; only "Segmentation" or "Detection" is supported (default="Segmentation")
  265. /// \param[in] mode Name of the data list txt file to be read (default="train")
  266. /// \param[in] class_indexing A str-to-int mapping from label name to index (default={})
  267. /// \param[in] decode Decode the images after reading (default=false)
  268. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  269. /// will be used to randomly iterate the entire dataset (default=nullptr)
  270. /// \return Shared pointer to the current Dataset
  271. std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
  272. const std::string &mode = "train",
  273. const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
  274. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  275. /// \brief Function to create a ZipDataset
  276. /// \notes Applies zip to the given datasets
  277. /// \param[in] datasets List of shared pointers to the datasets that we want to zip
  278. /// \return Shared pointer to the current ZipDataset
  279. std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
  280. /// \class Dataset datasets.h
  281. /// \brief A base class to represent a dataset in the data pipeline.
  282. class Dataset : public std::enable_shared_from_this<Dataset> {
  283. public:
  284. friend class Iterator;
  285. /// \brief Constructor
  286. Dataset();
  287. /// \brief Destructor
  288. ~Dataset() = default;
  289. /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object
  290. /// \return The list of shared pointers to the newly created DatasetOps
  291. virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
  292. /// \brief Pure virtual function for derived class to implement parameters validation
  293. /// \return bool true if all the parameters are valid
  294. virtual bool ValidateParams() = 0;
  295. /// \brief Setter function for runtime number of workers
  296. /// \param[in] num_workers The number of threads in this operator
  297. /// \return Shared pointer to the original object
  298. std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers) {
  299. #if !defined(_WIN32) && !defined(_WIN64)
  300. #ifndef ENABLE_ANDROID
  301. int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF);
  302. if (cpu_count < 0 || cpu_count > INT32_MAX) {
  303. MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count;
  304. return nullptr;
  305. }
  306. if (num_workers < 1 || num_workers > cpu_count) {
  307. MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count;
  308. return nullptr;
  309. }
  310. #endif
  311. #endif
  312. num_workers_ = num_workers;
  313. return shared_from_this();
  314. }
  315. /// \brief Function to create an Iterator over the Dataset pipeline
  316. /// \param[in] columns List of columns to be used to specify the order of columns
  317. /// \return Shared pointer to the Iterator
  318. std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
  319. /// \brief Function to create a BatchDataset
  320. /// \notes Combines batch_size number of consecutive rows into batches
  321. /// \param[in] batch_size Path to the root directory that contains the dataset
  322. /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
  323. /// batch. If true, and if there are less than batch_size rows
  324. /// available to make the last batch, then those rows will
  325. /// be dropped and not propagated to the next node
  326. /// \return Shared pointer to the current BatchDataset
  327. std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
  328. /// \brief Function to create a Vocab from source dataset
  329. /// \notes Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
  330. /// which contains top_k most frequent words (if top_k is specified)
  331. /// \param[in] columns Column names to get words from. It can be a vector of column names
  332. /// \param[in] freq_range A tuple of integers (min_frequency, max_frequency). Words within the frequency
  333. /// range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
  334. /// can be set to default, which corresponds to 0/total_words separately
  335. /// \param[in] top_k Number of words to be built into vocab. top_k most frequent words are
  336. /// taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
  337. /// \param[in] special_tokens A list of strings, each one is a special token
  338. /// \param[in] special_first Whether special_tokens will be prepended/appended to vocab, If special_tokens
  339. /// is specified and special_first is set to default, special_tokens will be prepended
  340. /// \return Shared pointer to the current Vocab
  341. std::shared_ptr<Vocab> BuildVocab(const std::vector<std::string> &columns = {},
  342. const std::pair<int64_t, int64_t> &freq_range = {0, kDeMaxFreq},
  343. int64_t top_k = kDeMaxTopk, const std::vector<std::string> &special_tokens = {},
  344. bool special_first = true);
  345. /// \brief Function to create a ConcatDataset
  346. /// \notes Concat the datasets in the input
  347. /// \param[in] datasets List of shared pointers to the dataset that should be concatenated together
  348. /// \return Shared pointer to the current ConcatDataset
  349. std::shared_ptr<ConcatDataset> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets);
  350. /// \brief Function to create a MapDataset
  351. /// \notes Applies each operation in operations to this dataset
  352. /// \param[in] operations Vector of operations to be applied on the dataset. Operations are
  353. /// applied in the order they appear in this list
  354. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  355. /// operation as input. The size of this list must match the number of
  356. /// input columns expected by the first operator. The default input_columns
  357. /// is the first column
  358. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  359. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  360. /// The size of this list must match the number of output columns of the
  361. /// last operation. The default output_columns will have the same
  362. /// name as the input columns, i.e., the columns will be replaced
  363. /// \param[in] project_columns A list of column names to project
  364. /// \return Shared pointer to the current MapDataset
  365. std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
  366. std::vector<std::string> input_columns = {},
  367. std::vector<std::string> output_columns = {},
  368. const std::vector<std::string> &project_columns = {});
  369. /// \brief Function to create a Project Dataset
  370. /// \notes Applies project to the dataset
  371. /// \param[in] columns The name of columns to project
  372. /// \return Shared pointer to the current Dataset
  373. std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
  374. /// \brief Function to create a Rename Dataset
  375. /// \notes Renames the columns in the input dataset
  376. /// \param[in] input_columns List of the input columns to rename
  377. /// \param[in] output_columns List of the output columns
  378. /// \return Shared pointer to the current Dataset
  379. std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
  380. const std::vector<std::string> &output_columns);
  381. /// \brief Function to create a RepeatDataset
  382. /// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
  383. /// \param[in] count Number of times the dataset should be repeated
  384. /// \return Shared pointer to the current Dataset
  385. /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
  386. /// due to a limitation in the current implementation
  387. std::shared_ptr<Dataset> Repeat(int32_t count = -1);
  388. /// \brief Function to create a Shuffle Dataset
  389. /// \notes Randomly shuffles the rows of this dataset
  390. /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
  391. /// \return Shared pointer to the current ShuffleDataset
  392. std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size);
  393. /// \brief Function to create a SkipDataset
  394. /// \notes Skips count elements in this dataset.
  395. /// \param[in] count Number of elements the dataset to be skipped.
  396. /// \return Shared pointer to the current SkipDataset
  397. std::shared_ptr<SkipDataset> Skip(int32_t count);
  398. /// \brief Function to create a TakeDataset
  399. /// \notes Takes count elements in this dataset.
  400. /// \param[in] count Number of elements the dataset to be taken.
  401. /// \return Shared pointer to the current Dataset
  402. std::shared_ptr<Dataset> Take(int32_t count = -1);
  403. /// \brief Function to create a Zip Dataset
  404. /// \notes Applies zip to the dataset
  405. /// \param[in] datasets A list of shared pointers to the datasets that we want to zip
  406. /// \return Shared pointer to the current Dataset
  407. std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
  408. protected:
  409. std::vector<std::shared_ptr<Dataset>> children;
  410. std::shared_ptr<Dataset> parent;
  411. int32_t num_workers_;
  412. int32_t rows_per_buffer_;
  413. int32_t connector_que_size_;
  414. int32_t worker_connector_size_;
  415. };
  416. class SchemaObj {
  417. public:
  418. /// \brief Constructor
  419. explicit SchemaObj(const std::string &schema_file = "");
  420. /// \brief Destructor
  421. ~SchemaObj() = default;
  422. /// \brief SchemaObj init function
  423. /// \return bool true if schema init success
  424. bool init();
  425. /// \brief Add new column to the schema
  426. /// \param[in] name name of the column.
  427. /// \param[in] de_type data type of the column(TypeId).
  428. /// \param[in] shape shape of the column.
  429. /// \return bool true if schema init success
  430. bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
  431. /// \brief Add new column to the schema
  432. /// \param[in] name name of the column.
  433. /// \param[in] de_type data type of the column(std::string).
  434. /// \param[in] shape shape of the column.
  435. /// \return bool true if schema init success
  436. bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
  437. /// \brief Get a JSON string of the schema
  438. /// \return JSON string of the schema
  439. std::string to_json();
  440. /// \brief Get a JSON string of the schema
  441. std::string to_string() { return to_json(); }
  442. /// \brief set a new value to dataset_type
  443. inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; }
  444. /// \brief set a new value to num_rows
  445. inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
  446. /// \brief get the current num_rows
  447. inline int32_t get_num_rows() { return num_rows_; }
  448. private:
  449. /// \brief Parse the columns and add it to columns
  450. /// \param[in] columns dataset attribution information, decoded from schema file.
  451. /// support both nlohmann::json::value_t::array and nlohmann::json::value_t::onject.
  452. /// \return JSON string of the schema
  453. bool parse_column(nlohmann::json columns);
  454. /// \brief Get schema file from json file
  455. /// \param[in] json_obj object of json parsed.
  456. /// \return bool true if json dump success
  457. bool from_json(nlohmann::json json_obj);
  458. int32_t num_rows_;
  459. std::string dataset_type_;
  460. std::string schema_file_;
  461. nlohmann::json columns_;
  462. };
  463. /* ####################################### Derived Dataset classes ################################# */
  464. // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
  465. // (In alphabetical order)
  466. class AlbumDataset : public Dataset {
  467. public:
  468. /// \brief Constructor
  469. AlbumDataset(const std::string &dataset_dir, const std::string &data_schema,
  470. const std::vector<std::string> &column_names, bool decode, const std::shared_ptr<SamplerObj> &sampler);
  471. /// \brief Destructor
  472. ~AlbumDataset() = default;
  473. /// \brief a base class override function to create a runtime dataset op object from this class
  474. /// \return shared pointer to the newly created DatasetOp
  475. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  476. /// \brief Parameters validation
  477. /// \return bool true if all the params are valid
  478. bool ValidateParams() override;
  479. private:
  480. std::string dataset_dir_;
  481. std::string schema_path_;
  482. std::vector<std::string> column_names_;
  483. bool decode_;
  484. std::shared_ptr<SamplerObj> sampler_;
  485. };
  486. class CelebADataset : public Dataset {
  487. public:
  488. /// \brief Constructor
  489. CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
  490. const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
  491. const std::set<std::string> &extensions);
  492. /// \brief Destructor
  493. ~CelebADataset() = default;
  494. /// \brief a base class override function to create the required runtime dataset op objects for this class
  495. /// \return shared pointer to the list of newly created DatasetOps
  496. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  497. /// \brief Parameters validation
  498. /// \return bool true if all the params are valid
  499. bool ValidateParams() override;
  500. private:
  501. std::string dataset_dir_;
  502. std::string dataset_type_;
  503. bool decode_;
  504. std::set<std::string> extensions_;
  505. std::shared_ptr<SamplerObj> sampler_;
  506. };
  // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS (continued)
  // (In alphabetical order)
  509. class Cifar10Dataset : public Dataset {
  510. public:
  511. /// \brief Constructor
  512. Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
  513. /// \brief Destructor
  514. ~Cifar10Dataset() = default;
  515. /// \brief a base class override function to create the required runtime dataset op objects for this class
  516. /// \return The list of shared pointers to the newly created DatasetOps
  517. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  518. /// \brief Parameters validation
  519. /// \return bool true if all the params are valid
  520. bool ValidateParams() override;
  521. private:
  522. std::string dataset_dir_;
  523. std::shared_ptr<SamplerObj> sampler_;
  524. };
  525. class Cifar100Dataset : public Dataset {
  526. public:
  527. /// \brief Constructor
  528. Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
  529. /// \brief Destructor
  530. ~Cifar100Dataset() = default;
  531. /// \brief a base class override function to create the required runtime dataset op objects for this class
  532. /// \return The list of shared pointers to the newly created DatasetOps
  533. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  534. /// \brief Parameters validation
  535. /// \return bool true if all the params are valid
  536. bool ValidateParams() override;
  537. private:
  538. std::string dataset_dir_;
  539. std::shared_ptr<SamplerObj> sampler_;
  540. };
  541. /// \class CLUEDataset
  542. /// \brief A Dataset derived class to represent CLUE dataset
  543. class CLUEDataset : public Dataset {
  544. public:
  545. /// \brief Constructor
  546. CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
  547. ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
  548. /// \brief Destructor
  549. ~CLUEDataset() = default;
  550. /// \brief a base class override function to create the required runtime dataset op objects for this class
  551. /// \return The list of shared pointers to the newly created DatasetOps
  552. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  553. /// \brief Parameters validation
  554. /// \return bool true if all the params are valid
  555. bool ValidateParams() override;
  556. private:
  557. /// \brief Split string based on a character delimiter
  558. /// \return A string vector
  559. std::vector<std::string> split(const std::string &s, char delim);
  560. std::vector<std::string> dataset_files_;
  561. std::string task_;
  562. std::string usage_;
  563. int64_t num_samples_;
  564. ShuffleMode shuffle_;
  565. int32_t num_shards_;
  566. int32_t shard_id_;
  567. };
  568. class CocoDataset : public Dataset {
  569. public:
  570. /// \brief Constructor
  571. CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
  572. const bool &decode, const std::shared_ptr<SamplerObj> &sampler);
  573. /// \brief Destructor
  574. ~CocoDataset() = default;
  575. /// \brief a base class override function to create the required runtime dataset op objects for this class
  576. /// \return shared pointer to the list of newly created DatasetOps
  577. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  578. /// \brief Parameters validation
  579. /// \return bool true if all the params are valid
  580. bool ValidateParams() override;
  581. private:
  582. std::string dataset_dir_;
  583. std::string annotation_file_;
  584. std::string task_;
  585. bool decode_;
  586. std::shared_ptr<SamplerObj> sampler_;
  587. };
  /// \brief Record type for CSV
  enum CsvType : uint8_t { INT = 0, FLOAT, STRING };

  /// \brief Base class of CSV Record
  /// \details Carries only the record's type tag. Instances are used polymorphically through
  ///     std::shared_ptr<CsvBase> (see CSVDataset's column_defaults), so the destructor is virtual.
  struct CsvBase {
    CsvBase() = default;
    explicit CsvBase(CsvType t) : type(t) {}
    virtual ~CsvBase() = default;

    // Default member initializer so a default-constructed record has a deterministic tag
    // instead of an indeterminate value.
    CsvType type = INT;
  };

  /// \brief CSV Record that can represent integer, float and string.
  /// \tparam T type of the stored value.
  /// \details No destructor is declared (Rule of Zero). CsvBase's virtual destructor keeps it
  ///     virtual, and the implicit move operations are not suppressed.
  template <typename T>
  class CsvRecord : public CsvBase {
   public:
    CsvRecord() = default;
    /// \param[in] t type tag of the record.
    /// \param[in] v value of the record, taken by value and moved into place.
    CsvRecord(CsvType t, T v) : CsvBase(t), value(std::move(v)) {}

    // Value-initialized (0 / 0.0f / empty string) when default-constructed.
    T value{};
  };
  607. class CSVDataset : public Dataset {
  608. public:
  609. /// \brief Constructor
  610. CSVDataset(const std::vector<std::string> &dataset_files, char field_delim,
  611. const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names,
  612. int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
  613. /// \brief Destructor
  614. ~CSVDataset() = default;
  615. /// \brief a base class override function to create the required runtime dataset op objects for this class
  616. /// \return shared pointer to the list of newly created DatasetOps
  617. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  618. /// \brief Parameters validation
  619. /// \return bool true if all the params are valid
  620. bool ValidateParams() override;
  621. private:
  622. std::vector<std::string> dataset_files_;
  623. char field_delim_;
  624. std::vector<std::shared_ptr<CsvBase>> column_defaults_;
  625. std::vector<std::string> column_names_;
  626. int64_t num_samples_;
  627. ShuffleMode shuffle_;
  628. int32_t num_shards_;
  629. int32_t shard_id_;
  630. };
  631. /// \class ImageFolderDataset
  632. /// \brief A Dataset derived class to represent ImageFolder dataset
  633. class ImageFolderDataset : public Dataset {
  634. public:
  635. /// \brief Constructor
  636. ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
  637. std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing);
  638. /// \brief Destructor
  639. ~ImageFolderDataset() = default;
  640. /// \brief a base class override function to create the required runtime dataset op objects for this class
  641. /// \return The list of shared pointers to the newly created DatasetOps
  642. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  643. /// \brief Parameters validation
  644. /// \return bool true if all the params are valid
  645. bool ValidateParams() override;
  646. private:
  647. std::string dataset_dir_;
  648. bool decode_;
  649. bool recursive_;
  650. std::shared_ptr<SamplerObj> sampler_;
  651. std::map<std::string, int32_t> class_indexing_;
  652. std::set<std::string> exts_;
  653. };
  654. class ManifestDataset : public Dataset {
  655. public:
  656. /// \brief Constructor
  657. ManifestDataset(std::string dataset_file, std::string usage, std::shared_ptr<SamplerObj> sampler,
  658. const std::map<std::string, int32_t> &class_indexing, bool decode);
  659. /// \brief Destructor
  660. ~ManifestDataset() = default;
  661. /// \brief a base class override function to create the required runtime dataset op objects for this class
  662. /// \return The list of shared pointers to the newly created DatasetOps
  663. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  664. /// \brief Parameters validation
  665. /// \return bool true if all the params are valid
  666. bool ValidateParams() override;
  667. private:
  668. std::string dataset_file_;
  669. std::string usage_;
  670. bool decode_;
  671. std::map<std::string, int32_t> class_index_;
  672. std::shared_ptr<SamplerObj> sampler_;
  673. };
  674. class MnistDataset : public Dataset {
  675. public:
  676. /// \brief Constructor
  677. MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler);
  678. /// \brief Destructor
  679. ~MnistDataset() = default;
  680. /// \brief a base class override function to create the required runtime dataset op objects for this class
  681. /// \return The list of shared pointers to the newly created DatasetOps
  682. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  683. /// \brief Parameters validation
  684. /// \return bool true if all the params are valid
  685. bool ValidateParams() override;
  686. private:
  687. std::string dataset_dir_;
  688. std::shared_ptr<SamplerObj> sampler_;
  689. };
  690. class RandomDataset : public Dataset {
  691. public:
  692. // Some constants to provide limits to random generation.
  693. static constexpr int32_t kMaxNumColumns = 4;
  694. static constexpr int32_t kMaxRank = 4;
  695. static constexpr int32_t kMaxDimValue = 32;
  696. /// \brief Constructor
  697. RandomDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
  698. const std::vector<std::string> &columns_list, std::shared_ptr<SamplerObj> sampler)
  699. : total_rows_(total_rows),
  700. schema_path_(""),
  701. schema_(std::move(schema)),
  702. columns_list_(columns_list),
  703. sampler_(std::move(sampler)) {}
  704. /// \brief Constructor
  705. RandomDataset(const int32_t &total_rows, std::string schema_path, std::vector<std::string> columns_list,
  706. std::shared_ptr<SamplerObj> sampler)
  707. : total_rows_(total_rows), schema_path_(schema_path), columns_list_(columns_list), sampler_(std::move(sampler)) {}
  708. /// \brief Destructor
  709. ~RandomDataset() = default;
  710. /// \brief a base class override function to create the required runtime dataset op objects for this class
  711. /// \return The list of shared pointers to the newly created DatasetOps
  712. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  713. /// \brief Parameters validation
  714. /// \return bool true if all the params are valid
  715. bool ValidateParams() override;
  716. private:
  717. /// \brief A quick inline for producing a random number between (and including) min/max
  718. /// \param[in] min minimum number that can be generated.
  719. /// \param[in] max maximum number that can be generated.
  720. /// \return The generated random number
  721. int32_t GenRandomInt(int32_t min, int32_t max);
  722. int32_t total_rows_;
  723. std::string schema_path_;
  724. std::shared_ptr<SchemaObj> schema_;
  725. std::vector<std::string> columns_list_;
  726. std::shared_ptr<SamplerObj> sampler_;
  727. std::mt19937 rand_gen_;
  728. };
  729. /// \class TextFileDataset
  730. /// \brief A Dataset derived class to represent TextFile dataset
  731. class TextFileDataset : public Dataset {
  732. public:
  733. /// \brief Constructor
  734. TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
  735. int32_t shard_id);
  736. /// \brief Destructor
  737. ~TextFileDataset() = default;
  738. /// \brief a base class override function to create the required runtime dataset op objects for this class
  739. /// \return The list of shared pointers to the newly created DatasetOps
  740. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  741. /// \brief Parameters validation
  742. /// \return bool true if all the params are valid
  743. bool ValidateParams() override;
  744. private:
  745. std::vector<std::string> dataset_files_;
  746. int32_t num_samples_;
  747. int32_t num_shards_;
  748. int32_t shard_id_;
  749. ShuffleMode shuffle_;
  750. };
  751. class VOCDataset : public Dataset {
  752. public:
  753. /// \brief Constructor
  754. VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
  755. const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
  756. /// \brief Destructor
  757. ~VOCDataset() = default;
  758. /// \brief a base class override function to create the required runtime dataset op objects for this class
  759. /// \return shared pointer to the list of newly created DatasetOps
  760. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  761. /// \brief Parameters validation
  762. /// \return bool true if all the params are valid
  763. bool ValidateParams() override;
  764. private:
  765. const std::string kColumnImage = "image";
  766. const std::string kColumnTarget = "target";
  767. const std::string kColumnBbox = "bbox";
  768. const std::string kColumnLabel = "label";
  769. const std::string kColumnDifficult = "difficult";
  770. const std::string kColumnTruncate = "truncate";
  771. std::string dataset_dir_;
  772. std::string task_;
  773. std::string mode_;
  774. std::map<std::string, int32_t> class_index_;
  775. bool decode_;
  776. std::shared_ptr<SamplerObj> sampler_;
  777. };
  778. // DERIVED DATASET CLASSES FOR DATASET OPS
  779. // (In alphabetical order)
  780. class BatchDataset : public Dataset {
  781. public:
  782. /// \brief Constructor
  783. BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
  784. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
  785. /// \brief Destructor
  786. ~BatchDataset() = default;
  787. /// \brief a base class override function to create the required runtime dataset op objects for this class
  788. /// \return The list of shared pointers to the newly created DatasetOps
  789. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  790. /// \brief Parameters validation
  791. /// \return bool true if all the params are valid
  792. bool ValidateParams() override;
  793. private:
  794. int32_t batch_size_;
  795. bool drop_remainder_;
  796. bool pad_;
  797. std::vector<std::string> cols_to_map_;
  798. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
  799. };
  800. class BuildVocabDataset : public Dataset {
  801. public:
  802. /// \brief Constructor
  803. BuildVocabDataset(std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
  804. const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
  805. const std::vector<std::string> &special_tokens, bool special_first);
  806. /// \brief Destructor
  807. ~BuildVocabDataset() = default;
  808. /// \brief a base class override function to create the required runtime dataset op objects for this class
  809. /// \return The list of shared pointers to the newly created DatasetOps
  810. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  811. /// \brief Parameters validation
  812. /// \return bool true if all the params are valid
  813. bool ValidateParams() override;
  814. private:
  815. std::shared_ptr<Vocab> vocab_;
  816. std::vector<std::string> columns_;
  817. std::pair<int64_t, int64_t> freq_range_;
  818. int64_t top_k_;
  819. std::vector<std::string> special_tokens_;
  820. bool special_first_;
  821. };
  822. class ConcatDataset : public Dataset {
  823. public:
  824. /// \brief Constructor
  825. explicit ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets);
  826. /// \brief Destructor
  827. ~ConcatDataset() = default;
  828. /// \brief a base class override function to create the required runtime dataset op objects for this class
  829. /// \return The list of shared pointers to the newly created DatasetOps
  830. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  831. /// \brief Parameters validation
  832. /// \return bool true if all the params are valid
  833. bool ValidateParams() override;
  834. private:
  835. std::vector<std::shared_ptr<Dataset>> datasets_;
  836. };
  837. class MapDataset : public Dataset {
  838. public:
  839. /// \brief Constructor
  840. MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns = {},
  841. std::vector<std::string> output_columns = {}, const std::vector<std::string> &columns = {});
  842. /// \brief Destructor
  843. ~MapDataset() = default;
  844. /// \brief a base class override function to create the required runtime dataset op objects for this class
  845. /// \return The list of shared pointers to the newly created DatasetOps
  846. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  847. /// \brief Parameters validation
  848. /// \return bool true if all the params are valid
  849. bool ValidateParams() override;
  850. private:
  851. std::vector<std::shared_ptr<TensorOperation>> operations_;
  852. std::vector<std::string> input_columns_;
  853. std::vector<std::string> output_columns_;
  854. std::vector<std::string> project_columns_;
  855. };
  856. class ProjectDataset : public Dataset {
  857. public:
  858. /// \brief Constructor
  859. explicit ProjectDataset(const std::vector<std::string> &columns);
  860. /// \brief Destructor
  861. ~ProjectDataset() = default;
  862. /// \brief a base class override function to create the required runtime dataset op objects for this class
  863. /// \return The list of shared pointers to the newly created DatasetOps
  864. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  865. /// \brief Parameters validation
  866. /// \return bool true if all the params are valid
  867. bool ValidateParams() override;
  868. private:
  869. std::vector<std::string> columns_;
  870. };
  871. class RenameDataset : public Dataset {
  872. public:
  873. /// \brief Constructor
  874. explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns);
  875. /// \brief Destructor
  876. ~RenameDataset() = default;
  877. /// \brief a base class override function to create the required runtime dataset op objects for this class
  878. /// \return The list of shared pointers to the newly created DatasetOps
  879. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  880. /// \brief Parameters validation
  881. /// \return bool true if all the params are valid
  882. bool ValidateParams() override;
  883. private:
  884. std::vector<std::string> input_columns_;
  885. std::vector<std::string> output_columns_;
  886. };
  887. class RepeatDataset : public Dataset {
  888. public:
  889. /// \brief Constructor
  890. explicit RepeatDataset(int32_t count);
  891. /// \brief Destructor
  892. ~RepeatDataset() = default;
  893. /// \brief a base class override function to create the required runtime dataset op objects for this class
  894. /// \return The list of shared pointers to the newly created DatasetOps
  895. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  896. /// \brief Parameters validation
  897. /// \return bool true if all the params are valid
  898. bool ValidateParams() override;
  899. private:
  900. int32_t repeat_count_;
  901. };
  902. class ShuffleDataset : public Dataset {
  903. public:
  904. ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch);
  905. ~ShuffleDataset() = default;
  906. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  907. bool ValidateParams() override;
  908. private:
  909. int32_t shuffle_size_;
  910. uint32_t shuffle_seed_;
  911. bool reset_every_epoch_;
  912. };
  913. class SkipDataset : public Dataset {
  914. public:
  915. /// \brief Constructor
  916. explicit SkipDataset(int32_t count);
  917. /// \brief Destructor
  918. ~SkipDataset() = default;
  919. /// \brief a base class override function to create the required runtime dataset op objects for this class
  920. /// \return The list of shared pointers to the newly created DatasetOps
  921. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  922. /// \brief Parameters validation
  923. /// \return bool true if all the params are valid
  924. bool ValidateParams() override;
  925. private:
  926. int32_t skip_count_;
  927. };
  928. class TakeDataset : public Dataset {
  929. public:
  930. /// \brief Constructor
  931. explicit TakeDataset(int32_t count);
  932. /// \brief Destructor
  933. ~TakeDataset() = default;
  934. /// \brief a base class override function to create the required runtime dataset op objects for this class
  935. /// \return shared pointer to the list of newly created DatasetOps
  936. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  937. /// \brief Parameters validation
  938. /// \return bool true if all the params are valid
  939. bool ValidateParams() override;
  940. private:
  941. int32_t take_count_;
  942. };
  943. class ZipDataset : public Dataset {
  944. public:
  945. /// \brief Constructor
  946. explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets);
  947. /// \brief Destructor
  948. ~ZipDataset() = default;
  949. /// \brief a base class override function to create the required runtime dataset op objects for this class
  950. /// \return The list of shared pointers to the newly created DatasetOps
  951. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  952. /// \brief Parameters validation
  953. /// \return bool true if all the params are valid
  954. bool ValidateParams() override;
  955. private:
  956. std::vector<std::shared_ptr<Dataset>> datasets_;
  957. };
  958. } // namespace api
  959. } // namespace dataset
  960. } // namespace mindspore
  961. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_