|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
#include <utility> |
|
|
|
|
|
|
|
#include "common/utils.h" |
|
|
|
#include "dataset/core/config_manager.h" |
|
|
|
#include "dataset/engine/data_buffer.h" |
|
|
|
#include "dataset/engine/datasetops/take_op.h" |
|
|
|
#include "dataset/engine/db_connector.h" |
|
|
|
@@ -25,7 +26,10 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace dataset { |
|
|
|
// Builder constructor. Creates the builder object. |
|
|
|
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {} |
|
|
|
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) { |
|
|
|
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); |
|
|
|
builder_op_connector_size_ = cfg->op_connector_size(); |
|
|
|
} |
|
|
|
|
|
|
|
Status TakeOp::Builder::SanityCheck() const { |
|
|
|
if (build_max_takes_ <= 0) { |
|
|
|
@@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const { |
|
|
|
// The builder "build" method creates the final object. |
|
|
|
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) { |
|
|
|
RETURN_IF_NOT_OK(SanityCheck()); |
|
|
|
*ptr = std::make_shared<TakeOp>(build_max_takes_); |
|
|
|
*ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Constructor of the TakeOp. |
|
|
|
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {} |
|
|
|
TakeOp::TakeOp(int32_t count, int32_t op_connector_size) |
|
|
|
: PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {} |
|
|
|
|
|
|
|
// A print method typically used for debugging |
|
|
|
void TakeOp::Print(std::ostream &out, bool show_all) const { |
|
|
|
@@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// This function will be call muti times to returns the buffer, when meet required max take count or meet |
|
|
|
// EOF buffer then this will stop. |
|
|
|
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) { |
|
|
|
if (child_.empty()) { |
|
|
|
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node."); |
|
|
|
} |
|
|
|
|
|
|
|
// Main entry point for Take |
|
|
|
Status TakeOp::operator()() { |
|
|
|
TaskManager::FindMe()->Post(); |
|
|
|
std::unique_ptr<DataBuffer> buf; |
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); |
|
|
|
|
|
|
|
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat); |
|
|
|
if (take_count_ == max_takes_) { |
|
|
|
if (state_ == OpState::kDeOpRunning) { |
|
|
|
MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer."; |
|
|
|
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); |
|
|
|
*p_buffer = std::move(eoe_buffer); |
|
|
|
state_ = OpState::kDeOpIdle; |
|
|
|
|
|
|
|
// Reset the count and drain |
|
|
|
if (!last_repeat) { |
|
|
|
take_count_ = 0; |
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); |
|
|
|
while (!buf->eoe() && !buf->eof()) { |
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); |
|
|
|
} |
|
|
|
while (buf->eof() == false) { |
|
|
|
if (take_count_ == max_takes_) { |
|
|
|
// Do drain Operation |
|
|
|
while (!buf->eoe() && !buf->eof()) { |
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); |
|
|
|
} |
|
|
|
} else if (state_ == OpState::kDeOpIdle) { |
|
|
|
MS_LOG(DEBUG) << "Meet max count and push-back eof buffer."; |
|
|
|
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); |
|
|
|
*p_buffer = std::move(eof_buffer); |
|
|
|
} |
|
|
|
|
|
|
|
// Loop until non EOE is received |
|
|
|
if (buf->eoe()) { |
|
|
|
take_count_ = 0; |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "Invalid OpState: " << state_; |
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); |
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); |
|
|
|
// Loop until non EOE is received |
|
|
|
if (buf->eoe()) { |
|
|
|
take_count_ = 0; |
|
|
|
*p_buffer = std::move(buf); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Check if the last buf is next eof |
|
|
|
if (buf->eof()) { |
|
|
|
*p_buffer = std::move(buf); |
|
|
|
return Status::OK(); |
|
|
|
// Get buffer and push back when take_count is still small |
|
|
|
if (take_count_ < max_takes_) { |
|
|
|
std::unique_ptr<DataBuffer> p_buffer; |
|
|
|
RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer)); |
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer))); |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); |
|
|
|
} |
|
|
|
|
|
|
|
// Get buffer and push back when take_count is still small |
|
|
|
if (take_count_ < max_takes_) { |
|
|
|
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer)); |
|
|
|
} |
|
|
|
take_count_ = 0; |
|
|
|
MS_LOG(DEBUG) << "Meet the end and push-back eof buffer."; |
|
|
|
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); |
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Class functor operator () override. |
|
|
|
// Most dataset ops operate by launching a thread (see ExecutionTree). |
|
|
|
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the |
|
|
|
// functor since this op runs inlined inside another operator. The function is overloaded to |
|
|
|
// ensure that it is not called by mistake (it will generate an error). |
|
|
|
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); } |
|
|
|
|
|
|
|
Status TakeOp::PrepareNodePostAction() { |
|
|
|
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); |
|
|
|
tree_->AddToRepeatStack(shared_from_this()); |
|
|
|
|