/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ #include #include "optimizer/optimizer.h" #include "optimizer/irpass.h" #include "ir/visitor.h" #include "operator/ops.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimMakeRef, X, Y, Z} -> Y class MakeRefEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { y_ = nullptr; auto gety = [this](const AnfNodePtr &node) -> bool { this->y_ = node; return true; }; AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); return y_; } void Visit(const AnfNodePtr &) override {} private: AnfNodePtr y_{nullptr}; }; // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y class GetMakeRefEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { auto cnode = node->cast(); if (cnode == nullptr || cnode->size() != 2) { return nullptr; } // {prim::kPrimGetRefKey/Value, {...}} auto ref = cnode->input(1)->cast(); if (ref == nullptr || !ref->IsApply(prim::kPrimMakeRef) || ref->size() != 4) { return nullptr; } // {prim::kPrimMakeRef, X, Y, Z} if (cnode->IsApply(prim::kPrimGetRefKey)) { return ref->input(1); } if (cnode->IsApply(prim::kPrimGetRefValue)) { return ref->input(2); } return nullptr; } }; // IsValueNode class ReplaceRefkeyByParam : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { if (!IsValueNode(node)) { return nullptr; } auto refkey = GetValueNode(node); auto resource = std::dynamic_pointer_cast(optimizer->resource()); MS_EXCEPTION_IF_NULL(resource); auto top_graph = resource->func_graph(); MS_EXCEPTION_IF_NULL(top_graph); for (const auto &tnode : top_graph->parameters()) { auto para = tnode->cast(); if (para != nullptr && para->name() == refkey->tag()) { return para; } } return nullptr; } }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_