/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/static_analysis/program_specialize.h"

#include <algorithm>
#include <exception>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "./common.h"
#include "operator/ops.h"
#include "operator/composite/do_signature.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "utils/graph_utils.h"
#include "utils/log_adapter.h"
#include "utils/profile.h"
#include "debug/trace.h"

namespace mindspore {
namespace abstract {
namespace {
// Returns the intermediate abstract already attached to conf's node if there is one; otherwise the abstract of
// the value evaluated for conf.
inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
  if (conf->node()->intermediate_abstract()) {
    return conf->node()->intermediate_abstract();
  }
  return conf->GetEvaluatedValue()->abstract();
}

// Creates a ValueNode holding v and tags it with abs_base as its abstract.
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
  AnfNodePtr value_node = NewValueNode(v);
  value_node->set_abstract(abs_base);
  MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
  return value_node;
}

// True when `parent` is fg itself or reachable by following fg->parent(). Also true when parent is nullptr,
// because the walk always terminates at nullptr.
bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
  while (fg != nullptr && fg != parent) {
    fg = fg->parent();
  }
  return fg == parent;
}

// Null-safe graph name for diagnostics: logging must never dereference a null graph.
std::string GraphNameOrNull(const FuncGraphPtr &fg) { return fg == nullptr ? std::string("nullptr") : fg->ToString(); }
}  // namespace

// Entry point: specializes fg under the top-level analysis context and returns the specialized clone.
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
  return SpecializeFuncGraph(fg, context);
}

// Returns the specialized clone of fg for `context`, memoized by context->SpecializeKey().
FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  auto iter = specializations_.find(context->SpecializeKey());
  if (iter != specializations_.end()) {
    return iter->second->specialized_func_graph();
  }

  std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
  FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
  // Register before Run(): a request for the same key issued while Run() is still in progress gets this
  // (partially filled) clone instead of starting a second specialization.
  specializations_[context->SpecializeKey()] = fg_spec;
  fg_spec->Run();
  return fg2;
}

// Looks up the specializer already registered for `context`; returns nullptr when there is none.
std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(context);
  auto iter = specializations_.find(context->SpecializeKey());
  if (iter != specializations_.end()) {
    return iter->second;
  }
  return nullptr;
}

// Returns a process-wide increasing counter ("1", "2", ...) used to tag cloned graphs in trace info.
std::string GetNextCounter() {
  static int g_CloneCounter = 1;
  std::string str_count = std::to_string(g_CloneCounter);
  g_CloneCounter++;
  return str_count;
}

// Clones fg for this specialization and seeds the work list with fg's return node and its parameters.
FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
                                           const AnalysisContextPtr &context)
    : specializer_(s), func_graph_(fg), context_(context) {
  parent_ = s->GetFuncGraphSpecializer(context->parent());
  engine_ = s->engine();
  cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
  repl_node_ = cloner_->cloned_node();
  specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
  todo_.push_back(fg->get_return());
  auto ps = fg->parameters();
  (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
}

// Clones `node` (and, for CNodes, its non-value inputs recursively) with the cloner of the specializer that owns
// node's func graph. ValueNodes are returned unchanged; an existing replica is reused.
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  FuncGraphPtr fg = node->func_graph();
  if (node->isa<ValueNode>()) {
    return node;
  }
  // Walk up to the specializer whose source graph owns `node`.
  std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  while (fg != nullptr && fg != specializer->func_graph_) {
    specializer = specializer->parent_;
    if (specializer == nullptr) {
      MS_LOG(EXCEPTION) << "No specializer found for the func graph of node: " << node->DebugString()
                        << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
    }
  }
  // If had replicated, just return that.
  auto iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    return iter->second;
  }

  auto new_node = specializer->cloner_->CloneDisconnected(node);
  if (node->isa<CNode>()) {
    if (!new_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
    }
    auto c_node = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(c_node);
    auto inputs = c_node->inputs();
    std::vector<AnfNodePtr> new_inputs;
    (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
                         [this](const AnfNodePtr &inp) -> AnfNodePtr {
                           if (inp->isa<ValueNode>()) {
                             return inp;
                           }
                           return ReplicateDisconnectedNode(inp);
                         });
    auto c_new_node = new_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(c_new_node);
    c_new_node->set_inputs(new_inputs);
  }

  // The cloner is expected to have recorded a distinct replica for `node`.
  iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    if (iter->second == node) {
      MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
    }
  } else {
    MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  }
  return new_node;
}

// Returns the replica of `node` recorded by the specializer owning node's graph, or `node` itself if none.
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  FuncGraphPtr fg = node->func_graph();

  std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  while (fg != nullptr && fg != specializer->func_graph_) {
    specializer = specializer->parent_;
    if (specializer == nullptr) {
      MS_LOG(EXCEPTION) << "No specializer found for the func graph of node: " << node->DebugString()
                        << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
    }
  }

  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  auto iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    return iter->second;
  }
  return node;
}

// Runs both specialization passes over this graph.
void FuncGraphSpecializer::Run() {
  MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
                << ", cloned func graph name: " << specialized_func_graph_->ToString()
                << ", func graph: " << func_graph_->get_return()->DebugString();
  FirstPass();
  SecondPass();
  MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
                << ", cloned func graph name: " << specialized_func_graph_->ToString()
                << ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
}

// Drains todo_ (LIFO). Nodes owned by another graph are delegated to the parent specializer; own nodes are
// processed once each.
void FuncGraphSpecializer::FirstPass() {
  while (!todo_.empty()) {
    AnfNodePtr node = todo_.back();
    todo_.pop_back();
    if (node->func_graph() == nullptr) {
      // do nothing for ValueNode
      continue;
    }
    if (node->func_graph() != func_graph_) {
      if (parent_ == nullptr) {
        MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
      }
      parent_->AddTodoItem(node);
      parent_->FirstPass();
      AnfNodePtr new_node = parent_->GetReplicatedNode(node);
      if (node->isa<CNode>()) {
        parent_->ProcessCNode(new_node->cast<CNodePtr>());
      }
      continue;
    }
    if (marked_.count(node) > 0) {
      continue;
    }
    (void)marked_.insert(node);
    ProcessNode(node);
  }
}

// Specialize CNode in func graphs
void FuncGraphSpecializer::SecondPass() {
  for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) {
    if (node->isa<CNode>()) {
      ProcessCNode(node->cast<CNodePtr>());
    }
  }
}

// Attaches the evaluated abstract to node's replica and, for CNodes, rewrites each replica input to a ValueNode
// (when the input's abstract pins a single value) or to the replicated node it resolves to.
void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  ScopeGuard scope_guard(node->scope());
  AnfNodeConfigPtr conf = MakeConfig(node);
  AnfNodePtr new_node = GetReplicatedNode(node);
  MS_EXCEPTION_IF_NULL(new_node);
  if (new_node->func_graph() != specialized_func_graph_) {
    MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
                      << ", new_node: " << new_node->DebugString()
                      << ", new_node->func_graph(): " << GraphNameOrNull(new_node->func_graph())
                      << ", specialized_func_graph_: " << GraphNameOrNull(specialized_func_graph_);
  }
  new_node->set_abstract(GetEvaluatedValueWrap(conf));
  if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
    // A partial closure created from the original node must now refer to its replica.
    auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
    if (partial_abstract->node() == node) {
      partial_abstract->set_node(new_node);
    }
  }

  MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();

  if (!node->isa<CNode>()) {
    return;
  }
  auto attrs = conf->GetEvaluatedValue()->attribute();
  auto c_old = node->cast<CNodePtr>();
  auto c_new = new_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(c_new);
  auto new_inputs = c_new->inputs();
  auto old_inputs = c_old->inputs();
  if (new_inputs.size() != old_inputs.size()) {
    MS_LOG(EXCEPTION) << "Replicated node input size " << new_inputs.size() << " differs from original size "
                      << old_inputs.size() << ", node: " << node->DebugString();
  }
  for (size_t i = 0; i < old_inputs.size(); ++i) {
    auto node_input = old_inputs[i];
    AnfNodeConfigPtr iconf = MakeConfig(node_input);
    AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
    // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
    // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
    AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
    if (replace_node == nullptr) {
      replace_node = BuildReplacedNode(iconf);
      MS_EXCEPTION_IF_NULL(replace_node);
      replace_node->set_abstract(ival);
      MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
    } else {
      MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
                    << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
    }
    if (new_inputs[i] != replace_node) {
      new_inputs[i] = replace_node;
      MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
    }
  }
  c_new->set_inputs(new_inputs);
}

// Follows the engine's anfnode_config_map redirections starting at conf, replicating each target node, queues
// the final node for processing and returns its replica.
AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);

  auto conf_iter = engine_->anfnode_config_map().find(conf);
  AnfNodeConfigPtr new_conf = conf;
  while (conf_iter != engine_->anfnode_config_map().end()) {
    MS_LOG(DEBUG) << "Origin conf: graph(" << GraphNameOrNull(new_conf->node()->func_graph()) << "), node("
                  << new_conf->node()->DebugString() << ")";
    new_conf = conf_iter->second;
    MS_EXCEPTION_IF_NULL(new_conf);
    MS_LOG(DEBUG) << "Replaced conf: graph(" << GraphNameOrNull(new_conf->node()->func_graph()) << "), node("
                  << new_conf->node()->DebugString() << ")";
    (void)ReplicateDisconnectedNode(new_conf->node());
    conf_iter = engine_->anfnode_config_map().find(new_conf);
  }
  todo_.push_back(new_conf->node());
  auto repl = GetReplicatedNode(new_conf->node());
  MS_LOG(DEBUG) << "Set repl: graph(" << GraphNameOrNull(repl->func_graph()) << "), node(" << repl->DebugString()
                << ") to replace origin: " << new_conf->node()->DebugString();
  return repl;
}

namespace {
const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");

// Only FuncGraph, MetaFuncGraph and Primitive value nodes are candidates for specialization.
inline bool CanSpecializeNode(const AnfNodePtr &node) {
  return IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node);
}
}  // namespace

// Builds the specialized replacement for a function-valued node. When no unique specialization exists, returns
// a "Dead Node" / "Poly Node" ValueNode carrying an AbstractError; any other failure throws.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
                                                      const AbstractBasePtrList &argvals) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(abs);
  AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
  MS_EXCEPTION_IF_NULL(real_a);

  AbstractFunctionPtr func = real_a->GetUnique();
  SpecializeStatusCode errcode = kSpecializeSuccess;
  ScopeGuard scope_guard(node->scope());
  AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
  if (repl != nullptr) {
    return repl;
  }
  if (errcode == kSpecializeFindUniqueArgvalDead) {
    const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
    repl = BuildValueNode(kDeadNode, error_dead_node);
    MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  } else if (errcode == kSpecializeFindUniqueArgvalPoly) {
    const auto error_poly_node = std::make_shared<AbstractError>(kPolyNode, node);
    repl = BuildValueNode(kPolyNode, error_poly_node);
    MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  } else {
    MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
                      << ", abstract: " << abs->ToString();
  }
  return repl;
}

// Resolves `func` applied to `args` to a concrete ValueNode (primitive or specialized graph). Returns nullptr and
// sets *errcode when FindUniqueArgvals cannot pick a single argument signature.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
                                                           const AbstractBasePtrList &args,
                                                           SpecializeStatusCode *errcode) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(func);
  MS_EXCEPTION_IF_NULL(errcode);
  *errcode = kSpecializeSuccess;

  // Already typed: nothing to resolve.
  auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func);
  if (real_func != nullptr) {
    return BuildValueNode(real_func->prim(), abs);
  }

  EvaluatorPtr eval = engine_->GetEvaluatorFor(func);
  MS_EXCEPTION_IF_NULL(eval);
  AbstractBasePtrList argvals = eval->NormalizeArgs(args);

  std::pair<AbstractBasePtrList, AbstractBasePtr> result;
  SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result);
  if (status != kSpecializeSuccess) {
    *errcode = status;
    return nullptr;
  }
  argvals = result.first;
  AbstractBasePtr unique_output = result.second;

  auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
  if (prim_func != nullptr) {
    auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
    return BuildValueNode(prim_func->prim(), type_func);
  }

  if (!eval->isa<BaseGraphEvaluator>()) {
    MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString();
  }
  auto real_eval = dyn_cast<BaseGraphEvaluator>(eval);

  if (func->context() == nullptr) {
    MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  }
  AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
  MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
                << ", graph: " << context->func_graph()->get_return()->DebugString();
  FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
  return BuildValueNode(v, abs);
}

// Specializes the callee of new_node (a parameter or poly callee). If the callee is a partial closure, the result
// is re-wrapped as Partial(specialized_fn, bound_args...) in new_node's graph.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
  MS_EXCEPTION_IF_NULL(new_node);
  auto new_inputs = new_node->inputs();
  AnfNodePtr func = new_inputs[0];
  AbstractBasePtr fnval = new_inputs[0]->abstract();

  // Argument abstracts: partial-bound args first, then the call-site args.
  AbstractBasePtrList args;
  auto backed_fnval = fnval;
  if (fnval->isa<PartialAbstractClosure>()) {
    auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
    backed_fnval = partial_closure->fn();
    args = partial_closure->args();
  }
  (void)std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
                       [](const AnfNodePtr &inp) { return inp->abstract(); });

  ScopeGuard scope_guard(new_node->scope());

  auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
  auto wrapped_node = specialized_node;
  if (fnval->isa<PartialAbstractClosure>()) {
    auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
    AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
                                        specialized_node};
    auto anf_node = partial_closure->node();
    MS_EXCEPTION_IF_NULL(anf_node);
    if (!anf_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString();
    }
    auto cnode = anf_node->cast<CNodePtr>();
    // Expected layout: (Partial, fn, arg_0, ..., arg_{n-1}).
    if (cnode->size() != partial_closure->args().size() + 2) {
      MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
                        << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
    }
    auto attrs = std::make_shared<AttrValueMap>();
    for (size_t i = 0; i < partial_closure->args().size(); i++) {
      auto old_node = cnode->input(i + 2);
      auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
      if (possibile_value_node != nullptr) {
        partial_node_list.push_back(possibile_value_node);
      } else {
        if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
          MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
        }
        partial_node_list.push_back(old_node);
      }
    }
    wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
    wrapped_node->set_abstract(partial_closure);
  }
  return wrapped_node;
}

// Returns this specializer's view of eval's cache, seeding it from eval->cache() on first access. The reference
// points into evalcaches_, so it stays valid independently of how Evaluator::cache() returns.
const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
  auto inserted = evalcaches_.emplace(eval, eval->cache());
  return inserted.first->second;
}

// Broadens every cached argument list of eval. If all of them broaden to one signature, re-runs eval on it,
// replaces this specializer's cache for eval with that single entry and returns (signature, output abstract).
// Otherwise returns an empty list and a null abstract.
std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
  const EvaluatorPtr &eval) {
  MS_EXCEPTION_IF_NULL(eval);
  auto cache_iter = evalcaches_.find(eval);
  if (cache_iter == evalcaches_.end() || cache_iter->second == nullptr) {
    MS_LOG(EXCEPTION) << "Evaluator cache is not initialized for evaluator: " << eval->ToString();
  }

  std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
  AbstractBasePtrList broaded_argvals;
  for (const auto &argvals_map : *cache_iter->second) {
    const auto &argvals = argvals_map.first;
    broaded_argvals.clear();

    (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
                         [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
    (void)choices.insert(broaded_argvals);
    MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
  }

  if (choices.size() != 1) {
    MS_LOG(DEBUG) << "Choices.size: " << choices.size();
    return std::make_pair(AbstractBasePtrList(), nullptr);
  }

  ConfigPtrList args_conf_list;
  (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
                       [](const AbstractBasePtr &v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });

  // if broaden return null
  EvalResultPtr ret = eval->Run(engine_, args_conf_list, nullptr);
  MS_EXCEPTION_IF_NULL(ret);
  EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
  (*real)[broaded_argvals] = ret;
  evalcaches_[eval] = real;
  return std::make_pair(broaded_argvals, ret->abstract());
}

// Specializes one CNode of the cloned graph (once per ProgramSpecializer): flattens nested Partial callees,
// specializes the callee and any FuncGraph/MetaFuncGraph/Primitive value-node arguments, then rewrites inputs.
void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
  MS_EXCEPTION_IF_NULL(new_node);
  if (specializer_->seen().count(new_node) > 0) {
    return;
  }
  specializer_->AddSeen(new_node);

  auto new_inputs = new_node->inputs();
  if (new_inputs.empty()) {
    MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
  }
  AnfNodePtr func = new_inputs[0];
  MS_EXCEPTION_IF_NULL(func);

  // First element is func so arg start from 1
  std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
  // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
  while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
    std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
    // First element is partial, second is func so arg is start from 2
    (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
    func = inputs[1];
  }
  new_inputs = args;
  (void)new_inputs.insert(new_inputs.begin(), func);

  AbstractBasePtrList argvals;
  MS_EXCEPTION_IF_NULL(new_inputs[0]);
  AbstractBasePtr fnval = new_inputs[0]->abstract();
  MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
                << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();

  // First element is func so function arguments start from 1
  for (size_t i = 1; i < new_inputs.size(); ++i) {
    argvals.push_back(new_inputs[i]->abstract());
    MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
                  << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
  }

  if (!func->isa<ValueNode>()) {
    MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
    if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
      auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
      EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
      std::pair<AbstractBasePtrList, AbstractBasePtr> result;
      AbstractBasePtrList empty_args;
      auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
      MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
      // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
      if (status == kSpecializeFindUniqueArgvalPoly ||
          (func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
                                      func->abstract()->isa<PartialAbstractClosure>()))) {
        auto wrapped_node = BuildSpecializedParameterNode(new_node);
        new_inputs[0] = wrapped_node;
      }
    }
  }

  if (CanSpecializeNode(func)) {
    // for primitive node , we build the primitive node with infered attributes in the first pass
    // so we do not build replaced node again here in second pass
    if (IsValueNode<Primitive>(func)) {
      new_inputs[0] = func;
    } else {
      new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
    }
  }

  // args[i] is stored at new_inputs[i + 1] because slot 0 holds the callee.
  for (size_t i = 0; i < argvals.size(); ++i) {
    if (CanSpecializeNode(args[i])) {
      new_inputs[i + 1] = BuildSpecializedNode(args[i], argvals[i], AbstractBasePtrList{});
    }
  }
  new_node->set_inputs(new_inputs);
}

namespace {
// Debug-logs every key of the evaluator cache after an exact-match lookup failed.
void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
  MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
  int i = 0;
  for (const auto &item : evaluator_cache_map) {
    MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
  }
}

// A higher-order primitive used with no arguments (directly or behind a DoSignature wrapper) is treated as POLY.
bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
  if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
    MS_LOG(DEBUG) << "High order primitive return POLY.";
    return true;
  }
  if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
    auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
    auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
    if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
      auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
      if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
        MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
        return true;
      }
    }
  }
  return false;
}
}  // end anonymous namespace

// Chooses the argument signature (and its output abstract) to specialize func/eval with. The options are
// tried in this order:
//   1. an exact cache hit;
//   2. the only cached signature, if there is exactly one;
//   3. DEAD when nothing is cached;
//   4. POLY for higher-order primitives;
//   5. a single broadened signature;
//   6. otherwise POLY.
SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
                                                             const AbstractBasePtrList &argvals,
                                                             std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
  MS_EXCEPTION_IF_NULL(func);
  MS_EXCEPTION_IF_NULL(eval);
  MS_EXCEPTION_IF_NULL(result);

  // Look up in place; nothing mutates the evaluator cache before these reads.
  const EvaluatorCacheMapPtr &eval_cache = eval->cache();
  MS_EXCEPTION_IF_NULL(eval_cache);
  const EvaluatorCacheMap &evaluator_cache_map = *eval_cache;
  auto exact_iter = evaluator_cache_map.find(argvals);
  if (exact_iter != evaluator_cache_map.end()) {
    *result = std::make_pair(argvals, exact_iter->second->abstract());
    return kSpecializeSuccess;
  }
  DumpEvaluatorCache(evaluator_cache_map, argvals);

  const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
  MS_EXCEPTION_IF_NULL(choices);

  auto choice_iter = choices->find(argvals);
  if (choice_iter != choices->end()) {
    *result = std::make_pair(argvals, choice_iter->second->abstract());
    return kSpecializeSuccess;
  }
  if (choices->size() == 1) {
    MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
    *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
    return kSpecializeSuccess;
  }
  if (choices->empty()) {
    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
    return kSpecializeFindUniqueArgvalDead;
  }
  if (IsPolyFunc(func, argvals)) {
    return kSpecializeFindUniqueArgvalPoly;
  }

  MS_LOG(DEBUG) << "Try to find generalized argvals.";
  *result = BuildFromBroadedArgsVal(eval);
  if (!result->first.empty()) {
    return kSpecializeSuccess;
  }
  MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
  return kSpecializeFindUniqueArgvalPoly;
}

// Returns prim unchanged when every entry of attrs already matches prim's attributes; otherwise returns a clone
// (through the Python `_clone` method for PrimitivePy) with all of attrs added.
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
  auto &prim_attrs = prim->attrs();
  const bool is_attr_same = std::all_of(attrs->begin(), attrs->end(), [&prim_attrs](const auto &item) {
    auto itr = prim_attrs.find(item.first);
    return itr != prim_attrs.end() && *(itr->second) == *(item.second);
  });
  if (is_attr_same) {
    return prim;
  }

  if (prim->isa<PrimitivePy>()) {
    PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
    auto clone_fn = prim_py->GetPyObj().attr("_clone");
    py::object new_obj = clone_fn();
    auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
    for (auto &item : *attrs) {
      cloned_prim->AddAttr(item.first, item.second);
    }
    return cloned_prim;
  }
  auto cloned_prim = std::make_shared<Primitive>(*prim);
  for (auto &item : *attrs) {
    cloned_prim->AddAttr(item.first, item.second);
  }
  return cloned_prim;
}

// Returns a ValueNode standing for origin_node when ival pins down a single value, or nullptr when it does not.
// A single value is one of: a unique primitive, a meta func graph, a func graph visible from here, or a concrete
// non-Any value.
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
                                                        const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(ival);

  AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
  if (abs != nullptr) {
    // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction.
    if (abs->isa<AbstractFuncUnion>()) {
      return nullptr;
    }
    ValuePtr value = nullptr;
    if (abs->isa<PrimitiveAbstractClosure>()) {
      auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
      // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
      if (attrs != nullptr) {
        value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
      } else {
        value = real_fn->prim();
      }
    } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
      auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
      value = real_fn->meta_func_graph();
    } else if (abs->isa<FuncGraphAbstractClosure>()) {
      auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
      value = real_fn->func_graph();
    } else {
      return nullptr;
    }
    MS_EXCEPTION_IF_NULL(value);
    // A nested graph (non-null parent) may only be referenced as a constant when origin_node is already a graph
    // constant and the parent is visible from the graph being specialized.
    if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
        (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
      return BuildValueNode(value, ival);
    }
    return nullptr;
  }

  ValuePtr val = ival->BuildValue();
  MS_EXCEPTION_IF_NULL(val);
  if (val->isa<AnyValue>()) {
    return nullptr;
  }
  // keep primitive 'depend' not to be optimized
  if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) {
    return nullptr;
  }
  return BuildValueNode(val, ival);
}

// Builds the analysis config for `node` in this specializer's context.
AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
  return engine_->MakeConfig(node, context_);
}
}  // namespace abstract
}  // namespace mindspore