Merge pull request !2772 from Jamie/removalpasstags/v0.6.0-beta
| @@ -409,7 +409,7 @@ Status BatchOp::UnpackPadInfo(const PadInfo &pad_info, | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status BatchOp::Accept(NodePass *p, bool *modified) { | Status BatchOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<BatchOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -111,6 +111,51 @@ void DatasetOp::RemoveParent(const DatasetOp *parent) { | |||||
| parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end()); | parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end()); | ||||
| } | } | ||||
| // Removes this node from the tree and connects it's parent/child together | |||||
| Status DatasetOp::Remove() { | |||||
| if (parent_.size() > 1) { | |||||
| std::string err_msg("No support for op removal if the operator has more than one parent"); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| if (child_.size() > 1) { | |||||
| std::string err_msg("No support for op removal if the operator has more than one child"); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| // Scenario's when removing node B: | |||||
| // A -> B -> C | |||||
| // A -> B | |||||
| // B -> C | |||||
| // | |||||
| // If we remove B, then first take our child A and update it's parent to be C | |||||
| // It's possible the parent is null if we are the root node being removed. | |||||
| if (!child_.empty()) { | |||||
| // If we have a parent, then assign chlid's parent to point to our parent. | |||||
| if (!parent_.empty()) { | |||||
| child_[0]->parent_[0] = parent_[0]; | |||||
| } else { | |||||
| // We don't have a parent, so we are the root node being removed. | |||||
| // clear the parent list of our child so that it becomes the new root. | |||||
| child_[0]->parent_.clear(); | |||||
| tree_->AssignRoot(child_[0]); | |||||
| } | |||||
| } | |||||
| // Next, if we had a parent, then set it's child to be our child. | |||||
| if (!parent_.empty()) { | |||||
| // if we have a child, then set our parent to point to it | |||||
| if (!child_.empty()) { | |||||
| parent_[0]->child_[0] = child_[0]; | |||||
| } else { | |||||
| // We don't have a child, so clear the child list of the current | |||||
| // parent because it will be empty once we are removed. | |||||
| parent_[0]->child_.clear(); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Getter function to get a shared pointer to our childAdds a operator to become our child. | // Getter function to get a shared pointer to our childAdds a operator to become our child. | ||||
| std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const { | std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const { | ||||
| MS_ASSERT(child_index < static_cast<int>(child_.size())); | MS_ASSERT(child_index < static_cast<int>(child_.size())); | ||||
| @@ -289,6 +334,12 @@ Status DatasetOp::ComputeColMap() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DatasetOp::PreAccept(NodePass *p, bool *modified) { | |||||
| // DatasetOp is the base class of visitor target pre-visit. | |||||
| // This method will only be called if its derived class does not implement one. | |||||
| return p->PreRunOnNode(shared_from_this(), modified); | |||||
| } | |||||
| Status DatasetOp::Accept(NodePass *p, bool *modified) { | Status DatasetOp::Accept(NodePass *p, bool *modified) { | ||||
| // DatasetOp is the base class of visitor target. | // DatasetOp is the base class of visitor target. | ||||
| // This method will only be called if its derived class does not implement one. | // This method will only be called if its derived class does not implement one. | ||||
| @@ -71,6 +71,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // @param child - shared pointer to the child to remove. | // @param child - shared pointer to the child to remove. | ||||
| Status RemoveChild(std::shared_ptr<DatasetOp> child); | Status RemoveChild(std::shared_ptr<DatasetOp> child); | ||||
| /// \brief Removes this node from the tree and connects it's parent/child together. | |||||
| /// \return Status eerror code returned | |||||
| Status Remove(); | |||||
| // Getter function to get a shared pointer to our child | // Getter function to get a shared pointer to our child | ||||
| // @param child_index - An operator can have n children. Indicates choose which child to return. | // @param child_index - An operator can have n children. Indicates choose which child to return. | ||||
| std::shared_ptr<DatasetOp> child(int32_t child_index) const; | std::shared_ptr<DatasetOp> child(int32_t child_index) const; | ||||
| @@ -264,10 +268,20 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // @return Vector of Children | // @return Vector of Children | ||||
| std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; } | std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; } | ||||
| // Base method for NodePass visit. | |||||
| // Subclass needs to override this if it requires special node visit access. | |||||
| // Check "dataset/engine/opt/pass.h" for more details. | |||||
| // @return Statue of the node visit | |||||
| /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up | |||||
| /// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main | |||||
| /// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it | |||||
| /// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details. | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| virtual Status PreAccept(NodePass *p, bool *modified); | |||||
| /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access. | |||||
| /// Check "dataset/engine/opt/pass.h" for more details. | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| virtual Status Accept(NodePass *p, bool *modified); | virtual Status Accept(NodePass *p, bool *modified); | ||||
| // Op name getter | // Op name getter | ||||
| @@ -285,6 +299,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // Computes a CRC value for the operator | // Computes a CRC value for the operator | ||||
| static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op); | static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op); | ||||
| /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived> | |||||
| /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr | |||||
| /// \return A shared_ptr casted to the derived class | |||||
| template <typename Derived> | |||||
| std::shared_ptr<Derived> shared_from_base() { | |||||
| return std::static_pointer_cast<Derived>(shared_from_this()); | |||||
| } | |||||
| protected: | protected: | ||||
| // Adds a parent operator to this operator | // Adds a parent operator to this operator | ||||
| // @notes External callers do not have access to this function. | // @notes External callers do not have access to this function. | ||||
| @@ -313,7 +313,7 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { | Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<DeviceQueueOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -261,7 +261,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status FilterOp::Accept(NodePass *p, bool *modified) { | Status FilterOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<FilterOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -367,7 +367,7 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status MapOp::Accept(NodePass *p, bool *modified) { | Status MapOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<MapOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -131,7 +131,7 @@ Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status ProjectOp::Accept(NodePass *p, bool *modified) { | Status ProjectOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<ProjectOp>(), modified); | |||||
| } | } | ||||
| // Compute the column map and save it into our own column name map | // Compute the column map and save it into our own column name map | ||||
| @@ -176,7 +176,7 @@ Status RenameOp::EoeReceived(int32_t) { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status RenameOp::Accept(NodePass *p, bool *modified) { | Status RenameOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<RenameOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -190,7 +190,7 @@ int32_t RepeatOp::num_producers() const { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status RepeatOp::Accept(NodePass *p, bool *modified) { | Status RepeatOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<RepeatOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -298,7 +298,7 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status ShuffleOp::Accept(NodePass *p, bool *modified) { | Status ShuffleOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<ShuffleOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -130,7 +130,7 @@ Status SkipOp::EofReceived(int32_t worker_id) { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status SkipOp::Accept(NodePass *p, bool *modified) { | Status SkipOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<SkipOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -249,7 +249,7 @@ Status GeneratorOp::Reset() { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status GeneratorOp::Accept(NodePass *p, bool *modified) { | Status GeneratorOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<GeneratorOp>(), modified); | |||||
| } | } | ||||
| Status GeneratorOp::ComputeColMap() { | Status GeneratorOp::ComputeColMap() { | ||||
| @@ -411,7 +411,7 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::se | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status ImageFolderOp::Accept(NodePass *p, bool *modified) { | Status ImageFolderOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<ImageFolderOp>(), modified); | |||||
| } | } | ||||
| Status ImageFolderOp::ComputeColMap() { | Status ImageFolderOp::ComputeColMap() { | ||||
| @@ -496,7 +496,7 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status MindRecordOp::Accept(NodePass *p, bool *modified) { | Status MindRecordOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<MindRecordOp>(), modified); | |||||
| } | } | ||||
| Status MindRecordOp::ComputeColMap() { | Status MindRecordOp::ComputeColMap() { | ||||
| @@ -1004,7 +1004,7 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status TFReaderOp::Accept(NodePass *p, bool *modified) { | Status TFReaderOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<TFReaderOp>(), modified); | |||||
| } | } | ||||
| Status TFReaderOp::ComputeColMap() { | Status TFReaderOp::ComputeColMap() { | ||||
| @@ -136,7 +136,7 @@ Status TakeOp::PrepareNodePostAction() { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status TakeOp::Accept(NodePass *p, bool *modified) { | Status TakeOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<TakeOp>(), modified); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -237,7 +237,7 @@ Status ZipOp::EoeReceived(int32_t) { | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status ZipOp::Accept(NodePass *p, bool *modified) { | Status ZipOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified); | |||||
| return p->RunOnNode(shared_from_base<ZipOp>(), modified); | |||||
| } | } | ||||
| Status ZipOp::ComputeColMap() { | Status ZipOp::ComputeColMap() { | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "dataset/engine/datasetops/shuffle_op.h" | #include "dataset/engine/datasetops/shuffle_op.h" | ||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| #include "dataset/engine/opt/pass.h" | #include "dataset/engine/opt/pass.h" | ||||
| #include "dataset/engine/opt/pre/removal_pass.h" | |||||
| #include "dataset/engine/perf/profiling.h" | #include "dataset/engine/perf/profiling.h" | ||||
| #include "dataset/engine/perf/monitor.h" | #include "dataset/engine/perf/monitor.h" | ||||
| @@ -214,7 +215,8 @@ Status ExecutionTree::PrepareTreePreAction() { | |||||
| bool modified = false; | bool modified = false; | ||||
| std::vector<std::unique_ptr<Pass>> pre_actions; | std::vector<std::unique_ptr<Pass>> pre_actions; | ||||
| // Construct pre actions | // Construct pre actions | ||||
| // example: pre_actions.push_back(new SomePass()); | |||||
| MS_LOG(INFO) << "Running pre pass"; | |||||
| pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass())); | |||||
| // Apply pre action passes | // Apply pre action passes | ||||
| for (auto &pass : pre_actions) { | for (auto &pass : pre_actions) { | ||||
| RETURN_IF_NOT_OK(pass->Run(this, &modified)); | RETURN_IF_NOT_OK(pass->Run(this, &modified)); | ||||
| @@ -1,6 +1,8 @@ | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | ||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| add_library(engine-opt OBJECT | add_library(engine-opt OBJECT | ||||
| pass.cc | |||||
| util/printer_pass.cc | |||||
| pass.cc | |||||
| pre/removal_nodes.cc | |||||
| pre/removal_pass.cc | |||||
| util/printer_pass.cc | |||||
| ) | ) | ||||
| @@ -61,6 +61,7 @@ Status NodePass::Run(ExecutionTree *tree, bool *modified) { | |||||
| // Helper function to perform DFS visit | // Helper function to perform DFS visit | ||||
| Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) { | Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) { | ||||
| RETURN_IF_NOT_OK(node->PreAccept(this, modified)); | |||||
| for (const auto &c : node->Children()) { | for (const auto &c : node->Children()) { | ||||
| RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); | RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); | ||||
| } | } | ||||
| @@ -159,6 +160,5 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -66,14 +66,16 @@ class Pass : public std::enable_shared_from_this<Pass> { | |||||
| // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. | // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. | ||||
| class TreePass : public Pass { | class TreePass : public Pass { | ||||
| public: | public: | ||||
| // Run the transformation pass against the execution tree. | |||||
| // @param tree - Pointer to the execution tree to be transformed. | |||||
| // @param modified - Pointer to the modified flag, | |||||
| /// \brief Run the transformation pass against the execution tree. | |||||
| /// \param[inout] tree Pointer to the execution tree to be transformed. | |||||
| /// \param[inout] modified Indicate if the tree was modified | |||||
| Status Run(ExecutionTree *tree, bool *modified) final; | Status Run(ExecutionTree *tree, bool *modified) final; | ||||
| // Derived classes may implement the runOnTree function to implement tree transformation. | |||||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||||
| // @return Status - The error code return | |||||
| /// \brief Derived classes may implement the runOnTree function to implement tree transformation. | |||||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||||
| /// \param[inout] tree The tree to operate on. | |||||
| /// \param[inout] Indicate of the tree was modified. | |||||
| /// \return Status The error code return | |||||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | ||||
| }; | }; | ||||
| @@ -90,14 +92,23 @@ class NodePass : public Pass { | |||||
| ~NodePass() = default; | ~NodePass() = default; | ||||
| // Run the transformation pass against the execution tree. | |||||
| // @param tree - Pointer to the execution tree to be transformed. | |||||
| // @param modified - Pointer to the modified flag, | |||||
| /// \brief Run the transformation pass against the execution tree | |||||
| /// \param[inout] tree Pointer to the execution tree to be transformed | |||||
| /// \param[inout] modified Indicator if the tree was changed | |||||
| Status Run(ExecutionTree *tree, bool *modified) final; | Status Run(ExecutionTree *tree, bool *modified) final; | ||||
| // Derived classes may implement the runOnNode function to implement node level tree transformation. | |||||
| // "modified" flag needs to be set to true if tree is modified during the pass execution. | |||||
| // @return Status - The error code return | |||||
| /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down | |||||
| /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[out] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | |||||
| /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation | |||||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[out] modified Indicator if the node was changed at all. | |||||
| /// \return Status The error code return | |||||
| virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | ||||
| // Visit methods to be overridden. | // Visit methods to be overridden. | ||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "dataset/engine/opt/pre/removal_nodes.h" | |||||
| #include "dataset/engine/opt/pre/removal_pass.h" | |||||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} | |||||
| // Perform ShuffleOp removal check. | |||||
| Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||||
| *modified = false; | |||||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||||
| if (is_caching_) { | |||||
| MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||||
| if (removal_pass_) { | |||||
| removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } else { | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||||
| #define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ | |||||
| #include <memory> | |||||
| #include "dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class RemovalPass; | |||||
| /// \class RemovalNodes removal_nodes.h | |||||
| /// \brief This is a NodePass who's job is to identify which nodes should be removed. | |||||
| /// It works in conjunction with the removal_pass. | |||||
| class RemovalNodes : public NodePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||||
| explicit RemovalNodes(RemovalPass *removal_pass); | |||||
| /// \brief Perform ShuffleOp removal check | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||||
| private: | |||||
| bool is_caching_; | |||||
| RemovalPass *removal_pass_; // Back pointer to the owning removal pass | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "dataset/engine/opt/pre/removal_nodes.h" | |||||
| #include "dataset/engine/opt/pre/removal_pass.h" | |||||
| #include "dataset/engine/execution_tree.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // constructor | |||||
| RemovalPass::RemovalPass() {} | |||||
| // Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||||
| Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||||
| // Create the removal node pass which can identify which nodes need to be removed. | |||||
| std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this); | |||||
| RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); | |||||
| // Then, execute the removal of any nodes that were set up for removal | |||||
| for (auto node : removal_nodes_) { | |||||
| node->Remove(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Adds an operator to the list of operators to be removed | |||||
| void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||||
| #define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class DatasetOp; | |||||
| /// \class RemovalPass removal_pass.h | |||||
| /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which | |||||
| /// nodes should be removed, and then removes them. | |||||
| class RemovalPass : public TreePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| RemovalPass(); | |||||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||||
| /// \param[inout] tree The tree to operate on. | |||||
| /// \param[inout] Indicate of the tree was modified. | |||||
| /// \return Status The error code return | |||||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||||
| /// \brief Adds an operator to the list of operators to be removed | |||||
| /// \param[in] dataset_op The operator to add to the removal list | |||||
| void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op); | |||||
| private: | |||||
| std::vector<std::shared_ptr<DatasetOp>> removal_nodes_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ | |||||