You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

datasets.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
  18. #include <sys/stat.h>
  19. #include <unistd.h>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <unordered_map>
  25. #include <unordered_set>
  26. #include <utility>
  27. #include <vector>
  28. #include "include/iterator.h"
  29. #include "include/samplers.h"
  30. namespace mindspore {
  31. namespace dataset {
  // Forward declarations of non-dataset helper types.
  class Tensor;
  class TensorShape;
  class TreeGetters;
  class DatasetCache;
  class DatasetNode;
  class Iterator;
  class TensorOperation;
  class SchemaObj;
  class SamplerObj;
  class DSCallback;

  // Dataset classes (in alphabetical order)
  class BatchDataset;
  class MapDataset;
  class ProjectDataset;
  class ShuffleDataset;
  47. /// \class Dataset datasets.h
  48. /// \brief A base class to represent a dataset in the data pipeline.
  49. class Dataset : public std::enable_shared_from_this<Dataset> {
  50. public:
  51. // need friend class so they can access the children_ field
  52. friend class Iterator;
  53. friend class TransferNode;
  54. /// \brief Constructor
  55. Dataset();
  56. /// \brief Destructor
  57. ~Dataset() = default;
  58. /// \brief Gets the dataset size
  59. /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  60. /// dataset size at the expense of accuracy.
  61. /// \return dataset size. If failed, return -1
  62. int64_t GetDatasetSize(bool estimate = false);
  63. // /// \brief Gets the output type
  64. // /// \return a vector of DataType. If failed, return an empty vector
  65. // std::vector<DataType> GetOutputTypes();
  66. /// \brief Gets the output shape
  67. /// \return a vector of TensorShape. If failed, return an empty vector
  68. std::vector<TensorShape> GetOutputShapes();
  69. /// \brief Gets the batch size
  70. /// \return int64_t
  71. int64_t GetBatchSize();
  72. /// \brief Gets the repeat count
  73. /// \return int64_t
  74. int64_t GetRepeatCount();
  75. /// \brief Gets the number of classes
  76. /// \return number of classes. If failed, return -1
  77. int64_t GetNumClasses();
  78. /// \brief Gets the column names
  79. /// \return Names of the columns. If failed, return an empty vector
  80. std::vector<std::string> GetColumnNames();
  81. /// \brief Gets the class indexing
  82. /// \return a map of ClassIndexing. If failed, return an empty map
  83. std::vector<std::pair<std::string, std::vector<int32_t>>> GetClassIndexing();
  84. /// \brief Setter function for runtime number of workers
  85. /// \param[in] num_workers The number of threads in this operator
  86. /// \return Shared pointer to the original object
  87. std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
  88. /// \brief Function to create an Iterator over the Dataset pipeline
  89. /// \param[in] columns List of columns to be used to specify the order of columns
  90. /// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
  91. /// An empty row is returned at the end of each epoch
  92. /// \return Shared pointer to the Iterator
  93. std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}, int32_t num_epochs = -1);
  94. /// \brief Function to create a BatchDataset
  95. /// \notes Combines batch_size number of consecutive rows into batches
  96. /// \param[in] batch_size The number of rows each batch is created with
  97. /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
  98. /// batch. If true, and if there are less than batch_size rows
  99. /// available to make the last batch, then those rows will
  100. /// be dropped and not propagated to the next node
  101. /// \return Shared pointer to the current BatchDataset
  102. std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
  103. /// \brief Function to create a MapDataset
  104. /// \notes Applies each operation in operations to this dataset
  105. /// \param[in] operations Vector of operations to be applied on the dataset. Operations are
  106. /// applied in the order they appear in this list
  107. /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
  108. /// operation as input. The size of this list must match the number of
  109. /// input columns expected by the first operator. The default input_columns
  110. /// is the first column
  111. /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
  112. /// This parameter is mandatory if len(input_columns) != len(output_columns)
  113. /// The size of this list must match the number of output columns of the
  114. /// last operation. The default output_columns will have the same
  115. /// name as the input columns, i.e., the columns will be replaced
  116. /// \param[in] project_columns A list of column names to project
  117. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  118. /// \return Shared pointer to the current MapDataset
  119. std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
  120. const std::vector<std::string> &input_columns = {},
  121. const std::vector<std::string> &output_columns = {},
  122. const std::vector<std::string> &project_columns = {},
  123. const std::shared_ptr<DatasetCache> &cache = nullptr,
  124. std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
  125. return std::make_shared<MapDataset>(shared_from_this(), operations, input_columns, output_columns, project_columns,
  126. cache, callbacks);
  127. }
  128. /// \brief Function to create a Project Dataset
  129. /// \notes Applies project to the dataset
  130. /// \param[in] columns The name of columns to project
  131. /// \return Shared pointer to the current Dataset
  132. std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns) {
  133. return std::make_shared<ProjectDataset>(shared_from_this(), columns);
  134. }
  135. /// \brief Function to create a Shuffle Dataset
  136. /// \notes Randomly shuffles the rows of this dataset
  137. /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
  138. /// \return Shared pointer to the current ShuffleDataset
  139. std::shared_ptr<ShuffleDataset> Shuffle(int32_t buffer_size) {
  140. return std::make_shared<ShuffleDataset>(shared_from_this(), buffer_size);
  141. }
  142. std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
  143. protected:
  144. std::shared_ptr<TreeGetters> tree_getters_;
  145. std::shared_ptr<DatasetNode> ir_node_;
  146. };
  147. class BatchDataset : public Dataset {
  148. public:
  149. BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false);
  150. ~BatchDataset() = default;
  151. };
  152. class MapDataset : public Dataset {
  153. public:
  154. MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
  155. const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns,
  156. const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache,
  157. std::vector<std::shared_ptr<DSCallback>> callbacks);
  158. ~MapDataset() = default;
  159. };
  160. class ProjectDataset : public Dataset {
  161. public:
  162. ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &columns);
  163. ~ProjectDataset() = default;
  164. };
  165. class ShuffleDataset : public Dataset {
  166. public:
  167. ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size);
  168. ~ShuffleDataset() = default;
  169. };
  170. /// \brief Function to create a SchemaObj
  171. /// \param[in] schema_file Path of schema file
  172. /// \return Shared pointer to the current schema
  173. std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "");
  174. class AlbumDataset : public Dataset {
  175. public:
  176. AlbumDataset(const std::string &dataset_dir, const std::string &data_schema,
  177. const std::vector<std::string> &column_names = {}, bool decode = false,
  178. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  179. const std::shared_ptr<DatasetCache> &cache = nullptr);
  180. ~AlbumDataset() = default;
  181. };
  182. /// \brief Function to create an AlbumDataset
  183. /// \notes The generated dataset is specified through setting a schema
  184. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  185. /// \param[in] data_schema Path to dataset schema file
  186. /// \param[in] column_names Column names used to specify columns to load, if empty, will read all columns.
  187. /// (default = {})
  188. /// \param[in] decode the option to decode the images in dataset (default = false)
  189. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  190. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  191. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  192. /// \return Shared pointer to the current Dataset
  193. std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  194. const std::vector<std::string> &column_names = {}, bool decode = false,
  195. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  196. const std::shared_ptr<DatasetCache> &cache = nullptr);
  197. class MnistDataset : public Dataset {
  198. public:
  199. explicit MnistDataset(const std::string &dataset_dir, const std::string &usage = "all",
  200. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  201. const std::shared_ptr<DatasetCache> &cache = nullptr);
  202. ~MnistDataset() = default;
  203. };
  204. /// \brief Function to create a MnistDataset
  205. /// \notes The generated dataset has two columns ["image", "label"]
  206. /// \param[in] dataset_dir Path to the root directory that contains the dataset
  207. /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all").
  208. /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
  209. /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
  210. /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
  211. /// \return Shared pointer to the current MnistDataset
  212. std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = "all",
  213. const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
  214. const std::shared_ptr<DatasetCache> &cache = nullptr);
  215. } // namespace dataset
  216. } // namespace mindspore
  217. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_