You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

de_pipeline.h 7.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_API_DE_PIPELINE_H_
  17. #define DATASET_API_DE_PIPELINE_H_
  18. #include <iostream>
  19. #include <memory>
  20. #include <stack>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <utility>
  24. #include <vector>
  25. #include "dataset/core/client.h" // DE client
  26. #include "dataset/engine/dataset_iterator.h"
  27. #include "dataset/util/status.h"
  28. #include "pybind11/numpy.h"
  29. #include "pybind11/pybind11.h"
  30. #include "pybind11/stl.h"
  31. namespace py = pybind11;
  32. namespace mindspore {
  33. namespace dataset {
  34. using DsOpPtr = std::shared_ptr<DatasetOp>;
  // Enum for the dataset operator names: identifies which operator
  // DEPipeline::AddNodeToTree is asked to build (it takes `const OpName &`).
  //
  // Deliberately an unscoped enum: the bare enumerator names (kShuffle,
  // kBatch, ...) are presumably used unqualified outside this header, e.g. by
  // the Python bindings -- NOTE(review) confirm before converting to
  // `enum class`, which would break such uses.
  //
  // Values follow declaration order (kShuffle == 0 ... kClue == 29). If the
  // numeric values are exposed to Python, inserting a new entry in the middle
  // renumbers every later entry, so append new operators at the end.
  //
  // NOTE(review): kCache has no matching Parse*Op declared on DEPipeline;
  // confirm how AddNodeToTree handles it.
  enum OpName {
    kShuffle,
    kMindrecord,
    kBatch,
    kBucketBatch,
    kBarrier,
    kCache,
    kRepeat,
    kSkip,
    kTake,
    kZip,
    kConcat,
    kMap,
    kFilter,
    kDeviceQueue,
    kGenerator,
    kRename,
    kTfReader,
    kProject,
    kImageFolder,
    kMnist,
    kManifest,
    kVoc,
    kCoco,
    kCifar10,
    kCifar100,
    kCelebA,
    kRandomData,
    kTextFile,
    kBuildVocab,
    kClue
  };
  68. // The C++ binder class that we expose to the python script.
  69. class DEPipeline {
  70. public:
  71. DEPipeline();
  72. ~DEPipeline();
  73. // Function to add a Node to the Execution Tree.
  74. Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output);
  75. // Function to add a child and parent relationship.
  76. static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op);
  77. // Function to assign the node as root.
  78. Status AssignRootNode(const DsOpPtr &dataset_op);
  79. // Function to launch the tree execution.
  80. Status LaunchTreeExec();
  81. // Get a row of data as dictionary of column name to the value.
  82. Status GetNextAsMap(py::dict *output);
  83. // Get a row of data as list.
  84. Status GetNextAsList(py::list *output);
  85. Status GetOutputShapes(py::list *output);
  86. Status GetOutputTypes(py::list *output);
  87. int GetDatasetSize() const;
  88. int GetBatchSize() const;
  89. int GetRepeatCount() const;
  90. Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  91. Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  92. Status BuildMindrecordSamplerChain(const py::handle &handle,
  93. std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
  94. int num_padded);
  95. Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  96. Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  97. Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  98. Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  99. Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  100. Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
  101. std::shared_ptr<DatasetOp> *bottom);
  102. Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  103. Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  104. Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  105. Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  106. Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  107. Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  108. Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  109. Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  110. Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  111. Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  112. Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  113. Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  114. Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  115. Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  116. Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  117. Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  118. void PrintTree();
  119. int32_t GetNumClasses() const;
  120. Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  121. Status SetBatchParameters(const py::dict &args);
  122. Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  123. Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  124. Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  125. Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
  126. private:
  127. // Execution tree that links the dataset operators.
  128. std::shared_ptr<ExecutionTree> tree_;
  129. std::unique_ptr<DatasetIterator> iterator_;
  130. static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
  131. /// \brief Helper function to inject a shuffle operator over top of the current operation being built.
  132. /// \param[in] shuffle_size The size to use in the shuffle buffer
  133. /// \param[in] input_op The operator to build shuffle on top of
  134. /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be
  135. /// the shuffle operator
  136. /// \return Status return code
  137. Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
  138. std::shared_ptr<DatasetOp> *shuffle_op);
  139. /// \brief Helper function to compute the shuffle size
  140. /// \param[in] num_files The number of files in the dataset
  141. /// \param[in] num_devices The number of devices in the dataset
  142. /// \param[in] num_rows The number of rows in the dataset
  143. /// \param[in] total_rows An upper bound on the total rows in the dataset
  144. /// \param[out] shuffle_size The resultant computed shuffle size
  145. /// \return Status return code
  146. Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
  147. int64_t *shuffle_size);
  148. int batch_size_;
  149. int repeat_num_;
  150. int num_rows_;
  151. int num_classes_;
  152. int temp_batch_size_;
  153. bool temp_drop_remainder_;
  154. };
  155. } // namespace dataset
  156. } // namespace mindspore
  157. #endif // DATASET_API_DE_PIPELINE_H_