| @@ -161,15 +161,18 @@ Status DatasetOp::EofReceived(int32_t worker_id) { | |||||
| return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer))); | return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer))); | ||||
| } | } | ||||
| // During tree prepare phase, operators may have specific operations to perform depending on | |||||
| // During tree prepare phase, operators may have specific pre-operations to perform depending on | |||||
| // their role. | // their role. | ||||
| Status DatasetOp::PrepareNodeAction() { | |||||
| Status DatasetOp::PrepareNodePreAction() { | |||||
| if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated); | |||||
| return Status::OK(); | |||||
| } | |||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| Status DatasetOp::PrepareNodePostAction() { | |||||
| // If this op does not have any children and it is in a repeat path of the tree... | // If this op does not have any children and it is in a repeat path of the tree... | ||||
| if (child_.size() == 0 && BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) { | |||||
| // Then, flag this operator as a leaf node in a repeat path of tree execution. | |||||
| BitSet(&op_ctrl_flags_, kDeOpRepeated); | |||||
| // Secondly, push ourselves onto the tree repeat stack. Later, the repeat operator | |||||
| if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) { | |||||
| // Push ourselves onto the tree repeat stack. Later, the repeat operator | |||||
| // above us will consume them. | // above us will consume them. | ||||
| tree_->AddToRepeatStack(shared_from_this()); | tree_->AddToRepeatStack(shared_from_this()); | ||||
| } | } | ||||
| @@ -150,11 +150,17 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // During tree prepare phase, operators may have specific operations to perform depending on | |||||
| // During tree prepare phase, operators may have specific pre-operations to perform depending on | |||||
| // their role. | // their role. | ||||
| // @notes Derived versions of this function should always call it's superclass version first | // @notes Derived versions of this function should always call it's superclass version first | ||||
| // before providing their own implementations. | // before providing their own implementations. | ||||
| virtual Status PrepareNodeAction(); | |||||
| virtual Status PrepareNodePreAction(); | |||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| // @notes Derived versions of this function should always call its superclass version first | |||||
| // before providing their own implementations. | |||||
| virtual Status PrepareNodePostAction(); | |||||
| // Getter function | // Getter function | ||||
| // @return The operator id | // @return The operator id | ||||
| @@ -64,14 +64,24 @@ class ParallelOp : public DatasetOp { | |||||
| return out; | return out; | ||||
| } | } | ||||
| // During tree prepare phase, operators may have specific operations to perform depending on | |||||
| // During tree prepare phase, operators may have specific pre-operations to perform depending on | |||||
| // their role. | // their role. | ||||
| // @notes Derived versions of this function should always call it's superclass version first | // @notes Derived versions of this function should always call it's superclass version first | ||||
| // before providing their own implementations. | // before providing their own implementations. | ||||
| // @return Status - The error return code | // @return Status - The error return code | ||||
| Status PrepareNodeAction() override { | |||||
| Status PrepareNodePreAction() override { | |||||
| // Run common code from super class before adding ParallelOp specific logic | // Run common code from super class before adding ParallelOp specific logic | ||||
| return (DatasetOp::PrepareNodeAction()); | |||||
| return (DatasetOp::PrepareNodePreAction()); | |||||
| } | |||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| // @notes Derived versions of this function should always call its superclass version first | |||||
| // before providing their own implementations. | |||||
| // @return Status - The error return code | |||||
| Status PrepareNodePostAction() override { | |||||
| // Run common code from super class before adding ParallelOp specific logic | |||||
| return (DatasetOp::PrepareNodePostAction()); | |||||
| } | } | ||||
| // Override base class reset to provide reset actions specific to the ParallelOp class. | // Override base class reset to provide reset actions specific to the ParallelOp class. | ||||
| @@ -64,13 +64,22 @@ class PipelineOp : public DatasetOp { | |||||
| // @return The number of threads that push data to the output connector | // @return The number of threads that push data to the output connector | ||||
| int32_t num_producers() const override { return 1; } | int32_t num_producers() const override { return 1; } | ||||
| // During tree prepare phase, operators may have specific operations to perform depending on | |||||
| // During tree prepare phase, operators may have specific pre-operations to perform depending on | |||||
| // their role. | // their role. | ||||
| // @notes Derived versions of this function should always call it's superclass version first | // @notes Derived versions of this function should always call it's superclass version first | ||||
| // before providing their own implementations. | // before providing their own implementations. | ||||
| Status PrepareNodeAction() override { | |||||
| Status PrepareNodePreAction() override { | |||||
| // Run common code from super class before adding PipelineOp specific logic | // Run common code from super class before adding PipelineOp specific logic | ||||
| return (DatasetOp::PrepareNodeAction()); | |||||
| return (DatasetOp::PrepareNodePreAction()); | |||||
| } | |||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| // @notes Derived versions of this function should always call its superclass version first | |||||
| // before providing their own implementations. | |||||
| Status PrepareNodePostAction() override { | |||||
| // Run common code from super class before adding PipelineOp specific logic | |||||
| return (DatasetOp::PrepareNodePostAction()); | |||||
| } | } | ||||
| protected: | protected: | ||||
| @@ -58,10 +58,10 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||||
| out << "RepeatOp:" | out << "RepeatOp:" | ||||
| << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ | << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ | ||||
| << "\nLeaf Nodes in my execution path:"; | << "\nLeaf Nodes in my execution path:"; | ||||
| if (!leaf_ops_.empty()) { | |||||
| if (!eoe_ops_.empty()) { | |||||
| out << "\n"; | out << "\n"; | ||||
| for (size_t i = 0; i < leaf_ops_.size(); i++) { | |||||
| out << " Operator: " << leaf_ops_[i]->id() << "\n"; | |||||
| for (size_t i = 0; i < eoe_ops_.size(); i++) { | |||||
| out << " Operator: " << eoe_ops_[i]->id() << "\n"; | |||||
| } | } | ||||
| } else { | } else { | ||||
| out << " kNone."; | out << " kNone."; | ||||
| @@ -71,21 +71,17 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Base-class override for executing specific RepeatOp configurations. This code will be called | // Base-class override for executing specific RepeatOp configurations. This code will be called | ||||
| // during the execution tree prepare phase when it is visiting this operator. | // during the execution tree prepare phase when it is visiting this operator. | ||||
| Status RepeatOp::PrepareNodeAction() { | |||||
| Status RepeatOp::PrepareNodePostAction() { | |||||
| // Run any common code from super class first before adding our own specific logic | // Run any common code from super class first before adding our own specific logic | ||||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodeAction()); | |||||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | |||||
| std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack(); | std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack(); | ||||
| while (leaf_op != nullptr) { | while (leaf_op != nullptr) { | ||||
| // Track the leaf operators that are under this repeat op. | // Track the leaf operators that are under this repeat op. | ||||
| leaf_ops_.push_back(leaf_op); | |||||
| // Special case. If the repeat count is 1, then pre-flag the leaf nodes | |||||
| // to tell them they are already at their last op: | |||||
| if (max_repeats_ == 1) { | |||||
| leaf_op->set_control_flag(kDeOpLastRepeat); | |||||
| } | |||||
| eoe_ops_.push_back(leaf_op); | |||||
| leaf_op = tree_->PopFromRepeatStack(); | leaf_op = tree_->PopFromRepeatStack(); | ||||
| } | } | ||||
| // Push ourselves to the stack in case one of our ancestors is a repeat op too. | |||||
| tree_->AddToRepeatStack(shared_from_this()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -127,16 +123,20 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo | |||||
| Status RepeatOp::EoeReceived(int32_t worker_id) { | Status RepeatOp::EoeReceived(int32_t worker_id) { | ||||
| repeat_count_++; | repeat_count_++; | ||||
| MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | ||||
| // If we've reached the requested repeat count, then flag the leaf nodes | |||||
| bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); | |||||
| bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); | |||||
| // If we've reached the requested repeat count, then flag the eoe nodes | |||||
| // to tell them they've got one more epoch to perform. When they reach the end | // to tell them they've got one more epoch to perform. When they reach the end | ||||
| // of the last epoch, they quit rather than loop again. | |||||
| if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) { | |||||
| for (size_t i = 0; i < leaf_ops_.size(); i++) { | |||||
| leaf_ops_[i]->set_control_flag(kDeOpLastRepeat); | |||||
| // of the last epoch, they quit rather than loop again. This happens in two cases: | |||||
| // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or, | |||||
| // 2- We are not repeated | |||||
| if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) { | |||||
| for (auto &eoe_op : eoe_ops_) { | |||||
| eoe_op->set_control_flag(kDeOpLastRepeat); | |||||
| } | } | ||||
| } | } | ||||
| if (repeat_count_ == max_repeats_) { | if (repeat_count_ == max_repeats_) { | ||||
| repeat_count_ = 0; | |||||
| state_ = OpState::kDeOpIdle; | state_ = OpState::kDeOpIdle; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -87,8 +87,8 @@ class RepeatOp : public PipelineOp { | |||||
| uint32_t PrepareFlags() const override; | uint32_t PrepareFlags() const override; | ||||
| // Base-class override for executing specific RepeatOp configurations. This code will be called | // Base-class override for executing specific RepeatOp configurations. This code will be called | ||||
| // during the execution tree prepare phase when it is visiting this operator. | |||||
| Status PrepareNodeAction() override; | |||||
| // during the execution tree post-prepare phase when it is visiting this operator. | |||||
| Status PrepareNodePostAction() override; | |||||
| // This function returns the buffer that is at the top of our output connector. The caller is | // This function returns the buffer that is at the top of our output connector. The caller is | ||||
| // typically our parent node, when the parent is asking us to provide the next buffer of data. | // typically our parent node, when the parent is asking us to provide the next buffer of data. | ||||
| @@ -119,9 +119,9 @@ class RepeatOp : public PipelineOp { | |||||
| int32_t num_producers() const override; | int32_t num_producers() const override; | ||||
| private: | private: | ||||
| int32_t max_repeats_; // The number of repeats that the user requested | |||||
| int32_t repeat_count_; // A counter for the current number of executed repeats | |||||
| std::vector<std::shared_ptr<DatasetOp>> leaf_ops_; // List of leaf operators underneath this repeat. | |||||
| int32_t max_repeats_; // The number of repeats that the user requested | |||||
| int32_t repeat_count_; // A counter for the current number of executed repeats | |||||
| std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat. | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -162,30 +162,25 @@ Status ExecutionTree::Prepare() { | |||||
| // Recursive function used during prepare phase to visit a node and drive any pre- and post- | // Recursive function used during prepare phase to visit a node and drive any pre- and post- | ||||
| // node actions during a tree walk. | // node actions during a tree walk. | ||||
| Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) { | Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) { | ||||
| int32_t num_children = dataset_op->child_.size(); | |||||
| // execute PreAction | |||||
| RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction()); | |||||
| // Before going down into children, make any prepare flags updates based on this | |||||
| // operator. | |||||
| // Before going down into children, make any prepare flags updates based on this operator. | |||||
| uint32_t op_prep_flags = dataset_op->PrepareFlags(); | uint32_t op_prep_flags = dataset_op->PrepareFlags(); | ||||
| // Sanity check. In future we can support nested repeats. for now it's not allowed. | |||||
| // If somebody above us already set the repeat flag, and now we are another repeat... | |||||
| if (BitTest(op_prep_flags, kDePrepRepeat) && BitTest(prepare_flags_, kDePrepRepeat)) { | |||||
| std::string err_msg("Nested RepeatOp detected! This is not supported yet."); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| BitSet(&prepare_flags_, op_prep_flags); | BitSet(&prepare_flags_, op_prep_flags); | ||||
| // Now, descend to children | // Now, descend to children | ||||
| for (int32_t i = 0; i < num_children; ++i) { | |||||
| RETURN_IF_NOT_OK(this->PrepareNode(dataset_op->child_[i])); | |||||
| for (const auto &i : dataset_op->child_) { | |||||
| RETURN_IF_NOT_OK(this->PrepareNode(i)); | |||||
| } | } | ||||
| // No more children, now we execute any prepare actions before going back up the | |||||
| // the tree on recursive function exit | |||||
| RETURN_IF_NOT_OK(dataset_op->PrepareNodeAction()); | |||||
| // Then clear the flags from this op now that we have prepared it. | // Then clear the flags from this op now that we have prepared it. | ||||
| BitClear(&prepare_flags_, op_prep_flags); | BitClear(&prepare_flags_, op_prep_flags); | ||||
| // No more children, now we execute any prepare actions before going back up | |||||
| // the tree on recursive function exit | |||||
| RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -417,6 +417,8 @@ class Dataset: | |||||
| >>> repeat_and_shuffle = data.repeat(50) | >>> repeat_and_shuffle = data.repeat(50) | ||||
| >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10) | >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10) | ||||
| """ | """ | ||||
| if count == 1: | |||||
| return self | |||||
| return RepeatDataset(self, count) | return RepeatDataset(self, count) | ||||
| @check_zip_dataset | @check_zip_dataset | ||||
| @@ -33,18 +33,29 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) { | |||||
| auto my_tree = std::make_shared<ExecutionTree>(); | auto my_tree = std::make_shared<ExecutionTree>(); | ||||
| std::shared_ptr<DatasetOp> parent_op = std::make_shared<RepeatOp>(32); | std::shared_ptr<DatasetOp> parent_op = std::make_shared<RepeatOp>(32); | ||||
| std::shared_ptr<DatasetOp> leaf_op = std::make_shared<RepeatOp>(16); | |||||
| std::string dataset_path; | |||||
| dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; | |||||
| // TFReaderOp | |||||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||||
| TFReaderOp::Builder builder; | |||||
| builder.SetDatasetFilesList({dataset_path}) | |||||
| .SetRowsPerBuffer(16) | |||||
| .SetWorkerConnectorSize(16) | |||||
| .SetNumWorkers(16); | |||||
| Status rc= builder.Build(&my_tfreader_op); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = my_tree->AssociateNode(my_tfreader_op); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| my_tree->AssociateNode(parent_op); | my_tree->AssociateNode(parent_op); | ||||
| my_tree->AssociateNode(leaf_op); | |||||
| ASSERT_NE(parent_op, nullptr); | ASSERT_NE(parent_op, nullptr); | ||||
| ASSERT_NE(leaf_op, nullptr); | |||||
| parent_op->AddChild(std::move(leaf_op)); | |||||
| parent_op->Print(std::cout, false); | |||||
| parent_op->PrepareNodeAction(); | |||||
| ASSERT_NE(my_tfreader_op, nullptr); | |||||
| parent_op->AddChild(std::move(my_tfreader_op)); | |||||
| MS_LOG(INFO) << parent_op; | |||||
| my_tree->Prepare(); | |||||
| RepeatOp RepeatOpOp(); | RepeatOp RepeatOpOp(); | ||||
| std::shared_ptr<RepeatOp> repeat_op; | std::shared_ptr<RepeatOp> repeat_op; | ||||
| Status rc = RepeatOp::Builder(3).Build(&repeat_op); | |||||
| rc = RepeatOp::Builder(3).Build(&repeat_op); | |||||
| ASSERT_NE(repeat_op, nullptr); | ASSERT_NE(repeat_op, nullptr); | ||||
| } | } | ||||
| @@ -16,6 +16,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision | |||||
| from util import save_and_check | from util import save_and_check | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import numpy as np | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"] | DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"] | ||||
| @@ -95,6 +96,141 @@ def test_tf_repeat_03(): | |||||
| assert num_iter == 2 | assert num_iter == 2 | ||||
| def generator(): | |||||
| for i in range(3): | |||||
| yield np.array([i]), | |||||
| def test_nested_repeat1(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(2) | |||||
| data = data.repeat(3) | |||||
| for i, d in enumerate(data): | |||||
| assert i % 3 == d[0][0] | |||||
| assert sum([1 for _ in data]) == 2 * 3 * 3 | |||||
| def test_nested_repeat2(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(1) | |||||
| data = data.repeat(1) | |||||
| for i, d in enumerate(data): | |||||
| assert i % 3 == d[0][0] | |||||
| assert sum([1 for _ in data]) == 3 | |||||
| def test_nested_repeat3(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(1) | |||||
| data = data.repeat(2) | |||||
| for i, d in enumerate(data): | |||||
| assert i % 3 == d[0][0] | |||||
| assert sum([1 for _ in data]) == 2 * 3 | |||||
| def test_nested_repeat4(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(2) | |||||
| data = data.repeat(1) | |||||
| for i, d in enumerate(data): | |||||
| assert i % 3 == d[0][0] | |||||
| assert sum([1 for _ in data]) == 2 * 3 | |||||
| def test_nested_repeat5(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.batch(3) | |||||
| data = data.repeat(2) | |||||
| data = data.repeat(3) | |||||
| for i, d in enumerate(data): | |||||
| assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) | |||||
| assert sum([1 for _ in data]) == 6 | |||||
| def test_nested_repeat6(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(2) | |||||
| data = data.batch(3) | |||||
| data = data.repeat(3) | |||||
| for i, d in enumerate(data): | |||||
| assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) | |||||
| assert sum([1 for _ in data]) == 6 | |||||
| def test_nested_repeat7(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(2) | |||||
| data = data.repeat(3) | |||||
| data = data.batch(3) | |||||
| for i, d in enumerate(data): | |||||
| assert np.array_equal(d[0], np.asarray([[0], [1], [2]])) | |||||
| assert sum([1 for _ in data]) == 6 | |||||
| def test_nested_repeat8(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.batch(2, drop_remainder=False) | |||||
| data = data.repeat(2) | |||||
| data = data.repeat(3) | |||||
| for i, d in enumerate(data): | |||||
| if i % 2 == 0: | |||||
| assert np.array_equal(d[0], np.asarray([[0], [1]])) | |||||
| else: | |||||
| assert np.array_equal(d[0], np.asarray([[2]])) | |||||
| assert sum([1 for _ in data]) == 6 * 2 | |||||
| def test_nested_repeat9(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat() | |||||
| data = data.repeat(3) | |||||
| for i, d in enumerate(data): | |||||
| assert i % 3 == d[0][0] | |||||
| if i == 10: | |||||
| break | |||||
| def test_nested_repeat10(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(3) | |||||
| data = data.repeat() | |||||
| for i, d in enumerate(data): | |||||
| assert i % 3 == d[0][0] | |||||
| if i == 10: | |||||
| break | |||||
| def test_nested_repeat11(): | |||||
| data = ds.GeneratorDataset(generator, ["data"]) | |||||
| data = data.repeat(2) | |||||
| data = data.repeat(3) | |||||
| data = data.repeat(4) | |||||
| data = data.repeat(5) | |||||
| for i, d in enumerate(data): | |||||
| assert i % 3 == d[0][0] | |||||
| assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3 | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| logger.info("--------test tf repeat 01---------") | logger.info("--------test tf repeat 01---------") | ||||
| # test_repeat_01() | # test_repeat_01() | ||||
| @@ -104,4 +240,3 @@ if __name__ == "__main__": | |||||
| logger.info("--------test tf repeat 03---------") | logger.info("--------test tf repeat 03---------") | ||||
| test_tf_repeat_03() | test_tf_repeat_03() | ||||