You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

datasets.h 39 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. /**
  2. * Copyright 2020-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_LITEAPI_INCLUDE_DATASETS_H_
  #include <sys/stat.h>
  #include <unistd.h>

  #include <algorithm>
  #include <cstdint>
  #include <functional>
  #include <iterator>
  #include <map>
  #include <memory>
  #include <set>
  #include <string>
  #include <unordered_map>
  #include <unordered_set>
  #include <utility>
  #include <vector>

  #include "include/api/dual_abi_helper.h"
  #include "include/api/types.h"
  #include "include/dataset/iterator.h"
  #include "include/dataset/samplers.h"
  #include "include/dataset/transforms.h"
  34. namespace mindspore {
  35. namespace dataset {
  36. class Tensor;
  37. class TensorShape;
  38. class TreeAdapter;
  39. class TreeAdapterLite;
  40. class TreeGetters;
  41. class DatasetCache;
  42. class DatasetNode;
  43. class Iterator;
  44. class TensorOperation;
  45. class SchemaObj;
  46. class SamplerObj;
  47. // Dataset classes (in alphabetical order)
  48. class BatchDataset;
  49. class MapDataset;
  50. class ProjectDataset;
  51. class ShuffleDataset;
  52. class DSCallback;
  53. /// \class Dataset datasets.h
  54. /// \brief A base class to represent a dataset in the data pipeline.
  55. class MS_API Dataset : public std::enable_shared_from_this<Dataset> {
  56. public:
  57. // need friend class so they can access the children_ field
  58. friend class Iterator;
  59. friend class TransferNode;
  60. /// \brief Constructor
  61. Dataset();
  62. /// \brief Destructor
  63. virtual ~Dataset() = default;
  64. /// \brief Gets the dataset size
  65. /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  66. /// dataset size at the expense of accuracy.
  67. /// \return dataset size. If failed, return -1
  68. int64_t GetDatasetSize(bool estimate = false);
  69. /// \brief Gets the output type
  70. /// \return a vector of DataType. If failed, return an empty vector
  71. std::vector<mindspore::DataType> GetOutputTypes();
  72. /// \brief Gets the output shape
  73. /// \return a vector of TensorShape. If failed, return an empty vector
  74. std::vector<std::vector<int64_t>> GetOutputShapes();
  75. /// \brief Gets the batch size
  76. /// \return int64_t
  77. int64_t GetBatchSize();
  78. /// \brief Gets the repeat count
  79. /// \return int64_t
  80. int64_t GetRepeatCount();
  81. /// \brief Gets the number of classes
  82. /// \return number of classes. If failed, return -1
  83. int64_t GetNumClasses();
  84. /// \brief Gets the column names
  85. /// \return Names of the columns. If failed, return an empty vector
  86. std::vector<std::string> GetColumnNames() { return VectorCharToString(GetColumnNamesCharIF()); }
  87. /// \brief Gets the class indexing
  88. /// \return a map of ClassIndexing. If failed, return an empty map
  89. std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() {
  90. return ClassIndexCharToString(GetClassIndexingCharIF());
  91. }
  92. /// \brief Setter function for runtime number of workers
  93. /// \param[in] num_workers The number of threads in this operator
  94. /// \return Shared pointer to the original object
  95. /// \par Example
  96. /// \code
  97. /// /* Set number of workers(threads) to process the dataset in parallel */
  98. /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true);
  99. /// ds = ds->SetNumWorkers(16);
  100. /// \endcode
  101. std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
  102. /// \brief Function to create an PullBasedIterator over the Dataset
  103. /// \param[in] columns List of columns to be used to specify the order of columns
  104. /// \return Shared pointer to the Iterator
  105. /// \par Example
  106. /// \code
  107. /// /* dataset is an instance of Dataset object */
  108. /// std::shared_ptr<Iterator> = dataset->CreatePullBasedIterator();
  109. /// std::unordered_map<std::string, mindspore::MSTensor> row;
  110. /// iter->GetNextRow(&row);
  111. /// \endcode
  112. std::shared_ptr<PullIterator> CreatePullBasedIterator(const std::vector<std::vector<char>> &columns = {});
  113. /// \brief Function to create an Iterator over the Dataset pipeline
  114. /// \param[in] columns List of columns to be used to specify the order of columns
  115. /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
  116. /// An empty row is returned at the end of each epoch
  117. /// \return Shared pointer to the Iterator
  118. /// \par Example
  119. /// \code
  120. /// /* dataset is an instance of Dataset object */
  121. /// std::shared_ptr<Iterator> = dataset->CreateIterator();
  122. /// std::unordered_map<std::string, mindspore::MSTensor> row;
  123. /// iter->GetNextRow(&row);
  124. /// \endcode
  125. std::shared_ptr<Iterator> CreateIterator(const std::vector<std::string> &columns = {}, int32_t num_epochs = -1) {
  126. return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
  127. }
  128. /// \brief Function to transfer data through a device.
  129. /// \notes If device is Ascend, features of data will be transferred one by one. The limitation
  130. /// of data transmission per time is 256M.
  131. /// \param[in] queue_name Channel name (default="", create new unique name).
  132. /// \param[in] device_type Type of device (default="", get from MSContext).
  133. /// \param[in] device_id id of device (default=1, get from MSContext).
  134. /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
  135. /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
  136. /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
  137. /// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
  138. /// of data or not(default=false).
  139. /// \return Returns true if no error encountered else false.
  140. bool DeviceQueue(const std::string &queue_name = "", const std::string &device_type = "", int32_t device_id = 0,
  141. int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
  142. bool create_data_info_queue = false) {
  143. return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
  144. total_batches, create_data_info_queue);
  145. }
  146. /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
  147. /// \note Usage restrictions:
  148. /// 1. Supported dataset formats: 'mindrecord' only
  149. /// 2. To save the samples in order, set dataset's shuffle to false and num_files to 1.
  150. /// 3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators
  151. /// with random attribute in map operator.
  152. /// 4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor
  153. /// multi-dimensional string.
  154. /// \param[in] file_name Path to dataset file
  155. /// \param[in] num_files Number of dataset files (default=1)
  156. /// \param[in] file_type Dataset format (default="mindrecord")
  157. /// \return Returns true if no error encountered else false
  158. /// \par Example
  159. /// \code
  160. /// /* Create a dataset and save its data into MindRecord */
  161. /// std::string folder_path = "/path/to/cifar_dataset";
  162. /// std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all", std::make_shared<SequentialSampler>(0, 10));
  163. /// std::string save_file = "Cifar10Data.mindrecord";
  164. /// bool rc = ds->Save(save_file);
  165. /// \endcode
  166. bool Save(const std::string &dataset_path, int32_t num_files = 1, const std::string &dataset_type = "mindrecord") {
  167. return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
  168. }
  169. /// \brief Function to create a BatchDataset
  170. /// \notes Combines batch_size number of consecutive rows into batches
  171. /// \param[in] batch_size The number of rows each batch is created with
  172. /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
  173. /// batch. If true, and if there are less than batch_size rows
  174. /// available to make the last batch, then those rows will
  175. /// be dropped and not propagated to the next node
  176. /// \return Shared pointer to the current BatchDataset
  177. /// \par Example
  178. /// \code
  179. /// /* Create a dataset where every 100 rows is combined into a batch */
  180. /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true);
  181. /// ds = ds->Batch(100, true);
  182. /// \endcode
  183. std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
  184. /// \brief Function to create a MapDataset
  185. /// \notes Applies each operation in operations to this dataset
  186. /// \param[in] operations Vector of raw pointers to TensorTransform objects to be applied on the dataset. Operations
  187. /// are applied in the order they appear in this list
  188. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  189. /// operation as input. The size of this list must match the number of
  190. /// input columns expected by the first operator. The default input_columns
  191. /// is the first column
  192. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  193. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  194. /// The size of this list must match the number of output columns of the
  195. /// last operation. The default output_columns will have the same
  196. /// name as the input columns, i.e., the columns will be replaced
  197. /// \param[in] project_columns A list of column names to project
  198. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  199. /// \return Shared pointer to the current MapDataset
  200. /// \par Example
  201. /// \code
  202. /// // Create objects for the tensor ops
  203. /// std::shared_ptr<TensorTransform> decode_op = std::make_shared<vision::Decode>(true);
  204. /// std::shared_ptr<TensorTransform> random_color_op = std::make_shared<vision::RandomColor>(0.0, 0.0);
  205. ///
  206. /// /* 1) Simple map example */
  207. /// // Apply decode_op on column "image". This column will be replaced by the outputted
  208. /// // column of decode_op. Since column_order is not provided, both columns "image"
  209. /// // and "label" will be propagated to the child node in their original order.
  210. /// dataset = dataset->Map({decode_op}, {"image"});
  211. ///
  212. /// // Decode and rename column "image" to "decoded_image".
  213. /// dataset = dataset->Map({decode_op}, {"image"}, {"decoded_image"});
  214. ///
  215. /// // Specify the order of the output columns.
  216. /// dataset = dataset->Map({decode_op}, {"image"}, {}, {"label", "image"});
  217. ///
  218. /// // Rename column "image" to "decoded_image" and also specify the order of the output columns.
  219. /// dataset = dataset->Map({decode_op}, {"image"}, {"decoded_image"}, {"label", "decoded_image"});
  220. ///
  221. /// // Rename column "image" to "decoded_image" and keep only this column.
  222. /// dataset = dataset->Map({decode_op}, {"image"}, {"decoded_image"}, {"decoded_image"});
  223. ///
  224. /// /* 2) Map example with more than one operation */
  225. /// // Create a dataset where the images are decoded, then randomly color jittered.
  226. /// // decode_op takes column "image" as input and outputs one column. The column
  227. /// // outputted by decode_op is passed as input to random_jitter_op.
  228. /// // random_jitter_op will output one column. Column "image" will be replaced by
  229. /// // the column outputted by random_jitter_op (the very last operation). All other
  230. /// // columns are unchanged. Since column_order is not specified, the order of the
  231. /// // columns will remain the same.
  232. /// dataset = dataset->Map({decode_op, random_jitter_op}, {"image"})
  233. /// \endcode
  234. std::shared_ptr<MapDataset> Map(const std::vector<TensorTransform *> &operations,
  235. const std::vector<std::string> &input_columns = {},
  236. const std::vector<std::string> &output_columns = {},
  237. const std::vector<std::string> &project_columns = {},
  238. const std::shared_ptr<DatasetCache> &cache = nullptr,
  239. const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
  240. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  241. (void)std::transform(
  242. operations.begin(), operations.end(), std::back_inserter(transform_ops),
  243. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  244. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  245. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  246. callbacks);
  247. }
  248. /// \brief Function to create a MapDataset
  249. /// \notes Applies each operation in operations to this dataset
  250. /// \param[in] operations Vector of shared pointers to TensorTransform objects to be applied on the dataset.
  251. /// Operations are applied in the order they appear in this list
  252. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  253. /// operation as input. The size of this list must match the number of
  254. /// input columns expected by the first operator. The default input_columns
  255. /// is the first column
  256. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  257. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  258. /// The size of this list must match the number of output columns of the
  259. /// last operation. The default output_columns will have the same
  260. /// name as the input columns, i.e., the columns will be replaced
  261. /// \param[in] project_columns A list of column names to project
  262. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  263. /// \return Shared pointer to the current MapDataset
  264. std::shared_ptr<MapDataset> Map(const std::vector<std::shared_ptr<TensorTransform>> &operations,
  265. const std::vector<std::string> &input_columns = {},
  266. const std::vector<std::string> &output_columns = {},
  267. const std::vector<std::string> &project_columns = {},
  268. const std::shared_ptr<DatasetCache> &cache = nullptr,
  269. const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
  270. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  271. (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
  272. [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
  273. return op != nullptr ? op->Parse() : nullptr;
  274. });
  275. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  276. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  277. callbacks);
  278. }
  279. /// \brief Function to create a MapDataset
  280. /// \notes Applies each operation in operations to this dataset
  281. /// \param[in] operations Vector of TensorTransform objects to be applied on the dataset. Operations are applied in
  282. /// the order they appear in this list
  283. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  284. /// operation as input. The size of this list must match the number of
  285. /// input columns expected by the first operator. The default input_columns
  286. /// is the first column
  287. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  288. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  289. /// The size of this list must match the number of output columns of the
  290. /// last operation. The default output_columns will have the same
  291. /// name as the input columns, i.e., the columns will be replaced
  292. /// \param[in] project_columns A list of column names to project
  293. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  294. /// \return Shared pointer to the current MapDataset
  295. std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> &operations,
  296. const std::vector<std::string> &input_columns = {},
  297. const std::vector<std::string> &output_columns = {},
  298. const std::vector<std::string> &project_columns = {},
  299. const std::shared_ptr<DatasetCache> &cache = nullptr,
  300. const std::vector<std::shared_ptr<DSCallback>> &callbacks = {}) {
  301. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  302. (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
  303. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  304. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  305. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  306. callbacks);
  307. }
  308. /// \brief Function to create a Project Dataset
  309. /// \notes Applies project to the dataset
  310. /// \param[in] columns The name of columns to project
  311. /// \return Shared pointer to the current Dataset
  312. /// \par Example
  313. /// \code
  314. /// /* Reorder the original column names in dataset */
  315. /// std::shared_ptr<Dataset> ds = Mnist(folder_path, "all", std::make_shared<RandomSampler>(false, 10));
  316. /// ds = ds->Project({"label", "image"});
  317. /// \endcode
  318. std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
  319. return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
  320. }
  321. /// \brief Function to create a Shuffle Dataset
  322. /// \notes Randomly shuffles the rows of this dataset
  323. /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
  324. /// \return Shared pointer to the current ShuffleDataset
  325. /// \par Example
  326. /// \code
  327. /// /* Rename the original column names in dataset */
  328. /// std::shared_ptr<Dataset> ds = Mnist(folder_path, "all", std::make_shared<RandomSampler>(false, 10));
  329. /// ds = ds->Rename({"image", "label"}, {"image_output", "label_output"});
  330. /// \endcode
  331. std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
  332. return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
  333. }
  334. std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
  335. protected:
  336. std::shared_ptr<TreeGetters> tree_getters_;
  337. std::shared_ptr<DatasetNode> ir_node_;
  338. private:
  339. // Char interface(CharIF) of GetColumnNames
  340. std::vector<std::vector<char>> GetColumnNamesCharIF();
  341. // Char interface(CharIF) of GetClassIndexing
  342. std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
  343. // Char interface(CharIF) of CreateIterator
  344. std::shared_ptr<Iterator> CreateIteratorCharIF(const std::vector<std::vector<char>> &columns, int32_t num_epochs);
  345. // Char interface(CharIF) of DeviceQueue
  346. bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
  347. int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
  348. // Char interface(CharIF) of Save
  349. bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
  350. };
  351. class MS_API SchemaObj {
  352. public:
  353. /// \brief Constructor
  354. explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}
  355. /// \brief Destructor
  356. ~SchemaObj() = default;
  357. /// \brief SchemaObj Init function
  358. /// \return bool true if schema initialization is successful
  359. Status Init();
  360. /// \brief Add new column to the schema with unknown shape of rank 1
  361. /// \param[in] name Name of the column.
  362. /// \param[in] ms_type Data type of the column(mindspore::DataType).
  363. /// \return Status code
  364. Status add_column(const std::string &name, mindspore::DataType ms_type) {
  365. return add_column_char(StringToChar(name), ms_type);
  366. }
  367. /// \brief Add new column to the schema with unknown shape of rank 1
  368. /// \param[in] name Name of the column.
  369. /// \param[in] ms_type Data type of the column(std::string).
  370. /// \param[in] shape Shape of the column.
  371. /// \return Status code
  372. Status add_column(const std::string &name, const std::string &ms_type) {
  373. return add_column_char(StringToChar(name), StringToChar(ms_type));
  374. }
  375. /// \brief Add new column to the schema
  376. /// \param[in] name Name of the column.
  377. /// \param[in] ms_type Data type of the column(mindspore::DataType).
  378. /// \param[in] shape Shape of the column.
  379. /// \return Status code
  380. Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
  381. return add_column_char(StringToChar(name), ms_type, shape);
  382. }
  383. /// \brief Add new column to the schema
  384. /// \param[in] name Name of the column.
  385. /// \param[in] ms_type Data type of the column(std::string).
  386. /// \param[in] shape Shape of the column.
  387. /// \return Status code
  388. Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
  389. return add_column_char(StringToChar(name), StringToChar(ms_type), shape);
  390. }
  391. /// \brief Get a JSON string of the schema
  392. /// \return JSON string of the schema
  393. std::string to_json() { return CharToString(to_json_char()); }
  394. /// \brief Get a JSON string of the schema
  395. std::string to_string() { return to_json(); }
  396. /// \brief Set a new value to dataset_type
  397. void set_dataset_type(const std::string &dataset_type);
  398. /// \brief Set a new value to num_rows
  399. void set_num_rows(int32_t num_rows);
  400. /// \brief Get the current num_rows
  401. int32_t get_num_rows() const;
  402. /// \brief Get schema file from JSON file
  403. /// \param[in] json_string Name of JSON file to be parsed.
  404. /// \return Status code
  405. Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); }
  406. /// \brief Parse and add column information
  407. /// \param[in] json_string Name of JSON string for column dataset attribute information, decoded from schema file.
  408. /// \return Status code
  409. Status ParseColumnString(const std::string &json_string) {
  410. return ParseColumnStringCharIF(StringToChar(json_string));
  411. }
  412. private:
  413. // Char constructor of SchemaObj
  414. explicit SchemaObj(const std::vector<char> &schema_file);
  415. // Char interface of add_column
  416. Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);
  417. Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);
  418. Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);
  419. Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
  420. const std::vector<int32_t> &shape);
  421. // Char interface of to_json
  422. const std::vector<char> to_json_char();
  423. // Char interface of FromJSONString
  424. Status FromJSONStringCharIF(const std::vector<char> &json_string);
  425. // Char interface of ParseColumnString
  426. Status ParseColumnStringCharIF(const std::vector<char> &json_string);
  427. struct Data;
  428. std::shared_ptr<Data> data_;
  429. };
  430. class MS_API BatchDataset : public Dataset {
  431. public:
  432. BatchDataset(const std::shared_ptr<Dataset> &input, int32_t batch_size, bool drop_remainder = false);
  433. ~BatchDataset() override = default;
  434. };
  435. class MS_API MapDataset : public Dataset {
  436. public:
  437. MapDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::shared_ptr<TensorOperation>> &operations,
  438. const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
  439. const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
  440. const std::vector<std::shared_ptr<DSCallback>> &callbacks);
  441. ~MapDataset() override = default;
  442. };
  443. class MS_API ProjectDataset : public Dataset {
  444. public:
  445. ProjectDataset(const std::shared_ptr<Dataset> &input, const std::vector<std::vector<char>> &columns);
  446. ~ProjectDataset() override = default;
  447. };
  448. class MS_API ShuffleDataset : public Dataset {
  449. public:
  450. ShuffleDataset(const std::shared_ptr<Dataset> &input, int32_t buffer_size);
  451. ~ShuffleDataset() override = default;
  452. };
  453. /// \brief Function to create a SchemaObj.
  454. /// \param[in] schema_file Path of schema file.
  455. /// \note The reason for using this API is that std::string will be constrained by the
  456. /// compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction.
  457. /// \return Shared pointer to the current schema.
  458. std::shared_ptr<SchemaObj> MS_API SchemaCharIF(const std::vector<char> &schema_file);
  459. /// \brief Function to create a SchemaObj.
  460. /// \param[in] schema_file Path of schema file.
  461. /// \return Shared pointer to the current schema.
  462. inline std::shared_ptr<SchemaObj> MS_API Schema(const std::string &schema_file = "") {
  463. return SchemaCharIF(StringToChar(schema_file));
  464. }
  465. class MS_API AlbumDataset : public Dataset {
  466. public:
  467. /// \brief Constructor of AlbumDataset.
  468. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  469. /// \param[in] data_schema Path to dataset schema file.
  470. /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns
  471. /// (default = {}).
  472. /// \param[in] decode The option to decode the images in dataset (default = false).
  473. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  474. /// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
  475. /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
  476. AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
  477. const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler,
  478. const std::shared_ptr<DatasetCache> &cache);
  479. /// \brief Constructor of AlbumDataset.
  480. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  481. /// \param[in] data_schema Path to dataset schema file.
  482. /// \param[in] column_names Column names used to specify columns to load.
  483. /// \param[in] decode The option to decode the images in dataset.
  484. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  485. /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
  486. AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
  487. const std::vector<std::vector<char>> &column_names, bool decode, const Sampler *sampler,
  488. const std::shared_ptr<DatasetCache> &cache);
  489. /// \brief Constructor of AlbumDataset.
  490. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  491. /// \param[in] data_schema Path to dataset schema file.
  492. /// \param[in] column_names Column names used to specify columns to load.
  493. /// \param[in] decode The option to decode the images in dataset.
  494. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  495. /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
  496. AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
  497. const std::vector<std::vector<char>> &column_names, bool decode,
  498. const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  499. /// \brief Destructor of AlbumDataset.
  500. ~AlbumDataset() override = default;
  501. };
  502. /// \brief Function to create an AlbumDataset
  503. /// \notes The generated dataset is specified through setting a schema
  504. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  505. /// \param[in] data_schema Path to dataset schema file
  506. /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns.
  507. /// (default = {})
  508. /// \param[in] decode the option to decode the images in dataset (default = false)
  509. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  510. /// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  511. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  512. /// \return Shared pointer to the current Dataset
  513. /// \par Example
  514. /// \code
  515. /// /* Define dataset path and MindData object */
  516. /// std::string folder_path = "/path/to/album_dataset_directory";
  517. /// std::string schema_file = "/path/to/album_schema_file";
  518. /// std::vector<std::string> column_names = {"image", "label", "id"};
  519. /// std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
  520. ///
  521. /// /* Create iterator to read dataset */
  522. /// std::shared_ptr<Iterator> iter = ds->CreateIterator();
  523. /// std::unordered_map<std::string, mindspore::MSTensor> row;
  524. /// iter->GetNextRow(&row);
  525. ///
  526. /// /* Note: As we defined before, each data dictionary owns keys "image", "label" and "id" */
  527. /// auto image = row["image"];
  528. /// \endcode
  529. inline std::shared_ptr<AlbumDataset> MS_API
  530. Album(const std::string &dataset_dir, const std::string &data_schema, const std::vector<std::string> &column_names = {},
  531. bool decode = false, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  532. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  533. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  534. VectorStringToChar(column_names), decode, sampler, cache);
  535. }
  536. /// \brief Function to create an AlbumDataset
  537. /// \notes The generated dataset is specified through setting a schema
  538. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  539. /// \param[in] data_schema Path to dataset schema file
  540. /// \param[in] column_names Column names used to specify columns to load
  541. /// \param[in] decode the option to decode the images in dataset
  542. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  543. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  544. /// \return Shared pointer to the current Dataset
  545. inline std::shared_ptr<AlbumDataset> MS_API Album(const std::string &dataset_dir, const std::string &data_schema,
  546. const std::vector<std::string> &column_names, bool decode,
  547. const Sampler *sampler,
  548. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  549. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  550. VectorStringToChar(column_names), decode, sampler, cache);
  551. }
  552. /// \brief Function to create an AlbumDataset
  553. /// \notes The generated dataset is specified through setting a schema
  554. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  555. /// \param[in] data_schema Path to dataset schema file
  556. /// \param[in] column_names Column names used to specify columns to load
  557. /// \param[in] decode the option to decode the images in dataset
  558. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  559. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  560. /// \return Shared pointer to the current Dataset
  561. inline std::shared_ptr<AlbumDataset> MS_API Album(const std::string &dataset_dir, const std::string &data_schema,
  562. const std::vector<std::string> &column_names, bool decode,
  563. const std::reference_wrapper<Sampler> sampler,
  564. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  565. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  566. VectorStringToChar(column_names), decode, sampler, cache);
  567. }
  568. class MS_API MnistDataset : public Dataset {
  569. public:
  570. /// \brief Constructor of MnistDataset.
  571. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  572. /// \param[in] usage Part of dataset of MNIST, can be "train", "test" or "all" (default = "all").
  573. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  574. /// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
  575. /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
  576. MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
  577. const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  578. /// \brief Constructor of MnistDataset.
  579. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  580. /// \param[in] usage Part of dataset of MNIST, can be "train", "test" or "all".
  581. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  582. /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
  583. MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
  584. const std::shared_ptr<DatasetCache> &cache);
  585. /// \brief Constructor of MnistDataset.
  586. /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  587. /// \param[in] usage Part of dataset of MNIST, can be "train", "test" or "all".
  588. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  589. /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
  590. MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
  591. const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  592. /// Destructor of MnistDataset.
  593. ~MnistDataset() override = default;
  594. };
  595. /// \brief Function to create a MnistDataset
  596. /// \notes The generated dataset has two columns ["image", "label"]
  597. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  598. /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all").
  599. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  600. /// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  601. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  602. /// \return Shared pointer to the current MnistDataset
  603. /// \par Example
  604. /// \code
  605. /// /* Define dataset path and MindData object */
  606. /// std::string folder_path = "/path/to/mnist_dataset_directory";
  607. /// std::shared_ptr<Dataset> ds = Mnist(folder_path, "all", std::make_shared<RandomSampler>(false, 20));
  608. ///
  609. /// /* Create iterator to read dataset */
  610. /// std::shared_ptr<Iterator> iter = ds->CreateIterator();
  611. /// std::unordered_map<std::string, mindspore::MSTensor> row;
  612. /// iter->GetNextRow(&row);
  613. ///
  614. /// /* Note: In MNIST dataset, each dictionary has keys "image" and "label" */
  615. /// auto image = row["image"];
  616. /// \endcode
  617. inline std::shared_ptr<MnistDataset> MS_API
  618. Mnist(const std::string &dataset_dir, const std::string &usage = "all",
  619. const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  620. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  621. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  622. }
  623. /// \brief Function to create a MnistDataset
  624. /// \notes The generated dataset has two columns ["image", "label"]
  625. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  626. /// \param[in] usage of MNIST, can be "train", "test" or "all"
  627. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  628. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  629. /// \return Shared pointer to the current MnistDataset
  630. inline std::shared_ptr<MnistDataset> MS_API Mnist(const std::string &dataset_dir, const std::string &usage,
  631. const Sampler *sampler,
  632. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  633. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  634. }
  635. /// \brief Function to create a MnistDataset
  636. /// \notes The generated dataset has two columns ["image", "label"]
  637. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  638. /// \param[in] usage of MNIST, can be "train", "test" or "all"
  639. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  640. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  641. /// \return Shared pointer to the current MnistDataset
  642. inline std::shared_ptr<MnistDataset> MS_API Mnist(const std::string &dataset_dir, const std::string &usage,
  643. const std::reference_wrapper<Sampler> sampler,
  644. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  645. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  646. }
  647. } // namespace dataset
  648. } // namespace mindspore
  649. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_