
Support nested repeat

tags/v0.2.0-alpha
hesham · 5 years ago
parent · commit 0fc23eee0f
10 changed files with 233 additions and 62 deletions
  1. mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc (+10 -7)
  2. mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h (+8 -2)
  3. mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h (+13 -3)
  4. mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h (+12 -3)
  5. mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc (+18 -18)
  6. mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h (+5 -5)
  7. mindspore/ccsrc/dataset/engine/execution_tree.cc (+10 -15)
  8. mindspore/dataset/engine/datasets.py (+2 -0)
  9. tests/ut/cpp/dataset/repeat_op_test.cc (+19 -8)
  10. tests/ut/python/dataset/test_repeat.py (+136 -1)

mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc (+10 -7)

@@ -161,15 +161,18 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
}

// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific pre-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodeAction() {
Status DatasetOp::PrepareNodePreAction() {
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
return Status::OK();
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodePostAction() {
// If this op does not have any children and it is in a repeat path of the tree...
if (child_.size() == 0 && BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) {
// Then, flag this operator as a leaf node in a repeat path of tree execution.
BitSet(&op_ctrl_flags_, kDeOpRepeated);

// Secondly, push ourselves onto the tree repeat stack. Later, the repeat operator
if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
// push ourselves onto the tree repeat stack. Later, the repeat operator
// above us will consume them.
tree_->AddToRepeatStack(shared_from_this());
}
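
Net effect of this hunk: the single PrepareNodeAction is split into a pre-action and a post-action. The pre-action runs before the tree walk descends into a node's children and only inherits the "this subtree is repeated" flag from the tree; the post-action runs after the children have been prepared, and a childless operator that carries the flag pushes itself onto the tree's repeat stack for the RepeatOp above it to collect. A hedged Python sketch of that contract, with toy names standing in for the C++ classes (ToyTree, ToyOp, K_REPEATED are made up for illustration):

```python
# Toy sketch of the two-phase prepare contract; not the real ExecutionTree/DatasetOp API.
K_REPEATED = 0x1  # stand-in for kDeOpRepeated / kDePrepRepeat

class ToyTree:
    def __init__(self):
        self.prepare_flags = 0   # set by a repeat op while its subtree is being prepared
        self.repeat_stack = []   # ops register here; the repeat op above drains it

class ToyOp:
    def __init__(self, children=()):
        self.children = list(children)
        self.flags = 0

    def prepare_pre(self, tree):
        # Pre-action, on the way down: inherit "I am inside a repeated subtree".
        if tree.prepare_flags & K_REPEATED:
            self.flags |= K_REPEATED

    def prepare_post(self, tree):
        # Post-action, on the way back up: a childless op under a repeat registers itself
        # so the nearest RepeatOp above can later mark its last repetition.
        if not self.children and (self.flags & K_REPEATED):
            tree.repeat_stack.append(self)
```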


mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h (+8 -2)

@@ -150,11 +150,17 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
return Status::OK();
}

// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific pre-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
virtual Status PrepareNodeAction();
virtual Status PrepareNodePreAction();

// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
virtual Status PrepareNodePostAction();

// Getter function
// @return The operator id


mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h (+13 -3)

@@ -64,14 +64,24 @@ class ParallelOp : public DatasetOp {
return out;
}

// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific pre-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
// @return Status - The error return code
Status PrepareNodeAction() override {
Status PrepareNodePreAction() override {
// Run common code from super class before adding ParallelOp specific logic
return (DatasetOp::PrepareNodeAction());
return (DatasetOp::PrepareNodePreAction());
}

// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
// @return Status - The error return code
Status PrepareNodePostAction() override {
// Run common code from super class before adding ParallelOp specific logic
return (DatasetOp::PrepareNodePostAction());
}

// Override base class reset to provide reset actions specific to the ParallelOp class.


mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h (+12 -3)

@@ -64,13 +64,22 @@ class PipelineOp : public DatasetOp {
// @return The number of threads that push data to the output connector
int32_t num_producers() const override { return 1; }

// During tree prepare phase, operators may have specific operations to perform depending on
// During tree prepare phase, operators may have specific pre-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status PrepareNodeAction() override {
Status PrepareNodePreAction() override {
// Run common code from super class before adding PipelineOp specific logic
return (DatasetOp::PrepareNodeAction());
return (DatasetOp::PrepareNodePreAction());
}

// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override {
// Run common code from super class before adding PipelineOp specific logic
return (DatasetOp::PrepareNodePostAction());
}

protected:


mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc (+18 -18)

@@ -58,10 +58,10 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
out << "RepeatOp:"
<< "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
<< "\nLeaf Nodes in my execution path:";
if (!leaf_ops_.empty()) {
if (!eoe_ops_.empty()) {
out << "\n";
for (size_t i = 0; i < leaf_ops_.size(); i++) {
out << " Operator: " << leaf_ops_[i]->id() << "\n";
for (size_t i = 0; i < eoe_ops_.size(); i++) {
out << " Operator: " << eoe_ops_[i]->id() << "\n";
}
} else {
out << " kNone.";
@@ -71,21 +71,17 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {

// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase when it is visiting this operator.
Status RepeatOp::PrepareNodeAction() {
Status RepeatOp::PrepareNodePostAction() {
// Run any common code from super class first before adding our own specific logic
RETURN_IF_NOT_OK(PipelineOp::PrepareNodeAction());
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
while (leaf_op != nullptr) {
// Track the leaf operators that are under this repeat op.
leaf_ops_.push_back(leaf_op);

// Special case. If the repeat count is 1, then pre-flag the leaf nodes
// to tell them they are already at their last op:
if (max_repeats_ == 1) {
leaf_op->set_control_flag(kDeOpLastRepeat);
}
eoe_ops_.push_back(leaf_op);
leaf_op = tree_->PopFromRepeatStack();
}
// Push ourselves to the stack in case one of our ascendants is repeat too.
tree_->AddToRepeatStack(shared_from_this());
return Status::OK();
}

@@ -127,16 +123,20 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
Status RepeatOp::EoeReceived(int32_t worker_id) {
repeat_count_++;
MS_LOG(INFO) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";

// If we've reached the requested repeat count, then flag the leaf nodes
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
// If we've reached the requested repeat count, then flag the eoe nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again.
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
for (size_t i = 0; i < leaf_ops_.size(); i++) {
leaf_ops_[i]->set_control_flag(kDeOpLastRepeat);
// of the last epoch, they quit rather than loop again. This happens in two cases:
// 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
// 2- We are not repeated
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
for (auto &eoe_op : eoe_ops_) {
eoe_op->set_control_flag(kDeOpLastRepeat);
}
}
if (repeat_count_ == max_repeats_) {
repeat_count_ = 0;
state_ = OpState::kDeOpIdle;
return Status::OK();
}
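
Two things change in repeat_op.cc. In PrepareNodePostAction the RepeatOp drains the repeat stack into eoe_ops_ (renamed from leaf_ops_, since with nesting the op below may itself be a RepeatOp that forwards EOE rather than a true leaf) and then pushes itself back onto the stack so an enclosing RepeatOp can claim it. In EoeReceived, the "one more epoch to go" flag is only propagated when this RepeatOp is either not itself repeated, or its outer repeat has already marked it as being on its last repetition. A small Python rendering of that guard, with assumed parameter names and an assumed -1 sentinel for kInfiniteRepeat:

```python
K_INFINITE_REPEAT = -1  # assumed sentinel standing in for kInfiniteRepeat

def should_flag_last_repeat(max_repeats, repeat_count, repeated, last_repeat):
    """Sketch of the new guard in RepeatOp::EoeReceived: tell the ops feeding us EOE that
    the next epoch is their last, but only if no outer repeat exists (not repeated) or
    the outer repeat already marked us as being on our own last pass (last_repeat)."""
    return (max_repeats != K_INFINITE_REPEAT
            and repeat_count == max_repeats - 1
            and (not repeated or last_repeat))

# Inner repeat(2) wrapped by an outer repeat(3): after the first inner epoch the inner op
# must NOT flag the ops below it yet, because the outer repeat still wants more passes.
assert not should_flag_last_repeat(2, 1, repeated=True, last_repeat=False)
# Once the outer repeat has marked the inner one as being on its last pass, it may flag.
assert should_flag_last_repeat(2, 1, repeated=True, last_repeat=True)
```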


mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h (+5 -5)

@@ -87,8 +87,8 @@ class RepeatOp : public PipelineOp {
uint32_t PrepareFlags() const override;

// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase when it is visiting this operator.
Status PrepareNodeAction() override;
// during the execution tree post-prepare phase when it is visiting this operator.
Status PrepareNodePostAction() override;

// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
@@ -119,9 +119,9 @@ class RepeatOp : public PipelineOp {
int32_t num_producers() const override;

private:
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats
std::vector<std::shared_ptr<DatasetOp>> leaf_ops_; // List of leaf operators underneath this repeat.
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
};
} // namespace dataset
} // namespace mindspore


mindspore/ccsrc/dataset/engine/execution_tree.cc (+10 -15)

@@ -162,30 +162,25 @@ Status ExecutionTree::Prepare() {
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) {
int32_t num_children = dataset_op->child_.size();
// execute PreAction
RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction());

// Before going down into children, make any prepare flags updates based on this
// operator.
// Before going down into children, make any prepare flags updates based on this operator.
uint32_t op_prep_flags = dataset_op->PrepareFlags();
// Sanity check. In future we can support nested repeats. for now it's not allowed.
// If somebody above us already set the repeat flag, and now we are another repeat...
if (BitTest(op_prep_flags, kDePrepRepeat) && BitTest(prepare_flags_, kDePrepRepeat)) {
std::string err_msg("Nested RepeatOp detected! This is not supported yet.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
BitSet(&prepare_flags_, op_prep_flags);

// Now, descend to children
for (int32_t i = 0; i < num_children; ++i) {
RETURN_IF_NOT_OK(this->PrepareNode(dataset_op->child_[i]));
for (const auto &i : dataset_op->child_) {
RETURN_IF_NOT_OK(this->PrepareNode(i));
}

// No more children, now we execute any prepare actions before going back up the
// the tree on recursive function exit
RETURN_IF_NOT_OK(dataset_op->PrepareNodeAction());

// Then clear the flags from this op now that we have prepared it.
BitClear(&prepare_flags_, op_prep_flags);

// No more children, now we execute any prepare actions before going back up the
// the tree on recursive function
RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());

return Status::OK();
}
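
With the nested-repeat guard removed, the prepare walk becomes: pre-action, merge this op's prepare flags, recurse into the children, clear those flags, post-action. Because the flags are set just before the descent and cleared right after it, each RepeatOp marks only its own subtree, and an outer RepeatOp sees the inner one as just another operator that can emit EOE. A hedged Python sketch of the recursion, with toy method names rather than the C++ ExecutionTree API:

```python
# Toy recursion mirroring the new ExecutionTree::PrepareNode flow described above.
def prepare_node(op, tree):
    op.prepare_pre(tree)              # pre-action on the way down
    op_flags = op.prepare_flags()     # a repeat op contributes the "repeat" prepare flag
    tree.prepare_flags |= op_flags    # visible to the whole subtree below
    for child in op.children:
        prepare_node(child, tree)
    tree.prepare_flags &= ~op_flags   # scope this op's flags to its own subtree
    op.prepare_post(tree)             # post-action on the way back up
```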



mindspore/dataset/engine/datasets.py (+2 -0)

@@ -417,6 +417,8 @@ class Dataset:
>>> repeat_and_shuffle = data.repeat(50)
>>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
"""
if count == 1:
return self
return RepeatDataset(self, count)

@check_zip_dataset
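
The only Python-side change: repeat(1) now short-circuits and returns the same Dataset instead of wrapping it in a RepeatDataset, so a no-op repeat adds no RepeatOp to the execution tree. A quick usage sketch consistent with the tests added below (it assumes the same GeneratorDataset pattern those tests use, with numpy and mindspore installed):

```python
import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(3):
        yield (np.array([i]),)

data = ds.GeneratorDataset(gen, ["data"])
assert data.repeat(1) is data                # repeat(1) is a no-op: same object back
nested = data.repeat(2).repeat(3)            # nested repeats now compose
assert sum(1 for _ in nested) == 2 * 3 * 3   # 3 rows per epoch, 6 epochs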


tests/ut/cpp/dataset/repeat_op_test.cc (+19 -8)

@@ -33,18 +33,29 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
auto my_tree = std::make_shared<ExecutionTree>();

std::shared_ptr<DatasetOp> parent_op = std::make_shared<RepeatOp>(32);

std::shared_ptr<DatasetOp> leaf_op = std::make_shared<RepeatOp>(16);
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
// TFReaderOp
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetWorkerConnectorSize(16)
.SetNumWorkers(16);
Status rc= builder.Build(&my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
my_tree->AssociateNode(parent_op);
my_tree->AssociateNode(leaf_op);
ASSERT_NE(parent_op, nullptr);
ASSERT_NE(leaf_op, nullptr);
parent_op->AddChild(std::move(leaf_op));
parent_op->Print(std::cout, false);
parent_op->PrepareNodeAction();
ASSERT_NE(my_tfreader_op, nullptr);
parent_op->AddChild(std::move(my_tfreader_op));
MS_LOG(INFO) << parent_op;
my_tree->Prepare();

RepeatOp RepeatOpOp();

std::shared_ptr<RepeatOp> repeat_op;
Status rc = RepeatOp::Builder(3).Build(&repeat_op);
rc = RepeatOp::Builder(3).Build(&repeat_op);
ASSERT_NE(repeat_op, nullptr);
}

tests/ut/python/dataset/test_repeat.py (+136 -1)

@@ -16,6 +16,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
from util import save_and_check

import mindspore.dataset as ds
import numpy as np
from mindspore import log as logger

DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
@@ -95,6 +96,141 @@ def test_tf_repeat_03():
assert num_iter == 2


def generator():
for i in range(3):
yield np.array([i]),


def test_nested_repeat1():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)

for i, d in enumerate(data):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 2 * 3 * 3


def test_nested_repeat2():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(1)
data = data.repeat(1)

for i, d in enumerate(data):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 3


def test_nested_repeat3():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(1)
data = data.repeat(2)

for i, d in enumerate(data):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 2 * 3


def test_nested_repeat4():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(1)

for i, d in enumerate(data):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 2 * 3


def test_nested_repeat5():
data = ds.GeneratorDataset(generator, ["data"])
data = data.batch(3)
data = data.repeat(2)
data = data.repeat(3)

for i, d in enumerate(data):
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))

assert sum([1 for _ in data]) == 6


def test_nested_repeat6():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.batch(3)
data = data.repeat(3)

for i, d in enumerate(data):
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))

assert sum([1 for _ in data]) == 6


def test_nested_repeat7():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)
data = data.batch(3)

for i, d in enumerate(data):
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))

assert sum([1 for _ in data]) == 6


def test_nested_repeat8():
data = ds.GeneratorDataset(generator, ["data"])
data = data.batch(2, drop_remainder=False)
data = data.repeat(2)
data = data.repeat(3)

for i, d in enumerate(data):
if i % 2 == 0:
assert np.array_equal(d[0], np.asarray([[0], [1]]))
else:
assert np.array_equal(d[0], np.asarray([[2]]))

assert sum([1 for _ in data]) == 6 * 2


def test_nested_repeat9():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat()
data = data.repeat(3)

for i, d in enumerate(data):
assert i % 3 == d[0][0]
if i == 10:
break


def test_nested_repeat10():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(3)
data = data.repeat()

for i, d in enumerate(data):
assert i % 3 == d[0][0]
if i == 10:
break


def test_nested_repeat11():
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)
data = data.repeat(4)
data = data.repeat(5)

for i, d in enumerate(data):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3


if __name__ == "__main__":
logger.info("--------test tf repeat 01---------")
# test_repeat_01()
@@ -104,4 +240,3 @@ if __name__ == "__main__":

logger.info("--------test tf repeat 03---------")
test_tf_repeat_03()
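
Taken together, the new tests pin down the intended semantics: chained repeat counts multiply the number of epochs, repeat(1) degenerates to the original dataset, an infinite repeat anywhere in the chain keeps the pipeline running, and placing batch() before or after the repeats changes only what one item is, not the multiplication. A small helper (hypothetical, not part of the test file) that makes the expected sizes explicit for the repeat-then-batch cases:

```python
import math

def expected_items(rows, repeats, batch_size=None, drop_remainder=False):
    """Hypothetical helper: expected item count when batch() comes after the repeats."""
    n = rows
    for r in repeats:
        n *= r
    if batch_size:
        n = n // batch_size if drop_remainder else math.ceil(n / batch_size)
    return n

assert expected_items(3, [2, 3]) == 18               # test_nested_repeat1
assert expected_items(3, [2, 3, 4, 5]) == 360        # test_nested_repeat11
assert expected_items(3, [2, 3], batch_size=3) == 6  # test_nested_repeat7 (batch last)
```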

