Merge pull request !2772 from Jamie/removalpasstags/v0.6.0-beta
| @@ -409,7 +409,7 @@ Status BatchOp::UnpackPadInfo(const PadInfo &pad_info, | |||
| // Visitor accept method for NodePass | |||
| Status BatchOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<BatchOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| @@ -111,6 +111,51 @@ void DatasetOp::RemoveParent(const DatasetOp *parent) { | |||
| parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end()); | |||
| } | |||
| // Removes this node from the tree and connects it's parent/child together | |||
| Status DatasetOp::Remove() { | |||
| if (parent_.size() > 1) { | |||
| std::string err_msg("No support for op removal if the operator has more than one parent"); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| if (child_.size() > 1) { | |||
| std::string err_msg("No support for op removal if the operator has more than one child"); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| // Scenario's when removing node B: | |||
| // A -> B -> C | |||
| // A -> B | |||
| // B -> C | |||
| // | |||
| // If we remove B, then first take our child A and update it's parent to be C | |||
| // It's possible the parent is null if we are the root node being removed. | |||
| if (!child_.empty()) { | |||
| // If we have a parent, then assign chlid's parent to point to our parent. | |||
| if (!parent_.empty()) { | |||
| child_[0]->parent_[0] = parent_[0]; | |||
| } else { | |||
| // We don't have a parent, so we are the root node being removed. | |||
| // clear the parent list of our child so that it becomes the new root. | |||
| child_[0]->parent_.clear(); | |||
| tree_->AssignRoot(child_[0]); | |||
| } | |||
| } | |||
| // Next, if we had a parent, then set it's child to be our child. | |||
| if (!parent_.empty()) { | |||
| // if we have a child, then set our parent to point to it | |||
| if (!child_.empty()) { | |||
| parent_[0]->child_[0] = child_[0]; | |||
| } else { | |||
| // We don't have a child, so clear the child list of the current | |||
| // parent because it will be empty once we are removed. | |||
| parent_[0]->child_.clear(); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Getter function to get a shared pointer to our childAdds a operator to become our child. | |||
| std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const { | |||
| MS_ASSERT(child_index < static_cast<int>(child_.size())); | |||
| @@ -289,6 +334,12 @@ Status DatasetOp::ComputeColMap() { | |||
| return Status::OK(); | |||
| } | |||
| Status DatasetOp::PreAccept(NodePass *p, bool *modified) { | |||
| // DatasetOp is the base class of visitor target pre-visit. | |||
| // This method will only be called if its derived class does not implement one. | |||
| return p->PreRunOnNode(shared_from_this(), modified); | |||
| } | |||
| Status DatasetOp::Accept(NodePass *p, bool *modified) { | |||
| // DatasetOp is the base class of visitor target. | |||
| // This method will only be called if its derived class does not implement one. | |||
| @@ -71,6 +71,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| // @param child - shared pointer to the child to remove. | |||
| Status RemoveChild(std::shared_ptr<DatasetOp> child); | |||
| /// \brief Removes this node from the tree and connects it's parent/child together. | |||
| /// \return Status eerror code returned | |||
| Status Remove(); | |||
| // Getter function to get a shared pointer to our child | |||
| // @param child_index - An operator can have n children. Indicates choose which child to return. | |||
| std::shared_ptr<DatasetOp> child(int32_t child_index) const; | |||
| @@ -264,10 +268,20 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| // @return Vector of Children | |||
| std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; } | |||
| // Base method for NodePass visit. | |||
| // Subclass needs to override this if it requires special node visit access. | |||
| // Check "dataset/engine/opt/pass.h" for more details. | |||
| // @return Statue of the node visit | |||
| /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up | |||
| /// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main | |||
| /// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it | |||
| /// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details. | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| virtual Status PreAccept(NodePass *p, bool *modified); | |||
| /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access. | |||
| /// Check "dataset/engine/opt/pass.h" for more details. | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| virtual Status Accept(NodePass *p, bool *modified); | |||
| // Op name getter | |||
| @@ -285,6 +299,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| // Computes a CRC value for the operator | |||
| static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op); | |||
| /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived> | |||
| /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr | |||
| /// \return A shared_ptr casted to the derived class | |||
| template <typename Derived> | |||
| std::shared_ptr<Derived> shared_from_base() { | |||
| return std::static_pointer_cast<Derived>(shared_from_this()); | |||
| } | |||
| protected: | |||
| // Adds a parent operator to this operator | |||
| // @notes External callers do not have access to this function. | |||
| @@ -313,7 +313,7 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { | |||
| // Visitor accept method for NodePass | |||
| Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<DeviceQueueOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| @@ -261,7 +261,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate | |||
| // Visitor accept method for NodePass | |||
| Status FilterOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<FilterOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -367,7 +367,7 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name | |||
| // Visitor accept method for NodePass | |||
| Status MapOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<MapOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -131,7 +131,7 @@ Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } | |||
| // Visitor accept method for NodePass | |||
| Status ProjectOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<ProjectOp>(), modified); | |||
| } | |||
| // Compute the column map and save it into our own column name map | |||
| @@ -176,7 +176,7 @@ Status RenameOp::EoeReceived(int32_t) { | |||
| // Visitor accept method for NodePass | |||
| Status RenameOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<RenameOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -190,7 +190,7 @@ int32_t RepeatOp::num_producers() const { | |||
| // Visitor accept method for NodePass | |||
| Status RepeatOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<RepeatOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -298,7 +298,7 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) { | |||
| // Visitor accept method for NodePass | |||
| Status ShuffleOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<ShuffleOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -130,7 +130,7 @@ Status SkipOp::EofReceived(int32_t worker_id) { | |||
| // Visitor accept method for NodePass | |||
| Status SkipOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<SkipOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -249,7 +249,7 @@ Status GeneratorOp::Reset() { | |||
| // Visitor accept method for NodePass | |||
| Status GeneratorOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<GeneratorOp>(), modified); | |||
| } | |||
| Status GeneratorOp::ComputeColMap() { | |||
| @@ -411,7 +411,7 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::se | |||
| // Visitor accept method for NodePass | |||
| Status ImageFolderOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<ImageFolderOp>(), modified); | |||
| } | |||
| Status ImageFolderOp::ComputeColMap() { | |||
| @@ -496,7 +496,7 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, | |||
| // Visitor accept method for NodePass | |||
| Status MindRecordOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<MindRecordOp>(), modified); | |||
| } | |||
| Status MindRecordOp::ComputeColMap() { | |||
| @@ -1004,7 +1004,7 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file | |||
| // Visitor accept method for NodePass | |||
| Status TFReaderOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<TFReaderOp>(), modified); | |||
| } | |||
| Status TFReaderOp::ComputeColMap() { | |||
| @@ -136,7 +136,7 @@ Status TakeOp::PrepareNodePostAction() { | |||
| // Visitor accept method for NodePass | |||
| Status TakeOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<TakeOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -237,7 +237,7 @@ Status ZipOp::EoeReceived(int32_t) { | |||
| // Visitor accept method for NodePass | |||
| Status ZipOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified); | |||
| return p->RunOnNode(shared_from_base<ZipOp>(), modified); | |||
| } | |||
| Status ZipOp::ComputeColMap() { | |||
| @@ -20,6 +20,7 @@ | |||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/engine/opt/pre/removal_pass.h" | |||
| #include "dataset/engine/perf/profiling.h" | |||
| #include "dataset/engine/perf/monitor.h" | |||
| @@ -214,7 +215,8 @@ Status ExecutionTree::PrepareTreePreAction() { | |||
| bool modified = false; | |||
| std::vector<std::unique_ptr<Pass>> pre_actions; | |||
| // Construct pre actions | |||
| // example: pre_actions.push_back(new SomePass()); | |||
| MS_LOG(INFO) << "Running pre pass"; | |||
| pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass())); | |||
| // Apply pre action passes | |||
| for (auto &pass : pre_actions) { | |||
| RETURN_IF_NOT_OK(pass->Run(this, &modified)); | |||
| @@ -1,6 +1,8 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine-opt OBJECT | |||
| pass.cc | |||
| util/printer_pass.cc | |||
| pass.cc | |||
| pre/removal_nodes.cc | |||
| pre/removal_pass.cc | |||
| util/printer_pass.cc | |||
| ) | |||
| @@ -61,6 +61,7 @@ Status NodePass::Run(ExecutionTree *tree, bool *modified) { | |||
| // Helper function to perform DFS visit | |||
| Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| RETURN_IF_NOT_OK(node->PreAccept(this, modified)); | |||
| for (const auto &c : node->Children()) { | |||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); | |||
| } | |||
| @@ -159,6 +160,5 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -66,14 +66,16 @@ class Pass : public std::enable_shared_from_this<Pass> { | |||
| // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. | |||
| class TreePass : public Pass { | |||
| public: | |||
| // Run the transformation pass against the execution tree. | |||
| // @param tree - Pointer to the execution tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| /// \brief Run the transformation pass against the execution tree. | |||
| /// \param[inout] tree Pointer to the execution tree to be transformed. | |||
| /// \param[inout] modified Indicate if the tree was modified | |||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||
| // Derived classes may implement the runOnTree function to implement tree transformation. | |||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| // @return Status - The error code return | |||
| /// \brief Derived classes may implement the runOnTree function to implement tree transformation. | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The error code return | |||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||
| }; | |||
| @@ -90,14 +92,23 @@ class NodePass : public Pass { | |||
| ~NodePass() = default; | |||
| // Run the transformation pass against the execution tree. | |||
| // @param tree - Pointer to the execution tree to be transformed. | |||
| // @param modified - Pointer to the modified flag, | |||
| /// \brief Run the transformation pass against the execution tree | |||
| /// \param[inout] tree Pointer to the execution tree to be transformed | |||
| /// \param[inout] modified Indicator if the tree was changed | |||
| Status Run(ExecutionTree *tree, bool *modified) final; | |||
| // Derived classes may implement the runOnNode function to implement node level tree transformation. | |||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| // @return Status - The error code return | |||
| /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down | |||
| /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution | |||
| /// \param[in] node The node being visited | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | |||
| /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution | |||
| /// \param[in] node The node being visited | |||
| /// \param[out] modified Indicator if the node was changed at all. | |||
| /// \return Status The error code return | |||
| virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | |||
| // Visit methods to be overridden. | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "dataset/engine/opt/pre/removal_nodes.h" | |||
| #include "dataset/engine/opt/pre/removal_pass.h" | |||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} | |||
| // Perform ShuffleOp removal check. | |||
| Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||
| *modified = false; | |||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||
| if (is_caching_) { | |||
| MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||
| if (removal_pass_) { | |||
| removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node)); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||
| #include <memory> | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class RemovalPass; | |||
| /// \class RemovalNodes removal_nodes.h | |||
| /// \brief This is a NodePass who's job is to identify which nodes should be removed. | |||
| /// It works in conjunction with the removal_pass. | |||
| class RemovalNodes : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||
| explicit RemovalNodes(RemovalPass *removal_pass); | |||
| /// \brief Perform ShuffleOp removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||
| private: | |||
| bool is_caching_; | |||
| RemovalPass *removal_pass_; // Back pointer to the owning removal pass | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "dataset/engine/opt/pre/removal_nodes.h" | |||
| #include "dataset/engine/opt/pre/removal_pass.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // constructor | |||
| RemovalPass::RemovalPass() {} | |||
| // Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| // Create the removal node pass which can identify which nodes need to be removed. | |||
| std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this); | |||
| RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); | |||
| // Then, execute the removal of any nodes that were set up for removal | |||
| for (auto node : removal_nodes_) { | |||
| node->Remove(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the list of operators to be removed | |||
| void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DatasetOp; | |||
| /// \class RemovalPass removal_pass.h | |||
| /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which | |||
| /// nodes should be removed, and then removes them. | |||
| class RemovalPass : public TreePass { | |||
| public: | |||
| /// \brief Constructor | |||
| RemovalPass(); | |||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \return Status The error code return | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| /// \brief Adds an operator to the list of operators to be removed | |||
| /// \param[in] dataset_op The operator to add to the removal list | |||
| void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op); | |||
| private: | |||
| std::vector<std::shared_ptr<DatasetOp>> removal_nodes_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||