/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/parse/resolve.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ir/param_info.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "utils/any.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass.h"

namespace mindspore {
namespace parse {
// Abstract value of a Python class object: a partial application of `make_record` whose bound argument is a
// TypeType scalar holding the class parsed from obj().
abstract::AbstractBasePtr ClassObject::ToAbstract() {
  ClassPtr cls_ptr = ParseDataClass(obj());
  auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
  abs_scalar->set_type(std::make_shared<TypeType>());
  abs_scalar->set_value(cls_ptr);

  AbstractBasePtrList args_spec_list = {abs_scalar};
  auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeRecord);
  return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
}

// Abstract value of a Python class type: a partial application of `create_instance` whose bound argument is this
// ClassType wrapped as a TypeType scalar. The closure's value description is set to ToString().
abstract::AbstractBasePtr ClassType::ToAbstract() {
  auto abs_scalar =
    std::make_shared<abstract::AbstractScalar>(shared_from_base<ClassType>(), std::make_shared<TypeType>());
  AbstractBasePtrList args_spec_list = {abs_scalar};

  auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
  auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
  ret_val->set_value_desc(ToString());
  return ret_val;
}

// Call the Python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in its namespace.
// Returns false (and logs) when the namespace object is None; otherwise stores the Python result in result_.
bool SymbolResolver::Resolve() {
  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);

  py::object obj = namespace_->obj();
  std::string symbol = symbol_->symbol();
  if (py::isinstance<py::none>(obj)) {
    MS_LOG(ERROR) << "Unresolved symbol: " << symbol;
    return false;
  }
  result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol));
  return true;
}

namespace {
// Convert a Python Parameter object into a weight Parameter node of the top-level graph.
// If the top graph already has a parameter with the same name, that node is reused; otherwise a new weight
// parameter is added with the object as its default value. The node is also recorded on `func_graph` via
// add_parameter_obj_node().
// Throws (MS_LOG(EXCEPTION)) if obj is None, has no usable `name` attribute, or no top graph is set.
AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
  MS_EXCEPTION_IF_NULL(func_graph);

  // parameter object should not be none
  if (py::isinstance<py::none>(obj)) {
    MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
  }

  if (!py::hasattr(obj, "name")) {
    MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
  }

  // get the parameter name from parameter object
  auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
  if (py::isinstance<py::none>(name_attr)) {
    MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
  }

  auto param_name = py::cast<std::string>(name_attr);
  auto top_graph = Parser::GetTopFuncGraph();
  // The top graph is dereferenced below; fail loudly instead of crashing if none has been set.
  MS_EXCEPTION_IF_NULL(top_graph);

  // if the parameter node has been created, return it
  AnfNodePtr para_node = nullptr;
  for (auto const &param : top_graph->parameters()) {
    auto param_node = dyn_cast<Parameter>(param);
    if (param_node != nullptr && param_node->name() == param_name) {
      para_node = param;
      break;
    }
  }
  if (para_node == nullptr) {
    auto node = top_graph->AddWeightParameter(param_name);
    auto value = py::cast<tensor::MetaTensorPtr>(obj);
    node->set_default_param(value);
    // set_abstract for parameter
    auto abs = value->ToAbstract();
    node->set_abstract(abs);
    para_node = node;
  }
  func_graph->add_parameter_obj_node(para_node);

  return para_node;
}

// Convert a resolved Python object into an ANF node inside `func_graph`:
//   - a Parameter object   -> weight Parameter node (see ResolveParameterObj);
//   - a parameter tuple    -> make_tuple CNode over the recursively resolved elements;
//   - anything else        -> ValueNode of the converted data, followed by a mixed-precision cast helper
//                             when the value is a Tensor.
// On success writes the node to *node and returns true; returns false (and logs) on failure.
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr output = nullptr;
  if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
    auto param = ResolveParameterObj(func_graph, obj);
    if (param == nullptr) {
      MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";
      return false;
    }
    MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();

    output = param;
  } else if (py::hasattr(obj, "__parameter_tuple__")) {
    auto tuple = obj.cast<py::tuple>();
    std::vector<AnfNodePtr> args;
    args.reserve(tuple.size() + 1);
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t it = 0; it < tuple.size(); ++it) {
      AnfNodePtr out = nullptr;
      bool success = ResolveObjectToNode(func_graph, tuple[it], &out);
      if (!success) {
        MS_LOG(ERROR) << "Resolve object to node failed";
        return false;
      }
      args.push_back(out);
    }
    output = NewCNode(args, func_graph);
  } else {
    ValuePtr convert_result = nullptr;
    bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve());
    if (!converted) {
      MS_LOG(ERROR) << "Convert data failed";
      return false;
    }
    MS_EXCEPTION_IF_NULL(convert_result);
    output = NewValueNode(convert_result);
    if (convert_result->isa<tensor::Tensor>()) {
      output = GetMixedPrecisionCastHelp(func_graph, output);
    }
  }
  *node = output;
  return true;
}

// Return true when every element of `value_vec` -- recursing into nested ValueTuple/ValueList -- is a FuncGraph
// or a Primitive. An empty sequence yields true because the loop body never runs.
bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
  for (auto &elem : value_vec) {
    if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
      const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
      auto is_graph = IsAllFuncInValueSequence(vec);
      if (!is_graph) {
        return false;
      }
    } else if (!elem->isa<FuncGraph>() && !elem->isa<Primitive>()) {
      return false;
    }
  }
  return true;
}

// Build a make_tuple CNode in `func_graph` mirroring `value_vec`: nested sequences become nested make_tuple
// CNodes, FuncGraphs are registered with `manager` and wrapped in ValueNodes, Primitives become ValueNodes.
// Throws (MS_LOG(EXCEPTION)) on any other element kind.
AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
                                     const std::vector<ValuePtr> &value_vec) {
  std::vector<AnfNodePtr> nodes;
  nodes.reserve(value_vec.size() + 1);
  nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (auto &elem : value_vec) {
    AnfNodePtr node = nullptr;
    if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
      const auto &vec = GetValue<std::vector<ValuePtr>>(elem);
      node = TransformToMakeTupleNodes(manager, func_graph, vec);
    } else if (elem->isa<FuncGraph>()) {
      FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>();
      manager->AddFuncGraph(new_fg);
      node = NewValueNode(new_fg);
    } else if (elem->isa<Primitive>()) {
      node = NewValueNode(elem);
    } else {
      MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString();
    }
    nodes.emplace_back(node);
  }
  auto cnode = func_graph->NewCNode(nodes);
  return cnode;
}

// Transform a ValueTuple or ValueList of graph/primitive values into a make_tuple of const graph/primitive nodes.
// Returns false and leaves *transformed untouched when the sequence holds anything other than graphs/primitives.
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
                                  const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value());
  if (!IsAllFuncInValueSequence(value_vec)) {
    return false;
  }

  // (1) A CellList or ordered cell is parsed as a ValueTuple of const graphs. The graph manager does not look
  //     inside a ValueTuple, so replace the node with a make_tuple of graph value nodes to make those graphs
  //     visible to it.
  // (2) A ValueTuple/ValueList of primitives may run into abstract errors, so make every element an
  //     independent node.
  auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
  // replace the ret ptr to be make tuple of graph value node
  *transformed = node_tuple_graphs;
  return true;
}

// Pairs TraceManager::DebugTrace with TraceManager::EndTrace for the lifetime of the object, so the trace stack
// stays balanced even when resolution throws via MS_LOG(EXCEPTION).
class ResolveTraceGuard {
 public:
  explicit ResolveTraceGuard(const AnfNodePtr &node) {
    TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info()));
  }
  ~ResolveTraceGuard() { TraceManager::EndTrace(); }
  ResolveTraceGuard(const ResolveTraceGuard &) = delete;
  ResolveTraceGuard &operator=(const ResolveTraceGuard &) = delete;
};

// Resolve the Python object into a node of node->func_graph(). If the resolved node is a graph value node, the
// graph is added to `manager`; if it is a ValueTuple/ValueList of graphs/primitives, it is replaced by a
// make_tuple of value nodes (see TransformVectorFuncValueNode). Throws if the object cannot be converted.
AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
                                        const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  ScopeGuard scope_guard(node->scope());
  AnfNodePtr resolved_node = nullptr;
  // Declared after scope_guard so it is destroyed first, matching the original EndTrace-before-return order.
  ResolveTraceGuard trace_guard(node);
  bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node);
  if (!success) {
    MS_LOG(EXCEPTION) << "Parse Resolve convert failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  }
  if (IsValueNode<FuncGraph>(resolved_node)) {
    auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
    manager->AddFuncGraph(new_fg);
  }

  // if the constant node is a constant vector of graphs, add the graphs to the manager
  if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
    (void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
                                       &resolved_node);
  }

  return resolved_node;
}
}  // namespace

// Resolve `symbol` in `name_space` and convert the result into a node of node->func_graph(), registering any
// resolved graphs with `manager`. Throws if the graph/manager is missing or the symbol cannot be resolved.
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
                         const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->func_graph() == nullptr || manager == nullptr) {
    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
  }
  SymbolResolver symbol_resolver(name_space, symbol, node);
  if (!symbol_resolver.Resolve()) {
    MS_EXCEPTION(TypeError) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  }

  py::object obj = symbol_resolver.result();
  AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
  return resolved_node;
}

// Resolve `symbol` in `name_space`; if the result is a Cell instance, resolve its attribute `attr` into a node
// of node->func_graph(). Returns nullptr when the resolved object is not a Cell instance.
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
                               const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->func_graph() == nullptr || manager == nullptr) {
    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
  }
  SymbolResolver symbol_resolver(name_space, symbol, node);
  if (!symbol_resolver.Resolve()) {
    MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  }

  py::object obj = symbol_resolver.result();
  if (!data_converter::IsCellInstance(obj)) {
    return nullptr;
  }
  py::object obj_attr = obj.attr(attr.c_str());
  AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
  return resolved_node;
}

namespace {
// Pass groups for the resolve optimizer: first "resolve_attr", then "resolve" (resolve + getattr).
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
  opt::OptPassGroupMap map({
    {"resolve_attr",
     {
       // for resolve primitive;
       irpass.resolver_resolve_attr_,
     }},
    {"resolve",
     {
       // for resolve and getattr primitive;
       irpass.resolver_resolve_,
       irpass.resolver_getattr_,
     }},
  });
  return map;
}
}  // namespace

// Run the resolve passes over `func_graph` with resource `res`. Returns false (and logs) if either is null.
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) {
  if (func_graph == nullptr || res == nullptr) {
    MS_LOG(ERROR) << "func_graph or resource is null";
    return false;
  }
  opt::irpass::ResolveIRPassLib irpass;
  opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass));
  MS_EXCEPTION_IF_NULL(opt_resolve);

  (void)parse::python_adapter::set_python_scoped();

  (void)opt_resolve->step(func_graph, use_profile);
  return true;
}

// Resolve every root graph of `manager`. Returns false only when `manager` is null.
// NOTE(review): failures of individual roots are logged but still return true -- confirm callers rely on this.
bool ResolveAll(const FuncGraphManagerPtr &manager) {
  if (manager == nullptr) {
    MS_LOG(ERROR) << "func graph manager is null";
    return false;
  }

  if (manager->roots().size() > 1) {
    MS_LOG(WARNING)
      << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs "
         "called from root graph, so it's not necessary to pass all graphs as roots. "
         "Please ensure your usage.";
  }
  // Should not use pipeline::Resource, because Resource::Clean cleans some global variables such as
  // ScopeManager; that would make JExpandedGraphs::GetBprop fail once the valid scope has been cleaned.
  auto res = std::make_shared<pipeline::ResourceBase>();
  res->set_manager(manager);

  auto roots = manager->roots();
  for (auto &fg : roots) {
    bool ret = ResolveFuncGraph(fg, res, false);
    if (!ret) {
      MS_EXCEPTION_IF_NULL(fg);
      MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed";
    }
  }
  return true;
}
}  // namespace parse
}  // namespace mindspore