/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ #include #include "ir/pattern_matcher.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { namespace internal { class GetRefValueTransform { public: GetRefValueTransform() {} ~GetRefValueTransform() = default; AnfNodePtr operator()(const AnfNodePtr &node) { CNodePtr cnode = node->cast(); auto inputs = cnode->inputs(); auto fg = GetValueNode(inputs[0])->cast(); if (fg->recursive()) { MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString(); return node; } auto new_fg = TransformableClone(fg, std::make_shared("GetRefValue")); auto output = new_fg->output(); new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output})); inputs[0] = NewValueNode(new_fg); auto ret_node = cnode->func_graph()->NewCNode(inputs); return ret_node; } }; } // namespace internal // {prim::kPrimMakeRef, X, Y, Z} -> Y class MakeRefEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x, y, z; MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y); return nullptr; } }; // {prim::kPrimGetRefValue, Parameter} -> Parameter class GetRefParamEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x; MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); return nullptr; } }; // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y // {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f} class GetMakeRefEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x, y, z; MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node)); internal::GetRefValueTransform trans; auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr { auto rep = trans(x.GetNode(node)); if (rep != nullptr) { return rep; } return nullptr; }; MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node)); return nullptr; } }; // IsValueNode class ReplaceRefkeyByParam : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr { auto refkey = GetValueNode(node); auto resource = std::dynamic_pointer_cast(optimizer->resource()); MS_EXCEPTION_IF_NULL(resource); auto top_graph = resource->func_graph(); MS_EXCEPTION_IF_NULL(top_graph); for (const auto &tnode : top_graph->parameters()) { auto para = tnode->cast(); if (para != nullptr && para->name() == refkey->tag()) { return para; } } return nullptr; }; PatternNode x; MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode, node)); return nullptr; } }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REF_ELIMINATE_H_