/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ #include #include #include "optimizer/optimizer.h" #include "optimizer/irpass.h" #include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "operator/ops.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimSwitch, true, X, Y} // {prim::kPrimSwitch, false, X, Y} class SwitchSimplify : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); auto getx = [this](const AnfNodePtr &node) -> bool { this->x_ = node; return true; }; auto gety = [this](const AnfNodePtr &node) -> bool { this->y_ = node; return true; }; AnfVisitor::Match(prim::kPrimSwitch, {IsValueNode, getx, gety})(node); // simplify the switch if (is_match_) { if (cond_) { return x_; } return y_; } return nullptr; } void Visit(const AnfNodePtr &node) override { if (!is_match_ && IsValueNode(node)) { cond_ = GetValue(GetValueNode(node)); is_match_ = true; } } void Reset() { x_ = nullptr; y_ = nullptr; cond_ = false; is_match_ = false; } private: bool is_match_{false}, cond_{false}; AnfNodePtr x_{nullptr}, y_{nullptr}; }; // {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => // {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} class FloatTupleGetItemSwitch : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); auto fg = node->func_graph(); if (Xs_.empty() || c_ == nullptr || fg == nullptr) { return nullptr; } auto true_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[1], c_}); auto false_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[2], c_}); return fg->NewCNode({NewValueNode(prim::kPrimSwitch), Xs_[0], true_node, false_node}); } void Visit(const CNodePtr &cnode) override { // {prim::kPrimSwith, X1, X2, X3} if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch) || cnode->size() != 4) { return; } // copy X1, X2, X3 auto &inputs = cnode->inputs(); (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); } void Visit(const ValueNodePtr &vnode) override { c_ = vnode; } void Reset() { Xs_.clear(); c_ = nullptr; } private: AnfNodePtr c_{nullptr}; std::vector Xs_{}; }; // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} class FloatEnvGetItemSwitch : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsNode, IsNode})(node); if (!is_match_) { return nullptr; } // {prim::kPrimEnvGetItem, {...}, X4, X5} auto cnode = node->cast(); auto sw_node = cnode->input(1)->cast(); auto x4 = cnode->input(2); auto x5 = cnode->input(3); is_match_ = false; AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsNode, IsNode})(sw_node); if (!is_match_) { return nullptr; } // {prim::kPrimSwitch, X1, X2, X3} auto x1 = sw_node->input(1); auto x2 = sw_node->input(2); auto x3 = sw_node->input(3); auto fg = node->func_graph(); if (fg == nullptr) { return nullptr; } auto true_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x2, x4, x5}); auto false_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x3, x4, x5}); return fg->NewCNode({NewValueNode(prim::kPrimSwitch), x1, true_node, false_node}); } void Visit(const AnfNodePtr &) override { is_match_ = true; } private: bool is_match_{false}; }; namespace internal { FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, const AbstractBasePtr &true_graph_output_abs, const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond); } // namespace internal // {{prim::kPrimSwitch, X, G1, G2}, Xs} class ConvertSwitchReplacement : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } Reset(); auto cnode = node->cast(); if (cnode->size() < 1) { return nullptr; } // {prim::kPrimSwitch, X, G1, G2} AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(cnode->input(0)); if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) { return nullptr; } auto true_output = g1_->output()->abstract(); auto false_output = g2_->output()->abstract(); auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); std::vector params; auto fg = node->func_graph(); auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params); return internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_); } void Visit(const AnfNodePtr &node) override { if (x_ == nullptr) { x_ = node; return; } AnfVisitor::Visit(node); } void Visit(const ValueNodePtr &vnode) override { auto g = GetValueNode(vnode); if (g1_ == nullptr) { g1_ = g; } else { g2_ = g; } } void Reset() { x_ = nullptr; g1_ = nullptr; g2_ = nullptr; } private: AnfNodePtr x_{nullptr}; FuncGraphPtr g1_{nullptr}, g2_{nullptr}; }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_