Merge pull request !856 from jiangzhiwen/dataset/skip_threadtags/v0.3.0-alpha
| @@ -16,6 +16,7 @@ | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/datasetops/skip_op.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| @@ -26,7 +27,10 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Builder constructor. Creates the builder object. | |||
| SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {} | |||
| SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_op_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
| Status SkipOp::Builder::SanityCheck() const { | |||
| if (build_max_skips_ < 0) { | |||
| @@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const { | |||
| // The builder "build" method creates the final object. | |||
| Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<SkipOp>(build_max_skips_); | |||
| *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_); | |||
| return Status::OK(); | |||
| } | |||
| // Constructor of the SkipOp. | |||
| SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {} | |||
| SkipOp::SkipOp(int32_t count, int32_t op_connector_size) | |||
| : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {} | |||
| // Destructor | |||
| SkipOp::~SkipOp() {} | |||
| @@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const { | |||
| << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_; | |||
| } | |||
| // Since the buffer may contain multi rows, this function will drop the rows | |||
| // that need to skip in it, and then return the buffer. | |||
| Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) { | |||
| if (child_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node."); | |||
| } | |||
| std::unique_ptr<DataBuffer> buf; | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| // Drop first max_skips_ rows | |||
| while (skip_count_ < max_skips_) { | |||
| if (buf->eoe() || buf->eof()) { | |||
| break; | |||
| } | |||
| // Consider the rows of buffer more than 1 | |||
| TensorRow drop_row; | |||
| int row_num = buf->NumRows(); | |||
| int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; | |||
| skip_count_ += drop_num; | |||
| for (int i = 0; i < drop_num; i++) { | |||
| RETURN_IF_NOT_OK(buf->PopRow(&drop_row)); | |||
| } | |||
| if (buf->NumRows() == 0) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| } | |||
| } | |||
| // Handling eoe | |||
| if (buf->eoe()) { | |||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | |||
| } | |||
| // Handling eof | |||
| if (buf->eof()) { | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| } | |||
| *p_buffer = std::move(buf); | |||
| return Status::OK(); | |||
| } | |||
| // Base-class override for handling cases when an eoe is received. | |||
| Status SkipOp::EoeReceived(int32_t worker_id) { | |||
| skip_count_ = 0; | |||
| @@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Class functor operator () override. | |||
| // Most dataset ops operate by launching a thread (see ExecutionTree). | |||
| // However, the SkipOp is defined as a inlined operator, so it is invalid to | |||
| // launch the functor since this op runs inlined inside another operator. The | |||
| // function is overloaded to ensure that it is not called by mistake (it will | |||
| // generate an error). | |||
| Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); } | |||
| // main entry point for skip | |||
| Status SkipOp::operator()() { | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<DataBuffer> curr_buffer; | |||
| RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); | |||
| while (curr_buffer->eof() == false) { | |||
| // Reset count | |||
| skip_count_ = 0; | |||
| while (curr_buffer->eoe() == false) { | |||
| // Drop first count rows | |||
| while (skip_count_ < max_skips_) { | |||
| if (curr_buffer->eoe() || curr_buffer->eof()) { | |||
| break; | |||
| } | |||
| // Consider the rows of buffer more than one | |||
| TensorRow drop_row; | |||
| int row_num = curr_buffer->NumRows(); | |||
| int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; | |||
| skip_count_ += drop_num; | |||
| for (int i = 0; i < drop_num; i++) { | |||
| RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row)); | |||
| } | |||
| if (curr_buffer->NumRows() == 0) { | |||
| RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); | |||
| RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); | |||
| } | |||
| // we got eoe, now try again until we got eof | |||
| MS_LOG(DEBUG) << "Skip operator EOE Received."; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); | |||
| RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); | |||
| } | |||
| MS_LOG(DEBUG) << "Skip operator EOF Received."; | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); | |||
| return Status::OK(); | |||
| } | |||
| // Base-class override for handling cases when an eof is received. | |||
| Status SkipOp::EofReceived(int32_t worker_id) { | |||
| @@ -42,6 +42,7 @@ class SkipOp : public PipelineOp { | |||
| private: | |||
| int32_t build_max_skips_; | |||
| int32_t builder_op_connector_size_; | |||
| Status SanityCheck() const; | |||
| }; | |||
| @@ -49,7 +50,7 @@ class SkipOp : public PipelineOp { | |||
| // Constructor of the SkipOp. | |||
| // @note The builder class should be used to call it | |||
| // @param count - The number of skips to do | |||
| explicit SkipOp(int32_t count); | |||
| explicit SkipOp(int32_t count, int32_t op_connector_size); | |||
| // Destructor | |||
| ~SkipOp(); | |||
| @@ -60,23 +61,11 @@ class SkipOp : public PipelineOp { | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // Class functor operator () override. | |||
| // Most dataset ops operate by launching a thread (see ExecutionTree). | |||
| // However, the SkipOp is defined as a inlined operator, so it is invalid to launch the | |||
| // functor since this op runs inlined inside another operator. The function is overloaded to | |||
| // ensure that it is not called by mistake (it will generate an error). | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // This function returns the buffer that is at the top of our output connector. The caller is | |||
| // typically our parent node, when the parent is asking us to provide the next buffer of data. | |||
| // Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get | |||
| // a buffer from our child. | |||
| // @param p_buffer - output pointer to the buffer that it will fetch. | |||
| // @param worker_id - The worker id | |||
| // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | |||
| // @return Status - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override; | |||
| // Base-class override for handling cases when an eoe is received. | |||
| // @param worker_id - The worker id | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| @@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // SkipOp | |||
| std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5); | |||
| std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2); | |||
| rc = my_tree->AssociateNode(skip_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -12,7 +12,6 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| @@ -51,7 +50,7 @@ def generator_md(): | |||
| def test_generator_skip(): | |||
| ds1 = ds.GeneratorDataset(generator_md, ["data"]) | |||
| ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4) | |||
| # Here ds1 should be [3, 4] | |||
| ds1 = ds1.skip(3) | |||
| @@ -60,6 +59,7 @@ def test_generator_skip(): | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 2 | |||
| assert buf == [3, 4] | |||
| def test_skip_1(): | |||
| @@ -72,6 +72,7 @@ def test_skip_1(): | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 0 | |||
| assert buf == [] | |||
| def test_skip_2(): | |||
| @@ -84,6 +85,7 @@ def test_skip_2(): | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 5 | |||
| assert buf == [0, 1, 2, 3, 4] | |||
| def test_skip_repeat_1(): | |||
| @@ -99,6 +101,7 @@ def test_skip_repeat_1(): | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 7 | |||
| assert buf == [3, 4, 0, 1, 2, 3, 4] | |||
| def test_skip_repeat_2(): | |||
| @@ -114,6 +117,7 @@ def test_skip_repeat_2(): | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 4 | |||
| assert buf == [3, 4, 3, 4] | |||
| def test_skip_repeat_3(): | |||
| @@ -132,6 +136,62 @@ def test_skip_repeat_3(): | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 6 | |||
| assert buf == [3, 4, 3, 4, 3, 4] | |||
| def test_skip_take_1(): | |||
| ds1 = ds.GeneratorDataset(generator_md, ["data"]) | |||
| # Here ds1 should be [0, 1, 2, 3] | |||
| ds1 = ds1.take(4) | |||
| # Here ds1 should be [2, 3] | |||
| ds1 = ds1.skip(2) | |||
| buf = [] | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 2 | |||
| assert buf == [2, 3] | |||
| def test_skip_take_2(): | |||
| ds1 = ds.GeneratorDataset(generator_md, ["data"]) | |||
| # Here ds1 should be [2, 3, 4] | |||
| ds1 = ds1.skip(2) | |||
| # Here ds1 should be [2, 3] | |||
| ds1 = ds1.take(2) | |||
| buf = [] | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 2 | |||
| assert buf == [2, 3] | |||
| def generator_1d(): | |||
| for i in range(64): | |||
| yield (np.array([i]), ) | |||
| def test_skip_filter_1(): | |||
| dataset = ds.GeneratorDataset(generator_1d, ['data']) | |||
| dataset = dataset.skip(5) | |||
| dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) | |||
| buf = [] | |||
| for item in dataset: | |||
| buf.append(item[0][0]) | |||
| assert buf == [5, 6, 7, 8, 9, 10] | |||
| def test_skip_filter_2(): | |||
| dataset = ds.GeneratorDataset(generator_1d, ['data']) | |||
| dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) | |||
| dataset = dataset.skip(5) | |||
| buf = [] | |||
| for item in dataset: | |||
| buf.append(item[0][0]) | |||
| assert buf == [5, 6, 7, 8, 9, 10] | |||
| if __name__ == "__main__": | |||
| @@ -142,3 +202,7 @@ if __name__ == "__main__": | |||
| test_skip_repeat_1() | |||
| test_skip_repeat_2() | |||
| test_skip_repeat_3() | |||
| test_skip_take_1() | |||
| test_skip_take_2() | |||
| test_skip_filter_1() | |||
| test_skip_filter_2() | |||