From 34bfa2f7c9199c078665107e80f7ab9a2a5d4e48 Mon Sep 17 00:00:00 2001 From: jiangzhiwen Date: Wed, 29 Apr 2020 17:18:12 +0800 Subject: [PATCH] fix skip --- .../dataset/engine/datasetops/skip_op.cc | 100 ++++++++---------- .../ccsrc/dataset/engine/datasetops/skip_op.h | 19 +--- tests/ut/cpp/dataset/skip_op_test.cc | 2 +- tests/ut/python/dataset/test_skip.py | 68 +++++++++++- 4 files changed, 118 insertions(+), 71 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc index d851f2c699..a7b642d9d1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc @@ -16,6 +16,7 @@ #include #include +#include "dataset/core/config_manager.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/db_connector.h" @@ -26,7 +27,10 @@ namespace mindspore { namespace dataset { // Builder constructor. Creates the builder object. -SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {} +SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} Status SkipOp::Builder::SanityCheck() const { if (build_max_skips_ < 0) { @@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const { // The builder "build" method creates the final object. Status SkipOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_skips_); + *ptr = std::make_shared(build_max_skips_, builder_op_connector_size_); return Status::OK(); } // Constructor of the SkipOp. -SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {} +SkipOp::SkipOp(int32_t count, int32_t op_connector_size) + : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {} // Destructor SkipOp::~SkipOp() {} @@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const { << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_; } -// Since the buffer may contain multi rows, this function will drop the rows -// that need to skip in it, and then return the buffer. -Status SkipOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { - if (child_.empty()) { - RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node."); - } - - std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - - // Drop first max_skips_ rows - while (skip_count_ < max_skips_) { - if (buf->eoe() || buf->eof()) { - break; - } - - // Consider the rows of buffer more than 1 - TensorRow drop_row; - int row_num = buf->NumRows(); - int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; - skip_count_ += drop_num; - for (int i = 0; i < drop_num; i++) { - RETURN_IF_NOT_OK(buf->PopRow(&drop_row)); - } - if (buf->NumRows() == 0) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - } - } - - // Handling eoe - if (buf->eoe()) { - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - } - - // Handling eof - if (buf->eof()) { - RETURN_IF_NOT_OK(EofReceived(worker_id)); - } - - *p_buffer = std::move(buf); - return Status::OK(); -} - // Base-class override for handling cases when an eoe is received. Status SkipOp::EoeReceived(int32_t worker_id) { skip_count_ = 0; @@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) { return Status::OK(); } -// Class functor operator () override. -// Most dataset ops operate by launching a thread (see ExecutionTree). -// However, the SkipOp is defined as a inlined operator, so it is invalid to -// launch the functor since this op runs inlined inside another operator. The -// function is overloaded to ensure that it is not called by mistake (it will -// generate an error). -Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); } +// main entry point for skip +Status SkipOp::operator()() { + TaskManager::FindMe()->Post(); + std::unique_ptr curr_buffer; + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + while (curr_buffer->eof() == false) { + // Reset count + skip_count_ = 0; + while (curr_buffer->eoe() == false) { + // Drop first count rows + while (skip_count_ < max_skips_) { + if (curr_buffer->eoe() || curr_buffer->eof()) { + break; + } + // Consider the rows of buffer more than one + TensorRow drop_row; + int row_num = curr_buffer->NumRows(); + int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; + skip_count_ += drop_num; + for (int i = 0; i < drop_num; i++) { + RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row)); + } + if (curr_buffer->NumRows() == 0) { + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + } + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + // we got eoe, now try again until we got eof + MS_LOG(DEBUG) << "Skip operator EOE Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + + MS_LOG(DEBUG) << "Skip operator EOF Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} // Base-class override for handling cases when an eof is received. Status SkipOp::EofReceived(int32_t worker_id) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h index 0ae520c3ad..a16b82ed21 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h @@ -42,6 +42,7 @@ class SkipOp : public PipelineOp { private: int32_t build_max_skips_; + int32_t builder_op_connector_size_; Status SanityCheck() const; }; @@ -49,7 +50,7 @@ class SkipOp : public PipelineOp { // Constructor of the SkipOp. // @note The builder class should be used to call it // @param count - The number of skips to do - explicit SkipOp(int32_t count); + explicit SkipOp(int32_t count, int32_t op_connector_size); // Destructor ~SkipOp(); @@ -60,23 +61,11 @@ class SkipOp : public PipelineOp { void Print(std::ostream &out, bool show_all) const override; // Class functor operator () override. - // Most dataset ops operate by launching a thread (see ExecutionTree). - // However, the SkipOp is defined as a inlined operator, so it is invalid to launch the - // functor since this op runs inlined inside another operator. The function is overloaded to - // ensure that it is not called by mistake (it will generate an error). + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work // @return Status - The error code return Status operator()() override; - // This function returns the buffer that is at the top of our output connector. The caller is - // typically our parent node, when the parent is asking us to provide the next buffer of data. - // Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get - // a buffer from our child. - // @param p_buffer - output pointer to the buffer that it will fetch. - // @param worker_id - The worker id - // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. - // @return Status - The error code return - Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; - // Base-class override for handling cases when an eoe is received. // @param worker_id - The worker id Status EoeReceived(int32_t worker_id) override; diff --git a/tests/ut/cpp/dataset/skip_op_test.cc b/tests/ut/cpp/dataset/skip_op_test.cc index c2168b24d4..697745512d 100644 --- a/tests/ut/cpp/dataset/skip_op_test.cc +++ b/tests/ut/cpp/dataset/skip_op_test.cc @@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) { ASSERT_TRUE(rc.IsOk()); // SkipOp - std::shared_ptr skip_op = std::make_shared(5); + std::shared_ptr skip_op = std::make_shared(5, 2); rc = my_tree->AssociateNode(skip_op); ASSERT_TRUE(rc.IsOk()); diff --git a/tests/ut/python/dataset/test_skip.py b/tests/ut/python/dataset/test_skip.py index 59893f6ded..ccbf40a55b 100644 --- a/tests/ut/python/dataset/test_skip.py +++ b/tests/ut/python/dataset/test_skip.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - import numpy as np import mindspore.dataset.transforms.vision.c_transforms as vision @@ -51,7 +50,7 @@ def generator_md(): def test_generator_skip(): - ds1 = ds.GeneratorDataset(generator_md, ["data"]) + ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4) # Here ds1 should be [3, 4] ds1 = ds1.skip(3) @@ -60,6 +59,7 @@ def test_generator_skip(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 2 + assert buf == [3, 4] def test_skip_1(): @@ -72,6 +72,7 @@ def test_skip_1(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 0 + assert buf == [] def test_skip_2(): @@ -84,6 +85,7 @@ def test_skip_2(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 5 + assert buf == [0, 1, 2, 3, 4] def test_skip_repeat_1(): @@ -99,6 +101,7 @@ def test_skip_repeat_1(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 7 + assert buf == [3, 4, 0, 1, 2, 3, 4] def test_skip_repeat_2(): @@ -114,6 +117,7 @@ def test_skip_repeat_2(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 4 + assert buf == [3, 4, 3, 4] def test_skip_repeat_3(): @@ -132,6 +136,62 @@ def test_skip_repeat_3(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 6 + assert buf == [3, 4, 3, 4, 3, 4] + +def test_skip_take_1(): + ds1 = ds.GeneratorDataset(generator_md, ["data"]) + + # Here ds1 should be [0, 1, 2, 3] + ds1 = ds1.take(4) + + # Here ds1 should be [2, 3] + ds1 = ds1.skip(2) + + buf = [] + for data in ds1: + buf.append(data[0][0]) + assert len(buf) == 2 + assert buf == [2, 3] + +def test_skip_take_2(): + ds1 = ds.GeneratorDataset(generator_md, ["data"]) + + # Here ds1 should be [2, 3, 4] + ds1 = ds1.skip(2) + + # Here ds1 should be [2, 3] + ds1 = ds1.take(2) + + buf = [] + for data in ds1: + buf.append(data[0][0]) + assert len(buf) == 2 + assert buf == [2, 3] + + +def generator_1d(): + for i in range(64): + yield (np.array([i]), ) + +def test_skip_filter_1(): + dataset = ds.GeneratorDataset(generator_1d, ['data']) + dataset = dataset.skip(5) + dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) + + buf = [] + for item in dataset: + buf.append(item[0][0]) + assert buf == [5, 6, 7, 8, 9, 10] + +def test_skip_filter_2(): + dataset = ds.GeneratorDataset(generator_1d, ['data']) + dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) + dataset = dataset.skip(5) + + buf = [] + for item in dataset: + buf.append(item[0][0]) + assert buf == [5, 6, 7, 8, 9, 10] if __name__ == "__main__": @@ -142,3 +202,7 @@ if __name__ == "__main__": test_skip_repeat_1() test_skip_repeat_2() test_skip_repeat_3() + test_skip_take_1() + test_skip_take_2() + test_skip_filter_1() + test_skip_filter_2()