You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

batch_op.h 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
  17. #define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
  18. #include <memory>
  19. #include <queue>
  20. #include <string>
  21. #include <unordered_map>
  22. #include <utility>
  23. #include <vector>
  24. #include "dataset/core/config_manager.h"
  25. #include "dataset/core/tensor.h"
  26. #include "dataset/engine/dataset_iterator.h"
  27. #include "dataset/engine/datasetops/parallel_op.h"
  28. #include "dataset/util/status.h"
  29. namespace mindspore {
  30. namespace dataset {
  31. class DataBuffer;
  32. using TensorBatch = std::vector<std::shared_ptr<Tensor>>;
  33. using TensorBatchTable = std::vector<TensorBatch>;
  34. class BatchOp : public ParallelOp {
  35. public:
  36. class Builder {
  37. public:
  38. // Builder constructor for Batch, batch size needs to be specified
  39. // @param int32_t batch_size
  40. explicit Builder(int32_t batch_size);
  41. // Builder constructor for Batch, batch size function needs to be specified
  42. // @param py::function batch_size_func
  43. explicit Builder(py::function batch_size_func);
  44. // Default destructor
  45. ~Builder() = default;
  46. // set number of parallel Workers on batch
  47. // @param int32_t num_workers
  48. // @return Builder & reference to builder class object
  49. Builder &SetNumWorkers(int32_t num_workers) {
  50. builder_num_workers_ = num_workers;
  51. return *this;
  52. }
  53. // set drop for batch op,default false
  54. // @param bool drop
  55. // @return Builder & reference to builder class object
  56. Builder &SetDrop(bool drop) {
  57. builder_drop_ = drop;
  58. return *this;
  59. }
  60. // set connector size for batch
  61. // @param int32_t op_conn_size
  62. // @return Builder & reference to builder class object
  63. Builder &SetOpConnectorSize(int32_t op_connector_size) {
  64. builder_op_connector_size_ = (op_connector_size == 0 ? builder_op_connector_size_ : op_connector_size);
  65. return *this;
  66. }
  67. // set columns to perform map on
  68. // @param const std::vector<std::string> & cols_to_map - name of columns to perform map on
  69. // @return Builder & reference to builder class object
  70. Builder &SetColumnsToMap(const std::vector<std::string> &cols_to_map) {
  71. builder_cols_to_map_ = cols_to_map;
  72. return *this;
  73. }
  74. // set columns to perform map on
  75. // @param const std::vector<std::string> & cols_to_map - name of columns to perform map on
  76. // @return Builder & reference to builder class object
  77. Builder &SetBatchMapFunc(py::function batch_map_func) {
  78. builder_batch_map_func_ = batch_map_func;
  79. return *this;
  80. }
  81. // SetBatchSizeFunc, a function that calls to python after every batch is made
  82. // @param py::function batch_size_func - python function to call, GIL required before calling
  83. // @return Builder & reference to builder class object
  84. Builder &SetBatchSizeFunc(py::function batch_size_func) {
  85. builder_batch_size_func_ = batch_size_func;
  86. return *this;
  87. }
  88. // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg
  89. // @return Status - The error code return
  90. Status Build(std::shared_ptr<BatchOp> *);
  91. private:
  92. // Sanity check for builder class args
  93. // @return Status - The error code return
  94. Status SanityCheck();
  95. bool builder_drop_;
  96. int32_t builder_batch_size_;
  97. int32_t builder_num_workers_;
  98. int32_t builder_op_connector_size_;
  99. std::vector<std::string> builder_cols_to_map_;
  100. py::function builder_batch_size_func_;
  101. py::function builder_batch_map_func_;
  102. };
  103. enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 };
  104. // Parameters associate with one batch.
  105. // This struct is used for both internal control and python callback.
  106. // This struct is bound to python with read-only access.
  107. struct CBatchInfo {
  108. CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl)
  109. : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {}
  110. CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {}
  111. CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {}
  112. explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {}
  113. int64_t epoch_num_; // i-th epoch. i starts from 0
  114. int64_t batch_num_; // i-th batch since the start of current epoch. i starts from 0
  115. int64_t total_batch_num_; // i-th batch since the start of first epoch. i starts from 0
  116. batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3
  117. const int64_t get_batch_num() const { return batch_num_; }
  118. const int64_t get_epoch_num() const { return epoch_num_; }
  119. };
  120. // BatchOp constructor
  121. // @param int32_t batch_size
  122. // @param bool drop
  123. // @param int32_t op_queue_size
  124. // @param int32_t rows_per_buf
  125. // @param int32_t num_workers
  126. BatchOp(int32_t batch_size, bool drop, int32_t op_queue_size, int32_t num_workers, const std::vector<std::string> &,
  127. py::function batch_size_func, py::function batch_map_func);
  128. // BatchOp destructor
  129. ~BatchOp() {}
  130. // @param int32_t workerId
  131. // @return Status - The error code return
  132. Status EofReceived(int32_t) override;
  133. // @param int32_t workerId
  134. // @return Status - The error code return
  135. Status EoeReceived(int32_t) override;
  136. // A print method typically used for debugging
  137. // @param out - The output stream to write output to
  138. // @param show_all - A bool to control if you want to show all info or just a summary
  139. void Print(std::ostream &out, bool show_all) const override;
  140. // << Stream output operator overload
  141. // @notes This allows you to write the debug print info using stream operators
  142. // @param out - reference to the output stream being overloaded
  143. // @param sO - reference to the BatchOp to display
  144. // @return - the output stream must be returned
  145. friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) {
  146. bo.Print(out, false);
  147. return out;
  148. }
  149. // Main loop of batch
  150. // @return Status - The error code return
  151. Status operator()() override;
  152. private:
  153. // Worker thread for doing the memcpy of batch
  154. // @param int32_t param workerId
  155. // @return Status - The error code return
  156. Status WorkerEntry(int32_t worker_id) override;
  157. // Generate buffer with batched tensors
  158. // @return Status - The error code return
  159. Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
  160. std::unique_ptr<DataBuffer> *db);
  161. // batch the rows in src table then put it to dest table
  162. // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
  163. // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
  164. // @param int32_t size - batch_size
  165. // @return Status - The error code return
  166. Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, size_t size);
  167. // Function that calls pyfunc to perform map on batch
  168. // @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
  169. // @return Status - The error code return
  170. Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
  171. // the number of thread pulling from the mOutConnector of the Op below
  172. // @return int32_t, 1
  173. int32_t num_consumers() const override { return 1; }
  174. // get the batch size for next batch
  175. // @return Status - The error code return
  176. Status GetBatchSize(int32_t *batch_size, CBatchInfo info);
  177. // Do the initialization of all queues then start all worker threads
  178. // @return Status - The error code return
  179. Status LaunchThreadsAndInitOp();
  180. // Invoke batch size function with current BatchInfo to generate batch size.
  181. // @return Status - The error code return
  182. Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info);
  183. // Invoke batch map function with current BatchInfo to generate tensors to batch.
  184. // @return Status - The error code return
  185. Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
  186. int32_t start_batch_size_;
  187. bool drop_;
  188. // Name of the columns to perform map op on
  189. std::vector<std::string> input_column_names_;
  190. // Iterator for fetching
  191. std::unique_ptr<ChildIterator> child_iterator_;
  192. // Map of column_name: column_index
  193. std::unordered_map<std::string, int32_t> column_name_map_;
  194. // Internal queue for task distribution
  195. QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_;
  196. // Function pointer of batch size function
  197. py::function batch_size_func_;
  198. // Function pointer of per batch map function
  199. py::function batch_map_func_;
  200. };
  201. } // namespace dataset
  202. } // namespace mindspore
  203. #endif // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_