You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

datasets.h 30 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  #include <sys/stat.h>
  #include <unistd.h>
  #include <algorithm>
  #include <functional>
  #include <iterator>
  #include <map>
  #include <memory>
  #include <set>
  #include <string>
  #include <unordered_map>
  #include <unordered_set>
  #include <utility>
  #include <vector>
  #include "include/api/dual_abi_helper.h"
  #include "include/api/types.h"
  #include "include/iterator.h"
  #include "include/samplers.h"
  #include "include/transforms.h"
  34. namespace mindspore {
  35. namespace dataset {
  // Forward declarations of supporting types (in alphabetical order)
  class DatasetCache;
  class DatasetNode;
  class DSCallback;
  class Iterator;
  class SamplerObj;
  class SchemaObj;
  class Tensor;
  class TensorOperation;
  class TensorShape;
  class TreeAdapter;
  class TreeAdapterLite;
  class TreeGetters;

  // Forward declarations of the dataset classes returned by the Dataset pipeline methods (in alphabetical order)
  class BatchDataset;
  class MapDataset;
  class ProjectDataset;
  class ShuffleDataset;
  53. /// \class Dataset datasets.h
  54. /// \brief A base class to represent a dataset in the data pipeline.
  55. class Dataset : public std::enable_shared_from_this<Dataset> {
  56. public:
  57. // need friend class so they can access the children_ field
  58. friend class Iterator;
  59. friend class TransferNode;
  60. /// \brief Constructor
  61. Dataset();
  62. /// \brief Destructor
  63. ~Dataset() = default;
  64. /// \brief Gets the dataset size
  65. /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  66. /// dataset size at the expense of accuracy.
  67. /// \return dataset size. If failed, return -1
  68. int64_t GetDatasetSize(bool estimate = false);
  69. /// \brief Gets the output type
  70. /// \return a vector of DataType. If failed, return an empty vector
  71. std::vector<mindspore::DataType> GetOutputTypes();
  72. /// \brief Gets the output shape
  73. /// \return a vector of TensorShape. If failed, return an empty vector
  74. std::vector<std::vector<int64_t>> GetOutputShapes();
  75. /// \brief Gets the batch size
  76. /// \return int64_t
  77. int64_t GetBatchSize();
  78. /// \brief Gets the repeat count
  79. /// \return int64_t
  80. int64_t GetRepeatCount();
  81. /// \brief Gets the number of classes
  82. /// \return number of classes. If failed, return -1
  83. int64_t GetNumClasses();
  84. /// \brief Gets the column names
  85. /// \return Names of the columns. If failed, return an empty vector
  86. std::vector<std::string> GetColumnNames() { return VectorCharToString(GetColumnNamesCharIF()); }
  87. /// \brief Gets the class indexing
  88. /// \return a map of ClassIndexing. If failed, return an empty map
  89. std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() {
  90. return ClassIndexCharToString(GetClassIndexingCharIF());
  91. }
  92. /// \brief Setter function for runtime number of workers
  93. /// \param[in] num_workers The number of threads in this operator
  94. /// \return Shared pointer to the original object
  95. std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
  96. /// \brief Function to create an PullBasedIterator over the Dataset
  97. /// \param[in] columns List of columns to be used to specify the order of columns
  98. /// \return Shared pointer to the Iterator
  99. std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {});
  100. /// \brief Function to create an Iterator over the Dataset pipeline
  101. /// \param[in] columns List of columns to be used to specify the order of columns
  102. /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
  103. /// An empty row is returned at the end of each epoch
  104. /// \return Shared pointer to the Iterator
  105. std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1) {
  106. return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
  107. }
  108. /// \brief Function to transfer data through a device.
  109. /// \notes If device is Ascend, features of data will be transferred one by one. The limitation
  110. /// of data transmission per time is 256M.
  111. /// \param[in] queue_name Channel name (default="", create new unique name).
  112. /// \param[in] device_type Type of device (default="", get from MSContext).
  113. /// \param[in] device_id id of device (default=1, get from MSContext).
  114. /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
  115. /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
  116. /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
  117. /// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
  118. /// of data or not(default=false).
  119. /// \return Returns true if no error encountered else false.
  120. bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0,
  121. int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
  122. bool create_data_info_queue = false) {
  123. return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
  124. total_batches, create_data_info_queue);
  125. }
  126. /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
  127. /// \note Usage restrictions:
  128. /// 1. Supported dataset formats: 'mindrecord' only
  129. /// 2. To save the samples in order, set dataset's shuffle to false and num_files to 1.
  130. /// 3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators
  131. /// with random attribute in map operator.
  132. /// 4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor
  133. /// multi-dimensional string.
  134. /// \param[in] file_name Path to dataset file
  135. /// \param[in] num_files Number of dataset files (default=1)
  136. /// \param[in] file_type Dataset format (default="mindrecord")
  137. /// \return Returns true if no error encountered else false
  138. bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
  139. return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
  140. }
  141. /// \brief Function to create a BatchDataset
  142. /// \notes Combines batch_size number of consecutive rows into batches
  143. /// \param[in] batch_size The number of rows each batch is created with
  144. /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
  145. /// batch. If true, and if there are less than batch_size rows
  146. /// available to make the last batch, then those rows will
  147. /// be dropped and not propagated to the next node
  148. /// \return Shared pointer to the current BatchDataset
  149. std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
  150. /// \brief Function to create a MapDataset
  151. /// \notes Applies each operation in operations to this dataset
  152. /// \param[in] operations Vector of raw pointers to TensorTransform objects to be applied on the dataset. Operations
  153. /// are applied in the order they appear in this list
  154. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  155. /// operation as input. The size of this list must match the number of
  156. /// input columns expected by the first operator. The default input_columns
  157. /// is the first column
  158. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  159. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  160. /// The size of this list must match the number of output columns of the
  161. /// last operation. The default output_columns will have the same
  162. /// name as the input columns, i.e., the columns will be replaced
  163. /// \param[in] project_columns A list of column names to project
  164. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  165. /// \return Shared pointer to the current MapDataset
  166. std::shared_ptr<MapDataset> Map(std::vector<TensorTransform *> operations,
  167. const std::vector<std::string> &input_columns = {},
  168. const std::vector<std::string> &output_columns = {},
  169. const std::vector<std::string> &project_columns = {},
  170. const std::shared_ptr<DatasetCache> &cache = nullptr,
  171. std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
  172. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  173. (void)std::transform(
  174. operations.begin(), operations.end(), std::back_inserter(transform_ops),
  175. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  176. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  177. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  178. callbacks);
  179. }
  180. /// \brief Function to create a MapDataset
  181. /// \notes Applies each operation in operations to this dataset
  182. /// \param[in] operations Vector of shared pointers to TensorTransform objects to be applied on the dataset.
  183. /// Operations are applied in the order they appear in this list
  184. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  185. /// operation as input. The size of this list must match the number of
  186. /// input columns expected by the first operator. The default input_columns
  187. /// is the first column
  188. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  189. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  190. /// The size of this list must match the number of output columns of the
  191. /// last operation. The default output_columns will have the same
  192. /// name as the input columns, i.e., the columns will be replaced
  193. /// \param[in] project_columns A list of column names to project
  194. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  195. /// \return Shared pointer to the current MapDataset
  196. std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations,
  197. const std::vector<std::string> &input_columns = {},
  198. const std::vector<std::string> &output_columns = {},
  199. const std::vector<std::string> &project_columns = {},
  200. const std::shared_ptr<DatasetCache> &cache = nullptr,
  201. std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
  202. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  203. (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
  204. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  205. return op != nullptr ? op->Parse() : nullptr;
  206. });
  207. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  208. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  209. callbacks);
  210. }
  211. /// \brief Function to create a MapDataset
  212. /// \notes Applies each operation in operations to this dataset
  213. /// \param[in] operations Vector of TensorTransform objects to be applied on the dataset. Operations are applied in
  214. /// the order they appear in this list
  215. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  216. /// operation as input. The size of this list must match the number of
  217. /// input columns expected by the first operator. The default input_columns
  218. /// is the first column
  219. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  220. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  221. /// The size of this list must match the number of output columns of the
  222. /// last operation. The default output_columns will have the same
  223. /// name as the input columns, i.e., the columns will be replaced
  224. /// \param[in] project_columns A list of column names to project
  225. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  226. /// \return Shared pointer to the current MapDataset
  227. std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations,
  228. const std::vector<std::string> &input_columns = {},
  229. const std::vector<std::string> &output_columns = {},
  230. const std::vector<std::string> &project_columns = {},
  231. const std::shared_ptr<DatasetCache> &cache = nullptr,
  232. std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
  233. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  234. (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
  235. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  236. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  237. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  238. callbacks);
  239. }
  240. /// \brief Function to create a Project Dataset
  241. /// \notes Applies project to the dataset
  242. /// \param[in] columns The name of columns to project
  243. /// \return Shared pointer to the current Dataset
  244. std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
  245. return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
  246. }
  247. /// \brief Function to create a Shuffle Dataset
  248. /// \notes Randomly shuffles the rows of this dataset
  249. /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
  250. /// \return Shared pointer to the current ShuffleDataset
  251. std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
  252. return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
  253. }
  254. std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
  255. protected:
  256. std::shared_ptr<TreeGetters> tree_getters_;
  257. std::shared_ptr<DatasetNode> ir_node_;
  258. private:
  259. // Char interface(CharIF) of GetColumnNames
  260. std::vector<std::vector<char>> GetColumnNamesCharIF();
  261. // Char interface(CharIF) of GetClassIndexing
  262. std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
  263. // Char interface(CharIF) of CreateIterator
  264. std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);
  265. // Char interface(CharIF) of DeviceQueue
  266. bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
  267. int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
  268. // Char interface(CharIF) of Save
  269. bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
  270. };
  271. class SchemaObj {
  272. public:
  273. /// \brief Constructor
  274. explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}
  275. /// \brief Destructor
  276. ~SchemaObj() = default;
  277. /// \brief SchemaObj Init function
  278. /// \return bool true if schema initialization is successful
  279. Status Init();
  280. /// \brief Add new column to the schema with unknown shape of rank 1
  281. /// \param[in] name Name of the column.
  282. /// \param[in] ms_type Data type of the column(mindspore::DataType).
  283. /// \return Status code
  284. Status add_column(const std::string &name, mindspore::DataType ms_type) {
  285. return add_column_char(StringToChar(name), ms_type);
  286. }
  287. /// \brief Add new column to the schema with unknown shape of rank 1
  288. /// \param[in] name Name of the column.
  289. /// \param[in] ms_type Data type of the column(std::string).
  290. /// \param[in] shape Shape of the column.
  291. /// \return Status code
  292. Status add_column(const std::string &name, const std::string &ms_type) {
  293. return add_column_char(StringToChar(name), StringToChar(ms_type));
  294. }
  295. /// \brief Add new column to the schema
  296. /// \param[in] name Name of the column.
  297. /// \param[in] ms_type Data type of the column(mindspore::DataType).
  298. /// \param[in] shape Shape of the column.
  299. /// \return Status code
  300. Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
  301. return add_column_char(StringToChar(name), ms_type, shape);
  302. }
  303. /// \brief Add new column to the schema
  304. /// \param[in] name Name of the column.
  305. /// \param[in] ms_type Data type of the column(std::string).
  306. /// \param[in] shape Shape of the column.
  307. /// \return Status code
  308. Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
  309. return add_column_char(StringToChar(name), StringToChar(ms_type), shape);
  310. }
  311. /// \brief Get a JSON string of the schema
  312. /// \return JSON string of the schema
  313. std::string to_json() { return CharToString(to_json_char()); }
  314. /// \brief Get a JSON string of the schema
  315. std::string to_string() { return to_json(); }
  316. /// \brief Set a new value to dataset_type
  317. void set_dataset_type(std::string dataset_type);
  318. /// \brief Set a new value to num_rows
  319. void set_num_rows(int32_t num_rows);
  320. /// \brief Get the current num_rows
  321. int32_t get_num_rows() const;
  322. /// \brief Get schema file from JSON file
  323. /// \param[in] json_string Name of JSON file to be parsed.
  324. /// \return Status code
  325. Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); }
  326. /// \brief Parse and add column information
  327. /// \param[in] json_string Name of JSON string for column dataset attribute information, decoded from schema file.
  328. /// \return Status code
  329. Status ParseColumnString(const std::string &json_string) {
  330. return ParseColumnStringCharIF(StringToChar(json_string));
  331. }
  332. private:
  333. // Char constructor of SchemaObj
  334. explicit SchemaObj(const std::vector<char> &schema_file);
  335. // Char interface of add_column
  336. Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);
  337. Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);
  338. Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);
  339. Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
  340. const std::vector<int32_t> &shape);
  341. // Char interface of to_json
  342. const std::vector<char> to_json_char();
  343. // Char interface of FromJSONString
  344. Status FromJSONStringCharIF(const std::vector<char> &json_string);
  345. // Char interface of ParseColumnString
  346. Status ParseColumnStringCharIF(const std::vector<char> &json_string);
  347. struct Data;
  348. std::shared_ptr<Data> data_;
  349. };
  350. class BatchDataset : public Dataset {
  351. public:
  352. BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false);
  353. ~BatchDataset() = default;
  354. };
  355. class MapDataset : public Dataset {
  356. public:
  357. MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
  358. const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
  359. const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
  360. std::vector<std::shared_ptr<DSCallback>> callbacks);
  361. ~MapDataset() = default;
  362. };
  363. class ProjectDataset : public Dataset {
  364. public:
  365. ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns);
  366. ~ProjectDataset() = default;
  367. };
  368. class ShuffleDataset : public Dataset {
  369. public:
  370. ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size);
  371. ~ShuffleDataset() = default;
  372. };
  373. /// \brief Function to create a SchemaObj
  374. /// \param[in] schema_file Path of schema file
  375. /// \note This api exists because std::string will constrained by ABI compile macro but char don't.
  376. /// \return Shared pointer to the current schema
  377. std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file);
  378. /// \brief Function to create a SchemaObj
  379. /// \param[in] schema_file Path of schema file
  380. /// \return Shared pointer to the current schema
  381. inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
  382. return SchemaCharIF(StringToChar(schema_file));
  383. }
  384. class AlbumDataset : public Dataset {
  385. public:
  386. AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
  387. const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler,
  388. const std::shared_ptr<DatasetCache> &cache);
  389. AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
  390. const std::vector<std::vector<char>> &column_names, bool decode, Sampler *sampler,
  391. const std::shared_ptr<DatasetCache> &cache);
  392. AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
  393. const std::vector<std::vector<char>> &column_names, bool decode,
  394. const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
  395. ~AlbumDataset() = default;
  396. };
  397. /// \brief Function to create an AlbumDataset
  398. /// \notes The generated dataset is specified through setting a schema
  399. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  400. /// \param[in] data_schema Path to dataset schema file
  401. /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns.
  402. /// (default = {})
  403. /// \param[in] decode the option to decode the images in dataset (default = false)
  404. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  405. /// given,
  406. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  407. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  408. /// \return Shared pointer to the current Dataset
  409. inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  410. const std::vector<std::string> &column_names = {}, bool decode = false,
  411. const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  412. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  413. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  414. VectorStringToChar(column_names), decode, sampler, cache);
  415. }
  416. /// \brief Function to create an AlbumDataset
  417. /// \notes The generated dataset is specified through setting a schema
  418. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  419. /// \param[in] data_schema Path to dataset schema file
  420. /// \param[in] column_names Column names used to specify columns to load
  421. /// \param[in] decode the option to decode the images in dataset
  422. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  423. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  424. /// \return Shared pointer to the current Dataset
  425. inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  426. const std::vector<std::string> &column_names, bool decode, Sampler *sampler,
  427. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  428. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  429. VectorStringToChar(column_names), decode, sampler, cache);
  430. }
  431. /// \brief Function to create an AlbumDataset
  432. /// \notes The generated dataset is specified through setting a schema
  433. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  434. /// \param[in] data_schema Path to dataset schema file
  435. /// \param[in] column_names Column names used to specify columns to load
  436. /// \param[in] decode the option to decode the images in dataset
  437. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  438. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  439. /// \return Shared pointer to the current Dataset
  440. inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  441. const std::vector<std::string> &column_names, bool decode,
  442. const std::reference_wrapper<Sampler> sampler,
  443. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  444. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  445. VectorStringToChar(column_names), decode, sampler, cache);
  446. }
  447. class MnistDataset : public Dataset {
  448. public:
  449. explicit MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
  450. const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  451. explicit MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, Sampler *sampler,
  452. const std::shared_ptr<DatasetCache> &cache);
  453. explicit MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
  454. const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
  455. ~MnistDataset() = default;
  456. };
  457. /// \brief Function to create a MnistDataset
  458. /// \notes The generated dataset has two columns ["image", "label"]
  459. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  460. /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all").
  461. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  462. /// given,
  463. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  464. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  465. /// \return Shared pointer to the current MnistDataset
  466. inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = "all",
  467. const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  468. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  469. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  470. }
  471. /// \brief Function to create a MnistDataset
  472. /// \notes The generated dataset has two columns ["image", "label"]
  473. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  474. /// \param[in] usage of MNIST, can be "train", "test" or "all"
  475. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  476. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  477. /// \return Shared pointer to the current MnistDataset
  478. inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage, Sampler *sampler,
  479. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  480. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  481. }
  482. /// \brief Function to create a MnistDataset
  483. /// \notes The generated dataset has two columns ["image", "label"]
  484. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  485. /// \param[in] usage of MNIST, can be "train", "test" or "all"
  486. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  487. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  488. /// \return Shared pointer to the current MnistDataset
  489. inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
  490. const std::reference_wrapper<Sampler> sampler,
  491. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  492. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  493. }
  494. } // namespace dataset
  495. } // namespace mindspore
  496. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_