You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

batch_op.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
  17. #define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
  18. #include <algorithm>
  19. #include <map>
  20. #include <memory>
  21. #include <queue>
  22. #include <set>
  23. #include <string>
  24. #include <unordered_map>
  25. #include <utility>
  26. #include <vector>
  27. #include "dataset/core/config_manager.h"
  28. #include "dataset/core/tensor.h"
  29. #include "dataset/engine/dataset_iterator.h"
  30. #include "dataset/engine/datasetops/parallel_op.h"
  31. #include "dataset/util/status.h"
  32. namespace mindspore {
  33. namespace dataset {
  34. class DataBuffer;
  35. using TensorBatch = TensorRow;
  36. using TensorBatchTable = std::vector<TensorBatch>;
  37. using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>;
  38. class BatchOp : public ParallelOp {
  39. public:
  40. class Builder {
  41. public:
  42. // Builder constructor for Batch, batch size needs to be specified
  43. // @param int32_t batch_size
  44. explicit Builder(int32_t batch_size);
  45. // Default destructor
  46. ~Builder() = default;
  47. // set number of parallel Workers on batch
  48. // @param int32_t num_workers
  49. // @return Builder & reference to builder class object
  50. Builder &SetNumWorkers(int32_t num_workers) {
  51. builder_num_workers_ = num_workers;
  52. return *this;
  53. }
  54. // set drop for batch op,default false
  55. // @param bool drop
  56. // @return Builder & reference to builder class object
  57. Builder &SetDrop(bool drop) {
  58. builder_drop_ = drop;
  59. return *this;
  60. }
  61. Builder &SetPaddingMap(const PadInfo &pad_map, bool pad = true) {
  62. builder_pad_ = pad;
  63. builder_pad_map_ = pad_map;
  64. return *this;
  65. }
  66. // set connector size for batch
  67. // @param int32_t op_conn_size
  68. // @return Builder & reference to builder class object
  69. Builder &SetOpConnectorSize(int32_t op_connector_size) {
  70. builder_op_connector_size_ = (op_connector_size == 0 ? builder_op_connector_size_ : op_connector_size);
  71. return *this;
  72. }
  73. // set columns to perform map on
  74. // @param const std::vector<std::string> & cols_to_map - name of columns to perform map on
  75. // @return Builder & reference to builder class object
  76. Builder &SetColumnsToMap(const std::vector<std::string> &cols_to_map) {
  77. builder_cols_to_map_ = cols_to_map;
  78. return *this;
  79. }
  80. // set columns to perform map on
  81. // @param const std::vector<std::string> & cols_to_map - name of columns to perform map on
  82. // @return Builder & reference to builder class object
  83. Builder &SetBatchMapFunc(py::function batch_map_func) {
  84. builder_batch_map_func_ = batch_map_func;
  85. return *this;
  86. }
  87. // SetBatchSizeFunc, a function that calls to python after every batch is made
  88. // @param py::function batch_size_func - python function to call, GIL required before calling
  89. // @return Builder & reference to builder class object
  90. Builder &SetBatchSizeFunc(py::function batch_size_func) {
  91. builder_batch_size_func_ = batch_size_func;
  92. return *this;
  93. }
  94. // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg
  95. // @return Status - The error code return
  96. Status Build(std::shared_ptr<BatchOp> *);
  97. private:
  98. // Sanity check for builder class args
  99. // @return Status - The error code return
  100. Status SanityCheck();
  101. bool builder_drop_;
  102. bool builder_pad_;
  103. int32_t builder_batch_size_;
  104. int32_t builder_num_workers_;
  105. int32_t builder_op_connector_size_;
  106. std::vector<std::string> builder_cols_to_map_;
  107. PadInfo builder_pad_map_;
  108. py::function builder_batch_size_func_;
  109. py::function builder_batch_map_func_;
  110. };
  111. enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 };
  112. // Parameters associate with one batch.
  113. // This struct is used for both internal control and python callback.
  114. // This struct is bound to python with read-only access.
  115. struct CBatchInfo {
  116. CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl)
  117. : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {}
  118. CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {}
  119. CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {}
  120. explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {}
  121. int64_t epoch_num_; // i-th epoch. i starts from 0
  122. int64_t batch_num_; // i-th batch since the start of current epoch. i starts from 0
  123. int64_t total_batch_num_; // i-th batch since the start of first epoch. i starts from 0
  124. batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3
  125. const int64_t get_batch_num() const { return batch_num_; }
  126. const int64_t get_epoch_num() const { return epoch_num_; }
  127. };
  128. // BatchOp constructor
  129. // @param int32_t batch_size
  130. // @param bool drop
  131. // @param int32_t op_queue_size
  132. // @param int32_t rows_per_buf
  133. // @param int32_t num_workers
  134. BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
  135. const std::vector<std::string> &, py::function batch_size_func, py::function batch_map_func, PadInfo pad_map);
  136. // BatchOp destructor
  137. ~BatchOp() {}
  138. // @param int32_t workerId
  139. // @return Status - The error code return
  140. Status EofReceived(int32_t) override;
  141. // @param int32_t workerId
  142. // @return Status - The error code return
  143. Status EoeReceived(int32_t) override;
  144. // A print method typically used for debugging
  145. // @param out - The output stream to write output to
  146. // @param show_all - A bool to control if you want to show all info or just a summary
  147. void Print(std::ostream &out, bool show_all) const override;
  148. // << Stream output operator overload
  149. // @notes This allows you to write the debug print info using stream operators
  150. // @param out - reference to the output stream being overloaded
  151. // @param sO - reference to the BatchOp to display
  152. // @return - the output stream must be returned
  153. friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) {
  154. bo.Print(out, false);
  155. return out;
  156. }
  157. // Main loop of batch
  158. // @return Status - The error code return
  159. Status operator()() override;
  160. // Base-class override for NodePass visitor acceptor.
  161. // @param p - Pointer to the NodePass to be accepted.
  162. // @param modified - Whether this node visit modified the pipeline.
  163. // @return - Status of the node visit.
  164. Status Accept(NodePass *p, bool *modified) override;
  165. // Op name getter
  166. // @return Name of the current Op
  167. std::string Name() const override { return "BatchOp"; }
  168. // batch the rows in src table then put it to dest table
  169. // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
  170. // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
  171. // @param int32_t size - batch_size
  172. // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
  173. // @return Status - The error code return
  174. static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
  175. dsize_t batch_size);
  176. // @param table
  177. // @param const PadInfo &pad_info pad info
  178. // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
  179. // @return Status - The error code return
  180. static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
  181. const std::unordered_map<std::string, int32_t> &column_name_id_map);
  182. private:
  183. // Worker thread for doing the memcpy of batch
  184. // @param int32_t param workerId
  185. // @return Status - The error code return
  186. Status WorkerEntry(int32_t worker_id) override;
  187. // Generate buffer with batched tensors
  188. // @return Status - The error code return
  189. Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
  190. std::unique_ptr<DataBuffer> *db);
  191. // Function that calls pyfunc to perform map on batch
  192. // @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
  193. // @return Status - The error code return
  194. Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
  195. // @param const PadInfo &pad_info pad info to unpack
  196. // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
  197. // @param std::set<int32_t> *cols, col ids to perform pad on
  198. // @param std::vector<float> *vals, default padding value for each column
  199. // @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user
  200. // @return Status - The error code return
  201. static Status UnpackPadInfo(const PadInfo &pad_info,
  202. const std::unordered_map<std::string, int32_t> &column_name_id_map,
  203. std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
  204. std::vector<std::vector<dsize_t>> *pad_shapes);
  205. // the number of thread pulling from the mOutConnector of the Op below
  206. // @return int32_t, 1
  207. int32_t num_consumers() const override { return 1; }
  208. // get the batch size for next batch
  209. // @return Status - The error code return
  210. Status GetBatchSize(int32_t *batch_size, CBatchInfo info);
  211. // Do the initialization of all queues then start all worker threads
  212. // @return Status - The error code return
  213. Status LaunchThreadsAndInitOp();
  214. // Invoke batch size function with current BatchInfo to generate batch size.
  215. // @return Status - The error code return
  216. Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info);
  217. // Invoke batch map function with current BatchInfo to generate tensors to batch.
  218. // @return Status - The error code return
  219. Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
  220. int32_t start_batch_size_;
  221. bool drop_; // bool for whether to drop remainder or not
  222. bool pad_; // bool for whether to perform padding on tensor
  223. std::vector<std::string> pyfunc_column_names_; // Name of the columns to perform map op on
  224. PadInfo pad_info_; // column names to perform padding on
  225. std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1
  226. QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_; // internal queue for syncing worker
  227. py::function batch_size_func_; // Function pointer of batch size function
  228. py::function batch_map_func_; // Function pointer of per batch map function
  229. };
  230. } // namespace dataset
  231. } // namespace mindspore
  232. #endif // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_