| @@ -574,7 +574,7 @@ Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t * | |||
| std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize))); | |||
| return pre; | |||
| }); | |||
| RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1)); | |||
| RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1)); | |||
| TensorRow row; | |||
| RETURN_IF_NOT_OK(GetRow(tree_adapter, &row)); | |||
| int64_t row_cnt = 0; | |||
| @@ -214,7 +214,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||
| // The driver of the prepare phase of the execution tree. | |||
| // Prepare phase consists of three sub phases | |||
| // | |||
| // 1. PrepareTreePreAction() | |||
| // 1. PreAction() | |||
| // Compulsory transformation/action pre optimization. | |||
| // For example, CacheOp Insertion | |||
| // | |||
| @@ -222,41 +222,44 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||
| // Optimization transformation/action, optional | |||
| // For example, MapOp Fusion | |||
| // | |||
| // 3. PrepareTreePostAction() | |||
| // 3. PostAction() | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| Status ExecutionTree::Prepare(int32_t num_epochs) { | |||
| Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) { | |||
| num_epochs_ = num_epochs; | |||
| partially_prepare_ = partial; | |||
| // Pre optimization compulsory transformation | |||
| RETURN_IF_NOT_OK(this->PrepareTreePreAction()); | |||
| RETURN_IF_NOT_OK(this->PreAction()); | |||
| // If optional optimizations are enabled | |||
| if (optimize_) { | |||
| RETURN_IF_NOT_OK(this->Optimize()); | |||
| } | |||
| // Post optimization compulsory transformation | |||
| RETURN_IF_NOT_OK(this->PrepareTreePostAction()); | |||
| RETURN_IF_NOT_OK(this->PostAction()); | |||
| // The tree is ready to be prepared. | |||
| tree_state_ = kDeTStatePrepare; | |||
| // Existing transformation implementation, will be removed later | |||
| RETURN_IF_NOT_OK(this->PrepareDeprecated()); | |||
| return Status::OK(); | |||
| } | |||
| Status ExecutionTree::PrepareTreePreAction() { | |||
| Status ExecutionTree::PreAction() { | |||
| bool modified = false; | |||
| std::vector<std::unique_ptr<Pass>> pre_actions; | |||
| // Construct pre actions | |||
| if (!partially_prepare_) { | |||
| #ifndef ENABLE_ANDROID | |||
| pre_actions.push_back(std::make_unique<CacheErrorPass>()); | |||
| #endif | |||
| pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | |||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | |||
| #ifndef ENABLE_ANDROID | |||
| pre_actions.push_back(std::make_unique<CacheTransformPass>()); | |||
| pre_actions.push_back(std::make_unique<CacheErrorPass>()); | |||
| #endif | |||
| pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | |||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | |||
| } | |||
| // this offers a way to override the preset optimization pass with customized ones | |||
| // this is used when certain nodes are removed for tree getters | |||
| @@ -276,15 +279,17 @@ Status ExecutionTree::PrepareTreePreAction() { | |||
| return Status::OK(); | |||
| } | |||
| Status ExecutionTree::PrepareTreePostAction() { | |||
| // The tree is ready to be prepared. | |||
| tree_state_ = kDeTStatePrepare; | |||
| Status ExecutionTree::PostAction() { | |||
| bool modified = false; | |||
| OptPass post_actions; | |||
| // Construct pre actions | |||
| MS_LOG(INFO) << "Running post pass loops."; | |||
| #ifndef ENABLE_ANDROID | |||
| // Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind. | |||
| // The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API. | |||
| // This is because Python API binding to TensorOperation is still in progress. | |||
| post_actions.push_back(std::make_unique<CacheErrorPass>()); | |||
| post_actions.push_back(std::make_unique<CacheTransformPass>()); | |||
| post_actions.push_back(std::make_unique<RepeatPass>()); | |||
| #endif | |||
| @@ -340,9 +345,6 @@ Status ExecutionTree::PrepareDeprecated() { | |||
| // Recursive function used during prepare phase to visit a node and drive any pre- and post- | |||
| // node actions during a tree walk. | |||
| Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) { | |||
| // execute PreAction | |||
| RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction()); | |||
| // Before going down into children, make any prepare flags updates based on this operator. | |||
| uint32_t op_prep_flags = dataset_op->PrepareFlags(); | |||
| BitSet(&prepare_flags_, op_prep_flags); | |||
| @@ -169,7 +169,7 @@ class ExecutionTree { | |||
| // The driver of the prepare phase of the execution tree. | |||
| // Prepare phase consists of three sub phases | |||
| // | |||
| // 1. PrepareTreePreAction() | |||
| // 1. PreAction() | |||
| // Compulsory transformation/action pre optimization. | |||
| // For example, CacheOp Insertion | |||
| // | |||
| @@ -177,20 +177,20 @@ class ExecutionTree { | |||
| // Optimization transformation/action, optional | |||
| // For example, MapOp Fusion | |||
| // | |||
| // 3. PrepareTreePostAction() | |||
| // 3. PostAction() | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| Status Prepare(int num_epochs = -1); | |||
| Status Prepare(int num_epochs = -1, bool partial = false); | |||
| // Compulsory transformation/action pre optimization. | |||
| // @return Status - The error code return | |||
| Status PrepareTreePreAction(); | |||
| Status PreAction(); | |||
| // Compulsory transformation/action post optimization. | |||
| // @return Status - The error code return | |||
| Status PrepareTreePostAction(); | |||
| Status PostAction(); | |||
| // Optimization transformation/action, optional. | |||
| // @return Status - The error code return | |||
| @@ -281,6 +281,7 @@ class ExecutionTree { | |||
| std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager | |||
| bool optimize_; // Flag to enable optional optimizations | |||
| std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare() | |||
| bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes. | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -23,6 +23,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/batch_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -139,5 +140,16 @@ Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status BatchNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<BatchNode>(), modified); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status BatchNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<BatchNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -74,6 +74,18 @@ class BatchNode : public DatasetNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t batch_size_; | |||
| bool drop_remainder_; | |||
| @@ -46,12 +46,40 @@ BucketBatchByLengthNode::BucketBatchByLengthNode( | |||
| std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() { | |||
| auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_, | |||
| element_length_function_, pad_info_, pad_to_bucket_boundary_); | |||
| element_length_function_, pad_info_, pad_to_bucket_boundary_, | |||
| drop_remainder_); | |||
| return node; | |||
| } | |||
| void BucketBatchByLengthNode::Print(std::ostream &out) const { | |||
| out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)"; | |||
| out << Name() + "(columns:" + PrintColumns(column_names_); | |||
| int i = 0; | |||
| for (auto it : bucket_boundaries_) { | |||
| if (i == 0) { | |||
| out << ",bucket_boundaries:{"; | |||
| } | |||
| out << it; | |||
| if (i < bucket_boundaries_.size() - 1) { | |||
| out << ","; | |||
| } else { | |||
| out << "}"; | |||
| } | |||
| i++; | |||
| } | |||
| i = 0; | |||
| for (auto it : bucket_batch_sizes_) { | |||
| if (i == 0) { | |||
| out << ",bucket_batch_sizes:{"; | |||
| } | |||
| out << it; | |||
| if (i < bucket_batch_sizes_.size() - 1) { | |||
| out << ","; | |||
| } else { | |||
| out << "}"; | |||
| } | |||
| i++; | |||
| } | |||
| out << ")"; | |||
| } | |||
| Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| @@ -90,14 +90,14 @@ Status BuildSentenceVocabNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status BuildSentenceVocabNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status BuildSentenceVocabNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified); | |||
| } | |||
| @@ -59,17 +59,17 @@ class BuildSentenceVocabNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| std::shared_ptr<SentencePieceVocab> vocab_; | |||
| @@ -85,14 +85,14 @@ Status BuildVocabNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildVocabNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status BuildVocabNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<BuildVocabNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status BuildVocabNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified); | |||
| } | |||
| @@ -58,17 +58,17 @@ class BuildVocabNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| std::shared_ptr<Vocab> vocab_; | |||
| @@ -39,8 +39,10 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets | |||
| } | |||
| std::shared_ptr<DatasetNode> ConcatNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| // create an empty vector to copy a concat | |||
| auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>()); | |||
| auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler, | |||
| children_flag_and_nums_, children_start_end_index_); | |||
| return node; | |||
| } | |||
| @@ -80,14 +82,14 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ConcatNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status ConcatNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<ConcatNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status ConcatNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<ConcatNode>(), modified); | |||
| } | |||
| @@ -66,17 +66,17 @@ class ConcatNode : public DatasetNode { | |||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | |||
| std::vector<std::pair<int, int>> children_start_end_index_; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -242,9 +242,27 @@ DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) { | |||
| worker_connector_size_ = cfg->worker_connector_size(); | |||
| } | |||
| const bool DatasetNode::IsTree() const { | |||
| bool is_tree = true; | |||
| if (this->parent_.size() > 1) { | |||
| MS_LOG(WARNING) << Name() << " has more than one parent."; | |||
| return false; | |||
| } | |||
| for (const auto &child : children_) { | |||
| is_tree = child->IsTree(); | |||
| if (!is_tree) { | |||
| MS_LOG(WARNING) << Name() << " has more than one parent."; | |||
| break; | |||
| } | |||
| } | |||
| return is_tree; | |||
| } | |||
| // this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied | |||
| std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() { | |||
| std::shared_ptr<DatasetNode> new_node = this->Copy(); | |||
| // temporary fix to set the num_workers to the new node. | |||
| new_node->SetNumWorkers(this->num_workers_); | |||
| for (const auto &child : children_) { | |||
| new_node->AddChild(child->DeepCopy()); | |||
| } | |||
| @@ -298,12 +316,31 @@ void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) { | |||
| children_.push_back(child); | |||
| child->parent_.push_back(this); | |||
| } else if (child != nullptr) { | |||
| MS_LOG(WARNING) << "DatasetNode::AddChild() failed: " + child->Name() + "'s parent isn't a nullptr."; | |||
| MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent"; | |||
| children_.push_back(child); | |||
| child->parent_.push_back(this); | |||
| } | |||
| } | |||
| // Insert a node as a child of this node. This node's children becomes the children of the inserted node. | |||
| Status DatasetNode::InsertBelow(std::shared_ptr<DatasetNode> node) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent."); | |||
| for (auto child : children_) { | |||
| node->children_.push_back(child); | |||
| child->parent_.clear(); | |||
| child->parent_.push_back(node.get()); | |||
| } | |||
| // Then establish the new parent-child relationship with the new parent. | |||
| children_.clear(); | |||
| children_.push_back(node); | |||
| node->parent_.clear(); | |||
| node->parent_.push_back(this); | |||
| return Status::OK(); | |||
| } | |||
| // Remove this node from its parent. Add the child of this node to its parent. | |||
| // for now, this remove is limited to node with a single child or no child | |||
| Status DatasetNode::Remove() { | |||
| @@ -325,14 +362,14 @@ Status DatasetNode::Remove() { | |||
| } | |||
| // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. | |||
| Status DatasetNode::Accept(NodePass *p, bool *modified) { | |||
| Status DatasetNode::Accept(IRNodePass *p, bool *modified) { | |||
| // This method will only be called if its derived class does not implement one. | |||
| return p->Visit(shared_from_this(), modified); | |||
| } | |||
| // In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit | |||
| // after all child nodes are visited. | |||
| Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| Status DatasetNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // This method will only be called if its derived class does not implement one. | |||
| return p->VisitAfter(shared_from_this(), modified); | |||
| } | |||
| @@ -369,17 +406,5 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||
| RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | |||
| } | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status SourceNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<SourceNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status SourceNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<SourceNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -32,7 +32,7 @@ namespace dataset { | |||
| class Dataset; | |||
| class SamplerObj; | |||
| class NodePass; | |||
| class IRNodePass; | |||
| class DatasetSizeGetter; | |||
| // Names for non-leaf IR node | |||
| @@ -182,6 +182,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \brief Establish the parent-child relationship between this node and its child. | |||
| void AddChild(std::shared_ptr<DatasetNode> child); | |||
| /// \brief Insert the input node below this node. This node's children becomes the children of the inserted node. | |||
| Status InsertBelow(std::shared_ptr<DatasetNode> node); | |||
| /// \brief detach this node from its parent, add its child (if any) to its parent | |||
| /// \return error code, return error if node has more than 1 children | |||
| Status Remove(); | |||
| @@ -190,6 +193,25 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \return True if the data of this node will be cached | |||
| const bool IsCached() const { return (cache_ != nullptr); } | |||
| /// \brief Check if this node is a tree | |||
| /// \return True if the structure is indeed a tree, i.e., no node has more than one parent | |||
| const bool IsTree() const; | |||
| /// \brief Check if this node is a leaf node. | |||
| /// \return True if this is a leaf node. | |||
| const bool IsLeaf() const { return children_.empty(); } | |||
| /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes | |||
| /// \return True if the dataset represented by this node is a mappable dataset | |||
| const bool IsMappable() const { return mappable_; } | |||
| /// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes | |||
| /// \return True if a cache-enabled operator is an ancestor of this node | |||
| const bool IsDescendantOfCache() const { return descendant_of_cache_; } | |||
| /// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes | |||
| void HasCacheAbove() { descendant_of_cache_ = true; } | |||
| /// \brief Setter function for runtime number of workers | |||
| /// \param[in] num_workers The number of threads in this operator | |||
| /// \return Shared pointer to the original object | |||
| @@ -203,7 +225,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| return std::static_pointer_cast<Derived>(shared_from_this()); | |||
| } | |||
| /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up | |||
| /// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up | |||
| /// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node | |||
| /// visit on the way back up the tree after its descendants are visited. | |||
| /// \notes Subclass needs to override this if it requires special node visit access. | |||
| @@ -211,15 +233,15 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| virtual Status Accept(NodePass *p, bool *modified); | |||
| virtual Status Accept(IRNodePass *p, bool *modified); | |||
| /// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited. | |||
| /// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited. | |||
| /// \notes Subclass needs to override this if it requires special node visit access. | |||
| /// Check "dataset/engine/opt/pass.h" for more details. | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| virtual Status AcceptAfter(NodePass *p, bool *modified); | |||
| virtual Status AcceptAfter(IRNodePass *p, bool *modified); | |||
| virtual bool IsSizeDefined() { return true; } | |||
| @@ -235,55 +257,22 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| std::string PrintColumns(const std::vector<std::string> &columns) const; | |||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||
| void PrintNode(std::ostream &out, int *level) const; | |||
| }; | |||
| // SourceNode represents the leaf nodes of a pipeline where the data is pulled into. | |||
| class SourceNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| SourceNode() : DatasetNode() {} | |||
| /// \brief Constructor that initializes the cache | |||
| /// \param dataset_cache DatasetCache | |||
| explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {} | |||
| /// \brief Destructor | |||
| ~SourceNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes | |||
| /// \return True if the dataset represented by this node is a mappable dataset | |||
| const bool IsMappable() const { return mappable_; } | |||
| protected: | |||
| bool mappable_; | |||
| bool descendant_of_cache_; | |||
| }; | |||
| // MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes. | |||
| class MappableSourceNode : public SourceNode { | |||
| class MappableSourceNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| MappableSourceNode() : SourceNode() { mappable_ = true; } | |||
| MappableSourceNode() : DatasetNode() { mappable_ = true; } | |||
| /// \brief Constructor that initializes the cache | |||
| /// \param dataset_cache DatasetCache | |||
| explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { | |||
| explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) { | |||
| mappable_ = true; | |||
| // Initially set to false, and set to true by the optimizer when conditions are met. | |||
| descendant_of_cache_ = false; | |||
| } | |||
| /// \brief Destructor | |||
| @@ -295,15 +284,17 @@ class MappableSourceNode : public SourceNode { | |||
| }; | |||
| // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. | |||
| class NonMappableSourceNode : public SourceNode { | |||
| class NonMappableSourceNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| NonMappableSourceNode() : SourceNode() { mappable_ = false; } | |||
| NonMappableSourceNode() : DatasetNode() { mappable_ = false; } | |||
| /// \brief Constructor that initializes the cache | |||
| /// \param dataset_cache DatasetCache | |||
| explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) { | |||
| explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) { | |||
| mappable_ = false; | |||
| // Initially set to false, and set to true by the optimizer when conditions are met. | |||
| descendant_of_cache_ = false; | |||
| } | |||
| /// \brief Destructor | |||
| @@ -313,34 +304,6 @@ class NonMappableSourceNode : public SourceNode { | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| }; | |||
| // NonLeafNode represents operations over data in a pipeline. | |||
| class NonLeafNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| NonLeafNode() = default; | |||
| /// \brief Destructor | |||
| ~NonLeafNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| }; | |||
| // SinkNode represents the end node of a pipeline where the data is pushed out | |||
| class SinkNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| SinkNode() = default; | |||
| /// \brief Destructor | |||
| ~SinkNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | |||
| @@ -32,8 +32,9 @@ EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epo | |||
| // The root node's parent must set to null pointer. | |||
| this->AddChild(child); | |||
| } | |||
| std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() { | |||
| auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_); | |||
| auto node = std::make_shared<EpochCtrlNode>(num_epochs_); | |||
| return node; | |||
| } | |||
| @@ -29,7 +29,10 @@ namespace dataset { | |||
| class EpochCtrlNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | |||
| explicit EpochCtrlNode(int32_t num_epochs) : num_epochs_(num_epochs) {} | |||
| /// \brief Constructor | |||
| EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs); | |||
| /// \brief Destructor | |||
| ~EpochCtrlNode() = default; | |||
| @@ -60,14 +60,14 @@ Status FilterNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status FilterNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status FilterNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<FilterNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status FilterNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status FilterNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<FilterNode>(), modified); | |||
| } | |||
| @@ -58,17 +58,17 @@ class FilterNode : public DatasetNode { | |||
| bool IsSizeDefined() override { return false; }; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| std::shared_ptr<TensorOp> predicate_; | |||
| @@ -42,14 +42,16 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr | |||
| } | |||
| std::shared_ptr<DatasetNode> MapNode::Copy() { | |||
| auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_, | |||
| std::vector<std::shared_ptr<TensorOperation>> operations = operations_; | |||
| auto node = std::make_shared<MapNode>(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_, | |||
| callbacks_); | |||
| return node; | |||
| } | |||
| void MapNode::Print(std::ostream &out) const { | |||
| out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + | |||
| ",<project_cols>" + ",...)"; | |||
| ",<project_cols>" + ",num_tensor_ops:" | |||
| << operations_.size() << ",...)"; | |||
| } | |||
| Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| @@ -101,14 +103,14 @@ Status MapNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status MapNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status MapNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<MapNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status MapNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status MapNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<MapNode>(), modified); | |||
| } | |||
| @@ -63,17 +63,17 @@ class MapNode : public DatasetNode { | |||
| const auto &TensorOperations() const { return operations_; } | |||
| auto &TensorOperations() { return operations_; } | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| std::vector<std::shared_ptr<TensorOperation>> operations_; | |||
| @@ -70,14 +70,14 @@ Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RepeatNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status RepeatNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<RepeatNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status RepeatNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<RepeatNode>(), modified); | |||
| } | |||
| @@ -66,17 +66,17 @@ class RepeatNode : public DatasetNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t repeat_count_; | |||
| @@ -72,14 +72,14 @@ Status RootNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RootNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status RootNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<RootNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status RootNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status RootNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<RootNode>(), modified); | |||
| } | |||
| @@ -58,17 +58,17 @@ class RootNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t num_epochs_; | |||
| @@ -21,6 +21,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/skip_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -70,5 +71,16 @@ Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status SkipNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<SkipNode>(), modified); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status SkipNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<SkipNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -64,6 +64,18 @@ class SkipNode : public DatasetNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t skip_count_; | |||
| }; | |||
| @@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> AlbumNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); | |||
| return node; | |||
| } | |||
| @@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | |||
| extensions_(extensions) {} | |||
| std::shared_ptr<DatasetNode> CelebANode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); | |||
| return node; | |||
| } | |||
| @@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> Cifar100Node::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_); | |||
| return node; | |||
| } | |||
| @@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> Cifar10Node::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_); | |||
| return node; | |||
| } | |||
| @@ -208,7 +208,7 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| RETURN_IF_NOT_OK(clue_op->Init()); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { | |||
| // Inject ShuffleOp | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| @@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> CocoNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); | |||
| return node; | |||
| } | |||
| @@ -119,7 +119,7 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| RETURN_IF_NOT_OK(csv_op->Init()); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { | |||
| // Inject ShuffleOp | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| @@ -33,8 +33,16 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector< | |||
| column_names_(column_names), | |||
| column_types_(column_types) {} | |||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | |||
| : generator_function_(generator_function), schema_(schema) {} | |||
| std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | |||
| auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_); | |||
| std::shared_ptr<GeneratorNode> node; | |||
| if (schema_ == nullptr) { | |||
| node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_); | |||
| } else { | |||
| node = std::make_shared<GeneratorNode>(generator_function_, schema_); | |||
| } | |||
| return node; | |||
| } | |||
| @@ -42,9 +50,6 @@ void GeneratorNode::Print(std::ostream &out) const { | |||
| out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)"; | |||
| } | |||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | |||
| : generator_function_(generator_function), schema_(schema) {} | |||
| Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>(); | |||
| @@ -42,7 +42,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar | |||
| exts_(extensions) {} | |||
| std::shared_ptr<DatasetNode> ImageFolderNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = | |||
| std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); | |||
| return node; | |||
| @@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> ManifestNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_); | |||
| return node; | |||
| } | |||
| @@ -54,12 +54,13 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st | |||
| std::shared_ptr<DatasetNode> MindDataNode::Copy() { | |||
| std::shared_ptr<MindDataNode> node; | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| if (dataset_files_.empty()) { | |||
| node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); | |||
| } else { | |||
| node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_); | |||
| } | |||
| node->SetSampleBytes(&sample_bytes_); | |||
| return node; | |||
| } | |||
| @@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> MnistNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_); | |||
| return node; | |||
| } | |||
| @@ -86,7 +86,7 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); | |||
| RETURN_IF_NOT_OK(text_file_op->Init()); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { | |||
| // Inject ShuffleOp | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| @@ -134,7 +134,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| RETURN_IF_NOT_OK(tf_reader_op->Init()); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { | |||
| // Inject ShuffleOp | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| @@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const | |||
| sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> VOCNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); | |||
| auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); | |||
| return node; | |||
| } | |||
| @@ -22,6 +22,7 @@ | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/datasetops/take_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| @@ -68,5 +69,16 @@ Status TakeNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status TakeNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<TakeNode>(), modified); | |||
| } | |||
| // Visitor accepting method for IRNodePass | |||
| Status TakeNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<TakeNode>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -64,6 +64,18 @@ class TakeNode : public DatasetNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| int32_t take_count_; | |||
| }; | |||
| @@ -104,14 +104,14 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status TransferNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status TransferNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<TransferNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status TransferNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status TransferNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<TransferNode>(), modified); | |||
| } | |||
| @@ -58,17 +58,17 @@ class TransferNode : public DatasetNode { | |||
| static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id); | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| std::string queue_name_; | |||
| @@ -79,14 +79,14 @@ Status ZipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ZipNode::Accept(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status ZipNode::Accept(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->Visit(shared_from_base<ZipNode>(), modified); | |||
| } | |||
| // Visitor accepting method for NodePass | |||
| Status ZipNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // Visitor accepting method for IRNodePass | |||
| Status ZipNode::AcceptAfter(IRNodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->VisitAfter(shared_from_base<ZipNode>(), modified); | |||
| } | |||
| @@ -64,19 +64,20 @@ class ZipNode : public DatasetNode { | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::vector<std::shared_ptr<DatasetNode>> datasets_; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| Status Accept(IRNodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(NodePass *p, bool *modified) override; | |||
| Status AcceptAfter(IRNodePass *p, bool *modified) override; | |||
| private: | |||
| std::vector<std::shared_ptr<DatasetNode>> datasets_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -6,9 +6,12 @@ add_library(engine-opt OBJECT | |||
| post/repeat_pass.cc | |||
| pre/cache_error_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| pre/cache_validation_pass.cc | |||
| pre/epoch_ctrl_pass.cc | |||
| pre/epoch_injection_pass.cc | |||
| pre/getter_pass.cc | |||
| pre/input_validation_pass.cc | |||
| pre/node_removal_pass.cc | |||
| pre/removal_pass.cc | |||
| util/printer_pass.cc | |||
| ) | |||
| @@ -87,7 +87,7 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Driver method for TreePass | |||
| Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| Status IRTreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| if (root_ir == nullptr || modified == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); | |||
| } | |||
| @@ -95,7 +95,7 @@ Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| } | |||
| // Driver method for NodePass | |||
| Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| Status IRNodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| if (root_ir == nullptr || modified == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); | |||
| } | |||
| @@ -110,7 +110,7 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| } | |||
| // Helper function to perform DFS visit | |||
| Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | |||
| Status IRNodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | |||
| bool m = false; | |||
| RETURN_IF_NOT_OK(node_ir->Accept(this, &m)); | |||
| @@ -125,7 +125,7 @@ Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi | |||
| } | |||
| // Helper function to perform BFS visit | |||
| Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | |||
| Status IRNodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) { | |||
| bool m = false; | |||
| // Initialize bfs queue with root | |||
| @@ -151,121 +151,113 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi | |||
| } | |||
| // For non-leaf IR node | |||
| Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| Status IRNodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| #endif | |||
| // For leaf IR Node | |||
| Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| } | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| // Driver method for TreePass | |||
| @@ -113,26 +113,18 @@ class GeneratorOp; | |||
| // The base class Pass is the basic unit of tree transformation. | |||
| // The actual implementation of the passes will be derived from here. | |||
| class Pass : public std::enable_shared_from_this<Pass> { | |||
| class IRPass : public std::enable_shared_from_this<IRPass> { | |||
| public: | |||
| // Run the transformation pass against the IR tree. | |||
| // @param root_ir - Pointer to the IR tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) = 0; | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| // Run the transformation pass against the execution tree. | |||
| // @param tree - Pointer to the execution tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| virtual Status Run(ExecutionTree *tree, bool *modified) = 0; | |||
| ////////////////////////////////// | |||
| virtual ~Pass() = default; | |||
| virtual ~IRPass() = default; | |||
| }; | |||
| // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. | |||
| class TreePass : public Pass { | |||
| // IRTreePass is a basic Pass class which performs transformation on IR tree directly. | |||
| class IRTreePass : public IRPass { | |||
| public: | |||
| /// \brief Run the transformation pass against the IR tree. | |||
| /// \param[inout] root_ir Pointer to the IR tree to be transformed. | |||
| @@ -145,44 +137,29 @@ class TreePass : public Pass { | |||
| /// \param[inout] Indicate if the tree was modified. | |||
| /// \return Status The error code return | |||
| virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| /// \brief Run the transformation pass against the execution tree. | |||
| /// \param[inout] tree Pointer to the execution tree to be transformed. | |||
| /// \param[inout] modified Indicate if the tree was modified | |||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||
| /// \brief Derived classes may implement the runOnTree function to implement tree transformation. | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The error code return | |||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||
| ////////////////////////////////// | |||
| }; | |||
| // NodePass is a base Pass class which performs transformation on node visiting. | |||
| // NodePass implements Visitor design pattern. | |||
| // IRNodePass is a base Pass class which performs transformation on node visiting. | |||
| // IRNodePass implements Visitor design pattern. | |||
| // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, | |||
| // and the other when all the descending nodes are visited. | |||
| // Actual transformation is done by implementing a new derived class of NodePass. | |||
| // Actual transformation is done by implementing a new derived class of IRNodePass. | |||
| // The derived class will implement the method Visit()/VisitAfter() passing specified node types | |||
| // it wants to action on them, overriding the ones defined in NodePass. | |||
| // it wants to action on them, overriding the ones defined in IRNodePass. | |||
| // If the derived class wants to perform the same action on all node types, | |||
| // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. | |||
| // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back | |||
| // to call the Visit()/VisitAfter() in this parent NodePass class. | |||
| class NodePass : public Pass { | |||
| // to call the Visit()/VisitAfter() in this parent IRNodePass class. | |||
| class IRNodePass : public IRPass { | |||
| public: | |||
| // Tree traversal order | |||
| enum Order { DFS, BFS }; | |||
| // Constructor | |||
| // Default DFS traversal | |||
| explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } | |||
| explicit IRNodePass(Order order = Order::DFS) { traversalOrder_ = order; } | |||
| ~NodePass() = default; | |||
| ~IRNodePass() = default; | |||
| /// \brief Run the transformation pass against the IR tree | |||
| /// \param[inout] root_ir Pointer to the IR tree to be transformed | |||
| @@ -251,12 +228,70 @@ class NodePass : public Pass { | |||
| virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified); | |||
| #endif | |||
| // Leaf IR node | |||
| virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified); | |||
| virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified); | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| private: | |||
| // Helper function to perform DFS visit | |||
| Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified); | |||
| // Helper function to perform BFS visit | |||
| Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified); | |||
| // Tree traversal order of the NodePass | |||
| Order traversalOrder_; | |||
| }; | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| // The base class Pass is the basic unit of tree transformation. | |||
| // The actual implementation of the passes will be derived from here. | |||
| class Pass : public std::enable_shared_from_this<Pass> { | |||
| public: | |||
| // Run the transformation pass against the execution tree. | |||
| // @param tree - Pointer to the execution tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| virtual Status Run(ExecutionTree *tree, bool *modified) = 0; | |||
| virtual ~Pass() = default; | |||
| }; | |||
| // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. | |||
| class TreePass : public Pass { | |||
| public: | |||
| /// \brief Run the transformation pass against the execution tree. | |||
| /// \param[inout] tree Pointer to the execution tree to be transformed. | |||
| /// \param[inout] modified Indicate if the tree was modified | |||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||
| /// \brief Derived classes may implement the runOnTree function to implement tree transformation. | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The error code return | |||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||
| }; | |||
| // NodePass is a base Pass class which performs transformation on node visiting. | |||
| // NodePass implements Visitor design pattern. | |||
| // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, | |||
| // and the other when all the descending nodes are visited. | |||
| // Actual transformation is done by implementing a new derived class of NodePass. | |||
| // The derived class will implement the method Visit()/VisitAfter() passing specified node types | |||
| // it wants to action on them, overriding the ones defined in NodePass. | |||
| // If the derived class wants to perform the same action on all node types, | |||
| // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. | |||
| // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back | |||
| // to call the Visit()/VisitAfter() in this parent NodePass class. | |||
| class NodePass : public Pass { | |||
| public: | |||
| // Tree traversal order | |||
| enum Order { DFS, BFS }; | |||
| // Constructor | |||
| // Default DFS traversal | |||
| explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } | |||
| ~NodePass() = default; | |||
| /// \brief Run the transformation pass against the execution tree | |||
| /// \param[inout] tree Pointer to the execution tree to be transformed | |||
| /// \param[inout] modified Indicator if the tree was changed | |||
| @@ -326,27 +361,18 @@ class NodePass : public Pass { | |||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||
| #endif | |||
| ////////////////////////////////// | |||
| private: | |||
| // Helper function to perform DFS visit | |||
| Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified); | |||
| // Helper function to perform BFS visit | |||
| Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified); | |||
| ////////////////////////////////// | |||
| // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. | |||
| // Helper function to perform DFS visit | |||
| Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); | |||
| // Helper function to perform BFS visit | |||
| Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified); | |||
| ////////////////////////////////// | |||
| // Tree traversal order of the NodePass | |||
| Order traversalOrder_; | |||
| }; | |||
| ////////////////////////////////// | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,163 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| CacheValidationPass::CacheValidationPass() : is_cached_(false), is_mappable_(false) {} | |||
| // Returns an error if BatchNode exists under a cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<BatchNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<BatchNode>): visiting " << node->Name() << "."; | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("BatchNode is not supported as a descendant operator under a cache."); | |||
| } | |||
| if (node->IsCached()) { | |||
| RETURN_STATUS_UNEXPECTED("BatchNode cannot be cached."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if ConcatNode exists under a cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ConcatNode>): visiting " << node->Name() << "."; | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("ConcatNode is not supported as a descendant operator under a cache."); | |||
| } | |||
| if (node->IsCached()) { | |||
| RETURN_STATUS_UNEXPECTED("ConcatNode cannot be cached."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if FilterNode exists under a cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<FilterNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<FilterNode>): visiting " << node->Name() << "."; | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("FilterNode is not supported as a descendant operator under a cache."); | |||
| } | |||
| if (node->IsCached()) { | |||
| RETURN_STATUS_UNEXPECTED("FilterNode cannot be cached."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if SkipNode exists under a cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<SkipNode>): visiting " << node->Name() << "."; | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("SkipNode is not supported as a descendant operator under a cache."); | |||
| } | |||
| if (node->IsCached()) { | |||
| RETURN_STATUS_UNEXPECTED("SkipNode cannot be cached."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if TakeNode exists under a cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<TakeNode>): visiting " << node->Name() << "."; | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("TakeNode (possibly from Split) is not supported as a descendant operator under a cache."); | |||
| } | |||
| if (node->IsCached()) { | |||
| RETURN_STATUS_UNEXPECTED("TakeNode cannot be cached."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if ZipNode exists under a cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<ZipNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<ZipNode>): visiting " << node->Name() << "."; | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("ZipNode is not supported as a descendant operator under a cache."); | |||
| } | |||
| if (node->IsCached()) { | |||
| RETURN_STATUS_UNEXPECTED("ZipNode cannot be cached."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if MapNode with non-deterministic tensor operations exists under a cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<MapNode>): visiting " << node->Name() << "."; | |||
| if (node->IsCached()) { | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("Nested cache operations over MapNode is not supported."); | |||
| } | |||
| // If Map is created to be cached, set the flag indicating we found an operation with a cache. | |||
| is_cached_ = true; | |||
| auto tfuncs = node->TensorOperations(); | |||
| for (size_t i = 0; i < tfuncs.size(); i++) { | |||
| if (tfuncs[i]->IsRandomOp()) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "MapNode with non-deterministic operations is not supported as a descendant of cache."); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Flag an error if we have a cache over another cache | |||
| Status CacheValidationPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::Visit(<DatasetNode>): visiting " << node->Name() << "."; | |||
| if (node->IsCached()) { | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("Nested cache operations over " + node->Name() + " is not supported."); | |||
| } | |||
| // If this node is created to be cached, set the flag. | |||
| is_cached_ = true; | |||
| } | |||
| if (node->IsLeaf() && node->IsMappable()) { | |||
| is_mappable_ = true; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if MappableSource <- Repeat <- Node with a cache | |||
| // Because there is no operator in the cache hit stream to consume EoEs, caching above repeat causes problem. | |||
| Status CacheValidationPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<RepeatNode>): visiting " << node->Name() << "."; | |||
| if (is_cached_ && is_mappable_) { | |||
| RETURN_STATUS_UNEXPECTED("A cache over a RepeatNode of a mappable dataset is not supported."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheValidationPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { | |||
| MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(<DatasetNode>): visiting " << node->Name() << "."; | |||
| // Reset the flag when all descendants are visited | |||
| if (node->IsCached()) { | |||
| is_cached_ = false; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,105 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_ | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class CacheValidationPass cache_validation_pass.h | |||
| /// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures. | |||
| class CacheValidationPass : public IRNodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| CacheValidationPass(); | |||
| /// \brief Destructor | |||
| ~CacheValidationPass() = default; | |||
| /// \brief Returns an error if BatchNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override; | |||
| /// \brief Returns an error if ConcatNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override; | |||
| /// \brief Returns an error if FilterNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override; | |||
| /// \brief Returns an error if SkipNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override; | |||
| /// \brief Returns an error if TakeNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override; | |||
| /// \brief Returns an error if ZipNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override; | |||
| /// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<MapNode> node, bool *modified) override; | |||
| /// \brief Returns an error if there is a cache over another cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| /// \brief Identifies and block repeat under cache scenarios | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override; | |||
| /// \brief Identifies the subtree above this node as not being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| private: | |||
| bool is_cached_; | |||
| bool is_mappable_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_ | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // constructor | |||
| EpochCtrlPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetNode> node) | |||
| : injection_point_(nullptr), num_epochs_(-1) {} | |||
| // Performs finder work for BuildVocabOp that has special rules about epoch control injection | |||
| Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<RootNode> node, bool *modified) { | |||
| // The injection is at the child of the root node | |||
| injection_point_ = node; | |||
| num_epochs_ = node->num_epochs(); | |||
| return Status::OK(); | |||
| } | |||
| // Performs finder work for BuildVocabOp that has special rules about epoch control injection | |||
| Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Performs finder work for BuildSentencePieceVocabNode that has special rules about epoch control injection | |||
| Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status EpochCtrlPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) { | |||
| // Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here. | |||
| // Move the injection point to the child of this node. | |||
| injection_point_ = node; | |||
| return Status::OK(); | |||
| } | |||
| // constructor | |||
| EpochCtrlPass::EpochCtrlPass() {} | |||
| // Runs an injection pass to inject in operators needed at the pre pass stage | |||
| Status EpochCtrlPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| MS_LOG(INFO) << "Pre pass: Injection pass started."; | |||
| // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. | |||
| // The finder can make updates to the EpochInjectionPass object. | |||
| EpochCtrlPass::InjectionFinder finder(root_ir); | |||
| RETURN_IF_NOT_OK(finder.Run(root_ir, modified)); | |||
| // The first injection logic is to check if we should inject the epoch control op as the root node. | |||
| // Do not inject the op if the number of epochs is 1. | |||
| std::shared_ptr<DatasetNode> parent = finder.injection_point(); | |||
| int32_t num_epochs = finder.num_epochs(); | |||
| if (num_epochs != 1 && parent != nullptr) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(parent->Children().size() == 1, "EpochCtrl must be injected on only one child."); | |||
| auto epoch_ctrl_node = std::make_shared<EpochCtrlNode>(num_epochs); | |||
| RETURN_IF_NOT_OK(parent->InsertBelow(epoch_ctrl_node)); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: Injection pass complete."; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,98 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DatasetOp; | |||
| /// \class EpochInjectionPass epoch_ctrl_pass.h | |||
| /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api | |||
| /// parsing. | |||
| class EpochCtrlPass : public IRTreePass { | |||
| /// \class InjectionFinder | |||
| /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for | |||
| /// operators that need to be injected. It is run first by the main injection pass to find out what operators | |||
| /// it may need to inject. | |||
| class InjectionFinder : public IRNodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit InjectionFinder(std::shared_ptr<DatasetNode> node); | |||
| /// \brief Destructor | |||
| ~InjectionFinder() = default; | |||
| /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<RootNode> node, bool *modified) override; | |||
| /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Register the TransferNode for further action. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override; | |||
| /// \brief Getter | |||
| std::shared_ptr<DatasetNode> injection_point() { return injection_point_; } | |||
| /// \brief Getter | |||
| int32_t num_epochs() { return num_epochs_; } | |||
| private: | |||
| std::shared_ptr<DatasetNode> injection_point_; | |||
| int32_t num_epochs_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| EpochCtrlPass(); | |||
| /// \brief Destructor | |||
| ~EpochCtrlPass() = default; | |||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The error code return | |||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ | |||
| @@ -26,7 +26,7 @@ namespace dataset { | |||
| /// \class InputValidationPass | |||
| /// \brief This is a parse pass that validates input parameters of the IR tree. | |||
| class InputValidationPass : public NodePass { | |||
| class InputValidationPass : public IRNodePass { | |||
| /// \brief Runs a validatation pass to check input parameters | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] *modified indicates whether the node has been visited | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| NodeRemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} | |||
| // Identifies the subtree below this node as a cached descendant tree. | |||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<DatasetNode> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Node removal pass: Operation with cache found, identified descendant tree."; | |||
| if (node->IsCached()) { | |||
| is_caching_ = true; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Resets the tracking of the cache within the tree | |||
| Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: Descendant walk is complete."; | |||
| if (is_caching_ && node->IsLeaf()) { | |||
| // Mark this leaf node to indicate it is a descendant of an operator with cache. | |||
| // This is currently used in non-mappable dataset (leaf) nodes to not add a ShuffleOp in DatasetNode::Build(). | |||
| node->HasCacheAbove(); | |||
| } | |||
| is_caching_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // Perform ShuffleOp removal check. | |||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| *modified = false; | |||
| #if 0 | |||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||
| if (is_caching_) { | |||
| MS_LOG(INFO) << "Shuffle under an operation with cache is identified for removal."; | |||
| nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node)); | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| // constructor | |||
| NodeRemovalPass::NodeRemovalPass() {} | |||
| // Walk the tree to collect the nodes to remove, then removes them. | |||
| Status NodeRemovalPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { | |||
| MS_LOG(INFO) << "Pre pass: node removal pass started."; | |||
| // Create the removal node pass which can identify which nodes need to be removed. | |||
| std::unique_ptr<NodeRemovalPass::RemovalNodes> removal_nodes = std::make_unique<NodeRemovalPass::RemovalNodes>(); | |||
| RETURN_IF_NOT_OK(removal_nodes->Run(root_ir, modified)); | |||
| // Then, execute the removal of any nodes that were set up for removal | |||
| for (auto node : removal_nodes->nodes_to_remove()) { | |||
| RETURN_IF_NOT_OK(node->Remove()); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: node removal pass complete."; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DatasetOp; | |||
| /// \class RemovalPass removal_pass.h | |||
| /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which | |||
| /// nodes should be removed, and then removes them. | |||
| class NodeRemovalPass : public IRTreePass { | |||
| /// \class RemovalNodes | |||
| /// \brief This is a NodePass who's job is to identify which nodes should be removed. | |||
| /// It works in conjunction with the removal_pass. | |||
| class RemovalNodes : public IRNodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||
| RemovalNodes(); | |||
| /// \brief Destructor | |||
| ~RemovalNodes() = default; | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| /// \brief Perform ShuffleNode removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override; | |||
| /// \brief Getter | |||
| /// \return All the nodes to be removed | |||
| std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; } | |||
| private: | |||
| bool is_caching_; | |||
| std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove_; | |||
| }; | |||
| public: | |||
| /// \brief Constructor | |||
| NodeRemovalPass(); | |||
| /// \brief Destructor | |||
| ~NodeRemovalPass() = default; | |||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The error code return | |||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ | |||
| @@ -17,34 +17,25 @@ | |||
| #include "minddata/dataset/engine/tree_adapter.h" | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) { | |||
| // Vector of actions in validation pass | |||
| std::vector<std::unique_ptr<NodePass>> validations; | |||
| MS_LOG(INFO) << "Running pre pass loops."; | |||
| validations.push_back(std::make_unique<InputValidationPass>()); | |||
| // Vector of flags for each action | |||
| // Apply validation actions | |||
| for (auto i = 0; i < validations.size(); i++) { | |||
| auto modified = false; | |||
| // InputValidationPass does not change the IR tree. We don't need to capture the "modified" value. | |||
| RETURN_IF_NOT_OK(validations[i]->Run(ir, &modified)); | |||
| } | |||
| // Vector of actions in pre-pass phase | |||
| std::vector<std::unique_ptr<Pass>> actions; | |||
| std::vector<std::unique_ptr<IRPass>> actions; | |||
| // We will gradually move CacheErrorPass, EpochInjectionPass, CacheTransformPass | |||
| // from ExecutionTree::PrepareTreePreAction to here. | |||
| MS_LOG(INFO) << "Running pre pass loops."; | |||
| actions.push_back(std::make_unique<InputValidationPass>()); | |||
| actions.push_back(std::make_unique<CacheValidationPass>()); | |||
| actions.push_back(std::make_unique<NodeRemovalPass>()); | |||
| actions.push_back(std::make_unique<EpochCtrlPass>()); | |||
| // Vector of flags for each action | |||
| std::vector<bool> modified(actions.size(), false); | |||
| @@ -60,7 +51,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) { | |||
| Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) { | |||
| // Vector of optimizations | |||
| std::vector<std::unique_ptr<NodePass>> optimizations; | |||
| std::vector<std::unique_ptr<IRNodePass>> optimizations; | |||
| MS_LOG(INFO) << "Running optimization pass loops"; | |||
| // We will gradually move TensorOpFusionPass from ExecutionTree::Optimize to here. | |||
| @@ -79,7 +70,7 @@ Status TreeAdapter::Optimize(std::shared_ptr<DatasetNode> ir) { | |||
| Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||
| // Vector of actions in post-pass phase | |||
| std::vector<std::unique_ptr<Pass>> actions; | |||
| std::vector<std::unique_ptr<IRPass>> actions; | |||
| MS_LOG(INFO) << "Running post pass loops."; | |||
| // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. | |||
| @@ -96,10 +87,6 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||
| } | |||
| Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) { | |||
| // Check if pipeline is valid or not | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ir->Parent().size() <= 1, | |||
| "The data pipeline is not a tree (i.e. one node has two consumers)"); | |||
| // Build the DatasetOp ExecutionTree from the optimized IR tree | |||
| std::vector<std::shared_ptr<DatasetOp>> ops; | |||
| RETURN_IF_NOT_OK(ir->Build(&ops)); | |||
| @@ -130,8 +117,12 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e | |||
| RETURN_UNEXPECTED_IF_NULL(input_ir); | |||
| MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n'; | |||
| // We will first walk the input tree to sanity check this is not a graph | |||
| // Flag an error when it is not a tree | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input_ir->IsTree(), "The data pipeline is not a tree (i.e. one node has two consumers)"); | |||
| // Copy the input IR tree and insert under the root node | |||
| // Create a root node to host the input IR tree, the deepcopied tree will be passed to optimization pass | |||
| // Create a root node to host the new copy of the input IR tree to pass to the optimizer | |||
| auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs); | |||
| MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n'; | |||
| @@ -151,11 +142,9 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e | |||
| // This will evolve in the long run | |||
| tree_ = std::make_unique<ExecutionTree>(); | |||
| // Build the Execution tree from the child of the root node | |||
| // Build the Execution tree from the child of the IR root node, which represent the root of the input IR tree | |||
| std::shared_ptr<DatasetOp> root_op; | |||
| // input_ir is the ir node before the deepcopy. | |||
| // We will replace input_ir with root_ir->Children()[0] once IR optimizer is in | |||
| RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op)); | |||
| RETURN_IF_NOT_OK(BuildExecutionTree(root_ir->Children()[0], &root_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); | |||
| if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); | |||
| @@ -163,7 +152,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e | |||
| // Note: We will gradually move the pre pass, optimizer pass, and post pass | |||
| // on ExecutionTree to perform on IR tree. | |||
| // Prepare the tree | |||
| RETURN_IF_NOT_OK(tree_->Prepare(num_epochs)); | |||
| RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true)); | |||
| // After the tree is prepared, the col_name_id_map can safely be obtained | |||
| column_name_map_ = tree_->root()->column_name_id_map(); | |||
| @@ -44,7 +44,7 @@ Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt | |||
| } | |||
| void NormalizeOp::Print(std::ostream &out) const { | |||
| out << "NormalizeOp, mean: " << mean_ << std::endl << "std: " << std_ << std::endl; | |||
| out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -83,7 +83,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) { | |||
| }; | |||
| exe_tree->SetPrePassOverride(pass); | |||
| ASSERT_OK(exe_tree->PrepareTreePreAction()); | |||
| ASSERT_OK(exe_tree->PreAction()); | |||
| std::stringstream ss; | |||
| // print the tree in std::string as a way to verify that nodes are indeed removed | |||
| @@ -124,7 +124,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { | |||
| }; | |||
| exe_tree->SetPrePassOverride(pass); | |||
| ASSERT_OK(exe_tree->PrepareTreePreAction()); | |||
| ASSERT_OK(exe_tree->PreAction()); | |||
| std::stringstream ss; | |||
| // print the tree in std::string as a way to verify that nodes are indeed removed | |||
| exe_tree->Print(ss); | |||
| @@ -237,7 +237,7 @@ def test_cache_map_failure1(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| assert "Nested cache operations is not supported!" in str(e.value) | |||
| assert "Nested cache operations" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure1 Ended.\n') | |||
| @@ -279,7 +279,7 @@ def test_cache_map_failure2(): | |||
| num_iter = 0 | |||
| for _ in dsz.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "ZipOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure2 Ended.\n') | |||
| @@ -319,7 +319,7 @@ def test_cache_map_failure3(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "BatchOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure3 Ended.\n') | |||
| @@ -361,7 +361,7 @@ def test_cache_map_failure4(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "FilterOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure4 Ended.\n') | |||
| @@ -402,7 +402,7 @@ def test_cache_map_failure5(): | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache" in str(e.value) | |||
| assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure5 Ended.\n') | |||
| @@ -522,7 +522,7 @@ def test_cache_map_failure8(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value) | |||
| assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure8 Ended.\n') | |||
| @@ -564,7 +564,7 @@ def test_cache_map_failure9(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure9 Ended.\n') | |||
| @@ -606,7 +606,7 @@ def test_cache_map_failure10(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "SkipOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure10 Ended.\n') | |||
| @@ -655,13 +655,13 @@ def test_cache_map_split1(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in ds2.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) | |||
| logger.info('test_cache_split1 Ended.\n') | |||