with comments from the peer review addressed tags/v1.2.0-rc1
| @@ -258,7 +258,7 @@ void DatasetNode::PrintTree(std::ostream &out) const { | |||||
| void DatasetNode::PrintNode(std::ostream &out, int *level) const { | void DatasetNode::PrintNode(std::ostream &out, int *level) const { | ||||
| const std::string prefix = "+-"; | const std::string prefix = "+-"; | ||||
| const std::string indent = " "; | |||||
| const std::string indent = "| "; | |||||
| out << prefix; | out << prefix; | ||||
| Print(out); | Print(out); | ||||
| for (const auto &c : this->Children()) { | for (const auto &c : this->Children()) { | ||||
| @@ -28,30 +28,33 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor for EpochCtrlNode | // Constructor for EpochCtrlNode | ||||
| EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) { | |||||
| EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : RepeatNode() { | |||||
| // The root node's parent must set to null pointer. | // The root node's parent must set to null pointer. | ||||
| this->AddChild(child); | this->AddChild(child); | ||||
| repeat_count_ = num_epochs; | |||||
| } | } | ||||
| std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() { | std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() { | ||||
| auto node = std::make_shared<EpochCtrlNode>(num_epochs_); | |||||
| auto node = std::make_shared<EpochCtrlNode>(repeat_count_); | |||||
| return node; | return node; | ||||
| } | } | ||||
| void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(num_epochs_) + ")"; } | |||||
| void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(repeat_count_) + ")"; } | |||||
| // Function to build the EpochCtrlOp | // Function to build the EpochCtrlOp | ||||
| Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | ||||
| node_ops->push_back(std::make_shared<EpochCtrlOp>(num_epochs_)); | |||||
| auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_); | |||||
| node_ops->push_back(new_op_); | |||||
| op_ = new_op_; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Function to validate the parameters for EpochCtrlNode | // Function to validate the parameters for EpochCtrlNode | ||||
| Status EpochCtrlNode::ValidateParams() { | Status EpochCtrlNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | ||||
| if (num_epochs_ <= 0 && num_epochs_ != -1) { | |||||
| if (repeat_count_ <= 0 && repeat_count_ != -1) { | |||||
| std::string err_msg = | std::string err_msg = | ||||
| "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_); | |||||
| "EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(repeat_count_); | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| @@ -63,5 +66,16 @@ Status EpochCtrlNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accepting method for IRNodePass | |||||
| Status EpochCtrlNode::Accept(IRNodePass *p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<EpochCtrlNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for IRNodePass | |||||
| Status EpochCtrlNode::AcceptAfter(IRNodePass *p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<EpochCtrlNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,15 +21,20 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class EpochCtrlNode : public DatasetNode { | |||||
| class EpochCtrlNode : public RepeatNode { | |||||
| // Allow GeneratorNode to access internal members | |||||
| friend class GeneratorNode; | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit EpochCtrlNode(int32_t num_epochs) : num_epochs_(num_epochs) {} | |||||
| explicit EpochCtrlNode(int32_t num_epochs) : RepeatNode() { repeat_count_ = num_epochs; } | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | ||||
| @@ -58,8 +63,17 @@ class EpochCtrlNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| private: | |||||
| int32_t num_epochs_; | |||||
| /// \brief Base-class override for accepting IRNodePass visitor | |||||
| /// \param[in] p The pass visiting this node | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(IRNodePass *p, bool *const modified) override; | |||||
| /// \brief Base-class override for accepting IRNodePass visitor | |||||
| /// \param[in] p The pass visiting this node | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(IRNodePass *p, bool *const modified) override; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -26,7 +26,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | |||||
| RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) | |||||
| : repeat_count_(count), reset_ancestor_(nullptr), op_(nullptr) { | |||||
| this->AddChild(child); | this->AddChild(child); | ||||
| } | } | ||||
| @@ -35,10 +36,22 @@ std::shared_ptr<DatasetNode> RepeatNode::Copy() { | |||||
| return node; | return node; | ||||
| } | } | ||||
| void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ")"; } | |||||
| void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ") "; } | |||||
| Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | ||||
| node_ops->push_back(std::make_shared<RepeatOp>(repeat_count_)); | |||||
| auto new_op = std::make_shared<RepeatOp>(repeat_count_); | |||||
| node_ops->push_back(new_op); | |||||
| op_ = new_op; | |||||
| // Add this RepeatOp to its RepeatOp/EpochCtrlOp ancestor's EOE list. | |||||
| // When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list. | |||||
| // The ancestor is updated by GeneratorNodePass post pass. | |||||
| // Assumption: | |||||
| // We build the run-time ops from IR nodes from top to bottom. Hence Repeat/EpochCtrl ancestor ops are built | |||||
| // before this Repeat op is built. | |||||
| if (reset_ancestor_ != nullptr) { | |||||
| reset_ancestor_->op_->AddToEoeList(new_op); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -28,10 +28,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class RepeatOp; | |||||
| class RepeatNode : public DatasetNode { | class RepeatNode : public DatasetNode { | ||||
| // Allow GeneratorNode to access internal members | |||||
| friend class GeneratorNode; | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count); | |||||
| RepeatNode() : op_(nullptr), reset_ancestor_(nullptr), repeat_count_(-1) {} | |||||
| /// \brief Constructor | |||||
| RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~RepeatNode() = default; | ~RepeatNode() = default; | ||||
| @@ -82,7 +90,34 @@ class RepeatNode : public DatasetNode { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | Status AcceptAfter(IRNodePass *const p, bool *const modified) override; | ||||
| private: | |||||
| /// \brief Record the Repeat/EpochCtrl node that is the closest ancestor of this node | |||||
| /// \param[in] src The ancestor node | |||||
| /// \return Status of the function | |||||
| Status AddResetAncestor(const std::shared_ptr<RepeatNode> &src) { | |||||
| /* | |||||
| * This check is to ensure we don't overwrite an existing value of its ancestor. | |||||
| * It is okay to assign to the same value more than once in RepeatNode (but not in GeneratorNode). | |||||
| * Consider the following scenario | |||||
| * EpochCtrl(-1) | |||||
| * | | |||||
| * Repeat | |||||
| * | | |||||
| * Concat | |||||
| * / \ | |||||
| * GenData1 GenData2 | |||||
| * | |||||
| * We will record the ancestor relationship of (Repeat, EpochCtrl) twice, one at Visit(GenData1), the other at | |||||
| * Visit(GenData2). | |||||
| */ | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(reset_ancestor_ == nullptr || reset_ancestor_ == src, | |||||
| "Internal error: Overwriting an existing value"); | |||||
| reset_ancestor_ = src; | |||||
| return Status::OK(); | |||||
| } | |||||
| protected: | |||||
| std::shared_ptr<RepeatOp> op_; // keep its corresponding run-time op of EpochCtrlNode and RepeatNode | |||||
| std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass | |||||
| int32_t repeat_count_; | int32_t repeat_count_; | ||||
| }; | }; | ||||
| @@ -20,7 +20,9 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | #include "minddata/dataset/engine/datasetops/source/generator_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -31,10 +33,11 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector< | |||||
| : MappableSourceNode(), | : MappableSourceNode(), | ||||
| generator_function_(generator_function), | generator_function_(generator_function), | ||||
| column_names_(column_names), | column_names_(column_names), | ||||
| column_types_(column_types) {} | |||||
| column_types_(column_types), | |||||
| reset_ancestor_(nullptr) {} | |||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | ||||
| : generator_function_(generator_function), schema_(schema) {} | |||||
| : MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {} | |||||
| std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | ||||
| std::shared_ptr<GeneratorNode> node; | std::shared_ptr<GeneratorNode> node; | ||||
| @@ -47,7 +50,7 @@ std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | |||||
| } | } | ||||
| void GeneratorNode::Print(std::ostream &out) const { | void GeneratorNode::Print(std::ostream &out) const { | ||||
| out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)"; | |||||
| out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>) "; | |||||
| } | } | ||||
| Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | ||||
| @@ -77,8 +80,17 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ | |||||
| // best be delivered when the test cases for this api is ready. | // best be delivered when the test cases for this api is ready. | ||||
| RETURN_IF_NOT_OK(op->Init()); | RETURN_IF_NOT_OK(op->Init()); | ||||
| node_ops->push_back(op); | |||||
| // Add this GeneratorOp to its RepeatOp/EpochCtrlOp ancestor's EOE list. | |||||
| // When the ancestor reaches an end-of-epoch boundary, it will send a "reset" signal to all the ops in the EOE list. | |||||
| // The ancestor is updated by GeneratorNodePass post pass. | |||||
| // Assumption: | |||||
| // We build the run-time ops from IR nodes from top to bottom. Hence Repeat/EpochCtrl ancestor ops are built | |||||
| // before this leaf Generator op is built. | |||||
| if (reset_ancestor_ != nullptr) { | |||||
| reset_ancestor_->op_->AddToEoeList(op); | |||||
| } | |||||
| node_ops->push_back(op); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -93,5 +105,17 @@ Status GeneratorNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = 0; | *shard_id = 0; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor accepting method for IRNodePass | |||||
| Status GeneratorNode::Accept(IRNodePass *p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->Visit(shared_from_base<GeneratorNode>(), modified); | |||||
| } | |||||
| // Visitor accepting method for IRNodePass | |||||
| Status GeneratorNode::AcceptAfter(IRNodePass *p, bool *const modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->VisitAfter(shared_from_base<GeneratorNode>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,6 +22,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -72,11 +74,33 @@ class GeneratorNode : public MappableSourceNode { | |||||
| bool IsSizeDefined() override { return false; } | bool IsSizeDefined() override { return false; } | ||||
| /// \brief Record the Repeat/EpochCtrl node that is the closest ancestor of this node | |||||
| /// \param[in] src The ancestor node | |||||
| /// \return Status of the function | |||||
| Status AddResetAncestor(const std::shared_ptr<RepeatNode> &src) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(reset_ancestor_ == nullptr, "Internal error: Overwriting an existing value"); | |||||
| reset_ancestor_ = src; | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | private: | ||||
| py::function generator_function_; | py::function generator_function_; | ||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||
| std::vector<DataType> column_types_; | std::vector<DataType> column_types_; | ||||
| std::shared_ptr<SchemaObj> schema_; | std::shared_ptr<SchemaObj> schema_; | ||||
| std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass | |||||
| /// \brief Base-class override for accepting IRNodePass visitor | |||||
| /// \param[in] p The pass visiting this node | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status Accept(IRNodePass *p, bool *const modified) override; | |||||
| /// \brief Base-class override for accepting IRNodePass visitor | |||||
| /// \param[in] p The pass visiting this node | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status AcceptAfter(IRNodePass *p, bool *const modified) override; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -1,19 +1,35 @@ | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | ||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| add_library(engine-opt OBJECT | |||||
| optional/tensor_op_fusion_pass.cc | |||||
| pass.cc | |||||
| post/auto_worker_pass.cc | |||||
| post/repeat_pass.cc | |||||
| pre/cache_error_pass.cc | |||||
| pre/cache_transform_pass.cc | |||||
| pre/cache_validation_pass.cc | |||||
| pre/deep_copy_pass.cc | |||||
| pre/epoch_ctrl_pass.cc | |||||
| pre/epoch_injection_pass.cc | |||||
| pre/getter_pass.cc | |||||
| pre/input_validation_pass.cc | |||||
| pre/node_removal_pass.cc | |||||
| pre/removal_pass.cc | |||||
| util/printer_pass.cc | |||||
| set(DATASET_ENGINE_OPT_SRC_FILES | |||||
| pass.cc | |||||
| post/auto_worker_pass.cc | |||||
| pre/cache_validation_pass.cc | |||||
| pre/deep_copy_pass.cc | |||||
| pre/getter_pass.cc | |||||
| pre/input_validation_pass.cc | |||||
| pre/epoch_ctrl_pass.cc | |||||
| pre/node_removal_pass.cc | |||||
| ) | |||||
| # This set of files is for ExecTree's optimizer. It is being migrated to IR's optimizer. | |||||
| # When the migration is complete, we will remove these files. | |||||
| set(DATASET_ENGINE_OPT_SRC_FILES | |||||
| ${DATASET_ENGINE_OPT_SRC_FILES} | |||||
| optional/tensor_op_fusion_pass.cc | |||||
| pre/cache_error_pass.cc | |||||
| post/repeat_pass.cc | |||||
| pre/cache_transform_pass.cc | |||||
| pre/epoch_injection_pass.cc | |||||
| util/printer_pass.cc | |||||
| pre/removal_pass.cc | |||||
| ) | |||||
| if (ENABLE_PYTHON) | |||||
| set(DATASET_ENGINE_OPT_SRC_FILES | |||||
| ${DATASET_ENGINE_OPT_SRC_FILES} | |||||
| post/generator_node_pass.cc | |||||
| ) | ) | ||||
| endif() | |||||
| add_library(engine-opt OBJECT ${DATASET_ENGINE_OPT_SRC_FILES}) | |||||
| @@ -22,6 +22,7 @@ | |||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | #include "minddata/dataset/engine/ir/datasetops/map_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| @@ -31,6 +32,7 @@ | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" | #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" | ||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | #include "minddata/dataset/engine/ir/datasetops/take_node.h" | ||||
| @@ -179,12 +181,26 @@ Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified) | |||||
| Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) { | Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) { | ||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status IRNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) { | Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) { | ||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified) { | Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified) { | ||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| #ifdef ENABLE_PYTHON | |||||
| Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) { | |||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) { | |||||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||||
| } | |||||
| #endif | |||||
| Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) { | Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) { | ||||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | ||||
| } | } | ||||
| @@ -31,6 +31,7 @@ class BatchNode; | |||||
| class BucketBatchByLengthNode; | class BucketBatchByLengthNode; | ||||
| class BuildVocabNode; | class BuildVocabNode; | ||||
| class ConcatNode; | class ConcatNode; | ||||
| class EpochCtrlNode; | |||||
| class FilterNode; | class FilterNode; | ||||
| class MapNode; | class MapNode; | ||||
| class ProjectNode; | class ProjectNode; | ||||
| @@ -43,6 +44,7 @@ class TakeNode; | |||||
| class TransferNode; | class TransferNode; | ||||
| class ZipNode; | class ZipNode; | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| class GeneratorNode; | |||||
| class SyncWaitNode; | class SyncWaitNode; | ||||
| #endif | #endif | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -198,8 +200,14 @@ class IRNodePass : public IRPass { | |||||
| virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified); | virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified); | ||||
| virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified); | virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified); | virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified); | ||||
| virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified); | |||||
| virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified); | virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified); | virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified); | ||||
| #ifdef ENABLE_PYTHON | |||||
| virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified); | |||||
| virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified); | |||||
| #endif | |||||
| virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified); | virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified); | ||||
| virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified); | virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified); | ||||
| virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified); | virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified); | ||||
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "minddata/dataset/engine/opt/post/generator_node_pass.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| GeneratorNodePass::GeneratorNodePass() : repeat_ancestors_({}) {} | |||||
| /* | |||||
| * A diagram showing how the code works: | |||||
| * With the tree below as an input | |||||
| * | |||||
| * EpochCtrl(-1) | |||||
| * / \ | |||||
| * Repeat1 \ | |||||
| * / Repeat3 | |||||
| * .. \ | |||||
| * / Generator2 | |||||
| * Repeat2 Add: Gen2-Rep3 | |||||
| * / | |||||
| * Generator1 | |||||
| * Add: Gen1-Rep2 | |||||
| * | |||||
| * The sequence of the DFS walk of the tree looks like this: | |||||
| * 1) Visit(EpochCtrl): push EpochCtrl, repeat_ancestors_ = { EpochCtrl } | |||||
| * 2) Visit(Repeat1): push Repeat1, repeat_ancestors_ = { EpochCtrl, Repeat1 } | |||||
| * 3) Visit(Repeat2): push Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1, Repeat2 } | |||||
| * 4) Visit(Generator1): record Repeat2 as its ancestor | |||||
| * record Repeat1 as Repeat2's ancestor | |||||
| * record EpochCtrl as Repeat1's ancestor | |||||
| * 5) VisitAfter(Repeat2): pop Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1 } | |||||
| * 6) VisitAfter(Repeat1): pop Repeat1, repeat_ancestors_ = { EpochCtrl } | |||||
| * 7) Visit(Repeat3): push Repeat3, repeat_ancestors_ = { EpochCtrl, Repeat3 } | |||||
| * 8) Visit(Generator2): record Repeat3 as its ancestor | |||||
| * record EpochCtrl as Repeat3's ancestor | |||||
| * 9) VisitAfter(Repeat3): pop Repeat3, repeat_ancestors_ = { EpochCtrl } | |||||
| * 10) VisitAfter(EpochCtrl): don't care. We could pop EpochCtrl. | |||||
| */ | |||||
| Status GeneratorNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) { | |||||
| // Add this EpochCtrl node as an ancestor of its descendant | |||||
| repeat_ancestors_.push_back(node); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GeneratorNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) { | |||||
| // Add this Repeat node as an ancestor of its descendant | |||||
| repeat_ancestors_.push_back(node); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GeneratorNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) { | |||||
| // Form a reset relationship with the immediate Repeat/EpochCtrl ancestor node of this leaf Generator Node | |||||
| // only when any of its ancestors is an infinite repeat. | |||||
| if (repeat_ancestors_.size() > 0) { | |||||
| bool infinite_repeat = false; | |||||
| for (auto &repeat_ancestor : repeat_ancestors_) { | |||||
| if (repeat_ancestor->Count() < 0) { | |||||
| infinite_repeat = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (infinite_repeat) { | |||||
| // Form a pair-wise relationship between this leaf Generator node and its immediate Repeat/EpochCtrl | |||||
| // ancestor node, and between the next adjacent pairs in the vector. For example, | |||||
| // if we have GeneratorNode -> Repeat1 -> Repeat2 -> EpochCtrl(-1), the pair-wise relationships are: | |||||
| // (GeneratorNode, Repeat1), (Repeat1, Repeat2), and (Repeat2, EpochCtrl) | |||||
| for (auto i = repeat_ancestors_.size() - 1; i > 0; --i) { | |||||
| auto ancestor = repeat_ancestors_[i - 1]; | |||||
| RETURN_IF_NOT_OK(repeat_ancestors_[i]->AddResetAncestor(ancestor)); | |||||
| } | |||||
| RETURN_IF_NOT_OK(node->AddResetAncestor(repeat_ancestors_.back())); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GeneratorNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) { | |||||
| // When we backtrack from the same Repeat node, we pop it out from the list of ancestors. | |||||
| repeat_ancestors_.pop_back(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GeneratorNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) { | |||||
| // As EpochCtrl node is a terminal node, the process stops here. | |||||
| // Popping it back out of the reset ancestors is unnecessary. | |||||
| // This function becomes a no-op function and can be deleted completely. | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \class GeneratorNodePass generator_node_pass.h | |||||
| /// \brief This is an IRNodePass whose job is to keep track of the Repeat ancestors met during the tree walk | |||||
| /// (repeat_ancestors_), presumably so that Generator nodes underneath them can be given a reset ancestor. | |||||
| class GeneratorNodePass : public IRNodePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| GeneratorNodePass(); | |||||
| /// \brief Destructor | |||||
| ~GeneratorNodePass() = default; | |||||
| /// \brief Record the starting point to collect the Generator node | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override; | |||||
| /// \brief Record the starting point to collect the Generator node | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; | |||||
| /// \brief Add the Generator node to the set | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override; | |||||
| /// \brief Pop this Repeat node from the list of repeat ancestors when backtracking out of it | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override; | |||||
| /// \brief No-op: EpochCtrl is a terminal node, so nothing is popped from the list of repeat ancestors | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override; | |||||
| private: | |||||
| std::vector<std::shared_ptr<RepeatNode>> repeat_ancestors_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H | |||||
| @@ -28,25 +28,10 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| RepeatPass::RepeatPass() | RepeatPass::RepeatPass() | ||||
| : is_repeated_(false), | |||||
| nested_repeats_(0), | |||||
| num_repeats_(1), | |||||
| num_epochs_(1), | |||||
| is_merge_(false), | |||||
| is_cached_(false), | |||||
| cache_lookup_(nullptr) {} | |||||
| : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} | |||||
| // Identifies the subtree below this node as being in a repeated path of the tree. | // Identifies the subtree below this node as being in a repeated path of the tree. | ||||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { | Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { | ||||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||||
| eoe_op_stacks_.push(std::move(new_stack)); | |||||
| // If we are already repeated, then this is a nested repeat. | |||||
| if (is_repeated_) { | |||||
| nested_repeats_++; | |||||
| } | |||||
| is_repeated_ = true; | |||||
| // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. | ||||
| // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. | ||||
| if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { | ||||
| @@ -70,13 +55,6 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modi | |||||
| // Identifies the subtree below this node as being in a repeated path of the tree. | // Identifies the subtree below this node as being in a repeated path of the tree. | ||||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { | Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { | ||||
| // EpochCtrl is derived from RepeatOp. Generally it should do the identical setup | |||||
| // that RepeatOp does. However, epoch control is actually simpler because it can | |||||
| // only exist as the root node so it doesn't need all the nested code. | |||||
| // Create a new stack for eoe operators and push onto our stack of stacks. | |||||
| std::unique_ptr<op_stack> new_stack = std::make_unique<op_stack>(); | |||||
| eoe_op_stacks_.push(std::move(new_stack)); | |||||
| is_repeated_ = true; | |||||
| // Get the total number of epochs from the EpochCtrlOp parameter | // Get the total number of epochs from the EpochCtrlOp parameter | ||||
| num_epochs_ = node->num_repeats(); | num_epochs_ = node->num_repeats(); | ||||
| // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. | ||||
| @@ -103,22 +81,6 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modif | |||||
| // Hooks up any identified eoe nodes under this repeat. | // Hooks up any identified eoe nodes under this repeat. | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) { | ||||
| // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | |||||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||||
| while (leaf_op != nullptr) { | |||||
| node->AddToEoeList(leaf_op); | |||||
| leaf_op = PopFromEOEOpStack(); | |||||
| } | |||||
| // At this point, we are done with the save area stack. It's a unique pointer to an empty stack | |||||
| // at this time, so we can pop it to get rid of it. | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| if (!current_stack->empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); | |||||
| } | |||||
| eoe_op_stacks_.pop(); | |||||
| // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up | ||||
| // and set its total repeats. It is important that the op is removed from the save area, | // and set its total repeats. It is important that the op is removed from the save area, | ||||
| // because the merge op above us may also take action on it later for a different case when | // because the merge op above us may also take action on it later for a different case when | ||||
| @@ -129,18 +91,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modifie | |||||
| cache_lookup_.reset(); | cache_lookup_.reset(); | ||||
| } | } | ||||
| // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. | |||||
| // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. | |||||
| if (nested_repeats_ > 0) { | |||||
| AddToEOEOpStack(node); | |||||
| nested_repeats_--; | |||||
| } else { | |||||
| // If we are not nested, or we were the top-most repeat, now we clear the flag | |||||
| if (nested_repeats_ != 0) { | |||||
| RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); | |||||
| } | |||||
| is_repeated_ = false; | |||||
| } | |||||
| if (is_cached_) { | if (is_cached_) { | ||||
| AddToCachedOpStack(node); | AddToCachedOpStack(node); | ||||
| } | } | ||||
| @@ -156,13 +106,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modifie | |||||
| // Hooks up any identified eoe nodes under this repeat. | // Hooks up any identified eoe nodes under this repeat. | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) { | ||||
| // Pop the leaf ops from the save-area stack and add them to the eoe node tracking | |||||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||||
| while (leaf_op != nullptr) { | |||||
| node->AddToEoeList(leaf_op); | |||||
| leaf_op = PopFromEOEOpStack(); | |||||
| } | |||||
| is_repeated_ = false; | |||||
| node->set_total_repeats(num_repeats_); | node->set_total_repeats(num_repeats_); | ||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | ||||
| // We finish the walk of this EpochCtrl's descendent nodes. | // We finish the walk of this EpochCtrl's descendent nodes. | ||||
| @@ -192,13 +135,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified | |||||
| } | } | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) { | ||||
| // If we are in a repeat path, then set our repeated flag | |||||
| if (is_repeated_) { | |||||
| // if infinite repeat save ourself in a stack for the repeat operator above us | |||||
| if (num_repeats_ < 0) { | |||||
| AddToEOEOpStack(node); | |||||
| } | |||||
| } | |||||
| // If we are under a cache op, then save ourselves to the cached op stack. | // If we are under a cache op, then save ourselves to the cached op stack. | ||||
| if (is_cached_) { | if (is_cached_) { | ||||
| AddToCachedOpStack(node); | AddToCachedOpStack(node); | ||||
| @@ -260,23 +196,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const mo | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Adds an operator to the eoe operator stack save area | |||||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| current_stack->push(dataset_op); | |||||
| } | |||||
| // Pops an operator from the eoe operator stack save area | |||||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | |||||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||||
| op_stack *current_stack = eoe_op_stacks_.top().get(); | |||||
| if (current_stack != nullptr && !current_stack->empty()) { | |||||
| top_op = current_stack->top(); | |||||
| current_stack->pop(); | |||||
| } | |||||
| return top_op; | |||||
| } | |||||
| // Adds an operator to the cached operator stack save area | // Adds an operator to the cached operator stack save area | ||||
| void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); } | ||||
| @@ -112,19 +112,6 @@ class RepeatPass : public NodePass { | |||||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override; | Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override; | ||||
| private: | private: | ||||
| /// \brief Adds an operator to the eoe operator stack save area | |||||
| /// \param op - The dataset op to work add to eoe stack | |||||
| /// \return Status The status code returned | |||||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||||
| /// \brief Pops an operator from the eoe operator stack save area | |||||
| /// \return shared_ptr to the popped operator | |||||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||||
| bool is_repeated_; // T/F if we are processing under a repeat | |||||
| int32_t nested_repeats_; // A counter for nested repeats | |||||
| std::stack<std::unique_ptr<op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) | |||||
| /// \brief Adds an operator to the cached operator stack save area | /// \brief Adds an operator to the cached operator stack save area | ||||
| /// \param op - The dataset op to work add to cached stack | /// \param op - The dataset op to work add to cached stack | ||||
| /// \return Status The status code returned | /// \return Status The status code returned | ||||
| @@ -21,6 +21,9 @@ | |||||
| #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | #include "minddata/dataset/engine/opt/pass.h" | ||||
| #include "minddata/dataset/engine/opt/post/auto_worker_pass.h" | #include "minddata/dataset/engine/opt/post/auto_worker_pass.h" | ||||
| #ifdef ENABLE_PYTHON | |||||
| #include "minddata/dataset/engine/opt/post/generator_node_pass.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" | #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" | ||||
| #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" | #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" | ||||
| #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" | #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" | ||||
| @@ -86,6 +89,9 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||||
| // skip this for getter pass | // skip this for getter pass | ||||
| actions.emplace_back(std::make_unique<AutoWorkerPass>()); | actions.emplace_back(std::make_unique<AutoWorkerPass>()); | ||||
| } | } | ||||
| #ifdef ENABLE_PYTHON | |||||
| actions.emplace_back(std::make_unique<GeneratorNodePass>()); | |||||
| #endif | |||||
| // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. | // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. | ||||
| @@ -148,6 +148,10 @@ if (BUILD_MINDDATA STREQUAL "full") | |||||
| "${MINDDATA_DIR}/engine/datasetops/source/sampler/python_sampler.cc" | "${MINDDATA_DIR}/engine/datasetops/source/sampler/python_sampler.cc" | ||||
| ) | ) | ||||
| list(REMOVE_ITEM MINDDATA_ENGINE_OPT_POST_SRC_FILES | |||||
| "${MINDDATA_DIR}/engine/opt/post/generator_node_pass.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_OPT_POST_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_OPT_POST_SRC_FILES | ||||
| "${MINDDATA_DIR}/engine/opt/post/repeat_pass.cc" | "${MINDDATA_DIR}/engine/opt/post/repeat_pass.cc" | ||||
| ) | ) | ||||
| @@ -0,0 +1,208 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import numpy as np | |||||
| import mindspore.dataset as ds | |||||
| from mindspore import log as logger | |||||
| # Generator yielding 2 single-column rows with the values 1 and 2 | |||||
| def generator_1to2(): | |||||
| for i in np.array([1, 2]): | |||||
| yield (np.array(i),) | |||||
| # Generator yielding 3 single-column rows with the values 10, 11 and 12 | |||||
| def generator_10to12(): | |||||
| for i in np.array([10, 11, 12]): | |||||
| yield (np.array(i),) | |||||
| # Generator yielding 3 single-column rows with the values 22, 23 and 24 | |||||
| def generator_22to24(): | |||||
| for i in np.array([22, 23, 24]): | |||||
| yield (np.array(i),) | |||||
| def test_simple_repeat(): | |||||
| # Since the number of epochs is 1, the GeneratorNodePass logic will not add the reset logic. | |||||
| logger.info("test_simple_repeat") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1to2, ["data"]) | |||||
| branch1 = data1.repeat(2) | |||||
| branch1 = branch1.skip(1) # Skip the first row | |||||
| output = np.array([0]) | |||||
| for item in branch1.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| output = np.append(output, item["data"]) | |||||
| golden = np.array([0, 2, 1, 2]) | |||||
| np.testing.assert_array_equal(output, golden) | |||||
| def test_generator_reset_1(): | |||||
| """ | |||||
| Test (Generator -> Repeat) + (Generator -> Repeat) + (Generator -> Repeat) | |||||
| """ | |||||
| logger.info("test_generator_reset_1") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1to2, ["data"]) | |||||
| branch1 = data1.repeat(4) | |||||
| data2 = ds.GeneratorDataset(generator_10to12, ["data"]) | |||||
| branch2 = data2.repeat(2) | |||||
| branch2 = branch2.take(10) # Meaningless operation, just want to insert an op in between | |||||
| data3 = ds.GeneratorDataset(generator_22to24, ["data"]) | |||||
| branch3 = data3.repeat(3) | |||||
| branch3 = branch3.skip(1) # Skip the first row | |||||
| concat1 = branch1 + branch2 | |||||
| concat2 = concat1 + branch3 | |||||
| output = np.array([0]) | |||||
| for item in concat2.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| output = np.append(output, item["data"]) | |||||
| golden = np.array([0, 1, 2, 1, 2, 1, 2, 1, 2, 10, 11, 12, 10, 11, 12, 23, 24, 22, 23, 24, 22, 23, 24]) | |||||
| np.testing.assert_array_equal(output, golden) | |||||
| def test_generator_reset_2(): | |||||
| """ | |||||
| Test ((Generator -> Repeat) + (Generator -> Repeat) -> Repeat) + (Generator) | |||||
| """ | |||||
| logger.info("test_generator_reset_2") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1to2, ["data"]) | |||||
| data1 = data1.skip(1) | |||||
| branch1 = data1.repeat(3) | |||||
| data2 = ds.GeneratorDataset(generator_10to12, ["data"]) | |||||
| branch2 = data2.repeat(2) | |||||
| branch2 = branch2.take(10) # Meaningless operation, just want to insert an op in between | |||||
| data3 = ds.GeneratorDataset(generator_22to24, ["data"]) | |||||
| branch3 = data3.skip(2) # Skip the first two rows | |||||
| concat1 = branch1 + branch2 | |||||
| concat2 = concat1.repeat(2).take(11) + branch3 | |||||
| output = np.array([0]) | |||||
| for item in concat2.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| output = np.append(output, item["data"]) | |||||
| golden = np.array([0, 2, 2, 2, 10, 11, 12, 10, 11, 12, 2, 2, 24]) | |||||
| np.testing.assert_array_equal(output, golden) | |||||
| def test_generator_reset_3(): | |||||
| """ | |||||
| Test (Generator -> Repeat -> Repeat) + ((Generator -> Repeat) + (Generator)) -> Repeat) -> EpochCtrl | |||||
| """ | |||||
| logger.info("test_generator_reset_3") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1to2, ["data"]) | |||||
| branch1 = data1.repeat(2) | |||||
| branch1 = branch1.skip(1) | |||||
| branch1 = branch1.take(2) | |||||
| branch1 = branch1.repeat(2) | |||||
| data2 = ds.GeneratorDataset(generator_10to12, ["data"]) | |||||
| branch2 = data2.repeat(2) | |||||
| data3 = ds.GeneratorDataset(generator_22to24, ["data"]) | |||||
| branch3 = data3.take(2) | |||||
| branch3 = branch3 | |||||
| concat1 = branch2 + branch3 | |||||
| concat2 = branch1 + concat1.repeat(3).skip(5).take(15) | |||||
| itr = concat2.create_dict_iterator(output_numpy=True) | |||||
| num_epochs = 5 | |||||
| output = np.array([0]) | |||||
| golden = np.array([0]) | |||||
| expected = np.array([2, 1, 2, 1, 12, 22, 23, 10, 11, 12, 10, 11, 12, 22, 23, 10, 11, 12, 10]) | |||||
| for _ in range(num_epochs): | |||||
| golden = np.append(golden, expected) | |||||
| for item in itr: | |||||
| output = np.append(output, item["data"]) | |||||
| np.testing.assert_array_equal(output, golden) | |||||
| itr.stop() | |||||
| def test_generator_reset_4(): | |||||
| """ | |||||
| Test Generator -> Repeat -> Repeat: the 2 generated rows are expected 4 * 2 = 8 times | |||||
| """ | |||||
| logger.info("test_generator_reset_4") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1to2, ["data"]) | |||||
| branch1 = data1.repeat(4).repeat(2) | |||||
| output = np.array([0]) | |||||
| for item in branch1.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| output = np.append(output, item["data"]) | |||||
| golden = np.array([0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) | |||||
| np.testing.assert_array_equal(output, golden) | |||||
| def test_generator_reset_5(): | |||||
| """ | |||||
| Test Generator -> Repeat -> Take -> Repeat -> EpochCtrl | |||||
| """ | |||||
| logger.info("test_generator_reset_5") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_1to2, ["data"]) | |||||
| branch1 = data1.repeat(3).take(3).repeat(2) | |||||
| num_epochs = 2 | |||||
| output = np.array([0]) | |||||
| itr = branch1.create_dict_iterator(output_numpy=True) | |||||
| for _ in range(num_epochs): | |||||
| for item in itr: | |||||
| output = np.append(output, item["data"]) | |||||
| golden = np.array([0, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1]) | |||||
| np.testing.assert_array_equal(output, golden) | |||||
| itr.stop() | |||||
| def test_generator_reset_6(): | |||||
| """ | |||||
| Test Generator -> Repeat -> Take -> Repeat -> Skip -> EpochCtrl | |||||
| """ | |||||
| logger.info("test_generator_reset_6") | |||||
| # apply dataset operations | |||||
| data1 = ds.GeneratorDataset(generator_10to12, ["data"]) | |||||
| branch1 = data1.repeat(2).take(5).repeat(2).skip(2) | |||||
| iter1 = branch1.create_dict_iterator(num_epochs=3, output_numpy=True) | |||||
| output = np.array([0]) | |||||
| for _ in range(2): | |||||
| for item in iter1: | |||||
| output = np.append(output, item["data"]) | |||||
| golden = np.array([0, 12, 10, 11, 10, 11, 12, 10, 11, 12, 10, 11, 10, 11, 12, 10, 11]) | |||||
| np.testing.assert_array_equal(output, golden) | |||||
| # intentionally not adding iter1.stop() to trigger the self-termination when iter1 is out of scope | |||||
| if __name__ == '__main__': | |||||
| test_generator_reset_1() | |||||
| test_generator_reset_2() | |||||
| test_generator_reset_3() | |||||
| test_generator_reset_4() | |||||
| test_generator_reset_5() | |||||
| test_generator_reset_6() | |||||
| logger.info('\n') | |||||