Browse Source

Migrate GeneratorOp resetting logic to IR optimizer (incl. debug code)

with comments from the peer review addressed
tags/v1.2.0-rc1
Nat Sutyanyong 5 years ago
parent
commit
686772cce7
17 changed files with 602 additions and 131 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  2. +20
    -6
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc
  3. +18
    -4
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h
  4. +16
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  5. +37
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  6. +28
    -4
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
  7. +24
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
  8. +32
    -16
      mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt
  9. +16
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
  10. +8
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
  11. +108
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/post/generator_node_pass.cc
  12. +75
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/post/generator_node_pass.h
  13. +1
    -82
      mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc
  14. +0
    -13
      mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h
  15. +6
    -0
      mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc
  16. +4
    -0
      mindspore/lite/minddata/CMakeLists.txt
  17. +208
    -0
      tests/ut/python/dataset/test_generator_reset_pass.py

+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -258,7 +258,7 @@ void DatasetNode::PrintTree(std::ostream &out) const {


void DatasetNode::PrintNode(std::ostream &out, int *level) const { void DatasetNode::PrintNode(std::ostream &out, int *level) const {
const std::string prefix = "+-"; const std::string prefix = "+-";
const std::string indent = " ";
const std::string indent = "| ";
out << prefix; out << prefix;
Print(out); Print(out);
for (const auto &c : this->Children()) { for (const auto &c : this->Children()) {


+ 20
- 6
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc View File

@@ -28,30 +28,33 @@ namespace mindspore {
namespace dataset { namespace dataset {


// Constructor for EpochCtrlNode // Constructor for EpochCtrlNode
EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) {
EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : RepeatNode() {
// The root node's parent must set to null pointer. // The root node's parent must set to null pointer.
this->AddChild(child); this->AddChild(child);
repeat_count_ = num_epochs;
} }


std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() { std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
auto node = std::make_shared<EpochCtrlNode>(num_epochs_);
auto node = std::make_shared<EpochCtrlNode>(repeat_count_);
return node; return node;
} }


void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(num_epochs_) + ")"; }
void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(repeat_count_) + ")"; }


// Function to build the EpochCtrlOp // Function to build the EpochCtrlOp
Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<EpochCtrlOp>(num_epochs_));
auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
node_ops->push_back(new_op_);
op_ = new_op_;
return Status::OK(); return Status::OK();
} }


// Function to validate the parameters for EpochCtrlNode // Function to validate the parameters for EpochCtrlNode
Status EpochCtrlNode::ValidateParams() { Status EpochCtrlNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
if (num_epochs_ <= 0 && num_epochs_ != -1) {
if (repeat_count_ <= 0 && repeat_count_ != -1) {
std::string err_msg = std::string err_msg =
"EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);
"EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(repeat_count_);
MS_LOG(ERROR) << err_msg; MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg); RETURN_STATUS_SYNTAX_ERROR(err_msg);
} }
@@ -63,5 +66,16 @@ Status EpochCtrlNode::ValidateParams() {
return Status::OK(); return Status::OK();
} }


// Visitor accepting method for IRNodePass
Status EpochCtrlNode::Accept(IRNodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<EpochCtrlNode>(), modified);
}

// Visitor accepting method for IRNodePass
Status EpochCtrlNode::AcceptAfter(IRNodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<EpochCtrlNode>(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 18
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h View File

@@ -21,15 +21,20 @@
#include <string> #include <string>
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class EpochCtrlNode : public DatasetNode {
class EpochCtrlNode : public RepeatNode {
// Allow GeneratorNode to access internal members
friend class GeneratorNode;

public: public:
/// \brief Constructor /// \brief Constructor
explicit EpochCtrlNode(int32_t num_epochs) : num_epochs_(num_epochs) {}
explicit EpochCtrlNode(int32_t num_epochs) : RepeatNode() { repeat_count_ = num_epochs; }


/// \brief Constructor /// \brief Constructor
EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
@@ -58,8 +63,17 @@ class EpochCtrlNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


private:
int32_t num_epochs_;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *const modified) override;

/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *const modified) override;
}; };


} // namespace dataset } // namespace dataset


+ 16
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc View File

@@ -26,7 +26,8 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) {
RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count)
: repeat_count_(count), reset_ancestor_(nullptr), op_(nullptr) {
this->AddChild(child); this->AddChild(child);
} }


@@ -35,10 +36,22 @@ std::shared_ptr<DatasetNode> RepeatNode::Copy() {
return node; return node;
} }


void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ")"; }
void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ") "; }


Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<RepeatOp>(repeat_count_));
auto new_op = std::make_shared<RepeatOp>(repeat_count_);
node_ops->push_back(new_op);
op_ = new_op;

// Add this RepeatOp to its RepeatOp/EpochCtrlOp ancestor's EOE list.
// When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list.
// The ancestor is updated by GeneratorNodePass post pass.
// Assumption:
// We build the run-time ops from IR nodes from top to bottom. Hence Repeat/EpochCtrl ancestor ops are built
// before this leaf Generator op is built.
if (reset_ancestor_ != nullptr) {
reset_ancestor_->op_->AddToEoeList(new_op);
}
return Status::OK(); return Status::OK();
} }




+ 37
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h View File

@@ -28,10 +28,18 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {


class RepeatOp;

class RepeatNode : public DatasetNode { class RepeatNode : public DatasetNode {
// Allow GeneratorNode to access internal members
friend class GeneratorNode;

public: public:
/// \brief Constructor /// \brief Constructor
explicit RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count);
RepeatNode() : op_(nullptr), reset_ancestor_(nullptr), repeat_count_(-1) {}

/// \brief Constructor
RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count);


/// \brief Destructor /// \brief Destructor
~RepeatNode() = default; ~RepeatNode() = default;
@@ -82,7 +90,34 @@ class RepeatNode : public DatasetNode {
/// \return Status of the node visit /// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override; Status AcceptAfter(IRNodePass *const p, bool *const modified) override;


private:
/// \brief Record the Repeat/EpochCtrl node that is the closest ancestor of this node
/// \param[in] src The ancestor node
/// \return Status of the function
Status AddResetAncestor(const std::shared_ptr<RepeatNode> &src) {
/*
* This check is to ensure we don't overwrite an existing value of its ancestor.
* It is okay to assign to the same value more than once in RepeatNode (but not in GeneratorNode).
* Consider the following scenario
* EpochCtrl(-1)
* |
* Repeat
* |
* Concat
* / \
* GenData1 GenData2
*
* We will record the ancestor relationship of (Repeat, EpochCtrl) twice, one at Visit(GenData1), the other at
* Visit(GenData2).
*/
CHECK_FAIL_RETURN_UNEXPECTED(reset_ancestor_ == nullptr || reset_ancestor_ == src,
"Internal error: Overwriting an existing value");
reset_ancestor_ = src;
return Status::OK();
}

protected:
std::shared_ptr<RepeatOp> op_; // keep its corresponding run-time op of EpochCtrlNode and RepeatNode
std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass
int32_t repeat_count_; int32_t repeat_count_;
}; };




+ 28
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc View File

@@ -20,7 +20,9 @@
#include <string> #include <string>
#include <vector> #include <vector>


#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h" #include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"


namespace mindspore { namespace mindspore {
@@ -31,10 +33,11 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
: MappableSourceNode(), : MappableSourceNode(),
generator_function_(generator_function), generator_function_(generator_function),
column_names_(column_names), column_names_(column_names),
column_types_(column_types) {}
column_types_(column_types),
reset_ancestor_(nullptr) {}


GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
: generator_function_(generator_function), schema_(schema) {}
: MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {}


std::shared_ptr<DatasetNode> GeneratorNode::Copy() { std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
std::shared_ptr<GeneratorNode> node; std::shared_ptr<GeneratorNode> node;
@@ -47,7 +50,7 @@ std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
} }


void GeneratorNode::Print(std::ostream &out) const { void GeneratorNode::Print(std::ostream &out) const {
out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)";
out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>) ";
} }


Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
@@ -77,8 +80,17 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
// best be delivered when the test cases for this api is ready. // best be delivered when the test cases for this api is ready.
RETURN_IF_NOT_OK(op->Init()); RETURN_IF_NOT_OK(op->Init());


node_ops->push_back(op);
// Add this GeneratorOp to its RepeatOp/EpochCtrlOp ancestor's EOE list.
// When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list.
// The ancestor is updated by GeneratorNodePass post pass.
// Assumption:
// We build the run-time ops from IR nodes from top to bottom. Hence Repeat/EpochCtrl ancestor ops are built
// before this leaf Generator op is built.
if (reset_ancestor_ != nullptr) {
reset_ancestor_->op_->AddToEoeList(op);
}


node_ops->push_back(op);
return Status::OK(); return Status::OK();
} }


@@ -93,5 +105,17 @@ Status GeneratorNode::GetShardId(int32_t *shard_id) {
*shard_id = 0; *shard_id = 0;
return Status::OK(); return Status::OK();
} }

// Visitor accepting method for IRNodePass
Status GeneratorNode::Accept(IRNodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<GeneratorNode>(), modified);
}

// Visitor accepting method for IRNodePass
Status GeneratorNode::AcceptAfter(IRNodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<GeneratorNode>(), modified);
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 24
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h View File

@@ -22,6 +22,8 @@
#include <vector> #include <vector>


#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"


namespace mindspore { namespace mindspore {
@@ -72,11 +74,33 @@ class GeneratorNode : public MappableSourceNode {


bool IsSizeDefined() override { return false; } bool IsSizeDefined() override { return false; }


/// \brief Record the Repeat/EpochCtrl node that is the immediate ancestor of this node
/// \param[in] src The ancestor node
/// \return Status of the function
Status AddResetAncestor(const std::shared_ptr<RepeatNode> &src) {
CHECK_FAIL_RETURN_UNEXPECTED(reset_ancestor_ == nullptr, "Internal error: Overwriting an existing value");
reset_ancestor_ = src;
return Status::OK();
}

private: private:
py::function generator_function_; py::function generator_function_;
std::vector<std::string> column_names_; std::vector<std::string> column_names_;
std::vector<DataType> column_types_; std::vector<DataType> column_types_;
std::shared_ptr<SchemaObj> schema_; std::shared_ptr<SchemaObj> schema_;
std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass

/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *const modified) override;

/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *const modified) override;
}; };


} // namespace dataset } // namespace dataset


+ 32
- 16
mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt View File

@@ -1,19 +1,35 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-opt OBJECT
optional/tensor_op_fusion_pass.cc
pass.cc
post/auto_worker_pass.cc
post/repeat_pass.cc
pre/cache_error_pass.cc
pre/cache_transform_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc
pre/epoch_ctrl_pass.cc
pre/epoch_injection_pass.cc
pre/getter_pass.cc
pre/input_validation_pass.cc
pre/node_removal_pass.cc
pre/removal_pass.cc
util/printer_pass.cc

set(DATASET_ENGINE_OPT_SRC_FILES
pass.cc
post/auto_worker_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc
pre/getter_pass.cc
pre/input_validation_pass.cc
pre/epoch_ctrl_pass.cc
pre/node_removal_pass.cc
)

# This set of files is for ExecTree's optimizer. It is being migrated to IR's optimizer.
# When the migration is complete, we will remove these files.
set(DATASET_ENGINE_OPT_SRC_FILES
${DATASET_ENGINE_OPT_SRC_FILES}
optional/tensor_op_fusion_pass.cc
pre/cache_error_pass.cc
post/repeat_pass.cc
pre/cache_transform_pass.cc
pre/epoch_injection_pass.cc
util/printer_pass.cc
pre/removal_pass.cc
)

if (ENABLE_PYTHON)
set(DATASET_ENGINE_OPT_SRC_FILES
${DATASET_ENGINE_OPT_SRC_FILES}
post/generator_node_pass.cc
) )
endif()

add_library(engine-opt OBJECT ${DATASET_ENGINE_OPT_SRC_FILES})

+ 16
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc View File

@@ -22,6 +22,7 @@
#endif #endif
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h" #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h" #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h" #include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h" #include "minddata/dataset/engine/ir/datasetops/project_node.h"
@@ -31,6 +32,7 @@
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h" #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#endif #endif
#include "minddata/dataset/engine/ir/datasetops/take_node.h" #include "minddata/dataset/engine/ir/datasetops/take_node.h"
@@ -179,12 +181,26 @@ Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified)
Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) { Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) { Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }
Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified) { Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
} }
#ifdef ENABLE_PYTHON
Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) { Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified); return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
} }


+ 8
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h View File

@@ -31,6 +31,7 @@ class BatchNode;
class BucketBatchByLengthNode; class BucketBatchByLengthNode;
class BuildVocabNode; class BuildVocabNode;
class ConcatNode; class ConcatNode;
class EpochCtrlNode;
class FilterNode; class FilterNode;
class MapNode; class MapNode;
class ProjectNode; class ProjectNode;
@@ -43,6 +44,7 @@ class TakeNode;
class TransferNode; class TransferNode;
class ZipNode; class ZipNode;
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
class GeneratorNode;
class SyncWaitNode; class SyncWaitNode;
#endif #endif
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
@@ -198,8 +200,14 @@ class IRNodePass : public IRPass {
virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified); virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified); virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified); virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified); virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified); virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified);
#ifdef ENABLE_PYTHON
virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified);
#endif
virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified); virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified); virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified); virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified);


+ 108
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/post/generator_node_pass.cc View File

@@ -0,0 +1,108 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"

namespace mindspore {
namespace dataset {

GeneratorNodePass::GeneratorNodePass() : repeat_ancestors_({}) {}
/*
* A diagram shows how the code work:
* With the tree below as an input
*
* EpochCtrl(-1)
* / \
* Repeat1 \
* / Repeat3
* .. \
* / Generator2
* Repeat2 Add: Gen2-Rep3
* /
* Generator1
* Add: Gen1-Rep2
*
* The sequence of the DFS walk of the tree looks like this:
* 1) Visit(EpochCtrl): push EpochCtrl, repeat_ancestors_ = { EpochCtrl }
* 2) Visit(Repeat1): push Repeat1, repeat_ancestors_ = { EpochCtrl, Repeat1 }
* 3) Visit(Repeat2): push Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1, Repeat2 }
* 4) Visit(Generator1): record Repeat2 as its ancestor
* record Repeat1 as Repeat2's ancestor
* record EpochCtrl as Repeat1's ancestor
* 5) VisitAfter(Repeat2): pop Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1 }
* 6) VisitAfter(Repeat1): pop Repeat1, repeat_ancestors_ = { EpochCtrl }
* 7) Visit(Repeat3): push Repeat3, repeat_ancestors_ = { EpochCtrl, Repeat3 }
* 8) Visit(Generator2): record Repeat3 as its ancestor
* record EpochCtrl as Repeat3's ancestor
* 9) VisitAfter(Repeat3): pop Repeat3, repeat_ancestors_ = { EpochCtrl }
* 10) VisitAfter(EpochCtrl): don't care. We could pop EpochCtrl.
*/

Status GeneratorNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
// Add this EpochCtrl node as an ancestor of its descendant
repeat_ancestors_.push_back(node);
return Status::OK();
}

Status GeneratorNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
// Add this Repeat node as an ancestor of its descendant
repeat_ancestors_.push_back(node);
return Status::OK();
}

Status GeneratorNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
// Form a reset relationship with the immediate Repeat/EpochCtrl ancestor node of this leaf Generator Node
// only when any of its ancestors is an infinite repeat.
if (repeat_ancestors_.size() > 0) {
bool infinite_repeat = false;
for (auto &repeat_ancestor : repeat_ancestors_) {
if (repeat_ancestor->Count() < 0) {
infinite_repeat = true;
break;
}
}
if (infinite_repeat) {
// Form a pair-wise relationship between this leaf Generator node and its immediate Repeat/EpochCtrl
// ancestor node, and between the next adjacent pairs in the vector. For example,
// if we have GeneratorNode -> Repeat1 -> Repeat2 -> EpochCtrl(-1), the pair-wise relationships are:
// (GeneratorNode, Repeat1), (Repeat1, Repeat2), and (Repeat2, EpochCtrl)
for (auto i = repeat_ancestors_.size() - 1; i > 0; --i) {
auto ancestor = repeat_ancestors_[i - 1];
RETURN_IF_NOT_OK(repeat_ancestors_[i]->AddResetAncestor(ancestor));
}
RETURN_IF_NOT_OK(node->AddResetAncestor(repeat_ancestors_.back()));
}
}
return Status::OK();
}

Status GeneratorNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
// When we backtrack from the same Repeat node, we pop it out from the list of ancestors.
repeat_ancestors_.pop_back();
return Status::OK();
}

Status GeneratorNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
// As EpochCtrl node is a terminal node, the process stops here.
// Popping it back out of the reset ancestors is unnecessary.
// This function becomes a no-op function and can be deleted completely.
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 75
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/post/generator_node_pass.h View File

@@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H

#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {

/// \class GeneratorNodePass generator_node_pass.h
/// \brief This is an IRNodePass whose job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
/// to the eoe-producing (typically leaf) nodes underneath it.
class GeneratorNodePass : public IRNodePass {
public:
/// \brief Constructor
GeneratorNodePass();

/// \brief Destructor
~GeneratorNodePass() = default;

/// \brief Record the starting point to collect the Generator node
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override;

/// \brief Record the starting point to collect the Generator node
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;

/// \brief Add the Generator node to the set
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override;

/// \brief Add the Generator node(s) from the set to this Repeat node for run-time processing
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override;

/// \brief Add the Generator node(s) from the set to this EpochCtrl node for run-time processing
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;

private:
std::vector<std::shared_ptr<RepeatNode>> repeat_ancestors_;
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H

+ 1
- 82
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc View File

@@ -28,25 +28,10 @@ namespace mindspore {
namespace dataset { namespace dataset {


RepeatPass::RepeatPass() RepeatPass::RepeatPass()
: is_repeated_(false),
nested_repeats_(0),
num_repeats_(1),
num_epochs_(1),
is_merge_(false),
is_cached_(false),
cache_lookup_(nullptr) {}
: num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {}


// Identifies the subtree below this node as being in a repeated path of the tree. // Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
// Create a new stack for eoe operators and push onto our stack of stacks.
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
eoe_op_stacks_.push(std::move(new_stack));
// If we are already repeated, then this is a nested repeat.
if (is_repeated_) {
nested_repeats_++;
}
is_repeated_ = true;

// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
@@ -70,13 +55,6 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modi


// Identifies the subtree below this node as being in a repeated path of the tree. // Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) {
// EpochCtrl is derived from RepeatOp. Generally it should do the identical setup
// that RepeatOp does. However, epoch control is actually simpler because it can
// only exist as the root node so it doesn't need all the nested code.
// Create a new stack for eoe operators and push onto our stack of stacks.
std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>();
eoe_op_stacks_.push(std::move(new_stack));
is_repeated_ = true;
// Get the total number of epochs from the EpochCtrlOp parameter // Get the total number of epochs from the EpochCtrlOp parameter
num_epochs_ = node->num_repeats(); num_epochs_ = node->num_repeats();
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times. // Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
@@ -103,22 +81,6 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modif


// Hooks up any identified eoe nodes under this repeat. // Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();

while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}

// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
// at this time, so we can pop it to get rid of it.
op_stack *current_stack = eoe_op_stacks_.top().get();
if (!current_stack->empty()) {
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
}
eoe_op_stacks_.pop();

// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
// and set its total repeats. It is important that the op is removed from the save area, // and set its total repeats. It is important that the op is removed from the save area,
// because the merge op above us may also take action on it later for a different case when // because the merge op above us may also take action on it later for a different case when
@@ -129,18 +91,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modifie
cache_lookup_.reset(); cache_lookup_.reset();
} }


// If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
// A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
if (nested_repeats_ > 0) {
AddToEOEOpStack(node);
nested_repeats_--;
} else {
// If we are not nested, or we were the top-most repeat, now we clear the flag
if (nested_repeats_ != 0) {
RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
}
is_repeated_ = false;
}
if (is_cached_) { if (is_cached_) {
AddToCachedOpStack(node); AddToCachedOpStack(node);
} }
@@ -156,13 +106,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modifie


// Hooks up any identified eoe nodes under this repeat. // Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) {
// Pop the leaf ops from the save-area stack and add them to the eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}
is_repeated_ = false;
node->set_total_repeats(num_repeats_); node->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
// We finish the walk of this EpochCtrl's descendent nodes. // We finish the walk of this EpochCtrl's descendent nodes.
@@ -192,13 +135,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified
} }


Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) { Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
// If we are in a repeat path, then set our repeated flag
if (is_repeated_) {
// if infinite repeat save ourself in a stack for the repeat operator above us
if (num_repeats_ < 0) {
AddToEOEOpStack(node);
}
}
// If we are under a cache op, then save ourselves to the cached op stack. // If we are under a cache op, then save ourselves to the cached op stack.
if (is_cached_) { if (is_cached_) {
AddToCachedOpStack(node); AddToCachedOpStack(node);
@@ -260,23 +196,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const mo
return Status::OK(); return Status::OK();
} }


// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
op_stack *current_stack = eoe_op_stacks_.top().get();
current_stack->push(dataset_op);
}

// Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
op_stack *current_stack = eoe_op_stacks_.top().get();
if (current_stack != nullptr && !current_stack->empty()) {
top_op = current_stack->top();
current_stack->pop();
}
return top_op;
}

// Adds an operator to the cached operator stack save area // Adds an operator to the cached operator stack save area
void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); }




+ 0
- 13
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h View File

@@ -112,19 +112,6 @@ class RepeatPass : public NodePass {
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override; Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override;


private: private:
/// \brief Adds an operator to the eoe operator stack save area
/// \param op - The dataset op to work add to eoe stack
/// \return Status The status code returned
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);

/// \brief Pops an operator from the eoe operator stack save area
/// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromEOEOpStack();

bool is_repeated_; // T/F if we are processing under a repeat
int32_t nested_repeats_; // A counter for nested repeats
std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting)

/// \brief Adds an operator to the cached operator stack save area /// \brief Adds an operator to the cached operator stack save area
/// \param op - The dataset op to work add to cached stack /// \param op - The dataset op to work add to cached stack
/// \return Status The status code returned /// \return Status The status code returned


+ 6
- 0
mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc View File

@@ -21,6 +21,9 @@
#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/post/auto_worker_pass.h" #include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
@@ -86,6 +89,9 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
// skip this for getter pass // skip this for getter pass
actions.emplace_back(std::make_unique<AutoWorkerPass>()); actions.emplace_back(std::make_unique<AutoWorkerPass>());
} }
#ifdef ENABLE_PYTHON
actions.emplace_back(std::make_unique<GeneratorNodePass>());
#endif


// We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here.




+ 4
- 0
mindspore/lite/minddata/CMakeLists.txt View File

@@ -148,6 +148,10 @@ if (BUILD_MINDDATA STREQUAL "full")
"${MINDDATA_DIR}/engine/datasetops/source/sampler/python_sampler.cc" "${MINDDATA_DIR}/engine/datasetops/source/sampler/python_sampler.cc"
) )


list(REMOVE_ITEM MINDDATA_ENGINE_OPT_POST_SRC_FILES
"${MINDDATA_DIR}/engine/opt/post/generator_node_pass.cc"
)

list(REMOVE_ITEM MINDDATA_ENGINE_OPT_POST_SRC_FILES list(REMOVE_ITEM MINDDATA_ENGINE_OPT_POST_SRC_FILES
"${MINDDATA_DIR}/engine/opt/post/repeat_pass.cc" "${MINDDATA_DIR}/engine/opt/post/repeat_pass.cc"
) )


+ 208
- 0
tests/ut/python/dataset/test_generator_reset_pass.py View File

@@ -0,0 +1,208 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.dataset as ds
from mindspore import log as logger

# Source of 2 rows holding the values (1, 2)
def generator_1to2():
    """Yield two single-column rows: (1,) then (2,)."""
    source = np.array([1, 2])
    for value in source:
        yield (np.array(value),)

# Source of 3 rows holding the values (10, 11, 12)
def generator_10to12():
    """Yield three single-column rows: (10,), (11,), (12,)."""
    source = np.array([10, 11, 12])
    for value in source:
        yield (np.array(value),)

# Source of 3 rows holding the values (22, 23, 24)
def generator_22to24():
    """Yield three single-column rows: (22,), (23,), (24,)."""
    source = np.array([22, 23, 24])
    for value in source:
        yield (np.array(value),)

def test_simple_repeat():
    """
    Test Generator -> Repeat -> Skip over a single epoch.

    Since the number of epochs is 1, the GeneratorPass logic will not add
    the reset logic.
    """
    logger.info("test_simple_repeat")
    # Pipeline: 2-row generator repeated twice, with the very first row dropped
    source = ds.GeneratorDataset(generator_1to2, ["data"])
    pipeline = source.repeat(2).skip(1)

    result = np.array([0])
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        result = np.append(result, row["data"])

    # (1, 2) twice, minus the leading 1
    expected = np.array([0, 2, 1, 2])
    np.testing.assert_array_equal(result, expected)

def test_generator_reset_1():
    """
    Test (Generator -> Repeat) + (Generator -> Repeat) + (Generator -> Repeat)
    """
    logger.info("test_generator_reset_1")
    # First branch: rows (1, 2) repeated four times
    branch1 = ds.GeneratorDataset(generator_1to2, ["data"]).repeat(4)
    # Second branch: rows (10..12) repeated twice; take(10) is a no-op whose only
    # purpose is to insert an extra op between the repeat and the concat
    branch2 = ds.GeneratorDataset(generator_10to12, ["data"]).repeat(2).take(10)
    # Third branch: rows (22..24) repeated three times, minus the first row
    branch3 = ds.GeneratorDataset(generator_22to24, ["data"]).repeat(3).skip(1)

    merged = branch1 + branch2 + branch3

    result = np.array([0])
    for row in merged.create_dict_iterator(num_epochs=1, output_numpy=True):
        result = np.append(result, row["data"])

    expected = np.array([0, 1, 2, 1, 2, 1, 2, 1, 2, 10, 11, 12, 10, 11, 12, 23, 24, 22, 23, 24, 22, 23, 24])

    np.testing.assert_array_equal(result, expected)

def test_generator_reset_2():
    """
    Test ((Generator -> Repeat) + (Generator -> Repeat) -> Repeat) + (Generator)
    """
    logger.info("test_generator_reset_2")
    # First branch: rows (1, 2) minus the first row, repeated three times
    branch1 = ds.GeneratorDataset(generator_1to2, ["data"]).skip(1).repeat(3)
    # Second branch: rows (10..12) repeated twice; take(10) is a no-op whose only
    # purpose is to insert an extra op between the repeat and the concat
    branch2 = ds.GeneratorDataset(generator_10to12, ["data"]).repeat(2).take(10)
    # Third branch: rows (22..24) minus the first two rows
    branch3 = ds.GeneratorDataset(generator_22to24, ["data"]).skip(2)

    merged = (branch1 + branch2).repeat(2).take(11) + branch3

    result = np.array([0])
    for row in merged.create_dict_iterator(num_epochs=1, output_numpy=True):
        result = np.append(result, row["data"])

    expected = np.array([0, 2, 2, 2, 10, 11, 12, 10, 11, 12, 2, 2, 24])

    np.testing.assert_array_equal(result, expected)

def test_generator_reset_3():
    """
    Test (Generator -> Repeat -> Skip -> Take -> Repeat) +
         (((Generator -> Repeat) + (Generator -> Take)) -> Repeat -> Skip -> Take) -> EpochCtrl

    Iterates the same iterator for several epochs to exercise the Generator
    reset logic under nested repeats with an EpochCtrl root.
    """
    logger.info("test_generator_reset_3")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1to2, ["data"])
    branch1 = data1.repeat(2)
    branch1 = branch1.skip(1)
    branch1 = branch1.take(2)
    branch1 = branch1.repeat(2)
    data2 = ds.GeneratorDataset(generator_10to12, ["data"])
    branch2 = data2.repeat(2)
    data3 = ds.GeneratorDataset(generator_22to24, ["data"])
    branch3 = data3.take(2)

    concat1 = branch2 + branch3
    concat2 = branch1 + concat1.repeat(3).skip(5).take(15)

    # num_epochs is left at its default (infinite), so an EpochCtrl op roots the tree
    itr = concat2.create_dict_iterator(output_numpy=True)

    num_epochs = 5
    output = np.array([0])
    golden = np.array([0])
    # One epoch's worth of rows; repeated identically each epoch
    expected = np.array([2, 1, 2, 1, 12, 22, 23, 10, 11, 12, 10, 11, 12, 22, 23, 10, 11, 12, 10])
    for _ in range(num_epochs):
        golden = np.append(golden, expected)
        for item in itr:
            output = np.append(output, item["data"])

    np.testing.assert_array_equal(output, golden)

    itr.stop()

def test_generator_reset_4():
    """
    Test Generator -> Repeat -> Repeat
    """
    logger.info("test_generator_reset_4")
    # Nested repeats: 2 rows x 4 x 2 = 16 rows in a single epoch
    pipeline = ds.GeneratorDataset(generator_1to2, ["data"]).repeat(4).repeat(2)

    result = np.array([0])
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        result = np.append(result, row["data"])

    expected = np.array([0] + [1, 2] * 8)

    np.testing.assert_array_equal(result, expected)

def test_generator_reset_5():
    """
    Test Generator -> Repeat -> Take -> Repeat -> EpochCtrl
    """
    logger.info("test_generator_reset_5")
    # (1, 2) x 3 truncated to (1, 2, 1), then repeated twice per epoch
    pipeline = ds.GeneratorDataset(generator_1to2, ["data"]).repeat(3).take(3).repeat(2)

    epochs = 2
    result = np.array([0])
    # num_epochs left at its default (infinite) so an EpochCtrl op roots the tree
    itr = pipeline.create_dict_iterator(output_numpy=True)

    for _ in range(epochs):
        for row in itr:
            result = np.append(result, row["data"])

    expected = np.array([0] + [1, 2, 1] * 4)

    np.testing.assert_array_equal(result, expected)

    itr.stop()

def test_generator_reset_6():
    """
    Test Generator -> Repeat -> Take -> Repeat -> Skip -> EpochCtrl
    """
    logger.info("test_generator_reset_6")
    # (10..12) x 2 truncated to 5 rows, repeated twice, with the first 2 rows dropped
    pipeline = ds.GeneratorDataset(generator_10to12, ["data"]).repeat(2).take(5).repeat(2).skip(2)
    iter1 = pipeline.create_dict_iterator(num_epochs=3, output_numpy=True)

    result = np.array([0])
    for _ in range(2):
        for row in iter1:
            result = np.append(result, row["data"])

    expected = np.array([0, 12, 10, 11, 10, 11, 12, 10, 11, 12, 10, 11, 10, 11, 12, 10, 11])

    np.testing.assert_array_equal(result, expected)

    # Intentionally no iter1.stop(): triggers self-termination when the iterator
    # goes out of scope


if __name__ == '__main__':
    # Run every scenario defined in this file; test_simple_repeat was previously
    # defined but never invoked by this driver.
    test_simple_repeat()
    test_generator_reset_1()
    test_generator_reset_2()
    test_generator_reset_3()
    test_generator_reset_4()
    test_generator_reset_5()
    test_generator_reset_6()
    logger.info('\n')

Loading…
Cancel
Save