/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ #include #include #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "frontend/optimizer/optimizer_caller.h" #include "ir/pattern_matcher.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimSwitch, true, X, Y} // {prim::kPrimSwitch, false, X, Y} class SwitchSimplify : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br; auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { auto cond_value_ = GetValue(GetValueNode(cond.GetNode(node))); if (cond_value_) { return true_br.GetNode(node); } return false_br.GetNode(node); }; MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, cond.CheckFunc(IsValueNode, node)); return nullptr; } }; // {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} => // {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} class FloatTupleGetItemSwitch : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br, x; MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), PPrimitive(prim::kPrimTupleGetItem, false_br, x)), x.CheckFunc(IsVNode, node)); return nullptr; } }; // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} class FloatEnvGetItemSwitch : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br, x, x2; MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2))); return nullptr; } }; namespace internal { FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, const AbstractBasePtr &true_graph_output_abs, const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, const FuncGraphPtr &func_graph); } // namespace internal // {{prim::kPrimSwitch, X, G1, G2}, Xs} class ConvertSwitchReplacement : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } PatternNode cond, true_br, false_br; auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { auto g1_ = GetValueNode(true_br.GetNode(node)); auto g2_ = GetValueNode(false_br.GetNode(node)); auto x_ = cond.GetNode(node); // for switch replace method, only graphs without graph inside can be replaced for (auto &item : g1_->value_nodes()) { auto value_node = item.first; if (IsValueNode(value_node)) { return nullptr; } } for (auto &item : g2_->value_nodes()) { auto value_node = item.first; if (IsValueNode(value_node)) { return nullptr; } } auto true_output = g1_->output()->abstract(); auto false_output = g2_->output()->abstract(); auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); std::vector params; auto fg = node->func_graph(); auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params); auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); return nnode; }; MATCH_REPLACE_LAMBDA_IF( node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda, true_br.CheckFunc(IsValueNode, node) && false_br.CheckFunc(IsValueNode, node)); return nullptr; } }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BRANCH_CULLING_H_