You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

datasets.h 48 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  18. #include <unistd.h>
  19. #include <map>
  20. #include <memory>
  21. #include <set>
  22. #include <string>
  23. #include <unordered_set>
  24. #include <utility>
  25. #include <vector>
  26. #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
  27. #include "minddata/dataset/core/constants.h"
  28. #include "minddata/dataset/engine/consumers/tree_consumer.h"
  29. #include "minddata/dataset/engine/data_schema.h"
  30. #include "minddata/dataset/include/iterator.h"
  31. #include "minddata/dataset/include/samplers.h"
  32. #include "minddata/dataset/include/tensor.h"
  33. #include "minddata/dataset/include/type_id.h"
  34. #include "minddata/dataset/kernels/c_func_op.h"
  35. #include "minddata/dataset/kernels/tensor_op.h"
  36. #include "minddata/dataset/util/path.h"
  37. #ifndef ENABLE_ANDROID
  38. #include "minddata/dataset/text/vocab.h"
  39. #endif
  40. namespace mindspore {
  41. namespace dataset {
  42. // Forward declare
  43. class DatasetOp;
  44. class DataSchema;
  45. class Tensor;
  46. class TensorShape;
  47. class TreeAdapter;
  48. class TreeGetters;
  49. #ifndef ENABLE_ANDROID
  50. class Vocab;
  51. #endif
  52. namespace api {
  53. class Dataset;
  54. class Iterator;
  55. class TensorOperation;
  56. class SchemaObj;
  57. class SamplerObj;
  58. // Datasets classes (in alphabetical order)
  59. class AlbumNode;
  60. class CelebANode;
  61. class Cifar10Node;
  62. class Cifar100Node;
  63. class CLUENode;
  64. class CocoNode;
  65. class CSVNode;
  66. class CsvBase;
  67. class ImageFolderNode;
  68. class BatchNode;
  69. #ifndef ENABLE_ANDROID
  70. class ManifestNode;
  71. class MindDataNode;
  72. #endif
  73. class MnistNode;
  74. class RandomNode;
  75. class TextFileNode;
  76. #ifndef ENABLE_ANDROID
  77. class TFRecordNode;
  78. class VOCNode;
  79. #endif
  80. // Dataset Op classes (in alphabetical order)
  81. #ifndef ENABLE_ANDROID
  82. class BucketBatchByLengthNode;
  83. class BuildVocabNode;
  84. #endif
  85. class ConcatNode;
  86. class MapNode;
  87. class ProjectNode;
  88. class RenameNode;
  89. class RepeatNode;
  90. class ShuffleNode;
  91. class SkipNode;
  92. class TakeNode;
  93. class TransferNode;
  94. class ZipNode;
  // Convenience macro: evaluate the Status expression `_s` exactly once; if it
  // is an error, log it and return a value-initialized (empty) result `{}`
  // from the enclosing function. Only usable in functions where `return {};`
  // is well-formed (e.g. returning a container or shared_ptr).
  95. #define RETURN_EMPTY_IF_ERROR(_s) \
  96. do { \
  97. Status __rc = (_s); \
  98. if (__rc.IsError()) { \
  99. MS_LOG(ERROR) << __rc; \
  100. return {}; \
  101. } \
  102. } while (false)
  103. Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
  104. int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op);
  105. // Helper function to validate dataset files parameter
  106. Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files);
  107. // Helper function to validate dataset num_shards and shard_id parameters
  108. Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id);
  109. // Helper function to validate dataset sampler parameter
  110. Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler);
  111. Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
  112. const std::unordered_set<std::string> &valid_strings);
  113. // Helper function to validate dataset input/output column parameter
  114. Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
  115. const std::vector<std::string> &columns);
  116. // Helper function to validate dataset directory parameter
  117. Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir);
  118. /// \brief Function to create a SchemaObj
  119. /// \param[in] schema_file Path of schema file
  120. /// \return Shared pointer to the current schema
  121. std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "");
  122. /// \brief Function to create an AlbumNode
  123. /// \notes The generated dataset is specified through setting a schema
  124. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  125. /// \param[in] data_schema Path to dataset schema file
  126. /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns.
  127. /// (default = {})
  128. /// \param[in] decode the option to decode the images in dataset (default = false)
  129. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  130. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  131. /// \return Shared pointer to the current Dataset
  132. std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::string &data_schema,
  133. const std::vector<std::string> &column_names = {}, bool decode = false,
  134. const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
  135. /// \brief Function to create a CelebANode
  136. /// \notes The generated dataset has two columns ['image', 'attr'].
  137. /// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
  138. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  139. /// \param[in] usage One of "all", "train", "valid" or "test" (default = "all").
  140. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  141. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  142. /// \param[in] decode Decode the images after reading (default=false).
  143. /// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
  144. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  145. /// The cache feature is under development and is not recommended.
  146. /// \return Shared pointer to the current Dataset
  147. std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage = "all",
  148. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false,
  149. const std::set<std::string> &extensions = {},
  150. const std::shared_ptr<DatasetCache> &cache = nullptr);
  151. /// \brief Function to create a Cifar10 Dataset
  152. /// \notes The generated dataset has two columns ["image", "label"]
  153. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  154. /// \param[in] usage of CIFAR10, can be "train", "test" or "all" (default = "all").
  155. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  156. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  157. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  158. /// The cache feature is under development and is not recommended.
  159. /// \return Shared pointer to the current Dataset
  160. std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage = "all",
  161. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  162. const std::shared_ptr<DatasetCache> &cache = nullptr);
  163. /// \brief Function to create a Cifar100 Dataset
  164. /// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"]
  165. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  166. /// \param[in] usage of CIFAR100, can be "train", "test" or "all" (default = "all").
  167. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  168. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  169. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  170. /// The cache feature is under development and is not recommended.
  171. /// \return Shared pointer to the current Dataset
  172. std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage = "all",
  173. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  174. const std::shared_ptr<DatasetCache> &cache = nullptr);
  175. /// \brief Function to create a CLUENode
  176. /// \notes The generated dataset has a variable number of columns depending on the task and usage
  177. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  178. /// will be sorted in a lexicographical order.
  179. /// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC").
  180. /// \param[in] usage Be used to "train", "test" or "eval" data (default="train").
  181. /// \param[in] num_samples The number of samples to be included in the dataset.
  182. /// (Default = 0 means all samples.)
  183. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  184. /// Can be any of:
  185. /// ShuffleMode::kFalse - No shuffling is performed.
  186. /// ShuffleMode::kFiles - Shuffle files only.
  187. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  188. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  189. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  190. /// specified only when num_shards is also specified. (Default = 0)
  191. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  192. /// The cache feature is under development and is not recommended.
  193. /// \return Shared pointer to the current CLUENode
  194. std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
  195. const std::string &usage = "train", int64_t num_samples = 0,
  196. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
  197. const std::shared_ptr<DatasetCache> &cache = nullptr);
  198. /// \brief Function to create a CocoNode
  199. /// \notes The generated dataset has multi-columns :
  200. /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  201. /// ['iscrowd', dtype=uint32]].
  202. /// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]].
  203. /// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
  204. /// ['num_keypoints', dtype=uint32]].
  205. /// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  206. /// ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
  207. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  208. /// \param[in] annotation_file Path to the annotation json
  209. /// \param[in] task Set the task type of reading coco data, now support 'Detection'/'Stuff'/'Panoptic'/'Keypoint'
  210. /// \param[in] decode Decode the images after reading
  211. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  212. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  213. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  214. /// The cache feature is under development and is not recommended.
  215. /// \return Shared pointer to the current Dataset
  216. std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file,
  217. const std::string &task = "Detection", const bool &decode = false,
  218. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  219. const std::shared_ptr<DatasetCache> &cache = nullptr);
  220. /// \brief Function to create a CSVNode
  221. /// \notes The generated dataset has a variable number of columns
  222. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  223. /// will be sorted in a lexicographical order.
  224. /// \param[in] field_delim A char that indicates the delimiter to separate fields (default=',').
  225. /// \param[in] column_defaults List of default values for the CSV field (default={}). Each item in the list is
  226. /// either a valid type (float, int, or string). If this is not provided, treats all columns as string type.
  227. /// \param[in] column_names List of column names of the dataset (default={}). If this is not provided, infers the
  228. /// column_names from the first row of CSV file.
  229. /// \param[in] num_samples The number of samples to be included in the dataset.
  230. /// (Default = 0 means all samples.)
  231. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  232. /// Can be any of:
  233. /// ShuffleMode::kFalse - No shuffling is performed.
  234. /// ShuffleMode::kFiles - Shuffle files only.
  235. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  236. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  237. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  238. /// specified only when num_shards is also specified. (Default = 0)
  239. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  240. /// The cache feature is under development and is not recommended.
  241. /// \return Shared pointer to the current Dataset
  242. std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
  243. const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
  244. const std::vector<std::string> &column_names = {}, int64_t num_samples = 0,
  245. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
  246. const std::shared_ptr<DatasetCache> &cache = nullptr);
  247. /// \brief Function to create an ImageFolderNode
  248. /// \notes A source dataset that reads images from a tree of directories
  249. /// All images within one folder have the same label
  250. /// The generated dataset has two columns ["image", "label"]
  251. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  252. /// \param[in] decode A flag to decode in ImageFolder
  253. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  254. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  255. /// \param[in] extensions File extensions to be read
  256. /// \param[in] class_indexing a class name to label map
  257. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  258. /// The cache feature is under development and is not recommended.
  259. /// \return Shared pointer to the current ImageFolderNode
  260. std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode = false,
  261. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  262. const std::set<std::string> &extensions = {},
  263. const std::map<std::string, int32_t> &class_indexing = {},
  264. const std::shared_ptr<DatasetCache> &cache = nullptr);
  265. #ifndef ENABLE_ANDROID
  266. /// \brief Function to create a ManifestNode
  267. /// \notes The generated dataset has two columns ["image", "label"]
  268. /// \param[in] dataset_file The dataset file to be read
  269. /// \param[in] usage Need "train", "eval" or "inference" data (default="train")
  270. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  271. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  272. /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
  273. /// names will be sorted alphabetically and each class will be given a unique index starting from 0).
  274. /// \param[in] decode Decode the images after reading (default=false).
  275. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  276. /// The cache feature is under development and is not recommended.
  277. /// \return Shared pointer to the current ManifestNode
  278. std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage = "train",
  279. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  280. const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
  281. const std::shared_ptr<DatasetCache> &cache = nullptr);
  282. #endif
  283. #ifndef ENABLE_ANDROID
  284. /// \brief Function to create a MindDataNode
  285. /// \param[in] dataset_file File name of one component of a mindrecord source. Other files with identical source
  286. /// in the same path will be found and loaded automatically.
  287. /// \param[in] columns_list List of columns to be read (default={})
  288. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  289. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()),
  290. /// supported sampler list: SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
  291. /// \param[in] padded_sample Samples will be appended to dataset, where keys are the same as column_list.
  292. /// \param[in] num_padded Number of padding samples. Dataset size plus num_padded should be divisible by num_shards.
  293. /// \return Shared pointer to the current MindDataNode
  294. std::shared_ptr<MindDataNode> MindData(const std::string &dataset_file,
  295. const std::vector<std::string> &columns_list = {},
  296. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  297. nlohmann::json padded_sample = nullptr, int64_t num_padded = 0);
  298. /// \brief Function to create a MindDataNode
  299. /// \param[in] dataset_files List of dataset files to be read directly.
  300. /// \param[in] columns_list List of columns to be read (default={})
  301. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  302. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()),
  303. /// supported sampler list: SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
  304. /// \param[in] padded_sample Samples will be appended to dataset, where keys are the same as column_list.
  305. /// \param[in] num_padded Number of padding samples. Dataset size plus num_padded should be divisible by num_shards.
  306. /// \return Shared pointer to the current MindDataNode
  307. std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_files,
  308. const std::vector<std::string> &columns_list = {},
  309. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  310. nlohmann::json padded_sample = nullptr, int64_t num_padded = 0);
  311. #endif
  312. /// \brief Function to create a MnistNode
  313. /// \notes The generated dataset has two columns ["image", "label"]
  314. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  315. /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all").
  316. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  317. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  318. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  319. /// The cache feature is under development and is not recommended.
  320. /// \return Shared pointer to the current MnistNode
  321. std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage = "all",
  322. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  323. const std::shared_ptr<DatasetCache> &cache = nullptr);
  324. /// \brief Function to create a ConcatNode
  325. /// \notes Reload "+" operator to concat two datasets
  326. /// \param[in] datasets1 Shared pointer to the first dataset to be concatenated
  327. /// \param[in] datasets2 Shared pointer to the second dataset to be concatenated
  328. /// \return Shared pointer to the current ConcatNode
  329. std::shared_ptr<ConcatNode> operator+(const std::shared_ptr<Dataset> &datasets1,
  330. const std::shared_ptr<Dataset> &datasets2);
  331. /// \brief Function to create a RandomNode
  332. /// \param[in] total_rows Number of rows for the dataset to generate (default=0, number of rows is random)
  333. /// \param[in] schema SchemaObj to set column type, data type and data shape
  334. /// \param[in] columns_list List of columns to be read (default={}, read all columns)
  335. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  336. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  337. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  338. /// The cache feature is under development and is not recommended.
  339. /// \return Shared pointer to the current Dataset
  340. template <typename T = std::shared_ptr<SchemaObj>>
  341. std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
  342. const std::vector<std::string> &columns_list = {},
  343. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  344. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  345. if (total_rows < 0) {
  346. MS_LOG(ERROR) << "RandomNode: total_rows must be greater than or equal 0, now get " << total_rows;
  347. return nullptr;
  348. }
  349. if (sampler == nullptr) {
  350. MS_LOG(ERROR) << "RandomNode: Sampler is not constructed correctly, sampler: nullptr";
  351. return nullptr;
  352. }
  353. if (!columns_list.empty()) {
  354. for (uint32_t i = 0; i < columns_list.size(); ++i) {
  355. if (columns_list[i].empty()) {
  356. MS_LOG(ERROR) << "RandomNode:columns_list"
  357. << "[" << i << "] should not be empty";
  358. return nullptr;
  359. }
  360. }
  361. std::set<std::string> columns_set(columns_list.begin(), columns_list.end());
  362. if (columns_set.size() != columns_list.size()) {
  363. MS_LOG(ERROR) << "RandomNode:columns_list: Every column name should not be same with others";
  364. return nullptr;
  365. }
  366. }
  367. std::shared_ptr<RandomNode> ds;
  368. if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
  369. std::shared_ptr<SchemaObj> schema_obj = schema;
  370. ds = std::make_shared<RandomNode>(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler),
  371. cache);
  372. } else {
  373. ds =
  374. std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache);
  375. }
  376. return ds;
  377. }
  378. /// \brief Function to create a TextFileNode
  379. /// \notes The generated dataset has one column ['text']
  380. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  381. /// will be sorted in a lexicographical order.
  382. /// \param[in] num_samples The number of samples to be included in the dataset.
  383. /// (Default = 0 means all samples.)
  384. /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
  385. /// Can be any of:
  386. /// ShuffleMode::kFalse - No shuffling is performed.
  387. /// ShuffleMode::kFiles - Shuffle files only.
  388. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  389. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  390. /// \param[in] shard_id The shard ID within num_shards. This argument should be
  391. /// specified only when num_shards is also specified. (Default = 0)
  392. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  393. /// The cache feature is under development and is not recommended.
  394. /// \return Shared pointer to the current TextFileNode
  395. std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0,
  396. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
  397. int32_t shard_id = 0, const std::shared_ptr<DatasetCache> &cache = nullptr);
  398. #ifndef ENABLE_ANDROID
  399. /// \brief Function to create a TFRecordNode
  400. /// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
  401. /// will be sorted in a lexicographical order.
  402. /// \param[in] schema SchemaObj or string to schema path. (Default = nullptr, which means that the
  403. /// meta data from the TFData file is considered the schema.)
  404. /// \param[in] columns_list List of columns to be read. (Default = {}, read all columns)
  405. /// \param[in] num_samples The number of samples to be included in the dataset.
  406. /// (Default = 0 means all samples.)
  407. /// If num_samples is 0 and numRows(parsed from schema) does not exist, read the full dataset;
  408. /// If num_samples is 0 and numRows(parsed from schema) is greater than 0, read numRows rows;
  409. /// If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
  410. /// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal)
  411. /// Can be any of:
  412. /// ShuffleMode::kFalse - No shuffling is performed.
  413. /// ShuffleMode::kFiles - Shuffle files only.
  414. /// ShuffleMode::kGlobal - Shuffle both the files and samples.
  415. /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
  416. /// \param[in] shard_id The shard ID within num_shards. This argument should be specified only
  417. /// when num_shards is also specified. (Default = 0)
  418. /// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of
  419. /// each shard may be not equal)
  420. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  421. /// The cache feature is under development and is not recommended.
  422. /// \return Shared pointer to the current TFRecordNode
  423. template <typename T = std::shared_ptr<SchemaObj>>
  424. std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
  425. const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
  426. ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
  427. int32_t shard_id = 0, bool shard_equal_rows = false,
  428. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  429. if (dataset_files.empty()) {
  430. MS_LOG(ERROR) << "TFRecordNode: dataset_files is not specified.";
  431. return nullptr;
  432. }
  433. for (auto f : dataset_files) {
  434. Path dataset_file(f);
  435. if (!dataset_file.Exists()) {
  436. MS_LOG(ERROR) << "TFRecordNode: dataset file: [" << f << "] is invalid or does not exist.";
  437. return nullptr;
  438. }
  439. }
  440. if (num_samples < 0) {
  441. MS_LOG(ERROR) << "TFRecordNode: Invalid number of samples: " << num_samples;
  442. return nullptr;
  443. }
  444. if (num_shards <= 0) {
  445. MS_LOG(ERROR) << "TFRecordNode: Invalid num_shards: " << num_shards;
  446. return nullptr;
  447. }
  448. if (shard_id < 0 || shard_id >= num_shards) {
  449. MS_LOG(ERROR) << "TFRecordNode: Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
  450. return nullptr;
  451. }
  452. if (cache == nullptr && !shard_equal_rows && dataset_files.size() < num_shards) {
  453. // This check only makes sense in a non-cache path. We should make sure there is at least one file per
  454. // shard in file-based sharding
  455. MS_LOG(ERROR) << "TFRecordNode: Invalid number of dataset files, should at least be " << std::to_string(num_shards);
  456. return nullptr;
  457. }
  458. std::shared_ptr<TFRecordNode> ds = nullptr;
  459. if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
  460. std::shared_ptr<SchemaObj> schema_obj = schema;
  461. ds = std::make_shared<TFRecordNode>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
  462. shard_id, shard_equal_rows, cache);
  463. } else {
  464. std::string schema_path = schema;
  465. if (!schema_path.empty()) {
  466. Path schema_file(schema_path);
  467. if (!schema_file.Exists()) {
  468. MS_LOG(ERROR) << "TFRecordNode: schema path [" << schema_path << "] is invalid or does not exist.";
  469. return nullptr;
  470. }
  471. }
  472. ds = std::make_shared<TFRecordNode>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards,
  473. shard_id, shard_equal_rows, cache);
  474. }
  475. return ds;
  476. }
  477. /// \brief Function to create a VOCNode
  478. /// \notes The generated dataset has multi-columns :
  479. /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
  480. /// ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
  481. /// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
  482. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  483. /// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection"
  484. /// \param[in] usage The type of data list text file to be read (default = "train").
  485. /// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task
  486. /// \param[in] decode Decode the images after reading
  487. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  488. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  489. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  490. /// The cache feature is under development and is not recommended.
  491. /// \return Shared pointer to the current Dataset
  492. std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
  493. const std::string &usage = "train",
  494. const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
  495. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  496. const std::shared_ptr<DatasetCache> &cache = nullptr);
/// \brief Function to create a cache to be attached to a dataset
  498. /// \param id A user assigned session id for the current pipeline
  499. /// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited
  500. /// \param spill Spill to disk if out of memory
  501. /// \param hostname optional host name
  502. /// \param port optional port
  503. /// \param num_connections optional number of connections
  504. /// \param prefetch_sz optional prefetch size
  505. /// \return Shared pointer to DatasetCache. If error, nullptr is returned.
  506. std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
  507. std::optional<std::string> hostname, std::optional<int32_t> port,
  508. std::optional<int32_t> num_connections,
  509. std::optional<int32_t> prefetch_sz);
  510. #endif
  511. /// \brief Function to create a ZipNode
  512. /// \notes Applies zip to the dataset
  513. /// \param[in] datasets List of shared pointers to the datasets that we want to zip
  514. /// \return Shared pointer to the current Dataset
  515. std::shared_ptr<ZipNode> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
  516. /// \class Dataset datasets.h
  517. /// \brief A base class to represent a dataset in the data pipeline.
  518. class Dataset : public std::enable_shared_from_this<Dataset> {
  519. public:
  520. // need friend class so they can access the children_ field
  521. friend class Iterator;
  522. friend class TransferNode;
  523. friend class mindspore::dataset::TreeAdapter;
  524. /// \brief Constructor
  525. Dataset();
  526. /// \brief Constructor that initializes the cache
  527. /// \param dataset_cache DatasetCache
  528. explicit Dataset(const std::shared_ptr<DatasetCache> &dataset_cache);
  529. /// \brief Destructor
  530. ~Dataset() = default;
  531. /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object
  532. /// \return The list of shared pointers to the newly created DatasetOps
  533. virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
  534. /// \brief Pure virtual function for derived class to implement parameters validation
  535. /// \return Status Status::OK() if all the parameters are valid
  536. virtual Status ValidateParams() = 0;
  537. /// \brief Pure virtual function for derived class to get the shard id of specific node
  538. /// \return Status Status::OK() if get shard id successfully
  539. virtual Status GetShardId(int32_t *shard_id) {
  540. return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
  541. }
  542. /// \brief Gets the dataset size
  543. /// \return int64_t
  544. int64_t GetDatasetSize();
  545. /// \brief Gets the output type
  546. /// \return vector of DataType
  547. std::vector<DataType> GetOutputTypes();
  548. /// \brief Gets the output shape
  549. /// \return vector of TensorShapes
  550. std::vector<TensorShape> GetOutputShapes();
  551. /// \brief Gets the batch size
  552. /// \return int64_t
  553. int64_t GetBatchSize();
  554. /// \brief Gets the the repeat count
  555. /// \return int64_t
  556. int64_t GetRepeatCount();
  557. /// \brief Setter function for runtime number of workers
  558. /// \param[in] num_workers The number of threads in this operator
  559. /// \return Shared pointer to the original object
  560. std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers) {
  561. #if !defined(_WIN32) && !defined(_WIN64)
  562. #ifndef ENABLE_ANDROID
  563. int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF);
  564. if (cpu_count < 0 || cpu_count > INT32_MAX) {
  565. MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count;
  566. return nullptr;
  567. }
  568. if (num_workers < 1 || num_workers > cpu_count) {
  569. MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count;
  570. return nullptr;
  571. }
  572. #endif
  573. #endif
  574. num_workers_ = num_workers;
  575. return shared_from_this();
  576. }
  577. /// \brief Function to create an Iterator over the Dataset pipeline
  578. /// \param[in] columns List of columns to be used to specify the order of columns
  579. /// \return Shared pointer to the Iterator
  580. std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
  581. /// \brief Function to transfer data through a device.
  582. /// \notes If device is Ascend, features of data will be transferred one by one. The limitation
  583. /// of data transmission per time is 256M.
  584. /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=True).
  585. /// \return Returns true if no error encountered else false.
  586. bool DeviceQueue(bool send_epoch_end = true);
  587. #ifndef ENABLE_ANDROID
  588. /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
  589. /// \note Usage restrictions:
  590. /// 1. Supported dataset formats: 'mindrecord' only
  591. /// 2. To save the samples in order, set dataset's shuffle to false and num_files to 1.
  592. /// 3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators
  593. /// with random attribute in map operator.
  594. /// 4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor
  595. /// multi-dimensional string.
  596. /// \param[in] file_name Path to dataset file
  597. /// \param[in] num_files Number of dataset files (default=1)
  598. /// \param[in] file_type Dataset format (default="mindrecord")
  599. /// \return Returns true if no error encountered else false
  600. bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord");
  601. #endif
  602. /// \brief Function to create a BatchNode
  603. /// \notes Combines batch_size number of consecutive rows into batches
  604. /// \param[in] batch_size Path to the root directory that contains the dataset
  605. /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
  606. /// batch. If true, and if there are less than batch_size rows
  607. /// available to make the last batch, then those rows will
  608. /// be dropped and not propagated to the next node
  609. /// \return Shared pointer to the current BatchNode
  610. std::shared_ptr<BatchNode> Batch(int32_t batch_size, bool drop_remainder = false);
  611. #ifndef ENABLE_ANDROID
  612. /// \brief Function to create a BucketBatchByLengthNode
  613. /// \notes Combines batch_size number of consecutive rows into batches
  614. /// \param[in] column_names Columns passed to element_length_function
  615. /// \param[in] bucket_boundaries A list consisting of the upper boundaries of the buckets.
  616. /// Must be strictly increasing. If there are n boundaries, n+1 buckets are created: One bucket for
  617. /// [0, bucket_boundaries[0]), one bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
  618. /// 0<i<n, and one bucket for [bucket_boundaries[n-1], inf).
  619. /// \param[in] bucket_batch_sizes A list consisting of the batch sizes for each bucket.
  620. /// Must contain elements equal to the size of bucket_boundaries + 1.
  621. /// \param[in] element_length_function A function pointer that takes in TensorRow and outputs a TensorRow. The
  622. /// output
  623. /// must contain a single tensor containing a single int32_t. If no value is provided, then size of column_names
  624. /// must be 1, and the size of the first dimension of that column will be taken as the length (default=nullptr)
  625. /// \param[in] pad_info Represents how to batch each column. The key corresponds to the column name, the value must
  626. /// be a tuple of 2 elements. The first element corresponds to the shape to pad to, and the second element
  627. /// corresponds to the value to pad with. If a column is not specified, then that column will be padded to the
  628. /// longest in the current batch, and 0 will be used as the padding value. Any unspecified dimensions will be
  629. /// padded to the longest in the current batch, unless if pad_to_bucket_boundary is true. If no padding is
  630. /// wanted, set pad_info to None (default=empty dictionary).
  631. /// \param[in] pad_to_bucket_boundary If true, will pad each unspecified dimension in pad_info to the
  632. /// bucket_boundary
  633. /// minus 1. If there are any elements that fall into the last bucket, an error will occur (default=false).
  634. /// \param[in] drop_remainder If true, will drop the last batch for each bucket if it is not a full batch
  635. /// (default=false).
  636. /// \return Shared pointer to the current BucketBatchByLengthNode
  637. std::shared_ptr<BucketBatchByLengthNode> BucketBatchByLength(
  638. const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
  639. const std::vector<int32_t> &bucket_batch_sizes,
  640. std::function<TensorRow(TensorRow)> element_length_function = nullptr,
  641. const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
  642. bool pad_to_bucket_boundary = false, bool drop_remainder = false);
  643. /// \brief Function to create a Vocab from source dataset
  644. /// \notes Build a vocab from a dataset. This would collect all the unique words in a dataset and return a vocab
  645. /// which contains top_k most frequent words (if top_k is specified)
  646. /// \param[in] columns Column names to get words from. It can be a vector of column names
  647. /// \param[in] freq_range A tuple of integers (min_frequency, max_frequency). Words within the frequency
  648. /// range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
  649. /// can be set to default, which corresponds to 0/total_words separately
  650. /// \param[in] top_k Number of words to be built into vocab. top_k most frequent words are
  651. /// taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken
  652. /// \param[in] special_tokens A list of strings, each one is a special token
  653. /// \param[in] special_first Whether special_tokens will be prepended/appended to vocab, If special_tokens
  654. /// is specified and special_first is set to default, special_tokens will be prepended
  655. /// \return Shared pointer to the current Vocab
  656. std::shared_ptr<Vocab> BuildVocab(const std::vector<std::string> &columns = {},
  657. const std::pair<int64_t, int64_t> &freq_range = {0, kDeMaxFreq},
  658. int64_t top_k = kDeMaxTopk, const std::vector<std::string> &special_tokens = {},
  659. bool special_first = true);
  660. #endif
  661. /// \brief Function to create a ConcatNode
  662. /// \notes Concat the datasets in the input
  663. /// \param[in] datasets List of shared pointers to the dataset that should be concatenated together
  664. /// \return Shared pointer to the current ConcatNode
  665. std::shared_ptr<ConcatNode> Concat(const std::vector<std::shared_ptr<Dataset>> &datasets);
  666. /// \brief Function to create a MapNode
  667. /// \notes Applies each operation in operations to this dataset
  668. /// \param[in] operations Vector of operations to be applied on the dataset. Operations are
  669. /// applied in the order they appear in this list
  670. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  671. /// operation as input. The size of this list must match the number of
  672. /// input columns expected by the first operator. The default input_columns
  673. /// is the first column
  674. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  675. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  676. /// The size of this list must match the number of output columns of the
  677. /// last operation. The default output_columns will have the same
  678. /// name as the input columns, i.e., the columns will be replaced
  679. /// \param[in] project_columns A list of column names to project
  680. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  681. /// The cache feature is under development and is not recommended.
  682. /// \return Shared pointer to the current MapNode
  683. std::shared_ptr<MapNode> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
  684. std::vector<std::string> input_columns = {},
  685. std::vector<std::string> output_columns = {},
  686. const std::vector<std::string> &project_columns = {},
  687. const std::shared_ptr<DatasetCache> &cache = nullptr);
  688. /// \brief Function to create a Project Dataset
  689. /// \notes Applies project to the dataset
  690. /// \param[in] columns The name of columns to project
  691. /// \return Shared pointer to the current Dataset
  692. std::shared_ptr<ProjectNode> Project(const std::vector<std::string> &columns);
  693. /// \brief Function to create a Rename Dataset
  694. /// \notes Renames the columns in the input dataset
  695. /// \param[in] input_columns List of the input columns to rename
  696. /// \param[in] output_columns List of the output columns
  697. /// \return Shared pointer to the current Dataset
  698. std::shared_ptr<RenameNode> Rename(const std::vector<std::string> &input_columns,
  699. const std::vector<std::string> &output_columns);
  700. /// \brief Function to create a RepeatNode
  701. /// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
  702. /// \param[in] count Number of times the dataset should be repeated
  703. /// \return Shared pointer to the current Dataset
  704. /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatNode`
  705. /// due to a limitation in the current implementation
  706. std::shared_ptr<Dataset> Repeat(int32_t count = -1);
  707. /// \brief Function to create a Shuffle Dataset
  708. /// \notes Randomly shuffles the rows of this dataset
  709. /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
  710. /// \return Shared pointer to the current ShuffleNode
  711. std::shared_ptr<ShuffleNode> Shuffle(int32_t buffer_size);
  712. /// \brief Function to create a SkipNode
  713. /// \notes Skips count elements in this dataset.
  714. /// \param[in] count Number of elements the dataset to be skipped.
  715. /// \return Shared pointer to the current SkipNode
  716. std::shared_ptr<SkipNode> Skip(int32_t count);
  717. /// \brief Function to create a TakeNode
  718. /// \notes Takes count elements in this dataset.
  719. /// \param[in] count Number of elements the dataset to be taken.
  720. /// \return Shared pointer to the current Dataset
  721. std::shared_ptr<Dataset> Take(int32_t count = -1);
  722. /// \brief Function to create a Zip Dataset
  723. /// \notes Applies zip to the dataset
  724. /// \param[in] datasets A list of shared pointers to the datasets that we want to zip
  725. /// \return Shared pointer to the current Dataset
  726. std::shared_ptr<ZipNode> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
  727. protected:
  728. std::vector<std::shared_ptr<Dataset>> children;
  729. std::shared_ptr<Dataset> parent;
  730. std::shared_ptr<TreeGetters> tree_getters_;
  731. int32_t num_workers_;
  732. int32_t rows_per_buffer_;
  733. int32_t connector_que_size_;
  734. int32_t worker_connector_size_;
  735. std::shared_ptr<DatasetCache> cache_;
  736. Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
  737. };
/// \brief Helper class that holds a dataset schema (column names, types, shapes) as JSON,
///     either built programmatically via add_column() or loaded from a schema file.
class SchemaObj {
 public:
  /// \brief Constructor
  /// \param[in] schema_file Optional path to a JSON schema file; an empty string means no file.
  explicit SchemaObj(const std::string &schema_file = "");

  /// \brief Destructor
  ~SchemaObj() = default;

  /// \brief SchemaObj init function
  /// \return bool true if schema init success
  bool init();

  /// \brief Add new column to the schema
  /// \param[in] name name of the column.
  /// \param[in] de_type data type of the column(TypeId).
  /// \param[in] shape shape of the column.
  /// \return bool true if schema init success
  bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);

  /// \brief Add new column to the schema
  /// \param[in] name name of the column.
  /// \param[in] de_type data type of the column(std::string).
  /// \param[in] shape shape of the column.
  /// \return bool true if schema init success
  bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape);

  /// \brief Get a JSON string of the schema
  /// \return JSON string of the schema
  std::string to_json();

  /// \brief Get a JSON string of the schema (alias of to_json())
  std::string to_string() { return to_json(); }

  /// \brief set a new value to dataset_type
  inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; }

  /// \brief set a new value to num_rows
  inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }

  /// \brief get the current num_rows
  inline int32_t get_num_rows() { return num_rows_; }

 private:
  /// \brief Parse the columns and add them to columns_
  /// \param[in] columns dataset attribute information, decoded from schema file.
  ///     support both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
  /// \return bool true if the columns were parsed successfully
  bool parse_column(nlohmann::json columns);

  /// \brief Populate this schema from a parsed JSON object
  /// \param[in] json_obj object of json parsed.
  /// \return bool true if the JSON was loaded successfully
  bool from_json(nlohmann::json json_obj);

  int32_t num_rows_;            // number of rows declared by the schema
  std::string dataset_type_;    // dataset type string recorded in the schema
  std::string schema_file_;     // path the schema was loaded from ("" if built in memory)
  nlohmann::json columns_;      // per-column metadata (name, type, shape)
};
  785. /* ####################################### Derived Dataset classes ################################# */
  786. } // namespace api
  787. } // namespace dataset
  788. } // namespace mindspore
  789. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_