Merge pull request !856 from jiangzhiwen/dataset/skip_threadtags/v0.3.0-alpha
@@ -16,6 +16,7 @@
 #include <iostream>
 #include <utility>
+#include "dataset/core/config_manager.h"
 #include "dataset/engine/data_buffer.h"
 #include "dataset/engine/datasetops/skip_op.h"
 #include "dataset/engine/db_connector.h"
@@ -26,7 +27,10 @@
 namespace mindspore {
 namespace dataset {
 // Builder constructor. Creates the builder object.
-SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
+SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
 
 Status SkipOp::Builder::SanityCheck() const {
   if (build_max_skips_ < 0) {
@@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
 // The builder "build" method creates the final object.
 Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<SkipOp>(build_max_skips_);
+  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
   return Status::OK();
 }
 
 // Constructor of the SkipOp.
-SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
+SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
+    : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
 
 // Destructor
 SkipOp::~SkipOp() {}
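For reference, a minimal usage sketch of the builder path above (not part of this PR; the helper name MakeSkipOp is made up, and the headers from the first hunk are assumed to be available): the caller now only supplies the skip count, while the connector size is picked up from the global ConfigManager instead of the hard-coded 0 the old constructor passed to PipelineOp.

```cpp
#include <memory>
#include "dataset/engine/datasetops/skip_op.h"

namespace mindspore {
namespace dataset {
// Hypothetical helper: build a SkipOp that drops the first 10 rows.
Status MakeSkipOp(std::shared_ptr<SkipOp> *out) {
  // Builder(count) stores the skip count and, as of this change, reads
  // op_connector_size() from GlobalContext::config_manager() by itself.
  SkipOp::Builder builder(10);
  return builder.Build(out);  // SanityCheck() rejects a negative count
}
}  // namespace dataset
}  // namespace mindspore
```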
@@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
       << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
 }
 
-// Since the buffer may contain multi rows, this function will drop the rows
-// that need to skip in it, and then return the buffer.
-Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
-  if (child_.empty()) {
-    RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
-  }
-
-  std::unique_ptr<DataBuffer> buf;
-  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
-  // Drop first max_skips_ rows
-  while (skip_count_ < max_skips_) {
-    if (buf->eoe() || buf->eof()) {
-      break;
-    }
-
-    // Consider the rows of buffer more than 1
-    TensorRow drop_row;
-    int row_num = buf->NumRows();
-    int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
-    skip_count_ += drop_num;
-    for (int i = 0; i < drop_num; i++) {
-      RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
-    }
-
-    if (buf->NumRows() == 0) {
-      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
-    }
-  }
-
-  // Handling eoe
-  if (buf->eoe()) {
-    RETURN_IF_NOT_OK(EoeReceived(worker_id));
-  }
-
-  // Handling eof
-  if (buf->eof()) {
-    RETURN_IF_NOT_OK(EofReceived(worker_id));
-  }
-
-  *p_buffer = std::move(buf);
-  return Status::OK();
-}
-
 // Base-class override for handling cases when an eoe is received.
 Status SkipOp::EoeReceived(int32_t worker_id) {
   skip_count_ = 0;
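The row-dropping arithmetic deleted here survives unchanged inside the new operator() below: each incoming buffer gives up min(NumRows(), max_skips_ - skip_count_) rows until the quota is met. A standalone sketch of that bookkeeping with plain ints (assumed buffer sizes, no MindSpore types):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Standalone illustration: with max_skips = 6 and buffers of 4 rows each, the
// first buffer is dropped entirely, the second loses 2 rows and forwards 2,
// and later buffers pass through untouched.
int main() {
  const int max_skips = 6;
  int skip_count = 0;
  for (int row_num : std::vector<int>{4, 4, 4}) {
    // Equivalent to: row_num + skip_count < max_skips ? row_num : max_skips - skip_count
    int drop_num = std::min(row_num, std::max(max_skips - skip_count, 0));
    skip_count += drop_num;
    std::printf("dropped %d, forwarded %d\n", drop_num, row_num - drop_num);
  }
  return 0;
}
```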
@@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
   return Status::OK();
 }
 
-// Class functor operator () override.
-// Most dataset ops operate by launching a thread (see ExecutionTree).
-// However, the SkipOp is defined as a inlined operator, so it is invalid to
-// launch the functor since this op runs inlined inside another operator. The
-// function is overloaded to ensure that it is not called by mistake (it will
-// generate an error).
-Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
+// Main entry point for SkipOp
+Status SkipOp::operator()() {
+  TaskManager::FindMe()->Post();
+  std::unique_ptr<DataBuffer> curr_buffer;
+  RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+  while (curr_buffer->eof() == false) {
+    // Reset the skip count at the start of each epoch
+    skip_count_ = 0;
+    while (curr_buffer->eoe() == false) {
+      // Drop the first max_skips_ rows
+      while (skip_count_ < max_skips_) {
+        if (curr_buffer->eoe() || curr_buffer->eof()) {
+          break;
+        }
+        // A buffer may contain more than one row
+        TensorRow drop_row;
+        int row_num = curr_buffer->NumRows();
+        int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
+        skip_count_ += drop_num;
+        for (int i = 0; i < drop_num; i++) {
+          RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
+        }
+        if (curr_buffer->NumRows() == 0) {
+          RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+        }
+      }
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
+      RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+    }
+    // We got an EOE; propagate it and keep pulling until an EOF arrives
+    MS_LOG(DEBUG) << "Skip operator EOE Received.";
+    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
+    RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+  }
+
+  MS_LOG(DEBUG) << "Skip operator EOF Received.";
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
+  return Status::OK();
+}
 
 // Base-class override for handling cases when an eof is received.
 Status SkipOp::EofReceived(int32_t worker_id) {
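One behavioural detail of the new master loop: skip_count_ is reset at the top of every pass, so each EOE re-arms the skip for the next epoch. A standalone sketch with plain ints (assumed epoch contents, no MindSpore types) that reproduces the [3, 4, 3, 4] pattern asserted in one of the repeat tests below:

```cpp
#include <cstdio>
#include <vector>

// Standalone illustration of the per-epoch reset: the skip quota is re-armed
// after every EOE, so each epoch drops its first max_skips rows again.
int main() {
  const int max_skips = 3;
  // Two epochs of a 5-row source; each inner vector ends with an implicit EOE.
  std::vector<std::vector<int>> epochs = {{0, 1, 2, 3, 4}, {0, 1, 2, 3, 4}};
  for (const auto &epoch : epochs) {
    int skip_count = 0;  // mirrors `skip_count_ = 0;` at the start of each epoch
    for (int row : epoch) {
      if (skip_count < max_skips) {
        ++skip_count;  // row is dropped
        continue;
      }
      std::printf("%d ", row);  // row is forwarded downstream
    }
    std::printf("| EOE\n");
  }
  return 0;  // prints "3 4 | EOE" twice, i.e. the [3, 4, 3, 4] pattern
}
```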
@@ -42,6 +42,7 @@ class SkipOp : public PipelineOp {
    private:
     int32_t build_max_skips_;
+    int32_t builder_op_connector_size_;
 
     Status SanityCheck() const;
   };
@@ -49,7 +50,7 @@ class SkipOp : public PipelineOp {
   // Constructor of the SkipOp.
   // @note The builder class should be used to call it
   // @param count - The number of skips to do
-  explicit SkipOp(int32_t count);
+  explicit SkipOp(int32_t count, int32_t op_connector_size);
 
   // Destructor
   ~SkipOp();
@@ -60,23 +61,11 @@ class SkipOp : public PipelineOp {
   void Print(std::ostream &out, bool show_all) const override;
 
   // Class functor operator () override.
-  // Most dataset ops operate by launching a thread (see ExecutionTree).
-  // However, the SkipOp is defined as a inlined operator, so it is invalid to launch the
-  // functor since this op runs inlined inside another operator. The function is overloaded to
-  // ensure that it is not called by mistake (it will generate an error).
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work.
   // @return Status - The error code return
   Status operator()() override;
 
-  // This function returns the buffer that is at the top of our output connector. The caller is
-  // typically our parent node, when the parent is asking us to provide the next buffer of data.
-  // Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
-  // a buffer from our child.
-  // @param p_buffer - output pointer to the buffer that it will fetch.
-  // @param worker_id - The worker id
-  // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
-  // @return Status - The error code return
-  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
-
   // Base-class override for handling cases when an eoe is received.
   // @param worker_id - The worker id
   Status EoeReceived(int32_t worker_id) override;
@@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
   ASSERT_TRUE(rc.IsOk());
 
   // SkipOp
-  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
+  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
 
   rc = my_tree->AssociateNode(skip_op);
   ASSERT_TRUE(rc.IsOk());
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 import numpy as np
 import mindspore.dataset.transforms.vision.c_transforms as vision
@@ -51,7 +50,7 @@ def generator_md():
 
 
 def test_generator_skip():
-    ds1 = ds.GeneratorDataset(generator_md, ["data"])
+    ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)
 
     # Here ds1 should be [3, 4]
     ds1 = ds1.skip(3)
@@ -60,6 +59,7 @@ def test_generator_skip():
     for data in ds1:
         buf.append(data[0][0])
     assert len(buf) == 2
+    assert buf == [3, 4]
 
 
 def test_skip_1():
@@ -72,6 +72,7 @@ def test_skip_1():
     for data in ds1:
         buf.append(data[0][0])
     assert len(buf) == 0
+    assert buf == []
 
 
 def test_skip_2():
@@ -84,6 +85,7 @@ def test_skip_2():
     for data in ds1:
         buf.append(data[0][0])
     assert len(buf) == 5
+    assert buf == [0, 1, 2, 3, 4]
 
 
 def test_skip_repeat_1():
@@ -99,6 +101,7 @@ def test_skip_repeat_1():
     for data in ds1:
         buf.append(data[0][0])
     assert len(buf) == 7
+    assert buf == [3, 4, 0, 1, 2, 3, 4]
 
 
 def test_skip_repeat_2():
@@ -114,6 +117,7 @@ def test_skip_repeat_2():
     for data in ds1:
         buf.append(data[0][0])
     assert len(buf) == 4
+    assert buf == [3, 4, 3, 4]
 
 
 def test_skip_repeat_3():
@@ -132,6 +136,62 @@ def test_skip_repeat_3():
     for data in ds1:
         buf.append(data[0][0])
     assert len(buf) == 6
+    assert buf == [3, 4, 3, 4, 3, 4]
+
+
+def test_skip_take_1():
+    ds1 = ds.GeneratorDataset(generator_md, ["data"])
+
+    # Here ds1 should be [0, 1, 2, 3]
+    ds1 = ds1.take(4)
+
+    # Here ds1 should be [2, 3]
+    ds1 = ds1.skip(2)
+
+    buf = []
+    for data in ds1:
+        buf.append(data[0][0])
+    assert len(buf) == 2
+    assert buf == [2, 3]
+
+
+def test_skip_take_2():
+    ds1 = ds.GeneratorDataset(generator_md, ["data"])
+
+    # Here ds1 should be [2, 3, 4]
+    ds1 = ds1.skip(2)
+
+    # Here ds1 should be [2, 3]
+    ds1 = ds1.take(2)
+
+    buf = []
+    for data in ds1:
+        buf.append(data[0][0])
+    assert len(buf) == 2
+    assert buf == [2, 3]
+
+
+def generator_1d():
+    for i in range(64):
+        yield (np.array([i]),)
+
+
+def test_skip_filter_1():
+    dataset = ds.GeneratorDataset(generator_1d, ['data'])
+    dataset = dataset.skip(5)
+    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
+
+    buf = []
+    for item in dataset:
+        buf.append(item[0][0])
+    assert buf == [5, 6, 7, 8, 9, 10]
+
+
+def test_skip_filter_2():
+    dataset = ds.GeneratorDataset(generator_1d, ['data'])
+    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
+    dataset = dataset.skip(5)
+
+    buf = []
+    for item in dataset:
+        buf.append(item[0][0])
+    assert buf == [5, 6, 7, 8, 9, 10]
+
+
 if __name__ == "__main__":
@@ -142,3 +202,7 @@ if __name__ == "__main__":
     test_skip_repeat_1()
     test_skip_repeat_2()
     test_skip_repeat_3()
+    test_skip_take_1()
+    test_skip_take_2()
+    test_skip_filter_1()
+    test_skip_filter_2()