Browse Source

!871 dataset: repair take op problem when there exist muti-thread in next node

Merge pull request !871 from ms_yan/take_operator
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
67f3d0eb5d
3 changed files with 85 additions and 70 deletions
  1. +35
    -55
      mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
  2. +4
    -15
      mindspore/ccsrc/dataset/engine/datasetops/take_op.h
  3. +46
    -0
      tests/ut/python/dataset/test_take.py

+ 35
- 55
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc View File

@@ -17,6 +17,7 @@
#include <utility>

#include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
@@ -25,7 +26,10 @@
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size();
}

Status TakeOp::Builder::SanityCheck() const {
if (build_max_takes_ <= 0) {
@@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<TakeOp>(build_max_takes_);
*ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
return Status::OK();
}

// Constructor of the TakeOp.
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
: PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}

// A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const {
@@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
}
}

// This function will be call muti times to returns the buffer, when meet required max take count or meet
// EOF buffer then this will stop.
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
}

// Main entry point for Take
Status TakeOp::operator()() {
TaskManager::FindMe()->Post();
std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));

bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
if (take_count_ == max_takes_) {
if (state_ == OpState::kDeOpRunning) {
MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer.";
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
*p_buffer = std::move(eoe_buffer);
state_ = OpState::kDeOpIdle;

// Reset the count and drain
if (!last_repeat) {
take_count_ = 0;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
while (buf->eof() == false) {
if (take_count_ == max_takes_) {
// Do drain Operation
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
}
} else if (state_ == OpState::kDeOpIdle) {
MS_LOG(DEBUG) << "Meet max count and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
*p_buffer = std::move(eof_buffer);
}

// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
} else {
MS_LOG(WARNING) << "Invalid OpState: " << state_;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
continue;
}
return Status::OK();
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
*p_buffer = std::move(buf);
return Status::OK();
}

// Check if the last buf is next eof
if (buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
// Get buffer and push back when take_count is still small
if (take_count_ < max_takes_) {
std::unique_ptr<DataBuffer> p_buffer;
RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer)));
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
}

// Get buffer and push back when take_count is still small
if (take_count_ < max_takes_) {
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
}
take_count_ = 0;
MS_LOG(DEBUG) << "Meet the end and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
return Status::OK();
}

@@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
return Status::OK();
}

// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }

Status TakeOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToRepeatStack(shared_from_this());


+ 4
- 15
mindspore/ccsrc/dataset/engine/datasetops/take_op.h View File

@@ -45,6 +45,7 @@ class TakeOp : public PipelineOp {

private:
int32_t build_max_takes_;
int32_t builder_op_connector_size_;

Status SanityCheck() const;
};
@@ -52,7 +53,7 @@ class TakeOp : public PipelineOp {
// Constructor of the TakeOp.
// @note The builder class should be used to call it
// @param count - The number of takes to do
explicit TakeOp(int32_t count);
explicit TakeOp(int32_t count, int32_t op_connector_size);

// Destructor
~TakeOp() = default;
@@ -72,23 +73,11 @@ class TakeOp : public PipelineOp {
return out;
}

// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
Status operator()() override;

// Gets a buffer from the child node. The caller is typically our parent node.
// @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
// this function will retry to pop the connector again and will get the non-EOE buffer if any.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;

// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first


+ 46
- 0
tests/ut/python/dataset/test_take.py View File

@@ -30,6 +30,12 @@ def generator_10():
yield np.array([i]),


def filter_func_ge(data):
if data > 3:
return False
return True


def test_take_01():
"""
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
@@ -297,6 +303,44 @@ def test_take_16():
assert sum([1 for _ in data1]) == 5


def test_take_17():
"""
Test take: take first, then do fiter operation
"""
logger.info("test_take_17")
data1 = ds.GeneratorDataset(generator_10, ["data"])

data1 = data1.take(8)
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)

# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert i == d[0][0]

assert sum([1 for _ in data1]) == 4


def test_take_18():
"""
Test take: take first, then do fiter, skip, batch and repeat operation
"""
logger.info("test_take_18")
data1 = ds.GeneratorDataset(generator_10, ["data"])

data1 = data1.take(8)
data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
data1 = data1.skip(2)

data1 = data1.batch(2)
data1 = data1.repeat(2)

# Here i refers to index, d refers to data element
for i, d in enumerate(data1):
assert 2 == d[0][0]

assert sum([1 for _ in data1]) == 2


if __name__ == '__main__':
test_take_01()
test_take_02()
@@ -314,4 +358,6 @@ if __name__ == '__main__':
test_take_14()
test_take_15()
test_take_16()
test_take_17()
test_take_18()
logger.info('== test take operation finished ==')

Loading…
Cancel
Save