You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

datasets.h 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  18. #include <sys/stat.h>
  19. #include <unistd.h>
  20. #include <algorithm>
  21. #include <map>
  22. #include <memory>
  23. #include <set>
  24. #include <string>
  25. #include <unordered_map>
  26. #include <unordered_set>
  27. #include <utility>
  28. #include <vector>
  29. #include "include/api/dual_abi_helper.h"
  30. #include "include/iterator.h"
  31. #include "include/samplers.h"
  32. #include "include/transforms.h"
  33. namespace mindspore {
  34. namespace dataset {
// Forward declarations: these types are only named here, not defined.
class Tensor;
class TensorShape;
class TreeGetters;
class DatasetCache;
class DatasetNode;
class Iterator;
class TensorOperation;
class SchemaObj;
class SamplerObj;
// Dataset classes (in alphabetical order)
class BatchDataset;
class MapDataset;
class ProjectDataset;
class ShuffleDataset;
// Not a dataset class: DSCallback is the element type of the `callbacks` argument accepted by Dataset::Map.
class DSCallback;
  50. /// \class Dataset datasets.h
  51. /// \brief A base class to represent a dataset in the data pipeline.
  52. class Dataset : public std::enable_shared_from_this<Dataset> {
  53. public:
  54. // need friend class so they can access the children_ field
  55. friend class Iterator;
  56. friend class TransferNode;
  57. /// \brief Constructor
  58. Dataset();
  59. /// \brief Destructor
  60. ~Dataset() = default;
  61. /// \brief Gets the dataset size
  62. /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  63. /// dataset size at the expense of accuracy.
  64. /// \return dataset size. If failed, return -1
  65. int64_t GetDatasetSize(bool estimate = false);
  66. // /// \brief Gets the output type
  67. // /// \return a vector of DataType. If failed, return an empty vector
  68. // std::vector<DataType> GetOutputTypes();
  69. /// \brief Gets the output shape
  70. /// \return a vector of TensorShape. If failed, return an empty vector
  71. std::vector<TensorShape> GetOutputShapes();
  72. /// \brief Gets the batch size
  73. /// \return int64_t
  74. int64_t GetBatchSize();
  75. /// \brief Gets the repeat count
  76. /// \return int64_t
  77. int64_t GetRepeatCount();
  78. /// \brief Gets the number of classes
  79. /// \return number of classes. If failed, return -1
  80. int64_t GetNumClasses();
  81. /// \brief Gets the column names
  82. /// \return Names of the columns. If failed, return an empty vector
  83. std::vector<std::string> GetColumnNames() { return VectorCharToString(GetColumnNamesCharIF()); }
  84. /// \brief Gets the class indexing
  85. /// \return a map of ClassIndexing. If failed, return an empty map
  86. std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing() {
  87. return ClassIndexCharToString(GetClassIndexingCharIF());
  88. }
  89. /// \brief Setter function for runtime number of workers
  90. /// \param[in] num_workers The number of threads in this operator
  91. /// \return Shared pointer to the original object
  92. std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
  93. /// \brief Function to create an Iterator over the Dataset pipeline
  94. /// \param[in] columns List of columns to be used to specify the order of columns
  95. /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
  96. /// An empty row is returned at the end of each epoch
  97. /// \return Shared pointer to the Iterator
  98. std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1) {
  99. return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
  100. }
  101. /// \brief Function to create a BatchDataset
  102. /// \notes Combines batch_size number of consecutive rows into batches
  103. /// \param[in] batch_size The number of rows each batch is created with
  104. /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
  105. /// batch. If true, and if there are less than batch_size rows
  106. /// available to make the last batch, then those rows will
  107. /// be dropped and not propagated to the next node
  108. /// \return Shared pointer to the current BatchDataset
  109. std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
  110. /// \brief Function to create a MapDataset
  111. /// \notes Applies each operation in operations to this dataset
  112. /// \param[in] operations Vector of operations to be applied on the dataset. Operations are
  113. /// applied in the order they appear in this list
  114. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  115. /// operation as input. The size of this list must match the number of
  116. /// input columns expected by the first operator. The default input_columns
  117. /// is the first column
  118. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  119. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  120. /// The size of this list must match the number of output columns of the
  121. /// last operation. The default output_columns will have the same
  122. /// name as the input columns, i.e., the columns will be replaced
  123. /// \param[in] project_columns A list of column names to project
  124. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  125. /// \return Shared pointer to the current MapDataset
  126. std::shared_ptr<MapDataset> Map(std::vector<TensorTransform *> operations,
  127. const std::vector<std::string> &input_columns = {},
  128. const std::vector<std::string> &output_columns = {},
  129. const std::vector<std::string> &project_columns = {},
  130. const std::shared_ptr<DatasetCache> &cache = nullptr,
  131. std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
  132. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  133. (void)std::transform(
  134. operations.begin(), operations.end(), std::back_inserter(transform_ops),
  135. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  136. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  137. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  138. callbacks);
  139. }
  140. std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations,
  141. const std::vector<std::string> &input_columns = {},
  142. const std::vector<std::string> &output_columns = {},
  143. const std::vector<std::string> &project_columns = {},
  144. const std::shared_ptr<DatasetCache> &cache = nullptr,
  145. std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
  146. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  147. (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
  148. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  149. return op != nullptr ? op->Parse() : nullptr;
  150. });
  151. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  152. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  153. callbacks);
  154. }
  155. std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations,
  156. const std::vector<std::string> &input_columns = {},
  157. const std::vector<std::string> &output_columns = {},
  158. const std::vector<std::string> &project_columns = {},
  159. const std::shared_ptr<DatasetCache> &cache = nullptr,
  160. std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
  161. std::vector<std::shared_ptr<TensorOperation>> transform_ops;
  162. (void)std::transform(operations.begin(), operations.end(), std::back_inserter(transform_ops),
  163. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  164. return std::make_shared<MapDataset>(shared_from_this(), transform_ops, VectorStringToChar(input_columns),
  165. VectorStringToChar(output_columns), VectorStringToChar(project_columns), cache,
  166. callbacks);
  167. }
  168. /// \brief Function to create a Project Dataset
  169. /// \notes Applies project to the dataset
  170. /// \param[in] columns The name of columns to project
  171. /// \return Shared pointer to the current Dataset
  172. std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
  173. return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
  174. }
  175. /// \brief Function to create a Shuffle Dataset
  176. /// \notes Randomly shuffles the rows of this dataset
  177. /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
  178. /// \return Shared pointer to the current ShuffleDataset
  179. std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
  180. return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
  181. }
  182. std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
  183. protected:
  184. std::shared_ptr<TreeGetters> tree_getters_;
  185. std::shared_ptr<DatasetNode> ir_node_;
  186. private:
  187. // Char interface(CharIF) of GetColumnNames
  188. std::vector<std::vector<char>> GetColumnNamesCharIF();
  189. // Char interface(CharIF) of GetClassIndexing
  190. std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> GetClassIndexingCharIF();
  191. // Char interface(CharIF) of CreateIterator
  192. std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);
  193. };
/// \brief Dataset type returned by Dataset::Batch.
class BatchDataset : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] input The dataset this node is built on top of
  /// \param[in] batch_size The number of rows each batch is created with
  /// \param[in] drop_remainder Whether to drop the last possibly incomplete batch (default = false)
  BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false);
  /// \brief Destructor
  ~BatchDataset() = default;
};
/// \brief Dataset type returned by the Dataset::Map overloads.
class MapDataset : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] input The dataset this node is built on top of
  /// \param[in] operations Already-parsed operations to apply
  /// \param[in] input_columns Input column names in their char-vector form
  /// \param[in] output_columns Output column names in their char-vector form
  /// \param[in] project_columns Column names to project, in their char-vector form
  /// \param[in] cache Tensor cache to use
  /// \param[in] callbacks Vector of DSCallback objects
  MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
             const std::vector<std::vector<char>> &input_columns, const std::vector<std::vector<char>> &output_columns,
             const std::vector<std::vector<char>> &project_columns, const std::shared_ptr<DatasetCache> &cache,
             std::vector<std::shared_ptr<DSCallback>> callbacks);
  /// \brief Destructor
  ~MapDataset() = default;
};
/// \brief Dataset type returned by Dataset::Project.
class ProjectDataset : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] input The dataset this node is built on top of
  /// \param[in] columns Names of the columns to project, in their char-vector form
  ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &columns);
  /// \brief Destructor
  ~ProjectDataset() = default;
};
/// \brief Dataset type returned by Dataset::Shuffle.
class ShuffleDataset : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] input The dataset this node is built on top of
  /// \param[in] buffer_size The size of the buffer used for shuffling
  ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size);
  /// \brief Destructor
  ~ShuffleDataset() = default;
};
  217. /// \brief Function to create a SchemaObj
  218. /// \param[in] schema_file Path of schema file
  219. /// \return Shared pointer to the current schema
  220. std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file);
  221. inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
  222. return SchemaCharIF(StringToChar(schema_file));
  223. }
/// \brief Dataset type returned by the Album() factory functions.
/// \note The three constructors take the same arguments and differ only in how the sampler is passed.
class AlbumDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, const std::shared_ptr<Sampler> &sampler,
               const std::shared_ptr<DatasetCache> &cache);
  /// \brief Constructor taking the sampler as a raw pointer
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode, Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);
  /// \brief Constructor taking the sampler as a std::reference_wrapper
  AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
               const std::vector<std::vector<char>> &column_names, bool decode,
               const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
  /// \brief Destructor
  ~AlbumDataset() = default;
};
  237. /// \brief Function to create an AlbumDataset
  238. /// \notes The generated dataset is specified through setting a schema
  239. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  240. /// \param[in] data_schema Path to dataset schema file
  241. /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns.
  242. /// (default = {})
  243. /// \param[in] decode the option to decode the images in dataset (default = false)
  244. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  245. /// given,
  246. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  247. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  248. /// \return Shared pointer to the current Dataset
  249. inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  250. const std::vector<std::string> &column_names = {}, bool decode = false,
  251. const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  252. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  253. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  254. VectorStringToChar(column_names), decode, sampler, cache);
  255. }
  256. /// \brief Function to create an AlbumDataset
  257. /// \notes The generated dataset is specified through setting a schema
  258. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  259. /// \param[in] data_schema Path to dataset schema file
  260. /// \param[in] column_names Column names used to specify columns to load
  261. /// \param[in] decode the option to decode the images in dataset
  262. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  263. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  264. /// \return Shared pointer to the current Dataset
  265. inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  266. const std::vector<std::string> &column_names, bool decode, Sampler *sampler,
  267. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  268. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  269. VectorStringToChar(column_names), decode, sampler, cache);
  270. }
  271. /// \brief Function to create an AlbumDataset
  272. /// \notes The generated dataset is specified through setting a schema
  273. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  274. /// \param[in] data_schema Path to dataset schema file
  275. /// \param[in] column_names Column names used to specify columns to load
  276. /// \param[in] decode the option to decode the images in dataset
  277. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  278. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  279. /// \return Shared pointer to the current Dataset
  280. inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  281. const std::vector<std::string> &column_names, bool decode,
  282. const std::reference_wrapper<Sampler> sampler,
  283. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  284. return std::make_shared<AlbumDataset>(StringToChar(dataset_dir), StringToChar(data_schema),
  285. VectorStringToChar(column_names), decode, sampler, cache);
  286. }
/// \brief Dataset type returned by the Mnist() factory functions.
/// \note The three constructors take the same arguments and differ only in how the sampler is passed.
class MnistDataset : public Dataset {
 public:
  /// \brief Constructor taking the sampler as a shared pointer
  explicit MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                        const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
  /// \brief Constructor taking the sampler as a raw pointer
  explicit MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, Sampler *sampler,
                        const std::shared_ptr<DatasetCache> &cache);
  /// \brief Constructor taking the sampler as a std::reference_wrapper
  explicit MnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                        const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache);
  /// \brief Destructor
  ~MnistDataset() = default;
};
  297. /// \brief Function to create a MnistDataset
  298. /// \notes The generated dataset has two columns ["image", "label"]
  299. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  300. /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all").
  301. /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
  302. /// given,
  303. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  304. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  305. /// \return Shared pointer to the current MnistDataset
  306. inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = "all",
  307. const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
  308. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  309. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  310. }
  311. /// \brief Function to create a MnistDataset
  312. /// \notes The generated dataset has two columns ["image", "label"]
  313. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  314. /// \param[in] usage of MNIST, can be "train", "test" or "all"
  315. /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  316. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  317. /// \return Shared pointer to the current MnistDataset
  318. inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage, Sampler *sampler,
  319. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  320. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  321. }
  322. /// \brief Function to create a MnistDataset
  323. /// \notes The generated dataset has two columns ["image", "label"]
  324. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  325. /// \param[in] usage of MNIST, can be "train", "test" or "all"
  326. /// \param[in] sampler Sampler object used to choose samples from the dataset.
  327. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  328. /// \return Shared pointer to the current MnistDataset
  329. inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
  330. const std::reference_wrapper<Sampler> sampler,
  331. const std::shared_ptr<DatasetCache> &cache = nullptr) {
  332. return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
  333. }
  334. } // namespace dataset
  335. } // namespace mindspore
  336. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_