/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef DATASET_ENGINE_DB_CONNECTOR_H_ #define DATASET_ENGINE_DB_CONNECTOR_H_ #include #include #include "dataset/engine/connector.h" #include "dataset/engine/data_buffer.h" #include "dataset/core/constants.h" namespace mindspore { namespace dataset { // DbConnector is a derived class from Connector with added logic to handle EOE and EOF. // The Connector class itself is responsible to ensure deterministic order on every run. class DbConnector : public Connector> { public: // Constructor of DbConnector // @note DbConnector will create internal N number of blocking queues, where N = nProducers. // See Connector.h for more details. // @param n_producers The number of threads producing data into this DbConnector. // @param n_consumers The number of thread consuming data from this DbConnector. // @param queue_capacity The number of element (DataBuffer) for each internal queue. DbConnector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity) : Connector>(n_producers, n_consumers, queue_capacity), end_of_file_(false) {} // Destructor of DbConnector ~DbConnector() = default; // Add a unique_ptr into the DbConnector. // @note The caller of this add method should use std::move to pass the ownership to DbConnector. // @param worker_id The id of a worker thread calling this method. // @param el A rvalue reference to an element to be passed/added/pushed. Status Add(int32_t worker_id, std::unique_ptr &&el) noexcept { return (Connector>::Push(worker_id, std::move(el))); } // Get a unique_ptr from the DbConnector. // @note After the first EOF Buffer is encountered, subsequent pop()s will return EOF Buffer. // This will provide/propagate the EOF to all consumer threads of this Connector. // Thus, When the num_consumers < num_producers, there will be extra EOF messages in some of the internal queues // and reset() must be called before reusing DbConnector. // @param worker_id The id of a worker thread calling this method. // @param result The address of a unique_ptr where the popped element will be placed. // @param retry_if_eoe A flag to allow the same thread invoke pop() again if the current pop returns eoe buffer. Status PopWithRetry(int32_t worker_id, std::unique_ptr *result, bool retry_if_eoe = false) noexcept { if (result == nullptr) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "[ERROR] nullptr detected when getting data from db connector"); } else { std::unique_lock lk(m_); RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; })); // Once an EOF message is encountered this flag will be set and we can return early. if (end_of_file_) { *result = std::make_unique(0, DataBuffer::kDeBFlagEOF); } else { RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); if (*result == nullptr) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "[ERROR] nullptr detected when getting data from db connector"); } // Setting the internal flag once the first EOF is encountered. if ((*result)->eof()) { end_of_file_ = true; } pop_from_ = (pop_from_ + 1) % num_producers_; } // Do not increment expect_consumer_ when result is eoe and retry_if_eoe is set. if (!((*result)->eoe() && retry_if_eoe)) { expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; } } cv_.NotifyAll(); return Status::OK(); } private: // A flag to indicate the end of stream has been encountered. bool end_of_file_; }; } // namespace dataset } // namespace mindspore #endif // DATASET_ENGINE_DB_CONNECTOR_H_