/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ #include #include #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/optimizer_caller.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "ir/pattern_matcher.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/parse_base.h" namespace mindspore { namespace opt { namespace irpass { const char PARSE_SUPER_NAME[] = "namespace"; // {prim::kPrimResolve, Ns, Sym} class ResolverResolve : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimResolve, {IsVNode, IsVNode})(node); if (sym_ != nullptr) { return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node); } return nullptr; } void Visit(const ValueNodePtr &vnode) override { if (IsValueNode(vnode)) { ns_ = GetValueNode(vnode); } else if (ns_ != nullptr && IsValueNode(vnode)) { sym_ = GetValueNode(vnode); } } void Reset() { ns_ = nullptr; sym_ = nullptr; } private: parse::NameSpacePtr ns_{nullptr}; parse::SymbolPtr sym_{nullptr}; }; // {prim::kPrimGetAttr, Ns, Str} class ResolverGetattr : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimGetAttr, {IsVNode, IsVNode})(node); if (sym_ != nullptr) { return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node); } return nullptr; } void Visit(const AnfNodePtr &node) override { if (IsValueNode(node)) { ns_ = GetValueNode(node); } else if (ns_ != nullptr && IsValueNode(node)) { auto str = GetValue(GetValueNode(node)); sym_ = std::make_shared(str); } } void Reset() { ns_ = nullptr; sym_ = nullptr; } private: parse::NameSpacePtr ns_{nullptr}; parse::SymbolPtr sym_{nullptr}; }; // {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node} class ResolveAttr : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { PatternNode ns_node, sym_node, attr_node; auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr { auto node_to_getattr = node->cast()->input(1); std::string attr_as_string = GetValueNode(attr_node.GetNode(node))->value(); auto ns_ = GetValueNode(ns_node.GetNode(node)); auto sym_ = GetValueNode(sym_node.GetNode(node)); if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) { // deal with the case of getting attr from a class member // and avoid the case of getting attr from self (the result of ParseSuper) auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string); return result; } return nullptr; }; MATCH_REPLACE_LAMBDA_IF( node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node), ResolveAttrLambda, attr_node.CheckFunc(IsValueNode, node)); return nullptr; } }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_