You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

datasets.h 33 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  18. #include <unistd.h>
  19. #include <vector>
  20. #include <memory>
  21. #include <set>
  22. #include <map>
  23. #include <utility>
  24. #include <string>
  25. #include "minddata/dataset/core/constants.h"
  26. #include "minddata/dataset/include/tensor.h"
  27. #include "minddata/dataset/include/iterator.h"
  28. #include "minddata/dataset/include/samplers.h"
  29. namespace mindspore {
  30. namespace dataset {
  31. // Forward declare
  32. class DatasetOp;
  33. class DataSchema;
  34. class Tensor;
  35. class TensorShape;
  36. namespace api {
  37. class TensorOperation;
  38. class SamplerObj;
  39. // Datasets classes (in alphabetical order)
  40. class CelebADataset;
  41. class Cifar10Dataset;
  42. class Cifar100Dataset;
  43. class CLUEDataset;
  44. class CocoDataset;
  45. class ImageFolderDataset;
  46. class MnistDataset;
  47. class TextFileDataset;
  48. class VOCDataset;
  49. // Dataset Op classes (in alphabetical order)
  50. class BatchDataset;
  51. class ConcatDataset;
  52. class MapDataset;
  53. class ProjectDataset;
  54. class RenameDataset;
  55. class RepeatDataset;
  56. class ShuffleDataset;
  57. class SkipDataset;
  58. class TakeDataset;
  59. class ZipDataset;
  60. /// \brief Function to create a CelebADataset
  61. /// \notes The generated dataset has two columns ['image', 'attr'].
  62. /// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
  63. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  64. /// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test' (default="all").
  65. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  66. /// will be used to randomly iterate the entire dataset (default=nullptr).
  67. /// \param[in] decode Decode the images after reading (default=false).
  68. /// \param[in] extensions Set of file extensions to be included in the dataset (default={}, an empty set).
  69. /// \return Shared pointer to the current CelebADataset
  70. std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
  71. const std::shared_ptr<SamplerObj> &sampler = nullptr, bool decode = false,
  72. const std::set<std::string> &extensions = {});
  73. /// \brief Function to create a Cifar10Dataset
  74. /// \notes The generated dataset has two columns ['image', 'label']
  75. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  76. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  77. /// will be used to randomly iterate the entire dataset (default=nullptr)
  78. /// \return Shared pointer to the current Cifar10Dataset
  79. std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
  80. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  81. /// \brief Function to create a Cifar100Dataset
  82. /// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
  83. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  84. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  85. /// will be used to randomly iterate the entire dataset (default=nullptr)
  86. /// \return Shared pointer to the current Cifar100Dataset
  87. std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
  88. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  89. /// \brief Function to create a CLUEDataset
  90. /// \notes The generated dataset has a variable number of columns depending on the task and usage
  91. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  92. /// will be sorted in a lexicographical order.
  93. /// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC").
  94. /// \param[in] usage Which part of the data to use, one of "train", "test" or "eval" (default="train").
  95. /// \param[in] num_samples The number of samples to be included in the dataset.
  96. /// (Default = 0 means all samples.)
  97. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  98. /// Can be any of:
  99. /// ShuffleMode::kFalse - No shuffling is performed.
  100. /// ShuffleMode::kFiles - Shuffle files only.
  101. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  102. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  103. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  104. /// specified only when num_shards is also specified. (Default = 0)
  105. /// \return Shared pointer to the current CLUEDataset
  106. std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
  107. const std::string &usage = "train", int64_t num_samples = 0,
  108. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
  109. int32_t shard_id = 0);
  110. /// \brief Function to create a CocoDataset
  111. /// \notes The generated dataset has multi-columns :
  112. /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  113. /// ['iscrowd', dtype=uint32]].
  114. /// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]].
  115. /// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
  116. /// ['num_keypoints', dtype=uint32]].
  117. /// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  118. /// ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
  119. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  120. /// \param[in] annotation_file Path to the annotation json
  121. /// \param[in] task The task type used when reading coco data; currently supports
  122. /// 'Detection'/'Stuff'/'Panoptic'/'Keypoint' (default="Detection")
  123. /// \param[in] decode Decode the images after reading (default=false)
  124. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  125. /// will be used to randomly iterate the entire dataset (default=nullptr)
  126. std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
  127. const std::string &task = "Detection", const bool &decode = false,
  128. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  129. /// \brief Function to create an ImageFolderDataset
  130. /// \notes A source dataset that reads images from a tree of directories
  131. /// All images within one folder have the same label
  132. /// The generated dataset has two columns ['image', 'label']
  133. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  134. /// \param[in] decode A flag to decode in ImageFolder (default=false)
  135. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
  136. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default=nullptr)
  137. /// \param[in] extensions File extensions to be read (default={}, an empty set)
  138. /// \param[in] class_indexing A map from class name to label (default={}, an empty map)
  139. /// \return Shared pointer to the current ImageFolderDataset
  140. std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode = false,
  141. const std::shared_ptr<SamplerObj> &sampler = nullptr,
  142. const std::set<std::string> &extensions = {},
  143. const std::map<std::string, int32_t> &class_indexing = {});
  144. /// \brief Function to create a MnistDataset
  145. /// \notes The generated dataset has two columns ['image', 'label']
  146. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  147. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
  148. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default=nullptr)
  149. /// \return Shared pointer to the current MnistDataset
  150. std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
  151. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  152. /// \brief Function to create a ConcatDataset
  153. /// \notes Overloads the "+" operator to concatenate two datasets
  154. /// \param[in] datasets1 Shared pointer to the first dataset to be concatenated
  155. /// \param[in] datasets2 Shared pointer to the second dataset to be concatenated
  156. /// \return Shared pointer to the current ConcatDataset
  157. std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
  158. const std::shared_ptr<Dataset> &datasets2);
  159. /// \brief Function to create a TextFileDataset
  160. /// \notes The generated dataset has one column ['text']
  161. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  162. /// will be sorted in a lexicographical order.
  163. /// \param[in] num_samples The number of samples to be included in the dataset.
  164. /// (Default = 0 means all samples.)
  165. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  166. /// Can be any of:
  167. /// ShuffleMode::kFalse - No shuffling is performed.
  168. /// ShuffleMode::kFiles - Shuffle files only.
  169. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  170. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  171. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  172. /// specified only when num_shards is also specified. (Default = 0)
  173. /// \return Shared pointer to the current TextFileDataset
  174. std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples = 0,
  175. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
  176. int32_t shard_id = 0);
  177. /// \brief Function to create a VOCDataset
  178. /// \notes The generated dataset has multi-columns :
  179. /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
  180. /// ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
  181. /// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
  182. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  183. /// \param[in] task The task type used when reading voc data; only "Segmentation" or "Detection" is supported
  184. /// (default="Segmentation")
  185. /// \param[in] mode The data list txt file to be read (default="train")
  186. /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, an empty map)
  187. /// \param[in] decode Decode the images after reading (default=false)
  188. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
  189. /// will be used to randomly iterate the entire dataset (default=nullptr)
  190. std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
  191. const std::string &mode = "train",
  192. const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
  193. const std::shared_ptr<SamplerObj> &sampler = nullptr);
  194. /// \brief Function to create a ZipDataset
  195. /// \notes Applies zip to the given datasets
  196. /// \param[in] datasets List of shared pointers to the datasets that we want to zip
  197. /// \return Shared pointer to the current ZipDataset
  198. std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
  199. /// \class Dataset datasets.h
  200. /// \brief A base class to represent a dataset in the data pipeline.
  201. class Dataset : public std::enable_shared_from_this<Dataset> {
  202. public:
  203. friend class Iterator;
  204. /// \brief Constructor
  205. Dataset();
  206. /// \brief Virtual destructor, so that derived datasets are destroyed correctly through a pointer to Dataset
  207. virtual ~Dataset() = default;
  208. /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object
  209. /// \return The list of shared pointers to the newly created DatasetOps
  210. virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
  211. /// \brief Pure virtual function for derived class to implement parameters validation
  212. /// \return bool True if all the params are valid
  213. virtual bool ValidateParams() = 0;
  214. /// \brief Setter function for runtime number of workers
  215. /// \param[in] num_workers The number of threads in this operator
  216. /// \return Shared pointer to the original object (which must already be owned by a std::shared_ptr)
  217. std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers) {
  218. num_workers_ = num_workers;
  219. return shared_from_this();
  220. }
  221. /// \brief Function to create an Iterator over the Dataset pipeline
  222. /// \param[in] columns List of columns to be used to specify the order of columns
  223. /// \return Shared pointer to the Iterator
  224. std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
  225. /// \brief Function to create a BatchDataset
  226. /// \notes Combines batch_size number of consecutive rows into batches
  227. /// \param[in] batch_size The number of consecutive rows to combine into each batch
  228. /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
  229. /// batch. If true, and if there are less than batch_size rows
  230. /// available to make the last batch, then those rows will
  231. /// be dropped and not propagated to the next node
  232. /// \return Shared pointer to the current BatchDataset
  233. std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
  234. /// \brief Function to create a ConcatDataset
  235. /// \notes Concat the datasets in the input
  236. /// \param[in] datasets List of shared pointers to the dataset that should be concatenated together
  237. /// \return Shared pointer to the current ConcatDataset
  238. std::shared_ptr<ConcatDataset> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets);
  239. /// \brief Function to create a MapDataset
  240. /// \notes Applies each operation in operations to this dataset
  241. /// \param[in] operations Vector of operations to be applied on the dataset. Operations are
  242. /// applied in the order they appear in this list
  243. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  244. /// operation as input. The size of this list must match the number of
  245. /// input columns expected by the first operator. The default input_columns
  246. /// is the first column
  247. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  248. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  249. /// The size of this list must match the number of output columns of the
  250. /// last operation. The default output_columns will have the same
  251. /// name as the input columns, i.e., the columns will be replaced
  252. /// \param[in] project_columns A list of column names to project
  253. /// \return Shared pointer to the current MapDataset
  254. std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
  255. std::vector<std::string> input_columns = {},
  256. std::vector<std::string> output_columns = {},
  257. const std::vector<std::string> &project_columns = {});
  258. /// \brief Function to create a Project Dataset
  259. /// \notes Applies project to the dataset
  260. /// \param[in] columns The name of columns to project
  261. /// \return Shared pointer to the current ProjectDataset
  262. std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
  263. /// \brief Function to create a Rename Dataset
  264. /// \notes Renames the columns in the input dataset
  265. /// \param[in] input_columns List of the input columns to rename
  266. /// \param[in] output_columns List of the output columns
  267. /// \return Shared pointer to the current RenameDataset
  268. std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
  269. const std::vector<std::string> &output_columns);
  270. /// \brief Function to create a RepeatDataset
  271. /// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
  272. /// \param[in] count Number of times the dataset should be repeated (default=-1)
  273. /// \return Shared pointer to the current Dataset
  274. /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
  275. /// due to a limitation in the current implementation
  276. std::shared_ptr<Dataset> Repeat(int32_t count = -1);
  277. /// \brief Function to create a Shuffle Dataset
  278. /// \notes Randomly shuffles the rows of this dataset
  279. /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
  280. /// \return Shared pointer to the current ShuffleDataset
  281. std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size);
  282. /// \brief Function to create a SkipDataset
  283. /// \notes Skips count elements in this dataset.
  284. /// \param[in] count Number of elements of the dataset to be skipped.
  285. /// \return Shared pointer to the current SkipDataset
  286. std::shared_ptr<SkipDataset> Skip(int32_t count);
  287. /// \brief Function to create a TakeDataset
  288. /// \notes Takes count elements in this dataset.
  289. /// \param[in] count Number of elements of the dataset to be taken (default=-1).
  290. /// \return Shared pointer to the current Dataset
  291. std::shared_ptr<Dataset> Take(int32_t count = -1);
  292. /// \brief Function to create a Zip Dataset
  293. /// \notes Applies zip to the dataset
  294. /// \param[in] datasets A list of shared pointers to the datasets that we want to zip
  295. /// \return Shared pointer to the current ZipDataset
  296. std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
  297. protected:
  // NOTE(review): `children` and `parent` are both owning shared_ptrs. If the implementation ever populates
  // both directions, that forms a reference cycle; confirm against the .cc (std::weak_ptr may suit `parent`).
  298. std::vector<std::shared_ptr<Dataset>> children;
  299. std::shared_ptr<Dataset> parent;
  // NOTE(review): none of the members below has an in-class initializer; presumably Dataset() sets them - confirm.
  300. int32_t num_workers_;
  301. int32_t rows_per_buffer_;
  302. int32_t connector_que_size_;
  303. int32_t worker_connector_size_;
  304. };
  305. /* ####################################### Derived Dataset classes ################################# */
  306. // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
  307. // (In alphabetical order)
  /// \class CelebADataset
  /// \brief A Dataset derived class to represent CelebA dataset
  308. class CelebADataset : public Dataset {
  309. public:
  310. /// \brief Constructor
  311. CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
  312. const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
  313. const std::set<std::string> &extensions);
  314. /// \brief Destructor
  315. ~CelebADataset() = default;
  316. /// \brief Base class override that creates the required runtime dataset op objects for this class
  317. /// \return The list of shared pointers to the newly created DatasetOps
  318. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  319. /// \brief Parameters validation
  320. /// \return bool true if all the params are valid
  321. bool ValidateParams() override;
  322. private:
  323. std::string dataset_dir_;
  324. std::string dataset_type_;
  325. bool decode_;
  326. std::set<std::string> extensions_;
  327. std::shared_ptr<SamplerObj> sampler_;
  328. };
  329. /// \class Cifar10Dataset
  330. /// \brief A Dataset derived class to represent Cifar10 dataset
  331. class Cifar10Dataset : public Dataset {
  332. public:
  333. /// \brief Constructor
  334. Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
  335. /// \brief Destructor
  336. ~Cifar10Dataset() = default;
  337. /// \brief Base class override that creates the required runtime dataset op objects for this class
  338. /// \return The list of shared pointers to the newly created DatasetOps
  339. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  340. /// \brief Parameters validation
  341. /// \return bool true if all the params are valid
  342. bool ValidateParams() override;
  343. private:
  344. std::string dataset_dir_;
  345. std::shared_ptr<SamplerObj> sampler_;
  346. };
  /// \class Cifar100Dataset
  /// \brief A Dataset derived class to represent Cifar100 dataset
  347. class Cifar100Dataset : public Dataset {
  348. public:
  349. /// \brief Constructor
  350. Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
  351. /// \brief Destructor
  352. ~Cifar100Dataset() = default;
  353. /// \brief Base class override that creates the required runtime dataset op objects for this class
  354. /// \return The list of shared pointers to the newly created DatasetOps
  355. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  356. /// \brief Parameters validation
  357. /// \return bool true if all the params are valid
  358. bool ValidateParams() override;
  359. private:
  360. std::string dataset_dir_;
  361. std::shared_ptr<SamplerObj> sampler_;
  362. };
  363. /// \class CLUEDataset
  364. /// \brief A Dataset derived class to represent CLUE dataset
  365. class CLUEDataset : public Dataset {
  366. public:
  367. /// \brief Constructor
  368. CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
  369. ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
  370. /// \brief Destructor
  371. ~CLUEDataset() = default;
  372. /// \brief Base class override that creates the required runtime dataset op objects for this class
  373. /// \return The list of shared pointers to the newly created DatasetOps
  374. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  375. /// \brief Parameters validation
  376. /// \return bool true if all the params are valid
  377. bool ValidateParams() override;
  378. private:
  379. /// \brief Split string based on a character delimiter
  /// \param[in] s The string to split
  /// \param[in] delim The delimiter character
  380. /// \return A string vector
  381. std::vector<std::string> split(const std::string &s, char delim);
  382. std::vector<std::string> dataset_files_;
  383. std::string task_;
  384. std::string usage_;
  385. int64_t num_samples_;
  386. ShuffleMode shuffle_;
  387. int32_t num_shards_;
  388. int32_t shard_id_;
  389. };
  /// \class CocoDataset
  /// \brief A Dataset derived class to represent Coco dataset
  390. class CocoDataset : public Dataset {
  391. public:
  392. /// \brief Constructor
  393. CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
  394. const bool &decode, const std::shared_ptr<SamplerObj> &sampler);
  395. /// \brief Destructor
  396. ~CocoDataset() = default;
  397. /// \brief Base class override that creates the required runtime dataset op objects for this class
  398. /// \return The list of shared pointers to the newly created DatasetOps
  399. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  400. /// \brief Parameters validation
  401. /// \return bool true if all the params are valid
  402. bool ValidateParams() override;
  403. private:
  404. std::string dataset_dir_;
  405. std::string annotation_file_;
  406. std::string task_;
  407. bool decode_;
  408. std::shared_ptr<SamplerObj> sampler_;
  409. };
  410. /// \class ImageFolderDataset
  411. /// \brief A Dataset derived class to represent ImageFolder dataset
  412. class ImageFolderDataset : public Dataset {
  413. public:
  414. /// \brief Constructor. Takes a `recursive` flag that the ImageFolder() factory declaration does not expose.
  415. ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
  416. std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing);
  417. /// \brief Destructor
  418. ~ImageFolderDataset() = default;
  419. /// \brief Base class override that creates the required runtime dataset op objects for this class
  420. /// \return The list of shared pointers to the newly created DatasetOps
  421. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  422. /// \brief Parameters validation
  423. /// \return bool true if all the params are valid
  424. bool ValidateParams() override;
  425. private:
  426. std::string dataset_dir_;
  427. bool decode_;
  428. bool recursive_;
  429. std::shared_ptr<SamplerObj> sampler_;
  430. std::map<std::string, int32_t> class_indexing_;
  431. std::set<std::string> exts_;
  432. };
  /// \class MnistDataset
  /// \brief A Dataset derived class to represent Mnist dataset
  433. class MnistDataset : public Dataset {
  434. public:
  435. /// \brief Constructor
  436. MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler);
  437. /// \brief Destructor
  438. ~MnistDataset() = default;
  439. /// \brief Base class override that creates the required runtime dataset op objects for this class
  440. /// \return The list of shared pointers to the newly created DatasetOps
  441. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  442. /// \brief Parameters validation
  443. /// \return bool true if all the params are valid
  444. bool ValidateParams() override;
  445. private:
  446. std::string dataset_dir_;
  447. std::shared_ptr<SamplerObj> sampler_;
  448. };
  449. /// \class TextFileDataset
  450. /// \brief A Dataset derived class to represent TextFile dataset
  451. class TextFileDataset : public Dataset {
  452. public:
  453. /// \brief Constructor
  454. TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
  455. int32_t shard_id);
  456. /// \brief Destructor
  457. ~TextFileDataset() = default;
  458. /// \brief Base class override that creates the required runtime dataset op objects for this class
  459. /// \return The list of shared pointers to the newly created DatasetOps
  460. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  461. /// \brief Parameters validation
  462. /// \return bool true if all the params are valid
  463. bool ValidateParams() override;
  464. private:
  465. std::vector<std::string> dataset_files_;
  // NOTE(review): num_samples is int32_t here but int64_t in CLUEDataset; confirm whether the difference is intended.
  466. int32_t num_samples_;
  467. int32_t num_shards_;
  468. int32_t shard_id_;
  469. ShuffleMode shuffle_;
  470. };
  /// \class VOCDataset
  /// \brief A Dataset derived class to represent VOC dataset
  471. class VOCDataset : public Dataset {
  472. public:
  473. /// \brief Constructor
  474. VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
  475. const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
  476. /// \brief Destructor
  477. ~VOCDataset() = default;
  478. /// \brief Base class override that creates the required runtime dataset op objects for this class
  479. /// \return The list of shared pointers to the newly created DatasetOps
  480. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  481. /// \brief Parameters validation
  482. /// \return bool true if all the params are valid
  483. bool ValidateParams() override;
  484. private:
  // Column-name constants. NOTE(review): these are non-static const members, so each instance stores its own
  // copies and the implicit copy assignment is deleted; consider making them static - confirm against the .cc.
  485. const std::string kColumnImage = "image";
  486. const std::string kColumnTarget = "target";
  487. const std::string kColumnBbox = "bbox";
  488. const std::string kColumnLabel = "label";
  489. const std::string kColumnDifficult = "difficult";
  490. const std::string kColumnTruncate = "truncate";
  491. std::string dataset_dir_;
  492. std::string task_;
  493. std::string mode_;
  494. std::map<std::string, int32_t> class_index_;
  495. bool decode_;
  496. std::shared_ptr<SamplerObj> sampler_;
  497. };
  498. // DERIVED DATASET CLASSES FOR DATASET OPS
  499. // (In alphabetical order)
  /// \class BatchDataset
  /// \brief A Dataset derived class to represent the Batch dataset op
  500. class BatchDataset : public Dataset {
  501. public:
  502. /// \brief Constructor. Takes pad, cols_to_map and pad_map in addition to the two arguments of Dataset::Batch().
  503. BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
  504. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
  505. /// \brief Destructor
  506. ~BatchDataset() = default;
  507. /// \brief Base class override that creates the required runtime dataset op objects for this class
  508. /// \return The list of shared pointers to the newly created DatasetOps
  509. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  510. /// \brief Parameters validation
  511. /// \return bool true if all the params are valid
  512. bool ValidateParams() override;
  513. private:
  514. int32_t batch_size_;
  515. bool drop_remainder_;
  516. bool pad_;
  517. std::vector<std::string> cols_to_map_;
  518. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
  519. };
  /// \class ConcatDataset
  /// \brief A Dataset derived class to represent the Concat dataset op
  520. class ConcatDataset : public Dataset {
  521. public:
  522. /// \brief Constructor
  523. explicit ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets);
  524. /// \brief Destructor
  525. ~ConcatDataset() = default;
  526. /// \brief Base class override that creates the required runtime dataset op objects for this class
  527. /// \return The list of shared pointers to the newly created DatasetOps
  528. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  529. /// \brief Parameters validation
  530. /// \return bool true if all the params are valid
  531. bool ValidateParams() override;
  532. private:
  533. std::vector<std::shared_ptr<Dataset>> datasets_;
  534. };
  535. class MapDataset : public Dataset {
  536. public:
  537. /// \brief Constructor
  538. MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns = {},
  539. std::vector<std::string> output_columns = {}, const std::vector<std::string> &columns = {});
  540. /// \brief Destructor
  541. ~MapDataset() = default;
  542. /// \brief a base class override function to create the required runtime dataset op objects for this class
  543. /// \return The list of shared pointers to the newly created DatasetOps
  544. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  545. /// \brief Parameters validation
  546. /// \return bool true if all the params are valid
  547. bool ValidateParams() override;
  548. private:
  549. std::vector<std::shared_ptr<TensorOperation>> operations_;
  550. std::vector<std::string> input_columns_;
  551. std::vector<std::string> output_columns_;
  552. std::vector<std::string> project_columns_;
  553. };
  554. class ProjectDataset : public Dataset {
  555. public:
  556. /// \brief Constructor
  557. explicit ProjectDataset(const std::vector<std::string> &columns);
  558. /// \brief Destructor
  559. ~ProjectDataset() = default;
  560. /// \brief a base class override function to create the required runtime dataset op objects for this class
  561. /// \return The list of shared pointers to the newly created DatasetOps
  562. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  563. /// \brief Parameters validation
  564. /// \return bool true if all the params are valid
  565. bool ValidateParams() override;
  566. private:
  567. std::vector<std::string> columns_;
  568. };
  569. class RenameDataset : public Dataset {
  570. public:
  571. /// \brief Constructor
  572. explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns);
  573. /// \brief Destructor
  574. ~RenameDataset() = default;
  575. /// \brief a base class override function to create the required runtime dataset op objects for this class
  576. /// \return The list of shared pointers to the newly created DatasetOps
  577. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  578. /// \brief Parameters validation
  579. /// \return bool true if all the params are valid
  580. bool ValidateParams() override;
  581. private:
  582. std::vector<std::string> input_columns_;
  583. std::vector<std::string> output_columns_;
  584. };
  585. class RepeatDataset : public Dataset {
  586. public:
  587. /// \brief Constructor
  588. explicit RepeatDataset(int32_t count);
  589. /// \brief Destructor
  590. ~RepeatDataset() = default;
  591. /// \brief a base class override function to create the required runtime dataset op objects for this class
  592. /// \return The list of shared pointers to the newly created DatasetOps
  593. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  594. /// \brief Parameters validation
  595. /// \return bool true if all the params are valid
  596. bool ValidateParams() override;
  597. private:
  598. int32_t repeat_count_;
  599. };
  600. class ShuffleDataset : public Dataset {
  601. public:
  602. ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch);
  603. ~ShuffleDataset() = default;
  604. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  605. bool ValidateParams() override;
  606. private:
  607. int32_t shuffle_size_;
  608. uint32_t shuffle_seed_;
  609. bool reset_every_epoch_;
  610. };
  611. class SkipDataset : public Dataset {
  612. public:
  613. /// \brief Constructor
  614. explicit SkipDataset(int32_t count);
  615. /// \brief Destructor
  616. ~SkipDataset() = default;
  617. /// \brief a base class override function to create the required runtime dataset op objects for this class
  618. /// \return The list of shared pointers to the newly created DatasetOps
  619. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  620. /// \brief Parameters validation
  621. /// \return bool true if all the params are valid
  622. bool ValidateParams() override;
  623. private:
  624. int32_t skip_count_;
  625. };
  626. class TakeDataset : public Dataset {
  627. public:
  628. /// \brief Constructor
  629. explicit TakeDataset(int32_t count);
  630. /// \brief Destructor
  631. ~TakeDataset() = default;
  632. /// \brief a base class override function to create the required runtime dataset op objects for this class
  633. /// \return shared pointer to the list of newly created DatasetOps
  634. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  635. /// \brief Parameters validation
  636. /// \return bool true if all the params are valid
  637. bool ValidateParams() override;
  638. private:
  639. int32_t take_count_;
  640. };
  641. class ZipDataset : public Dataset {
  642. public:
  643. /// \brief Constructor
  644. explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets);
  645. /// \brief Destructor
  646. ~ZipDataset() = default;
  647. /// \brief a base class override function to create the required runtime dataset op objects for this class
  648. /// \return The list of shared pointers to the newly created DatasetOps
  649. std::vector<std::shared_ptr<DatasetOp>> Build() override;
  650. /// \brief Parameters validation
  651. /// \return bool true if all the params are valid
  652. bool ValidateParams() override;
  653. private:
  654. std::vector<std::shared_ptr<Dataset>> datasets_;
  655. };
  656. } // namespace api
  657. } // namespace dataset
  658. } // namespace mindspore
  659. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_