You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

db_connector.h 4.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DB_CONNECTOR_H_
  17. #define DATASET_ENGINE_DB_CONNECTOR_H_
  18. #include <memory>
  19. #include <utility>
  20. #include "dataset/engine/connector.h"
  21. #include "dataset/engine/data_buffer.h"
  22. #include "dataset/core/constants.h"
  23. namespace mindspore {
  24. namespace dataset {
  25. // DbConnector is a derived class from Connector with added logic to handle EOE and EOF.
  26. // The Connector class itself is responsible to ensure deterministic order on every run.
  27. class DbConnector : public Connector<std::unique_ptr<DataBuffer>> {
  28. public:
  29. // Constructor of DbConnector
  30. // @note DbConnector will create internal N number of blocking queues, where N = nProducers.
  31. // See Connector.h for more details.
  32. // @param n_producers The number of threads producing data into this DbConnector.
  33. // @param n_consumers The number of thread consuming data from this DbConnector.
  34. // @param queue_capacity The number of element (DataBuffer) for each internal queue.
  35. DbConnector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity)
  36. : Connector<std::unique_ptr<DataBuffer>>(n_producers, n_consumers, queue_capacity), end_of_file_(false) {}
  37. // Destructor of DbConnector
  38. ~DbConnector() = default;
  39. // Add a unique_ptr<DataBuffer> into the DbConnector.
  40. // @note The caller of this add method should use std::move to pass the ownership to DbConnector.
  41. // @param worker_id The id of a worker thread calling this method.
  42. // @param el A rvalue reference to an element to be passed/added/pushed.
  43. Status Add(int32_t worker_id, std::unique_ptr<DataBuffer> &&el) noexcept {
  44. return (Connector<std::unique_ptr<DataBuffer>>::Push(worker_id, std::move(el)));
  45. }
  46. // Get a unique_ptr<DataBuffer> from the DbConnector.
  47. // @note After the first EOF Buffer is encountered, subsequent pop()s will return EOF Buffer.
  48. // This will provide/propagate the EOF to all consumer threads of this Connector.
  49. // Thus, When the num_consumers < num_producers, there will be extra EOF messages in some of the internal queues
  50. // and reset() must be called before reusing DbConnector.
  51. // @param worker_id The id of a worker thread calling this method.
  52. // @param result The address of a unique_ptr<DataBuffer> where the popped element will be placed.
  53. // @param retry_if_eoe A flag to allow the same thread invoke pop() again if the current pop returns eoe buffer.
  54. Status PopWithRetry(int32_t worker_id, std::unique_ptr<DataBuffer> *result, bool retry_if_eoe = false) noexcept {
  55. if (result == nullptr) {
  56. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  57. "[ERROR] nullptr detected when getting data from db connector");
  58. } else {
  59. std::unique_lock<std::mutex> lk(m_);
  60. RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; }));
  61. // Once an EOF message is encountered this flag will be set and we can return early.
  62. if (end_of_file_) {
  63. *result = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  64. } else {
  65. RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result));
  66. if (*result == nullptr) {
  67. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  68. "[ERROR] nullptr detected when getting data from db connector");
  69. }
  70. // Setting the internal flag once the first EOF is encountered.
  71. if ((*result)->eof()) {
  72. end_of_file_ = true;
  73. }
  74. pop_from_ = (pop_from_ + 1) % num_producers_;
  75. }
  76. // Do not increment expect_consumer_ when result is eoe and retry_if_eoe is set.
  77. if (!((*result)->eoe() && retry_if_eoe)) {
  78. expect_consumer_ = (expect_consumer_ + 1) % num_consumers_;
  79. }
  80. }
  81. cv_.NotifyAll();
  82. return Status::OK();
  83. }
  84. private:
  85. // A flag to indicate the end of stream has been encountered.
  86. bool end_of_file_;
  87. };
  88. } // namespace dataset
  89. } // namespace mindspore
  90. #endif // DATASET_ENGINE_DB_CONNECTOR_H_