You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

connector.h 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_CONNECTOR_H_
  17. #define DATASET_ENGINE_CONNECTOR_H_
  18. #include <memory>
  19. #include <string>
  20. #include <utility>
  21. #include <vector>
  22. #include "dataset/util/task_manager.h"
  23. #include "dataset/util/queue.h"
  24. #include "dataset/util/services.h"
  25. #include "dataset/util/cond_var.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Connector is a communication data structure between two groups of threads that
  29. // preserves the order of the elements passed through it.
  30. //
  31. // Example use case:
  32. // An initial tasks-list of [1,2,3,4,5,6,7,8,9] with 5 threads getting/processing elements from that list,
  33. // and pushing the processed elements to a Connector in any order whoever finishes processing first.
  34. // If the consumer of the Connector is single threaded, when the consumer pop() the
  35. // element from the Connector one by one, it will get [1,2,3,4,5,6,7,8,9].
  36. //
  37. // Requirements:
  38. // 1. Each thread in the group of consumer or producer threads must be assigned ids starting from 0.
  39. // 2. If your multi-threaded program is not reading from a Connector class but
  40. // wants to push to a Connector class, it must follow round-robin element distribution,
  41. // i.e., thread-id0 must have the first element, thread-id1 the second element,
  42. // and so on; then each of these workers can push to the Connector class asynchronously in parallel.
  43. //
  44. // Blocking conditions:
  45. // 1. Connector.push(int, T) can block when the internal queue it's trying to push is full.
  46. // 2. Connector.pop(int) can block when
  47. // - The internal queue it's trying to pop is empty.
  48. // - The caller thread of pop() is not equal to expect_consumer_. This is to enforce
  49. // the ordering.
  50. //
  51. // Future improvement:
  52. // 1. Fault tolerance: Right now, if one of the workers dies, the Connector will not work
  53. // properly.
  54. template <class T>
  55. class Connector {
  56. public:
  57. // Name: Constructor
  58. // Description: Initializing private members with the given input arguments.
  59. // expect_consumer_ and pop_from_ is initialized to 0 as part of
  60. // our requirements. We instantiate nProducers number of internal
  61. // queues so that each producer thread can push to its queue without
  62. // any sync overhead.
  63. // Constructor of Connector
  64. // Initializing private members with the given input arguments.
  65. // _expectConsumer and _popFrom is initialized to 0 as part of
  66. // our requirements. We instantiate nProducers number of internal
  67. // queues so that each producer thread can push to its queue without
  68. // any sync overhead.
  69. // @param n_producers The number of threads producing data into this DbConnector.
  70. // @param n_consumers The number of thread consuming data from this DbConnector.
  71. // @param queue_capacity The number of element (DataBuffer) for each queue.
  72. Connector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity)
  73. : num_producers_(n_producers), num_consumers_(n_consumers) {
  74. MS_LOG(INFO) << "A connector is created with " << n_producers << " producers and " << n_consumers << " consumers.";
  75. my_name_ = Services::GetUniqueID();
  76. // We require the consumers to have ids sequentially from 0 to the num_consumers_-1,
  77. // Otherwise a ordered list of consumer ids have to be passed here. (not implemented yet)
  78. expect_consumer_ = 0;
  79. // Roundrobin pop starts from index 0 of the queues_.
  80. pop_from_ = 0;
  81. // Initialize the queues_ to have num_producers_ number of queues.
  82. // Each queue is a blocking queue and has the same queue_capacity.
  83. queues_.Init(num_producers_, queue_capacity);
  84. }
  85. // Destructor of Connector
  86. virtual ~Connector() = default;
  87. // Get an element from the Connector.
  88. // @not Call to pop() can block the caller thread, see the blocking condition at the top of this file.
  89. // @param worker_id The id of a worker thread calling this method.
  90. // @param result The address of an object where the popped element will be placed.
  91. virtual Status Pop(int32_t worker_id, // The worker-id of the caller. See the requirement at the top of this file.
  92. T *result) noexcept {
  93. {
  94. DS_ASSERT(worker_id < num_consumers_);
  95. std::unique_lock<std::mutex> lk(m_);
  96. RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; }));
  97. RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result));
  98. pop_from_ = (pop_from_ + 1) % num_producers_;
  99. expect_consumer_ = (expect_consumer_ + 1) % num_consumers_;
  100. }
  101. cv_.NotifyAll();
  102. return Status::OK();
  103. }
  104. // Add an element into the DbConnector without the overhead of synchronization.
  105. // It may block when the internal queue is full.
  106. // The element passed to this function will be copied into the internal queue.
  107. // @param worker_id The id of a worker thread calling this method.
  108. // @param el A const lvalue element to be passed/added/pushed.
  109. Status Push(int32_t worker_id, const T &el) noexcept {
  110. DS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
  111. DS_ASSERT(queues_[worker_id] != nullptr);
  112. return (queues_[worker_id]->Add(el));
  113. }
  114. // Add an element into the DbConnector without the overhead of synchronization.
  115. // It may block when the internal queue is full.
  116. // The element passed to this function will be forwarded into the internal queue.
  117. // @param worker_id The id of a worker thread calling this method.
  118. // @param el An element to be passed/added/pushed.
  119. virtual Status Push(int32_t worker_id, T &&el) noexcept {
  120. DS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
  121. DS_ASSERT(queues_[worker_id] != nullptr);
  122. return (queues_[worker_id]->Add(std::forward<T>(el)));
  123. }
  124. // Resets the internal index tracking of the queue so that it can be used again with new inputs,
  125. // starting from the beginning.
  126. void Reset() {
  127. for (int i = 0; i < queues_.size(); ++i) {
  128. queues_[i]->ResetQue();
  129. }
  130. expect_consumer_ = 0;
  131. pop_from_ = 0;
  132. MS_LOG(INFO) << "Connector counters reset.";
  133. }
  134. void Print(std::ostream &out, bool showAll) const {
  135. out << "\n--------- Connector ------------"
  136. << "\nConnector Name : " << my_name_ << "\nNumber of consumers : " << num_consumers_
  137. << "\nNumber of producers : " << num_producers_ << "\n";
  138. }
  139. friend std::ostream &operator<<(std::ostream &out, const Connector &con) {
  140. con.print(out, false);
  141. return out;
  142. }
  143. // Register the internal resources with Task group for interruption service.
  144. // @param vg
  145. // @return
  146. Status Register(TaskGroup *vg) {
  147. Status rc = queues_.Register(vg);
  148. if (rc.IsOk()) {
  149. rc = cv_.Register(vg->GetIntrpService());
  150. }
  151. return rc;
  152. }
  153. protected:
  154. std::string my_name_;
  155. // A list of Queues that are thread safe.
  156. QueueList<T> queues_;
  157. // The consumer that we allow to get the next data from pop()
  158. int32_t expect_consumer_;
  159. // The index to the queues_ where the next data should be popped.
  160. int32_t pop_from_;
  161. int32_t num_producers_;
  162. int32_t num_consumers_;
  163. // Used in the Pop(), when a thread call pop() but it is not the expect_consumer_.
  164. std::mutex m_;
  165. CondVar cv_;
  166. };
  167. } // namespace dataset
  168. } // namespace mindspore
  169. #endif // DATASET_ENGINE_CONNECTOR_H_