diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
index f51c2a1539..7edf1dd288 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
@@ -161,15 +161,18 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
   return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
 }
 
-// During tree prepare phase, operators may have specific operations to perform depending on
+// During tree prepare phase, operators may have specific pre-operations to perform depending on
 // their role.
-Status DatasetOp::PrepareNodeAction() {
+Status DatasetOp::PrepareNodePreAction() {
+  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
+  return Status::OK();
+}
+
+// During tree prepare phase, operators may have specific post-operations to perform depending on
+// their role.
+Status DatasetOp::PrepareNodePostAction() {
   // If this op does not have any children and it is in a repeat path of the tree...
-  if (child_.size() == 0 && BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) {
-    // Then, flag this operator as a leaf node in a repeat path of tree execution.
-    BitSet(&op_ctrl_flags_, kDeOpRepeated);
-
-    // Secondly, push ourselves onto the tree repeat stack.  Later, the repeat operator
+  if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
+    // Push ourselves onto the tree repeat stack.  Later, the repeat operator
     // above us will consume them.
     tree_->AddToRepeatStack(shared_from_this());
   }
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
index a7d87c3092..0111f5239a 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
@@ -150,11 +150,17 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
     return Status::OK();
   }
 
-  // During tree prepare phase, operators may have specific operations to perform depending on
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
   // their role.
   // @notes Derived versions of this function should always call it's superclass version first
   // before providing their own implementations.
-  virtual Status PrepareNodeAction();
+  virtual Status PrepareNodePreAction();
+
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  virtual Status PrepareNodePostAction();
 
   // Getter function
   // @return The operator id
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h
index ceb7f2c4ac..142ec78360 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h
@@ -64,14 +64,24 @@ class ParallelOp : public DatasetOp {
     return out;
   }
 
-  // During tree prepare phase, operators may have specific operations to perform depending on
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
   // their role.
   // @notes Derived versions of this function should always call it's superclass version first
   // before providing their own implementations.
   // @return Status - The error return code
-  Status PrepareNodeAction() override {
+  Status PrepareNodePreAction() override {
     // Run common code from super class before adding ParallelOp specific logic
-    return (DatasetOp::PrepareNodeAction());
+    return (DatasetOp::PrepareNodePreAction());
+  }
+
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  // @return Status - The error return code
+  Status PrepareNodePostAction() override {
+    // Run common code from super class before adding ParallelOp specific logic
+    return (DatasetOp::PrepareNodePostAction());
   }
 
   // Override base class reset to provide reset actions specific to the ParallelOp class.
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h
index ee20f1d373..a14279032d 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h
@@ -64,13 +64,22 @@ class PipelineOp : public DatasetOp {
   // @return The number of threads that push data to the output connector
   int32_t num_producers() const override { return 1; }
 
-  // During tree prepare phase, operators may have specific operations to perform depending on
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
   // their role.
   // @notes Derived versions of this function should always call it's superclass version first
   // before providing their own implementations.
-  Status PrepareNodeAction() override {
+  Status PrepareNodePreAction() override {
     // Run common code from super class before adding PipelineOp specific logic
-    return (DatasetOp::PrepareNodeAction());
+    return (DatasetOp::PrepareNodePreAction());
+  }
+
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override {
+    // Run common code from super class before adding PipelineOp specific logic
+    return (DatasetOp::PrepareNodePostAction());
   }
 
  protected:
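Taken together, these changes split the old single post-order PrepareNodeAction() into a pre-action that fires on the way down the tree (so every operator under a RepeatOp can mark itself kDeOpRepeated before its children are visited) and a post-action that fires on the way back up (where ops in a repeat path register themselves on the repeat stack). A minimal standalone sketch of that two-phase walk; Node, PreAction, PostAction, and PrepareWalk are illustrative stand-ins for DatasetOp and ExecutionTree::PrepareNode, not the engine's API:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for DatasetOp and the prepare-phase tree walk.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
  void PreAction() { std::cout << "pre:  " << name << "\n"; }   // runs on the way down
  void PostAction() { std::cout << "post: " << name << "\n"; }  // runs on the way back up
};

void PrepareWalk(const std::shared_ptr<Node> &node) {
  node->PreAction();  // e.g. inherit the "I am under a repeat" flag before children run
  for (const auto &child : node->children) {
    PrepareWalk(child);
  }
  node->PostAction();  // e.g. push onto / drain the repeat stack on the way back up
}

int main() {
  auto leaf = std::make_shared<Node>(Node{"leaf", {}});
  auto repeat = std::make_shared<Node>(Node{"repeat", {leaf}});
  PrepareWalk(repeat);
  // Prints: pre: repeat, pre: leaf, post: leaf, post: repeat.
  return 0;
}

The same ordering is what lets ExecutionTree::PrepareNode (further down in this patch) scope kDePrepRepeat to a RepeatOp's subtree: the flag is set after the pre-action and cleared once the children have been walked.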
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
index 32723a9bd4..33c731c400 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
@@ -58,10 +58,10 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
   out << "RepeatOp:"
       << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
       << "\nLeaf Nodes in my execution path:";
-  if (!leaf_ops_.empty()) {
+  if (!eoe_ops_.empty()) {
     out << "\n";
-    for (size_t i = 0; i < leaf_ops_.size(); i++) {
-      out << "  Operator: " << leaf_ops_[i]->id() << "\n";
+    for (size_t i = 0; i < eoe_ops_.size(); i++) {
+      out << "  Operator: " << eoe_ops_[i]->id() << "\n";
     }
   } else {
     out << " kNone.";
@@ -71,21 +71,17 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
 // Base-class override for executing specific RepeatOp configurations.  This code will be called
 // during the execution tree prepare phase when it is visiting this operator.
-Status RepeatOp::PrepareNodeAction() {
+Status RepeatOp::PrepareNodePostAction() {
   // Run any common code from super class first before adding our own specific logic
-  RETURN_IF_NOT_OK(PipelineOp::PrepareNodeAction());
+  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
   std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
   while (leaf_op != nullptr) {
     // Track the leaf operators that are under this repeat op.
-    leaf_ops_.push_back(leaf_op);
-
-    // Special case.  If the repeat count is 1, then pre-flag the leaf nodes
-    // to tell them they are already at their last op:
-    if (max_repeats_ == 1) {
-      leaf_op->set_control_flag(kDeOpLastRepeat);
-    }
+    eoe_ops_.push_back(leaf_op);
     leaf_op = tree_->PopFromRepeatStack();
   }
+  // Push ourselves onto the stack in case one of our ancestors is a repeat too.
+  tree_->AddToRepeatStack(shared_from_this());
   return Status::OK();
 }
@@ -127,16 +123,20 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id,
 Status RepeatOp::EoeReceived(int32_t worker_id) {
   repeat_count_++;
   MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
-
-  // If we've reached the requested repeat count, then flag the leaf nodes
+  bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
+  bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
+  // If we've reached the requested repeat count, then flag the eoe nodes
   // to tell them they've got one more epoch to perform.  When they reach the end
-  // of the last epoch, they quit rather than loop again.
-  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
-    for (size_t i = 0; i < leaf_ops_.size(); i++) {
-      leaf_ops_[i]->set_control_flag(kDeOpLastRepeat);
+  // of the last epoch, they quit rather than loop again.  This happens in two cases:
+  // 1- We are also repeated (by another repeat op) and we are at the last repetition.  Or,
+  // 2- We are not repeated.
+  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
+    for (auto &eoe_op : eoe_ops_) {
+      eoe_op->set_control_flag(kDeOpLastRepeat);
     }
   }
   if (repeat_count_ == max_repeats_) {
+    repeat_count_ = 0;
     state_ = OpState::kDeOpIdle;
     return Status::OK();
   }
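The repeat stack is what makes nesting work: each RepeatOp's post-action drains the stack of the EOE-generating operators beneath it into eoe_ops_, then pushes itself so that an outer RepeatOp adopts it in turn. A toy simulation of that hand-off for a leaf under two nested repeats; plain strings and a std::stack stand in for the operators and for AddToRepeatStack/PopFromRepeatStack, so this is a sketch of the idea rather than the engine code:

#include <iostream>
#include <stack>
#include <string>
#include <vector>

int main() {
  // Simulated tree repeat stack; strings stand in for operators.
  std::stack<std::string> repeat_stack;

  // Leaf post-action: an op with no children in a repeat path pushes itself.
  repeat_stack.push("leaf");

  // Inner RepeatOp post-action: adopt everything below it as its eoe_ops_,
  // then push itself in case an outer repeat sits above.
  std::vector<std::string> inner_eoe_ops;
  while (!repeat_stack.empty()) {
    inner_eoe_ops.push_back(repeat_stack.top());
    repeat_stack.pop();
  }
  repeat_stack.push("inner_repeat");

  // Outer RepeatOp post-action: it adopts the inner repeat, not the leaf,
  // so each repeat only ever re-drives its direct EOE producers.
  std::vector<std::string> outer_eoe_ops;
  while (!repeat_stack.empty()) {
    outer_eoe_ops.push_back(repeat_stack.top());
    repeat_stack.pop();
  }

  std::cout << "inner eoe_ops: " << inner_eoe_ops.front() << "\n";  // leaf
  std::cout << "outer eoe_ops: " << outer_eoe_ops.front() << "\n";  // inner_repeat
  return 0;
}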
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
index 5cc7ec2efa..8497b4cf3c 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
@@ -87,8 +87,8 @@ class RepeatOp : public PipelineOp {
   uint32_t PrepareFlags() const override;
 
   // Base-class override for executing specific RepeatOp configurations.  This code will be called
-  // during the execution tree prepare phase when it is visiting this operator.
-  Status PrepareNodeAction() override;
+  // during the execution tree post-prepare phase when it is visiting this operator.
+  Status PrepareNodePostAction() override;
 
   // This function returns the buffer that is at the top of our output connector.  The caller is
   // typically our parent node, when the parent is asking us to provide the next buffer of data.
@@ -119,9 +119,9 @@ class RepeatOp : public PipelineOp {
   int32_t num_producers() const override;
 
  private:
-  int32_t max_repeats_;                                // The number of repeats that the user requested
-  int32_t repeat_count_;                               // A counter for the current number of executed repeats
-  std::vector<std::shared_ptr<DatasetOp>> leaf_ops_;   // List of leaf operators underneath this repeat.
+  int32_t max_repeats_;                                // The number of repeats that the user requested
+  int32_t repeat_count_;                               // A counter for the current number of executed repeats
+  std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;    // List of operators that can generate EOE underneath this repeat.
 };
 }  // namespace dataset
 }  // namespace mindspore
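EoeReceived() above now makes two decisions that the nested-repeat tests at the end of this patch exercise: it only warns its eoe_ops_ of the final pass when no outer repeat will restart it (or when the outer repeat is itself on its last pass), and it resets repeat_count_ to zero on completion so an outer repeat can drive it through another full cycle. A toy model of just that control flow; ToyRepeat and its fields mirror the patch but this is not the engine code:

#include <iostream>

constexpr int kInfiniteRepeat = -1;

struct ToyRepeat {
  int max_repeats;
  int repeat_count;
  bool repeated;     // kDeOpRepeated: another repeat sits above us
  bool last_repeat;  // kDeOpLastRepeat: that outer repeat is on its final pass

  void EoeReceived() {
    ++repeat_count;
    // Warn the EOE producers one epoch before the end, but only if no outer
    // repeat is going to restart us afterwards.
    if (max_repeats != kInfiniteRepeat && repeat_count == max_repeats - 1 &&
        (!repeated || last_repeat)) {
      std::cout << "flag eoe_ops_ with kDeOpLastRepeat\n";
    }
    if (repeat_count == max_repeats) {
      repeat_count = 0;  // reset so an outer repeat can run us through again
      std::cout << "go idle\n";
    }
  }
};

int main() {
  // An inner repeat(2) under an outer repeat that is not yet on its last pass:
  ToyRepeat inner{/*max_repeats=*/2, /*repeat_count=*/0,
                  /*repeated=*/true, /*last_repeat=*/false};
  inner.EoeReceived();  // count == max-1, but repeated && !last_repeat: no flag yet
  inner.EoeReceived();  // count == max: resets to 0 and goes idle
  return 0;
}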
diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc
index 20fcb836c5..ebfa532195 100644
--- a/mindspore/ccsrc/dataset/engine/execution_tree.cc
+++ b/mindspore/ccsrc/dataset/engine/execution_tree.cc
@@ -162,30 +162,25 @@ Status ExecutionTree::Prepare() {
 // Recursive function used during prepare phase to visit a node and drive any pre- and post-
 // node actions during a tree walk.
 Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
-  int32_t num_children = dataset_op->child_.size();
+  // Execute the pre-action
+  RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());
 
-  // Before going down into children, make any prepare flags updates based on this
-  // operator.
+  // Before going down into children, make any prepare flags updates based on this operator.
   uint32_t op_prep_flags = dataset_op->PrepareFlags();
-  // Sanity check.  In future we can support nested repeats.  for now it's not allowed.
-  // If somebody above us already set the repeat flag, and now we are another repeat...
-  if (BitTest(op_prep_flags, kDePrepRepeat) && BitTest(prepare_flags_, kDePrepRepeat)) {
-    std::string err_msg("Nested RepeatOp detected! This is not supported yet.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
   BitSet(&prepare_flags_, op_prep_flags);
 
   // Now, descend to children
-  for (int32_t i = 0; i < num_children; ++i) {
-    RETURN_IF_NOT_OK(this->PrepareNode(dataset_op->child_[i]));
+  for (const auto &i : dataset_op->child_) {
+    RETURN_IF_NOT_OK(this->PrepareNode(i));
   }
-  // No more children, now we execute any prepare actions before going back up the
-  // the tree on recursive function exit
-  RETURN_IF_NOT_OK(dataset_op->PrepareNodeAction());
-  // Then clear the flags from this op now that we have prepared it.
   BitClear(&prepare_flags_, op_prep_flags);
+
+  // No more children; now we execute any post-prepare actions before going back up
+  // the tree on the recursive function exit.
+  RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());
+
   return Status::OK();
 }
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 3d660d58a8..e40c24c140 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -417,6 +417,8 @@ class Dataset:
             >>> repeat_and_shuffle = data.repeat(50)
             >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
         """
+        if count == 1:
+            return self
         return RepeatDataset(self, count)
 
     @check_zip_dataset
diff --git a/tests/ut/cpp/dataset/repeat_op_test.cc b/tests/ut/cpp/dataset/repeat_op_test.cc
index 99e91afe81..e32e98cbd7 100644
--- a/tests/ut/cpp/dataset/repeat_op_test.cc
+++ b/tests/ut/cpp/dataset/repeat_op_test.cc
@@ -33,18 +33,29 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
   auto my_tree = std::make_shared<ExecutionTree>();
 
   std::shared_ptr<RepeatOp> parent_op = std::make_shared<RepeatOp>(32);
-
-  std::shared_ptr<RepeatOp> leaf_op = std::make_shared<RepeatOp>(16);
+  std::string dataset_path;
+  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
+  // TFReaderOp
+  std::shared_ptr<TFReaderOp> my_tfreader_op;
+  TFReaderOp::Builder builder;
+  builder.SetDatasetFilesList({dataset_path})
+    .SetRowsPerBuffer(16)
+    .SetWorkerConnectorSize(16)
+    .SetNumWorkers(16);
+  Status rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+  rc = my_tree->AssociateNode(my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
   my_tree->AssociateNode(parent_op);
-  my_tree->AssociateNode(leaf_op);
   ASSERT_NE(parent_op, nullptr);
-  ASSERT_NE(leaf_op, nullptr);
-  parent_op->AddChild(std::move(leaf_op));
-  parent_op->Print(std::cout, false);
-  parent_op->PrepareNodeAction();
+  ASSERT_NE(my_tfreader_op, nullptr);
+  parent_op->AddChild(std::move(my_tfreader_op));
+  MS_LOG(INFO) << parent_op;
+  my_tree->Prepare();
+  RepeatOp RepeatOpOp();
 
   std::shared_ptr<RepeatOp> repeat_op;
-  Status rc = RepeatOp::Builder(3).Build(&repeat_op);
+  rc = RepeatOp::Builder(3).Build(&repeat_op);
   ASSERT_NE(repeat_op, nullptr);
 }
diff --git a/tests/ut/python/dataset/test_repeat.py b/tests/ut/python/dataset/test_repeat.py
index 196a62c315..cb7a80e3d1 100644
--- a/tests/ut/python/dataset/test_repeat.py
+++ b/tests/ut/python/dataset/test_repeat.py
@@ -16,6 +16,7 @@
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from util import save_and_check
 import mindspore.dataset as ds
+import numpy as np
 from mindspore import log as logger
 
 DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
@@ -95,6 +96,141 @@ def test_tf_repeat_03():
     assert num_iter == 2
 
 
+def generator():
+    for i in range(3):
+        yield np.array([i]),
+
+
+def test_nested_repeat1():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(3)
+
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3 * 3
+
+
+def test_nested_repeat2():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(1)
+    data = data.repeat(1)
+
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 3
+
+
+def test_nested_repeat3():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(1)
+    data = data.repeat(2)
+
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3
+
+
+def test_nested_repeat4():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(1)
+
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3
+
+
+def test_nested_repeat5():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.batch(3)
+    data = data.repeat(2)
+    data = data.repeat(3)
+
+    for i, d in enumerate(data):
+        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
+
+    assert sum([1 for _ in data]) == 6
+
+
+def test_nested_repeat6():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.batch(3)
+    data = data.repeat(3)
+
+    for i, d in enumerate(data):
+        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
+
+    assert sum([1 for _ in data]) == 6
+
+
+def test_nested_repeat7():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(3)
+    data = data.batch(3)
+
+    for i, d in enumerate(data):
+        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
+
+    assert sum([1 for _ in data]) == 6
+
+
+def test_nested_repeat8():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.batch(2, drop_remainder=False)
+    data = data.repeat(2)
+    data = data.repeat(3)
+
+    for i, d in enumerate(data):
+        if i % 2 == 0:
+            assert np.array_equal(d[0], np.asarray([[0], [1]]))
+        else:
+            assert np.array_equal(d[0], np.asarray([[2]]))
+
+    assert sum([1 for _ in data]) == 6 * 2
+
+
+def test_nested_repeat9():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat()
+    data = data.repeat(3)
+
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+        if i == 10:
+            break
+
+
+def test_nested_repeat10():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(3)
+    data = data.repeat()
+
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+        if i == 10:
+            break
+
+
+def test_nested_repeat11():
+    data = ds.GeneratorDataset(generator, ["data"])
+    data = data.repeat(2)
+    data = data.repeat(3)
+    data = data.repeat(4)
+    data = data.repeat(5)
+
+    for i, d in enumerate(data):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3
+
+
 if __name__ == "__main__":
     logger.info("--------test tf repeat 01---------")
     # test_repeat_01()
@@ -104,4 +240,3 @@ if __name__ == "__main__":
 
     logger.info("--------test tf repeat 03---------")
     test_tf_repeat_03()
-