
batch_op.h

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
#define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_

#include <algorithm>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "dataset/core/config_manager.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
class DataBuffer;

using TensorBatch = std::vector<std::shared_ptr<Tensor>>;
using TensorBatchTable = std::vector<TensorBatch>;

class BatchOp : public ParallelOp {
 public:
  class Builder {
   public:
    // Builder constructor for Batch; the batch size must be specified
    // @param int32_t batch_size
    explicit Builder(int32_t batch_size);

    // Default destructor
    ~Builder() = default;

    // Set the number of parallel workers for batch
    // @param int32_t num_workers
    // @return Builder & reference to builder class object
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Set whether to drop the last incomplete batch; default false
    // @param bool drop
    // @return Builder & reference to builder class object
    Builder &SetDrop(bool drop) {
      builder_drop_ = drop;
      return *this;
    }

    // Set the per-column padding shapes and fill values
    // @param const std::map<std::string, std::pair<TensorShape, float>> & pad_map - column name to (shape, value)
    // @param bool pad - whether to perform padding, default true
    // @return Builder & reference to builder class object
    Builder &SetPaddingMap(const std::map<std::string, std::pair<TensorShape, float>> &pad_map, bool pad = true) {
      builder_pad_ = pad;
      builder_pad_map_ = pad_map;
      return *this;
    }

    // Set the output connector size for batch
    // @param int32_t op_connector_size
    // @return Builder & reference to builder class object
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = (op_connector_size == 0 ? builder_op_connector_size_ : op_connector_size);
      return *this;
    }

    // Set the columns to perform the map on
    // @param const std::vector<std::string> & cols_to_map - names of the columns to perform the map on
    // @return Builder & reference to builder class object
    Builder &SetColumnsToMap(const std::vector<std::string> &cols_to_map) {
      builder_cols_to_map_ = cols_to_map;
      return *this;
    }

    // Set the per-batch map function
    // @param py::function batch_map_func - python function to perform the map on a batch
    // @return Builder & reference to builder class object
    Builder &SetBatchMapFunc(py::function batch_map_func) {
      builder_batch_map_func_ = batch_map_func;
      return *this;
    }

    // SetBatchSizeFunc, a python function that is called after every batch is made
    // @param py::function batch_size_func - python function to call, GIL required before calling
    // @return Builder & reference to builder class object
    Builder &SetBatchSizeFunc(py::function batch_size_func) {
      builder_batch_size_func_ = batch_size_func;
      return *this;
    }

    // @param std::shared_ptr<BatchOp> *ptr - pointer to shared_ptr, actual return arg
    // @return Status - The error code returned
    Status Build(std::shared_ptr<BatchOp> *);

   private:
    // Sanity check for builder class args
    // @return Status - The error code returned
    Status SanityCheck();

    bool builder_drop_;
    bool builder_pad_;
    int32_t builder_batch_size_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    std::vector<std::string> builder_cols_to_map_;
    std::map<std::string, std::pair<TensorShape, float>> builder_pad_map_;
    py::function builder_batch_size_func_;
    py::function builder_batch_map_func_;
  };
  enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 };

  // Parameters associated with one batch.
  // This struct is used for both internal control and python callback.
  // This struct is bound to python with read-only access.
  struct CBatchInfo {
    CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl)
        : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {}
    CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {}
    CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {}
    explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {}
    int64_t epoch_num_;        // i-th epoch. i starts from 0
    int64_t batch_num_;        // i-th batch since the start of the current epoch. i starts from 0
    int64_t total_batch_num_;  // i-th batch since the start of the first epoch. i starts from 0
    batchCtrl ctrl_;           // No control=0, EOE=1, EOF=2, Quit=3
    const int64_t get_batch_num() const { return batch_num_; }
    const int64_t get_epoch_num() const { return epoch_num_; }
  };
  // BatchOp constructor
  // @param int32_t batch_size
  // @param bool drop
  // @param bool pad
  // @param int32_t op_queue_size
  // @param int32_t num_workers
  BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
          const std::vector<std::string> &, py::function batch_size_func, py::function batch_map_func,
          std::map<std::string, std::pair<TensorShape, float>> pad_map);

  // BatchOp destructor
  ~BatchOp() {}

  // @param int32_t worker_id
  // @return Status - The error code returned
  Status EofReceived(int32_t) override;

  // @param int32_t worker_id
  // @return Status - The error code returned
  Status EoeReceived(int32_t) override;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param bo - reference to the BatchOp to display
  // @return - the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) {
    bo.Print(out, false);
    return out;
  }

  // Main loop of batch
  // @return Status - The error code returned
  Status operator()() override;

  // Pad the input tensor according to pad_shape; the tensor and pad_shape must have the same rank.
  // @param std::shared_ptr<Tensor> src - tensor to pad from
  // @param std::shared_ptr<Tensor> *dst - return tensor padded
  // @param std::vector<dsize_t> pad_shape - shape to pad to
  // @param float pad_val - value to pad with
  // @return Status - The error code returned
  Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
                   float pad_val);

  // Base-class override for NodePass visitor acceptor.
  // @param p - Pointer to the NodePass to be accepted.
  // @param modified - Whether this node visit modified the pipeline.
  // @return - Status of the node visit.
  Status Accept(NodePass *p, bool *modified) override;
 private:
  // Recursive helper function. This function could be very expensive if called on a multi-dimensional tensor;
  // it is only meant to be called by PadTensor.
  // @tparam T - type of tensor and fill value
  // @param std::shared_ptr<Tensor> src - Tensor to pad from
  // @param std::shared_ptr<Tensor> dst - Tensor to pad to, return value
  // @param std::vector<dsize_t> cur_ind - recursion helper
  // @param T pad_val - value to pad tensor with
  // @param size_t cur_dim - recursion helper
  // @return Status - The error code returned
  Status PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> dst, std::vector<dsize_t> cur_ind,
                   const std::vector<dsize_t> &src_s, const std::vector<dsize_t> &dst_s, size_t cur_dim = 0);

  // Worker thread for doing the memcpy of a batch
  // @param int32_t worker_id
  // @return Status - The error code returned
  Status WorkerEntry(int32_t worker_id) override;

  // Generate a buffer with batched tensors
  // @return Status - The error code returned
  Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
                           std::unique_ptr<DataBuffer> *db);

  // Batch the rows in the src table then put them into the dest table
  // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
  // @param const std::unique_ptr<TensorQTable> *dest - dest table to hold batched rows
  // @param size_t size - batch_size
  // @return Status - The error code returned
  Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, size_t size);

  // Function that calls a pyfunc to perform the map on a batch
  // @param std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair - contains un-batched tensors
  // @return Status - The error code returned
  Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);

  // @param std::set<int32_t> *cols - col ids to perform pad on
  // @param std::vector<float> *vals - default padding value for each column
  // @param std::vector<std::vector<dsize_t>> *shapes - padding shape specified by user
  // @return Status - The error code returned
  Status UnpackPadInfo(std::set<int32_t> *cols, std::vector<float> *vals, std::vector<std::vector<dsize_t>> *shapes);

  // @param table_pair
  // @return Status - The error code returned
  Status PadColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);

  // The number of threads pulling from the mOutConnector of the Op below
  // @return int32_t, 1
  int32_t num_consumers() const override { return 1; }

  // Get the batch size for the next batch
  // @return Status - The error code returned
  Status GetBatchSize(int32_t *batch_size, CBatchInfo info);

  // Do the initialization of all queues then start all worker threads
  // @return Status - The error code returned
  Status LaunchThreadsAndInitOp();

  // Invoke the batch size function with the current BatchInfo to generate the batch size.
  // @return Status - The error code returned
  Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info);

  // Invoke the batch map function with the current BatchInfo to generate tensors to batch.
  // @return Status - The error code returned
  Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);

  int32_t start_batch_size_;
  bool drop_;                                      // bool for whether to drop the remainder or not
  bool pad_;                                       // bool for whether to perform padding on tensors
  std::vector<std::string> pyfunc_column_names_;   // Names of the columns to perform the map op on
  std::map<std::string, std::pair<TensorShape, float>> pad_info_;  // column names to perform padding on
  std::unique_ptr<ChildIterator> child_iterator_;  // child iterator for fetching TensorRows one by one
  QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_;  // internal queues for syncing workers
  py::function batch_size_func_;                   // Function pointer of batch size function
  py::function batch_map_func_;                    // Function pointer of per-batch map function
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
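
For orientation, below is a minimal sketch of how a caller might construct a BatchOp through the nested Builder declared above. It uses only the setters and Build() declared in this header; the batch size, worker count, connector size, and the function name BuildBatchOpExample are illustrative placeholders, not part of the header.

// Illustrative sketch only -- not part of batch_op.h.
#include <memory>
#include "dataset/engine/datasetops/batch_op.h"

namespace mindspore {
namespace dataset {
Status BuildBatchOpExample(std::shared_ptr<BatchOp> *out) {
  BatchOp::Builder builder(32);  // batch size of 32 (illustrative value)
  builder.SetDrop(true)          // drop the last incomplete batch
      .SetNumWorkers(4)          // four parallel worker threads
      .SetOpConnectorSize(16);   // output connector queue size
  return builder.Build(out);     // the returned Status carries the error code, if any
}
}  // namespace dataset
}  // namespace mindspore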