@@ -161,15 +161,18 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
  return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
}
// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific pre-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodeAction() {
Status DatasetOp::PrepareNodePreAction() {
  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
  return Status::OK();
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodePostAction() {
  // If this op does not have any children and it is in a repeat path of the tree...
  if (child_.size() == 0 && BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) {
    // Then, flag this operator as a leaf node in a repeat path of tree execution.
    BitSet(&op_ctrl_flags_, kDeOpRepeated);
    // Secondly, push ourselves onto the tree repeat stack. Later, the repeat operator
  if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
    // push ourselves onto the tree repeat stack. Later, the repeat operator
    // above us will consume them.
    tree_->AddToRepeatStack(shared_from_this());
  }
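
The hunk above splits the old single PrepareNodeAction into a pre-action that runs on the way down the tree (the op inherits the "repeated" flag from the tree's prepare flags) and a post-action that runs on the way back up (a repeated op with no children registers itself on the tree's repeat stack). A minimal Python sketch of that behaviour, with hypothetical names and bit values standing in for the C++ ones:

    # Hypothetical sketch only; names and bit values are illustrative, not the real C++ API.
    K_DE_OP_REPEATED = 1 << 0   # assumed control-flag bit (kDeOpRepeated)
    K_DE_PREP_REPEAT = 1 << 0   # assumed prepare-flag bit (kDePrepRepeat)

    class SketchTree:
        def __init__(self):
            self.prepare_flags = 0
            self.repeat_stack = []

    class SketchOp:
        def __init__(self, tree, children=None):
            self.tree = tree
            self.children = children or []
            self.op_ctrl_flags = 0

        def prepare_flags(self):
            return 0  # ordinary ops set no prepare flags for their subtree

        def prepare_node_pre_action(self):
            # Downward pass: pick up the "repeated" flag if a repeat op above us set it.
            if self.tree.prepare_flags & K_DE_PREP_REPEAT:
                self.op_ctrl_flags |= K_DE_OP_REPEATED

        def prepare_node_post_action(self):
            # Upward pass: a repeated op with no children pushes itself onto the
            # repeat stack so the repeat op above it can claim it later.
            if not self.children and (self.op_ctrl_flags & K_DE_OP_REPEATED):
                self.tree.repeat_stack.append(self)
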
@@ -150,11 +150,17 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
    return Status::OK();
  }
  // During tree prepare phase, operators may have specific operations to perform depending on
  // During tree prepare phase, operators may have specific pre-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  virtual Status PrepareNodeAction();
  virtual Status PrepareNodePreAction();
  // During tree prepare phase, operators may have specific post-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  virtual Status PrepareNodePostAction();
  // Getter function
  // @return The operator id
@@ -64,14 +64,24 @@ class ParallelOp : public DatasetOp {
    return out;
  }
  // During tree prepare phase, operators may have specific operations to perform depending on
  // During tree prepare phase, operators may have specific pre-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  // @return Status - The error return code
  Status PrepareNodeAction() override {
  Status PrepareNodePreAction() override {
    // Run common code from super class before adding ParallelOp specific logic
    return (DatasetOp::PrepareNodeAction());
    return (DatasetOp::PrepareNodePreAction());
  }
  // During tree prepare phase, operators may have specific post-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  // @return Status - The error return code
  Status PrepareNodePostAction() override {
    // Run common code from super class before adding ParallelOp specific logic
    return (DatasetOp::PrepareNodePostAction());
  }
  // Override base class reset to provide reset actions specific to the ParallelOp class.
@@ -64,13 +64,22 @@ class PipelineOp : public DatasetOp {
  // @return The number of threads that push data to the output connector
  int32_t num_producers() const override { return 1; }
  // During tree prepare phase, operators may have specific operations to perform depending on
  // During tree prepare phase, operators may have specific pre-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  Status PrepareNodeAction() override {
  Status PrepareNodePreAction() override {
    // Run common code from super class before adding PipelineOp specific logic
    return (DatasetOp::PrepareNodeAction());
    return (DatasetOp::PrepareNodePreAction());
  }
  // During tree prepare phase, operators may have specific post-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  Status PrepareNodePostAction() override {
    // Run common code from super class before adding PipelineOp specific logic
    return (DatasetOp::PrepareNodePostAction());
  }
 protected:
@@ -58,10 +58,10 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
  out << "RepeatOp:"
      << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
      << "\nLeaf Nodes in my execution path:";
  if (!leaf_ops_.empty()) {
  if (!eoe_ops_.empty()) {
    out << "\n";
    for (size_t i = 0; i < leaf_ops_.size(); i++) {
      out << " Operator: " << leaf_ops_[i]->id() << "\n";
    for (size_t i = 0; i < eoe_ops_.size(); i++) {
      out << " Operator: " << eoe_ops_[i]->id() << "\n";
    }
  } else {
    out << " kNone.";
@@ -71,21 +71,17 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase when it is visiting this operator.
Status RepeatOp::PrepareNodeAction() {
Status RepeatOp::PrepareNodePostAction() {
  // Run any common code from super class first before adding our own specific logic
  RETURN_IF_NOT_OK(PipelineOp::PrepareNodeAction());
  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
  while (leaf_op != nullptr) {
    // Track the leaf operators that are under this repeat op.
    leaf_ops_.push_back(leaf_op);
    // Special case. If the repeat count is 1, then pre-flag the leaf nodes
    // to tell them they are already at their last op:
    if (max_repeats_ == 1) {
      leaf_op->set_control_flag(kDeOpLastRepeat);
    }
    eoe_ops_.push_back(leaf_op);
    leaf_op = tree_->PopFromRepeatStack();
  }
  // Push ourselves onto the stack in case one of our ancestors is also a repeat op.
  tree_->AddToRepeatStack(shared_from_this());
  return Status::OK();
}
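
Continuing the sketch from the first hunk: the RepeatOp's post-action drains the repeat stack into its own eoe_ops list and then pushes itself, so an enclosing RepeatOp will treat it like any other EOE-generating op below it; this is the piece that makes nested repeats composable. Hypothetical Python names again, not the real API:

    class SketchRepeatOp(SketchOp):
        def __init__(self, tree, children=None):
            super().__init__(tree, children)
            self.eoe_ops = []

        def prepare_flags(self):
            return K_DE_PREP_REPEAT  # everything below us is in a repeat path

        def prepare_node_post_action(self):
            super().prepare_node_post_action()
            # Claim every op that registered itself while our subtree was walked...
            while self.tree.repeat_stack:
                self.eoe_ops.append(self.tree.repeat_stack.pop())
            # ...then register ourselves for any repeat op further up the tree.
            self.tree.repeat_stack.append(self)
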
@@ -127,16 +123,20 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
Status RepeatOp::EoeReceived(int32_t worker_id) {
  repeat_count_++;
  MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
  // If we've reached the requested repeat count, then flag the leaf nodes
  bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
  bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
  // If we've reached the requested repeat count, then flag the eoe nodes
  // to tell them they've got one more epoch to perform. When they reach the end
  // of the last epoch, they quit rather than loop again.
  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
    for (size_t i = 0; i < leaf_ops_.size(); i++) {
      leaf_ops_[i]->set_control_flag(kDeOpLastRepeat);
  // of the last epoch, they quit rather than loop again. This happens in two cases:
  // 1) we are ourselves repeated (by another repeat op) and this is our last repetition, or
  // 2) we are not repeated at all.
  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
    for (auto &eoe_op : eoe_ops_) {
      eoe_op->set_control_flag(kDeOpLastRepeat);
    }
  }
  if (repeat_count_ == max_repeats_) {
    repeat_count_ = 0;
    state_ = OpState::kDeOpIdle;
    return Status::OK();
  }
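
The new guard above can be read as: flag the EOE ops for their final pass only when we are one epoch away from max_repeats_ and no outer repeat will run us again, i.e. we are either not repeated ourselves or already on our own last repetition. A hedged restatement in Python (illustrative names; the infinite-repeat sentinel is passed in rather than assumed):

    def should_flag_last_repeat(max_repeats, repeat_count, repeated, last_repeat, infinite_repeat):
        # True when the eoe ops should be told this is their last epoch.
        return (max_repeats != infinite_repeat
                and repeat_count == max_repeats - 1
                and (not repeated or last_repeat))
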
@@ -87,8 +87,8 @@ class RepeatOp : public PipelineOp {
  uint32_t PrepareFlags() const override;
  // Base-class override for executing specific RepeatOp configurations. This code will be called
  // during the execution tree prepare phase when it is visiting this operator.
  Status PrepareNodeAction() override;
  // during the execution tree post-prepare phase when it is visiting this operator.
  Status PrepareNodePostAction() override;
  // This function returns the buffer that is at the top of our output connector. The caller is
  // typically our parent node, when the parent is asking us to provide the next buffer of data.
@@ -119,9 +119,9 @@ class RepeatOp : public PipelineOp {
  int32_t num_producers() const override;
 private:
  int32_t max_repeats_;   // The number of repeats that the user requested
  int32_t repeat_count_;  // A counter for the current number of executed repeats
  std::vector<std::shared_ptr<DatasetOp>> leaf_ops_;  // List of leaf operators underneath this repeat.
  int32_t max_repeats_;                               // The number of repeats that the user requested
  int32_t repeat_count_;                              // A counter for the current number of executed repeats
  std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;   // List of operators that can generate EOE underneath this repeat.
};
}  // namespace dataset
}  // namespace mindspore
@@ -162,30 +162,25 @@ Status ExecutionTree::Prepare() {
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
  int32_t num_children = dataset_op->child_.size();
  // execute PreAction
  RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());
  // Before going down into children, make any prepare flags updates based on this
  // operator.
  // Before going down into children, make any prepare flags updates based on this operator.
  uint32_t op_prep_flags = dataset_op->PrepareFlags();
  // Sanity check. In the future we can support nested repeats; for now it's not allowed.
  // If somebody above us already set the repeat flag, and now we are another repeat...
  if (BitTest(op_prep_flags, kDePrepRepeat) && BitTest(prepare_flags_, kDePrepRepeat)) {
    std::string err_msg("Nested RepeatOp detected! This is not supported yet.");
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  BitSet(&prepare_flags_, op_prep_flags);
  // Now, descend to children
  for (int32_t i = 0; i < num_children; ++i) {
    RETURN_IF_NOT_OK(this->PrepareNode(dataset_op->child_[i]));
  for (const auto &i : dataset_op->child_) {
    RETURN_IF_NOT_OK(this->PrepareNode(i));
  }
  // No more children, now we execute any prepare actions before going back up
  // the tree on recursive function exit
  RETURN_IF_NOT_OK(dataset_op->PrepareNodeAction());
  // Then clear the flags from this op now that we have prepared it.
  BitClear(&prepare_flags_, op_prep_flags);
  // No more children, now we execute any prepare actions before going back up
  // the tree on recursive function exit
  RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());
  return Status::OK();
}
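
The reworked walk is a depth-first traversal with a pre-visit hook (flag inheritance) and a post-visit hook (repeat-stack bookkeeping); the nested-repeat rejection is simply gone. Tying the earlier sketches together, a minimal Python version of the same pattern (illustrative only):

    def prepare_node(tree, op):
        # Pre-visit: the op reads whatever flags are currently set on the tree.
        op.prepare_node_pre_action()
        # This op's own prepare flags apply to everything underneath it.
        op_prep_flags = op.prepare_flags()
        tree.prepare_flags |= op_prep_flags
        for child in op.children:
            prepare_node(tree, child)
        # Post-visit: repeat-stack bookkeeping happens on the way back up.
        op.prepare_node_post_action()
        # Clear this op's flags once its subtree has been prepared.
        tree.prepare_flags &= ~op_prep_flags

Walking a SketchTree with a SketchRepeatOp over another SketchRepeatOp over a plain SketchOp leaves the outer repeat holding the inner one in its eoe_ops and the inner one holding the leaf, which is exactly the chain the EoeReceived change relies on.
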
@@ -417,6 +417,8 @@ class Dataset:
            >>> repeat_and_shuffle = data.repeat(50)
            >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
        """
        if count == 1:
            return self
        return RepeatDataset(self, count)

    @check_zip_dataset
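
With the early return added above, repeating once becomes a no-op at the API level: repeat(1) hands back the same dataset object instead of wrapping it in a RepeatDataset. A small illustrative check (assumes the patched mindspore.dataset is importable):

    import numpy as np
    import mindspore.dataset as ds

    def gen():
        for i in range(3):
            yield (np.array([i]),)

    data = ds.GeneratorDataset(gen, ["data"])
    assert data.repeat(1) is data       # count == 1 short-circuits to self
    assert data.repeat(2) is not data   # other counts still wrap in RepeatDataset
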
@@ -33,18 +33,29 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
  auto my_tree = std::make_shared<ExecutionTree>();
  std::shared_ptr<DatasetOp> parent_op = std::make_shared<RepeatOp>(32);
  std::shared_ptr<DatasetOp> leaf_op = std::make_shared<RepeatOp>(16);
  std::string dataset_path;
  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  // TFReaderOp
  std::shared_ptr<TFReaderOp> my_tfreader_op;
  TFReaderOp::Builder builder;
  builder.SetDatasetFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetWorkerConnectorSize(16)
    .SetNumWorkers(16);
  Status rc = builder.Build(&my_tfreader_op);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op);
  ASSERT_TRUE(rc.IsOk());
  my_tree->AssociateNode(parent_op);
  my_tree->AssociateNode(leaf_op);
  ASSERT_NE(parent_op, nullptr);
  ASSERT_NE(leaf_op, nullptr);
  parent_op->AddChild(std::move(leaf_op));
  parent_op->Print(std::cout, false);
  parent_op->PrepareNodeAction();
  ASSERT_NE(my_tfreader_op, nullptr);
  parent_op->AddChild(std::move(my_tfreader_op));
  MS_LOG(INFO) << parent_op;
  my_tree->Prepare();
  RepeatOp RepeatOpOp();
  std::shared_ptr<RepeatOp> repeat_op;
  Status rc = RepeatOp::Builder(3).Build(&repeat_op);
  rc = RepeatOp::Builder(3).Build(&repeat_op);
  ASSERT_NE(repeat_op, nullptr);
}
@@ -16,6 +16,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
from util import save_and_check
import mindspore.dataset as ds
import numpy as np
from mindspore import log as logger

DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
@@ -95,6 +96,141 @@ def test_tf_repeat_03():
    assert num_iter == 2

def generator():
    for i in range(3):
        yield np.array([i]),

def test_nested_repeat1():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(3)
    for i, d in enumerate(data):
        assert i % 3 == d[0][0]
    assert sum([1 for _ in data]) == 2 * 3 * 3

def test_nested_repeat2():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(1)
    data = data.repeat(1)
    for i, d in enumerate(data):
        assert i % 3 == d[0][0]
    assert sum([1 for _ in data]) == 3

def test_nested_repeat3():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(1)
    data = data.repeat(2)
    for i, d in enumerate(data):
        assert i % 3 == d[0][0]
    assert sum([1 for _ in data]) == 2 * 3

def test_nested_repeat4():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(1)
    for i, d in enumerate(data):
        assert i % 3 == d[0][0]
    assert sum([1 for _ in data]) == 2 * 3

def test_nested_repeat5():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.batch(3)
    data = data.repeat(2)
    data = data.repeat(3)
    for i, d in enumerate(data):
        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
    assert sum([1 for _ in data]) == 6

def test_nested_repeat6():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.batch(3)
    data = data.repeat(3)
    for i, d in enumerate(data):
        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
    assert sum([1 for _ in data]) == 6

def test_nested_repeat7():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(3)
    data = data.batch(3)
    for i, d in enumerate(data):
        assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
    assert sum([1 for _ in data]) == 6

def test_nested_repeat8():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.batch(2, drop_remainder=False)
    data = data.repeat(2)
    data = data.repeat(3)
    for i, d in enumerate(data):
        if i % 2 == 0:
            assert np.array_equal(d[0], np.asarray([[0], [1]]))
        else:
            assert np.array_equal(d[0], np.asarray([[2]]))
    assert sum([1 for _ in data]) == 6 * 2

def test_nested_repeat9():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat()
    data = data.repeat(3)
    for i, d in enumerate(data):
        assert i % 3 == d[0][0]
        if i == 10:
            break

def test_nested_repeat10():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(3)
    data = data.repeat()
    for i, d in enumerate(data):
        assert i % 3 == d[0][0]
        if i == 10:
            break

def test_nested_repeat11():
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(3)
    data = data.repeat(4)
    data = data.repeat(5)
    for i, d in enumerate(data):
        assert i % 3 == d[0][0]
    assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3

if __name__ == "__main__":
    logger.info("--------test tf repeat 01---------")
    # test_repeat_01()
@@ -104,4 +240,3 @@ if __name__ == "__main__":
    logger.info("--------test tf repeat 03---------")
    test_tf_repeat_03()