You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

connector.h 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONNECTOR_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONNECTOR_H_
  #include <atomic>
  #include <cstdint>
  #include <memory>
  #include <mutex>
  #include <ostream>
  #include <string>
  #include <utility>
  #include <vector>
  #include "minddata/dataset/util/task_manager.h"
  #include "minddata/dataset/util/queue.h"
  #include "minddata/dataset/util/services.h"
  #include "minddata/dataset/util/cond_var.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Connector is a communication data structure between two group of threads that
  29. // preserve the order.
  30. //
  31. // Example use case:
  32. // An initial tasks-list of [1,2,3,4,5,6,7,8,9] with 5 threads getting/processing elements from that list,
  33. // and pushing the processed elements to a Connector in any order whoever finishes processing first.
  34. // If the consumer of the Connector is single threaded, when the consumer pop() the
  35. // element from the Connector one by one, it will get [1,2,3,4,5,6,7,8,9].
  36. //
  37. // Requirements:
  38. // 1. Each thread in the group of consumer or producer threads must be assigned ids starting from 0.
  39. // 2. If your multi-threads program is not reading from a Connector class but
  40. // want to push to a Connector class, you must follow roundrobin element distribution,
  41. // i.e., the thread-id0 must have the first element, thread-id1 has the second element,
  42. // and so on; then each of this worker can push to the Connector class async in parallel.
  43. //
  44. // Blocking conditions:
  45. // 1. Connector.push(int, T) can block when the internal queue it's trying to push is full.
  46. // 2. Connector.pop(int) can block when
  47. // - The internal queue it's trying to pop is empty.
  48. // - The caller thread of pop() is not equal to the _expectConsumer. This is to enforce
  49. // the ordering.
  50. //
  51. // Future improvement:
  52. // 1. Fault tolerant: Right now, if one of the worker dies, the Connector will not work
  53. // properly.
  54. template <class T>
  55. class Connector {
  56. public:
  57. // Name: Constructor
  58. // Description: Initializing private members with the given input arguments.
  59. // expect_consumer_ and pop_from_ is initialized to 0 as part of
  60. // our requirements. We instantiate nProducers number of internal
  61. // queues so that each producer thread can push to its queue without
  62. // any sync overhead.
  63. // Constructor of Connector
  64. // Initializing private members with the given input arguments.
  65. // _expectConsumer and _popFrom is initialized to 0 as part of
  66. // our requirements. We instantiate nProducers number of internal
  67. // queues so that each producer thread can push to its queue without
  68. // any sync overhead.
  69. // @param n_producers The number of threads producing data into this DbConnector.
  70. // @param n_consumers The number of thread consuming data from this DbConnector.
  71. // @param queue_capacity The number of element (DataBuffer) for each queue.
  72. Connector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity)
  73. : num_producers_(n_producers), num_consumers_(n_consumers) {
  74. MS_LOG(DEBUG) << "A connector is created with " << n_producers << " producers and " << n_consumers << " consumers.";
  75. my_name_ = Services::GetUniqueID();
  76. // We require the consumers to have ids sequentially from 0 to the num_consumers_-1,
  77. // Otherwise a ordered list of consumer ids have to be passed here. (not implemented yet)
  78. expect_consumer_ = 0;
  79. // Roundrobin pop starts from index 0 of the queues_.
  80. pop_from_ = 0;
  81. // Initialize the queues_ to have num_producers_ number of queues.
  82. // Each queue is a blocking queue and has the same queue_capacity.
  83. queues_.Init(num_producers_, queue_capacity);
  84. }
  85. // Destructor of Connector
  86. virtual ~Connector() = default;
  87. // Get an element from the Connector.
  88. // @not Call to pop() can block the caller thread, see the blocking condition at the top of this file.
  89. // @param worker_id The id of a worker thread calling this method.
  90. // @param result The address of an object where the popped element will be placed.
  91. virtual Status Pop(int32_t worker_id, // The worker-id of the caller. See the requirement at the top of this file.
  92. T *result) noexcept {
  93. {
  94. MS_ASSERT(worker_id < num_consumers_);
  95. std::unique_lock<std::mutex> lk(m_);
  96. RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; }));
  97. RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result));
  98. pop_from_ = (pop_from_ + 1) % num_producers_;
  99. out_buffers_count_++;
  100. expect_consumer_ = (expect_consumer_ + 1) % num_consumers_;
  101. }
  102. cv_.NotifyAll();
  103. return Status::OK();
  104. }
  105. // Add an element into the DbConnector without the overhead of synchronization.
  106. // It may block when the internal queue is full.
  107. // The element passed to this function will be copied into the internal queue.
  108. // @param worker_id The id of a worker thread calling this method.
  109. // @param el A const lvalue element to be passed/added/pushed.
  110. Status Push(int32_t worker_id, const T &el) noexcept {
  111. MS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
  112. MS_ASSERT(queues_[worker_id] != nullptr);
  113. return (queues_[worker_id]->Add(el));
  114. }
  115. auto out_buffers_count() const { return out_buffers_count_.load(); }
  116. // Add an element into the DbConnector without the overhead of synchronization.
  117. // It may block when the internal queue is full.
  118. // The element passed to this function will be forwarded into the internal queue.
  119. // @param worker_id The id of a worker thread calling this method.
  120. // @param el An element to be passed/added/pushed.
  121. virtual Status Push(int32_t worker_id, T &&el) noexcept {
  122. MS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
  123. MS_ASSERT(queues_[worker_id] != nullptr);
  124. return (queues_[worker_id]->Add(std::forward<T>(el)));
  125. }
  126. // Resets the internal index tracking of the queue so that it can be used again with new inputs,
  127. // starting from the beginning.
  128. void Reset() {
  129. for (int i = 0; i < queues_.size(); ++i) {
  130. queues_[i]->ResetQue();
  131. }
  132. expect_consumer_ = 0;
  133. pop_from_ = 0;
  134. out_buffers_count_ = 0;
  135. MS_LOG(DEBUG) << "Connector counters reset.";
  136. }
  137. void Print(std::ostream &out, bool showAll) const {
  138. out << "\n--------- Connector ------------"
  139. << "\nConnector Name : " << my_name_ << "\nNumber of consumers : " << num_consumers_
  140. << "\nNumber of producers : " << num_producers_ << "\n";
  141. }
  142. friend std::ostream &operator<<(std::ostream &out, const Connector &con) {
  143. con.print(out, false);
  144. return out;
  145. }
  146. // Get current size of connector.
  147. int32_t size() const {
  148. int32_t size = 0;
  149. for (int32_t i = 0; i < queues_.size(); ++i) {
  150. size += queues_[i]->size();
  151. }
  152. return size;
  153. }
  154. int32_t capacity() const {
  155. int32_t capacity = 0;
  156. for (int32_t i = 0; i < queues_.size(); ++i) {
  157. capacity += queues_[i]->capacity();
  158. }
  159. return capacity;
  160. }
  161. // Register the internal resources with Task group for interruption service.
  162. // @param vg
  163. // @return
  164. Status Register(TaskGroup *vg) {
  165. Status rc = queues_.Register(vg);
  166. if (rc.IsOk()) {
  167. rc = cv_.Register(vg->GetIntrpService());
  168. }
  169. return rc;
  170. }
  171. protected:
  172. std::string my_name_;
  173. // A list of Queues that are thread safe.
  174. QueueList<T> queues_;
  175. // The consumer that we allow to get the next data from pop()
  176. int32_t expect_consumer_;
  177. // The index to the queues_ where the next data should be popped.
  178. int32_t pop_from_;
  179. int32_t num_producers_;
  180. int32_t num_consumers_;
  181. // Used in the Pop(), when a thread call pop() but it is not the expect_consumer_.
  182. std::mutex m_;
  183. CondVar cv_;
  184. std::atomic<std::int64_t> out_buffers_count_ = 0;
  185. };
  186. } // namespace dataset
  187. } // namespace mindspore
  188. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONNECTOR_H_