You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

datasets.cc 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include "minddata/dataset/include/datasets.h"
  18. #include "minddata/dataset/include/transforms.h"
  19. #include "minddata/dataset/include/samplers.h"
  20. #include "minddata/dataset/engine/dataset_iterator.h"
  21. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  22. #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
  24. #include "minddata/dataset/engine/datasetops/batch_op.h"
  25. #include "minddata/dataset/engine/datasetops/map_op.h"
  26. #include "minddata/dataset/engine/datasetops/repeat_op.h"
  27. #include "minddata/dataset/engine/datasetops/shuffle_op.h"
  28. #include "minddata/dataset/engine/datasetops/project_op.h"
  29. #include "minddata/dataset/engine/datasetops/zip_op.h"
  30. #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
  31. #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
  32. #include "minddata/dataset/core/config_manager.h"
  33. #include "minddata/dataset/util/random.h"
  34. namespace mindspore {
  35. namespace dataset {
  36. namespace api {
// Evaluates the Status expression `_s` once and, if it reports an error, makes the
// *enclosing function* return nullptr. Only usable inside functions whose return type
// can be constructed from nullptr. The do/while(false) wrapper makes the macro behave
// as a single statement (e.g. safe after an unbraced `if`).
#define RETURN_NULL_IF_ERROR(_s) \
  do {                           \
    Status __rc = (_s);          \
    if (__rc.IsError()) {        \
      return nullptr;            \
    }                            \
  } while (false)
  44. // Function to create the iterator, which will build and launch the execution tree.
  45. std::shared_ptr<Iterator> Dataset::CreateIterator() {
  46. std::shared_ptr<Iterator> iter;
  47. try {
  48. iter = std::make_shared<Iterator>();
  49. Status rc = iter->BuildAndLaunchTree(shared_from_this());
  50. if (rc.IsError()) {
  51. MS_LOG(ERROR) << rc;
  52. MS_LOG(ERROR) << "CreateIterator failed.";
  53. return nullptr;
  54. }
  55. return iter;
  56. } catch (const std::exception &err) {
  57. MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
  58. return nullptr;
  59. }
  60. return iter;
  61. }
  62. // Constructor
  63. Dataset::Dataset() {
  64. // Fetch some default value from config manager
  65. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  66. num_workers_ = cfg->num_parallel_workers();
  67. rows_per_buffer_ = cfg->rows_per_buffer();
  68. connector_que_size_ = cfg->op_connector_size();
  69. }
  70. // Function to create a ImageFolderDataset.
  71. std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
  72. std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
  73. std::map<std::string, int32_t> class_indexing) {
  74. // This arg is exist in ImageFolderOp, but not externalized (in Python API). The default value is false.
  75. bool recursive = false;
  76. // Create logical representation of ImageFolderDataset.
  77. auto ds = std::make_shared<ImageFolderDataset>(dataset_dir, decode, sampler, recursive, extensions, class_indexing);
  78. // Call derived class validation method.
  79. return ds->ValidateParams() ? ds : nullptr;
  80. }
  81. // Function to create a MnistDataset.
  82. std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) {
  83. auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
  84. // Call derived class validation method.
  85. return ds->ValidateParams() ? ds : nullptr;
  86. }
  87. // Function to create a Cifar10Dataset.
  88. std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
  89. std::shared_ptr<SamplerObj> sampler) {
  90. auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
  91. // Call derived class validation method.
  92. return ds->ValidateParams() ? ds : nullptr;
  93. }
  94. // Function to create a Batch dataset
  95. std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
  96. // Default values
  97. std::vector<std::string> cols_to_map = {};
  98. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
  99. bool pad = false;
  100. auto ds = std::make_shared<BatchDataset>(batch_size, drop_remainder, pad, cols_to_map, pad_map);
  101. if (!ds->ValidateParams()) {
  102. return nullptr;
  103. }
  104. ds->children.push_back(shared_from_this());
  105. return ds;
  106. }
  107. // Function to create Repeat dataset.
  108. std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
  109. // Workaround for repeat == 1, do not inject repeat.
  110. if (count == 1) {
  111. return shared_from_this();
  112. }
  113. auto ds = std::make_shared<RepeatDataset>(count);
  114. if (!ds->ValidateParams()) {
  115. return nullptr;
  116. }
  117. ds->children.push_back(shared_from_this());
  118. return ds;
  119. }
  120. // Function to create a Map dataset.
  121. std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
  122. std::vector<std::string> input_columns,
  123. std::vector<std::string> output_columns,
  124. const std::vector<std::string> &project_columns) {
  125. auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
  126. if (!ds->ValidateParams()) {
  127. return nullptr;
  128. }
  129. ds->children.push_back(shared_from_this());
  130. return ds;
  131. }
  132. // Function to create a ShuffleOp
  133. std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
  134. // Pass in reshuffle_each_epoch with true
  135. auto ds = std::make_shared<ShuffleDataset>(shuffle_size, true);
  136. if (!ds->ValidateParams()) {
  137. return nullptr;
  138. }
  139. ds->children.push_back(shared_from_this());
  140. return ds;
  141. }
  142. // Function to create a ProjectDataset.
  143. std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
  144. auto ds = std::make_shared<ProjectDataset>(columns);
  145. // Call derived class validation method.
  146. if (!ds->ValidateParams()) {
  147. return nullptr;
  148. }
  149. ds->children.push_back(shared_from_this());
  150. return ds;
  151. }
  152. // Function to create a Zip dataset
  153. std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  154. // Default values
  155. auto ds = std::make_shared<ZipDataset>();
  156. if (!ds->ValidateParams()) {
  157. return nullptr;
  158. }
  159. for (auto dataset : datasets) {
  160. ds->children.push_back(dataset);
  161. }
  162. return ds;
  163. }
  164. // Helper function to create default RandomSampler.
  165. std::shared_ptr<SamplerObj> CreateDefaultSampler() {
  166. const int32_t num_samples = 0; // 0 means to sample all ids.
  167. bool replacement = false;
  168. return std::make_shared<RandomSamplerObj>(replacement, num_samples);
  169. }
  170. /* ####################################### Derived Dataset classes ################################# */
  171. ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
  172. bool recursive, std::set<std::string> extensions,
  173. std::map<std::string, int32_t> class_indexing)
  174. : dataset_dir_(dataset_dir),
  175. decode_(decode),
  176. sampler_(sampler),
  177. recursive_(recursive),
  178. class_indexing_(class_indexing),
  179. exts_(extensions) {}
  180. bool ImageFolderDataset::ValidateParams() {
  181. if (dataset_dir_.empty()) {
  182. MS_LOG(ERROR) << "No dataset path is specified.";
  183. return false;
  184. }
  185. return true;
  186. }
  187. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ImageFolderDataset::Build() {
  188. // A vector containing shared pointer to the Dataset Ops that this object will create
  189. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  190. // If user does not specify Sampler, create a default sampler, i.e., RandomSampler.
  191. if (sampler_ == nullptr) {
  192. sampler_ = CreateDefaultSampler();
  193. }
  194. // Do internal Schema generation.
  195. // This arg is exist in ImageFolderOp, but not externalized (in Python API).
  196. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  197. TensorShape scalar = TensorShape::CreateScalar();
  198. RETURN_NULL_IF_ERROR(
  199. schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  200. RETURN_NULL_IF_ERROR(
  201. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
  202. node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  203. recursive_, decode_, exts_, class_indexing_, std::move(schema),
  204. std::move(sampler_->Build())));
  205. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  206. }
  207. MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
  208. : dataset_dir_(dataset_dir), sampler_(sampler) {}
  209. bool MnistDataset::ValidateParams() {
  210. if (dataset_dir_.empty()) {
  211. MS_LOG(ERROR) << "No dataset path is specified.";
  212. return false;
  213. }
  214. return true;
  215. }
  216. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MnistDataset::Build() {
  217. // A vector containing shared pointer to the Dataset Ops that this object will create
  218. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  219. // If user does not specify Sampler, create a default sampler, i.e., RandomSampler.
  220. if (sampler_ == nullptr) {
  221. sampler_ = CreateDefaultSampler();
  222. }
  223. // Do internal Schema generation.
  224. auto schema = std::make_unique<DataSchema>();
  225. RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  226. TensorShape scalar = TensorShape::CreateScalar();
  227. RETURN_NULL_IF_ERROR(
  228. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  229. node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  230. std::move(schema), std::move(sampler_->Build())));
  231. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  232. }
  233. BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
  234. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
  235. : batch_size_(batch_size),
  236. drop_remainder_(drop_remainder),
  237. pad_(pad),
  238. cols_to_map_(cols_to_map),
  239. pad_map_(pad_map) {}
  240. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> BatchDataset::Build() {
  241. // A vector containing shared pointer to the Dataset Ops that this object will create
  242. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  243. #ifdef ENABLE_PYTHON
  244. py::function noop;
  245. node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
  246. cols_to_map_, noop, noop, pad_map_));
  247. #else
  248. node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
  249. cols_to_map_, pad_map_));
  250. #endif
  251. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  252. }
  253. bool BatchDataset::ValidateParams() {
  254. if (batch_size_ <= 0) {
  255. return false;
  256. }
  257. return true;
  258. }
  259. RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
  260. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> RepeatDataset::Build() {
  261. // A vector containing shared pointer to the Dataset Ops that this object will create
  262. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  263. node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
  264. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  265. }
  266. bool RepeatDataset::ValidateParams() {
  267. if (repeat_count_ <= 0) {
  268. return false;
  269. }
  270. return true;
  271. }
  272. MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns,
  273. std::vector<std::string> output_columns, const std::vector<std::string> &project_columns)
  274. : operations_(operations),
  275. input_columns_(input_columns),
  276. output_columns_(output_columns),
  277. project_columns_(project_columns) {}
  278. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MapDataset::Build() {
  279. // A vector containing shared pointer to the Dataset Ops that this object will create
  280. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  281. // Currently default is true, and this is not exposed to user.
  282. bool perf_mode = true;
  283. std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  284. // Build tensorOp from tensorOperation vector
  285. // This is to ensure each iterator hold its own copy of the tensorOp objects.
  286. (void)std::transform(
  287. operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
  288. [](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
  289. // This parameter will be removed with next rebase
  290. std::vector<std::string> col_orders;
  291. auto map_op =
  292. std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_, perf_mode);
  293. if (!project_columns_.empty()) {
  294. auto project_op = std::make_shared<ProjectOp>(project_columns_);
  295. node_ops.push_back(project_op);
  296. }
  297. node_ops.push_back(map_op);
  298. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  299. }
  300. bool MapDataset::ValidateParams() {
  301. if (operations_.empty()) {
  302. return false;
  303. }
  304. return true;
  305. }
  306. // Constructor for ShuffleDataset
  307. ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
  308. : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
  309. // Function to build the ShuffleOp
  310. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ShuffleDataset::Build() {
  311. // A vector containing shared pointer to the Dataset Ops that this object will create
  312. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  313. node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
  314. rows_per_buffer_));
  315. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  316. }
  317. // Function to validate the parameters for ShuffleDataset
  318. bool ShuffleDataset::ValidateParams() {
  319. if (shuffle_size_ <= 1) {
  320. MS_LOG(ERROR) << "ShuffleDataset: Invalid input, shuffle_size: " << shuffle_size_;
  321. return false;
  322. }
  323. return true;
  324. }
  325. // Constructor for Cifar10Dataset
  326. Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
  327. : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
  328. bool Cifar10Dataset::ValidateParams() {
  329. if (dataset_dir_.empty()) {
  330. MS_LOG(ERROR) << "No dataset path is specified.";
  331. return false;
  332. }
  333. if (num_samples_ < 0) {
  334. MS_LOG(ERROR) << "Number of samples cannot be negative";
  335. return false;
  336. }
  337. return true;
  338. }
  339. // Function to build CifarOp
  340. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Cifar10Dataset::Build() {
  341. // A vector containing shared pointer to the Dataset Ops that this object will create
  342. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  343. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  344. if (sampler_ == nullptr) {
  345. sampler_ = CreateDefaultSampler();
  346. }
  347. // Do internal Schema generation.
  348. auto schema = std::make_unique<DataSchema>();
  349. RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  350. TensorShape scalar = TensorShape::CreateScalar();
  351. RETURN_NULL_IF_ERROR(
  352. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  353. node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
  354. dataset_dir_, connector_que_size_, std::move(schema),
  355. std::move(sampler_->Build())));
  356. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  357. }
  358. // Function to build ProjectOp
  359. ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
  360. bool ProjectDataset::ValidateParams() {
  361. if (columns_.empty()) {
  362. MS_LOG(ERROR) << "No columns are specified.";
  363. return false;
  364. }
  365. return true;
  366. }
  367. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ProjectDataset::Build() {
  368. // A vector containing shared pointer to the Dataset Ops that this object will create
  369. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  370. node_ops.push_back(std::make_shared<ProjectOp>(columns_));
  371. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  372. }
  373. // Function to build ZipOp
  374. ZipDataset::ZipDataset() {}
  375. bool ZipDataset::ValidateParams() { return true; }
  376. std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ZipDataset::Build() {
  377. // A vector containing shared pointer to the Dataset Ops that this object will create
  378. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  379. node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
  380. return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
  381. }
  382. } // namespace api
  383. } // namespace dataset
  384. } // namespace mindspore