diff --git a/mindspore/ccsrc/ir/optimizer_caller.h b/mindspore/ccsrc/ir/optimizer_caller.h new file mode 100644 index 0000000000..bd30454147 --- /dev/null +++ b/mindspore/ccsrc/ir/optimizer_caller.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ +#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ + +#include "ir/anf.h" +#include "optimizer/opt.h" + +namespace mindspore { +class OptimizerCaller { + public: + virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } +}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ diff --git a/mindspore/ccsrc/ir/visitor.cc b/mindspore/ccsrc/ir/visitor.cc index efebe3124a..9e63f4f9c1 100644 --- a/mindspore/ccsrc/ir/visitor.cc +++ b/mindspore/ccsrc/ir/visitor.cc @@ -14,11 +14,10 @@ * limitations under the License. */ -#include "ir/visitor.h" #include "ir/func_graph.h" +#include "ir/visitor.h" namespace mindspore { -AnfNodePtr AnfVisitor::operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } void AnfVisitor::Visit(const AnfNodePtr &node) { node->accept(this); } void AnfVisitor::Visit(const CNodePtr &cnode) { diff --git a/mindspore/ccsrc/ir/visitor.h b/mindspore/ccsrc/ir/visitor.h index e771f7ad28..6dcf28249a 100644 --- a/mindspore/ccsrc/ir/visitor.h +++ b/mindspore/ccsrc/ir/visitor.h @@ -18,14 +18,12 @@ #define MINDSPORE_CCSRC_IR_VISITOR_H_ #include -#include "ir/anf.h" -#include "optimizer/opt.h" +#include "ir/optimizer_caller.h" namespace mindspore { using VisitFuncType = std::function; -class AnfVisitor { +class AnfVisitor : public OptimizerCaller { public: - virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &); virtual void Visit(const AnfNodePtr &); virtual void Visit(const CNodePtr &); virtual void Visit(const ValueNodePtr &); diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/optimizer/irpass/branch_culling.h index 736f67b5dd..bb5e021886 100644 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.h @@ -20,22 +20,21 @@ #include #include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" -#include "operator/ops.h" +#include "ir/optimizer_caller.h" #include "ir/pattern_matcher.h" +#include "operator/ops.h" +#include "optimizer/irpass.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimSwitch, true, X, Y} // {prim::kPrimSwitch, false, X, Y} -class SwitchSimplify { +class SwitchSimplify : public OptimizerCaller { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br; auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { auto cond_value_ = GetValue(GetValueNode(cond.GetNode(node))); @@ -54,9 +53,9 @@ class SwitchSimplify { // {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => // {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} -class FloatTupleGetItemSwitch { +class FloatTupleGetItemSwitch : public OptimizerCaller { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br, x; MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), @@ -69,9 +68,9 @@ class FloatTupleGetItemSwitch { // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} -class FloatEnvGetItemSwitch { +class FloatEnvGetItemSwitch : public OptimizerCaller { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br, x, x2; MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), @@ -93,9 +92,9 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN } // namespace internal // {{prim::kPrimSwitch, X, G1, G2}, Xs} -class ConvertSwitchReplacement { +class ConvertSwitchReplacement : public OptimizerCaller { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; }