|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607 |
- /**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "pipeline/static_analysis/program_specialize.h"
-
- #include <algorithm>
- #include <exception>
- #include "./common.h"
- #include "operator/ops.h"
- #include "operator/composite/do_signature.h"
- #include "utils/graph_utils.h"
- #include "utils/profile.h"
- #include "debug/trace.h"
-
- namespace mindspore {
- namespace abstract {
- namespace {
- inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
- if (conf->node()->intermediate_abstract()) {
- return conf->node()->intermediate_abstract();
- }
- return conf->GetEvaluatedValue();
- }
-
- AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
- AnfNodePtr value_node = NewValueNode(v);
- value_node->set_abstract(abs_base);
- MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
- return value_node;
- }
-
- bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
- while (fg != nullptr && fg != parent) {
- fg = fg->parent();
- }
- return fg == parent;
- }
- } // namespace
-
- FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
- MS_EXCEPTION_IF_NULL(fg);
- MS_EXCEPTION_IF_NULL(context);
- MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
- return SpecializeFuncGraph(fg, context);
- }
-
- FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
- MS_EXCEPTION_IF_NULL(fg);
- MS_EXCEPTION_IF_NULL(context);
- auto iter = specializations_.find(context->SpecializeKey());
- if (iter != specializations_.end()) {
- return iter->second->specialized_func_graph();
- }
-
- std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
- FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
- specializations_[context->SpecializeKey()] = fg_spec;
- fg_spec->Run();
- return fg2;
- }
-
- std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
- MS_EXCEPTION_IF_NULL(context);
- auto iter = specializations_.find(context->SpecializeKey());
- if (iter != specializations_.end()) {
- return iter->second;
- }
- return nullptr;
- }
-
// Return the decimal text of a process-wide counter that starts at 1 and grows by one per call.
std::string GetNextCounter() {
  static int next_clone_id = 1;
  // Post-increment: the current value is formatted, then the counter advances.
  return std::to_string(next_clone_id++);
}
-
- FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
- const AnalysisContextPtr &context)
- : specializer_(s), func_graph_(fg), context_(context) {
- parent_ = s->GetFuncGraphSpecializer(context->parent());
- engine_ = s->engine();
- cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
- repl_node_ = cloner_->cloned_node();
- specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
- todo_.push_back(fg->get_return());
- auto ps = fg->parameters();
- (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
- }
-
- AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
- MS_EXCEPTION_IF_NULL(node);
- FuncGraphPtr fg = node->func_graph();
-
- if (node->isa<ValueNode>()) {
- return node;
- }
- std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
- while (fg != nullptr && fg != specializer->func_graph_) {
- specializer = specializer->parent_;
- }
- // If had replicated, just return that.
- auto iter = specializer->repl_node_->find(node);
- if (iter != specializer->repl_node_->end()) {
- return iter->second;
- }
-
- auto new_node = specializer->cloner_->CloneDisconnected(node);
- if (node->isa<CNode>()) {
- if (!new_node->isa<CNode>()) {
- MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
- }
- auto c_node = node->cast<CNodePtr>();
- MS_EXCEPTION_IF_NULL(c_node);
- auto inputs = c_node->inputs();
- std::vector<AnfNodePtr> new_inputs;
- (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
- [this](const AnfNodePtr &inp) -> AnfNodePtr {
- if (inp->isa<ValueNode>()) {
- return inp;
- }
- return ReplicateDisconnectedNode(inp);
- });
- auto c_new_node = new_node->cast<CNodePtr>();
- MS_EXCEPTION_IF_NULL(c_new_node);
- c_new_node->set_inputs(new_inputs);
- }
-
- iter = specializer->repl_node_->find(node);
- if (iter != specializer->repl_node_->end()) {
- if (iter->second == node) {
- MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
- }
- } else {
- MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
- }
- return new_node;
- }
-
- AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
- MS_EXCEPTION_IF_NULL(node);
- FuncGraphPtr fg = node->func_graph();
-
- std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
- while (fg != nullptr && fg != specializer->func_graph_) {
- specializer = specializer->parent_;
- }
-
- MS_EXCEPTION_IF_NULL(specializer->repl_node_);
- auto iter = specializer->repl_node_->find(node);
- if (iter != specializer->repl_node_->end()) {
- return iter->second;
- }
- return node;
- }
-
- void FuncGraphSpecializer::Run() {
- MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
- << ", cloned func graph name: " << specialized_func_graph_->ToString()
- << ", func graph: " << func_graph_->get_return()->DebugString();
- FirstPass();
- SecondPass();
- MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
- << ", cloned func graph name: " << specialized_func_graph_->ToString()
- << ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
- }
-
- void FuncGraphSpecializer::FirstPass() {
- while (todo_.size()) {
- AnfNodePtr node = todo_.back();
- todo_.pop_back();
- if (node->func_graph() == nullptr) {
- // do nothing for ValueNode
- continue;
- }
- if (node->func_graph() != func_graph_) {
- if (parent_ == nullptr) {
- MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
- }
- parent_->AddTodoItem(node);
- parent_->FirstPass();
- AnfNodePtr new_node = parent_->GetReplicatedNode(node);
- if (node->isa<CNode>()) {
- parent_->ProcessCNode(new_node->cast<CNodePtr>());
- }
- continue;
- }
- if (marked_.count(node) > 0) {
- continue;
- }
- (void)marked_.insert(node);
- ProcessNode(node);
- }
- }
-
- // Specialize CNode in func graphs
- void FuncGraphSpecializer::SecondPass() {
- for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) {
- if (node->isa<CNode>()) {
- ProcessCNode(node->cast<CNodePtr>());
- }
- }
- }
-
- void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
- MS_EXCEPTION_IF_NULL(node);
- ScopeGuard scope_guard(node->scope());
- AnfNodeConfigPtr conf = MakeConfig(node);
- AnfNodePtr new_node = GetReplicatedNode(node);
- MS_EXCEPTION_IF_NULL(new_node);
-
- if (new_node->func_graph() != specialized_func_graph_) {
- MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
- << ", new_node: " << new_node->DebugString()
- << ", new_node->func_graph(): " << new_node->func_graph()->ToString()
- << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
- return;
- }
- new_node->set_abstract(GetEvaluatedValueWrap(conf));
- MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
-
- if (node->isa<CNode>()) {
- auto c_old = node->cast<CNodePtr>();
- auto c_new = new_node->cast<CNodePtr>();
- auto new_inputs = c_new->inputs();
- auto old_inputs = c_old->inputs();
- for (size_t i = 0; i < old_inputs.size(); ++i) {
- auto node_input = old_inputs[i];
- AnfNodeConfigPtr iconf = MakeConfig(node_input);
- AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
- // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
- // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
- AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival);
- if (replace_node == nullptr) {
- replace_node = BuildReplacedNode(iconf);
- MS_EXCEPTION_IF_NULL(replace_node);
- replace_node->set_abstract(ival);
- MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
- } else {
- MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
- << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
- }
- if (new_inputs[i] != replace_node) {
- new_inputs[i] = replace_node;
- MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
- }
- }
- c_new->set_inputs(new_inputs);
- }
- }
-
- AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
- MS_EXCEPTION_IF_NULL(conf);
-
- auto conf_iter = engine_->anfnode_config_map().find(conf);
- AnfNodeConfigPtr new_conf = conf;
- while (conf_iter != engine_->anfnode_config_map().end()) {
- MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node("
- << new_conf->node()->DebugString() << ")";
- new_conf = conf_iter->second;
- MS_EXCEPTION_IF_NULL(new_conf);
- MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node("
- << conf->node()->DebugString() << ")";
- (void)ReplicateDisconnectedNode(new_conf->node());
- conf_iter = engine_->anfnode_config_map().find(new_conf);
- }
- todo_.push_back(new_conf->node());
- auto repl = GetReplicatedNode(new_conf->node());
- if (repl->func_graph()) {
- MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString()
- << ") to replace origin:" << new_conf->node()->DebugString();
- } else {
- MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
- << ") to replace origin: " << new_conf->node()->DebugString();
- }
- return repl;
- }
-
- namespace {
- const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
- const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
-
- inline bool CanSpecializeNode(const AnfNodePtr &node) {
- if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
- return true;
- }
- return false;
- }
- } // namespace
-
- AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
- const AbstractBasePtrList &argvals) {
- MS_EXCEPTION_IF_NULL(abs);
- AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
- MS_EXCEPTION_IF_NULL(real_a);
-
- AbstractFunctionPtr func = real_a->GetUnique();
- SpecializeStatusCode errcode;
- ScopeGuard scope_guard(node->scope());
- AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
- if (repl == nullptr) {
- if (errcode == kSpecializeFindUniqueArgvalDead) {
- const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
- repl = BuildValueNode(kDeadNode, error_dead_node);
- MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
- } else if (errcode == kSpecializeFindUniqueArgvalPoly) {
- const auto error_poly_node = std::make_shared<AbstractError>(kPolyNode, node);
- repl = BuildValueNode(kPolyNode, error_poly_node);
- MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
- } else {
- MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
- << ", abstract: " << abs->ToString();
- }
- }
-
- return repl;
- }
-
- AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
- const AbstractBasePtrList &args,
- SpecializeStatusCode *errcode) {
- MS_EXCEPTION_IF_NULL(abs);
- MS_EXCEPTION_IF_NULL(func);
- MS_EXCEPTION_IF_NULL(errcode);
- *errcode = kSpecializeSuccess;
-
- auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func);
- if (real_func != nullptr) {
- return BuildValueNode(real_func->prim(), abs);
- }
-
- EvaluatorPtr eval;
- eval = engine_->GetEvaluatorFor(func);
- MS_EXCEPTION_IF_NULL(eval);
- AbstractBasePtrList argvals = eval->NormalizeArgs(args);
-
- std::pair<AbstractBasePtrList, AbstractBasePtr> result;
- SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result);
- if (status != kSpecializeSuccess) {
- *errcode = status;
- return nullptr;
- }
- argvals = result.first;
- AbstractBasePtr unique_output = result.second;
-
- auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
- if (prim_func != nullptr) {
- auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
- return BuildValueNode(prim_func->prim(), type_func);
- }
-
- if (!eval->isa<BaseFuncGraphEvaluator>()) {
- MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString();
- }
- auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
-
- if (func->context() != nullptr) {
- if (!IsVisible(func_graph_, func->context()->func_graph())) {
- MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
- }
- } else {
- MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
- }
- AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
- MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
- << ", graph: " << context->func_graph()->get_return()->DebugString();
- FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
- return BuildValueNode(v, abs);
- }
-
- const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
- auto cache_iter = evalcaches_.find(eval);
- if (cache_iter == evalcaches_.end()) {
- evalcaches_[eval] = eval->cache();
- return eval->cache();
- }
- return cache_iter->second;
- }
-
- std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
- const EvaluatorPtr &eval) {
- MS_EXCEPTION_IF_NULL(eval);
- std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
- AbstractBasePtr ret = nullptr;
- AbstractBasePtrList broaded_argvals;
- for (auto &argvals_map : *evalcaches_[eval]) {
- auto argvals = argvals_map.first;
- broaded_argvals.clear();
-
- (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
- [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
- (void)choices.insert(broaded_argvals);
- MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
- }
-
- if (1 == choices.size()) {
- ConfigPtrList args_conf_list;
- (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
- [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
-
- // if broaden return null
- ret = eval->Run(engine_, args_conf_list, nullptr);
- EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
-
- (*real)[broaded_argvals] = ret;
- evalcaches_[eval] = real;
- return std::make_pair(broaded_argvals, ret);
- } else {
- MS_LOG(DEBUG) << "Choices.size: " << choices.size();
- return std::make_pair(AbstractBasePtrList(), nullptr);
- }
- }
-
- void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
- MS_EXCEPTION_IF_NULL(new_node);
- if (specializer_->seen().count(new_node) > 0) {
- return;
- }
- specializer_->AddSeen(new_node);
-
- auto new_inputs = new_node->inputs();
- if (new_inputs.empty()) {
- MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
- }
- AnfNodePtr func = new_inputs[0];
- MS_EXCEPTION_IF_NULL(func);
-
- // First element is func so arg start from 1
- std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
- // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
- while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
- std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
- // First element is partial, second is func so arg is start from 2
- (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
- func = inputs[1];
- new_inputs = args;
- (void)new_inputs.insert(new_inputs.begin(), func);
- }
-
- AbstractBasePtrList argvals;
- MS_EXCEPTION_IF_NULL(new_inputs[0]);
- AbstractBasePtr fnval = new_inputs[0]->abstract();
- MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
- << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
-
- // First element is func so function arguments start from 1
- for (size_t i = 1; i < new_inputs.size(); ++i) {
- argvals.push_back(new_inputs[i]->abstract());
- MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
- << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
- }
-
- if (CanSpecializeNode(func)) {
- new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
- }
-
- for (size_t i = 0; i < argvals.size();) {
- size_t next = i + 1;
- if (CanSpecializeNode(args[i])) {
- new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
- }
- // support for partial(Multitype) which Multitype should not be inferred to POLY.
- // after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only,
- // so even with partial parameter, it will specialize to that graph.
- // Maybe a better idea should inline graph with partial node first, then it will have full
- // parameter list to infer and specialize.
- MS_EXCEPTION_IF_NULL(new_inputs[next]);
- if (new_inputs[next]->isa<ValueNode>() && (GetValueNode(new_inputs[next]) == kPolyNode) &&
- IsPrimitive(func, prim::kPrimPartial)) {
- new_inputs[next] = args[i];
- }
- i = next;
- }
-
- new_node->set_inputs(new_inputs);
- }
-
- namespace {
- void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
- MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
- int i = 0;
- for (const auto &item : evaluator_cache_map) {
- MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
- }
- }
-
- bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
- if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
- MS_LOG(DEBUG) << "High order primitive return POLY.";
- return true;
- }
- if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
- auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
- auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
- if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
- auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
- if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
- MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
- return true;
- }
- }
- }
- return false;
- }
- } // end anonymous namespace
-
- SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
- const AbstractBasePtrList &argvals,
- std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
- MS_EXCEPTION_IF_NULL(func);
- MS_EXCEPTION_IF_NULL(eval);
- MS_EXCEPTION_IF_NULL(result);
-
- EvaluatorCacheMap evaluator_cache_map = *eval->cache();
- if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
- *result = std::make_pair(argvals, evaluator_cache_map[argvals]);
- return kSpecializeSuccess;
- }
- DumpEvaluatorCache(evaluator_cache_map, argvals);
-
- const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
- MS_EXCEPTION_IF_NULL(choices);
-
- if (choices->count(argvals)) {
- *result = std::make_pair(argvals, (*choices)[argvals]);
- return kSpecializeSuccess;
- } else if (choices->size() == 1) {
- MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
- *result = std::make_pair(choices->begin()->first, choices->begin()->second);
- return kSpecializeSuccess;
- } else if (choices->empty()) {
- MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
- return kSpecializeFindUniqueArgvalDead;
- } else {
- if (IsPolyFunc(func, argvals)) {
- return kSpecializeFindUniqueArgvalPoly;
- }
-
- MS_LOG(DEBUG) << "Try to find generalized argvals.";
- *result = BuildFromBroadedArgsVal(eval);
- if (!result->first.empty()) {
- return kSpecializeSuccess;
- }
- MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
- return kSpecializeFindUniqueArgvalPoly;
- }
- }
-
- AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) {
- MS_EXCEPTION_IF_NULL(origin_node);
- MS_EXCEPTION_IF_NULL(ival);
-
- AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
- if (abs != nullptr) {
- // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction.
- if (abs->isa<AbstractFuncUnion>()) {
- return nullptr;
- }
- ValuePtr value = nullptr;
- if (abs->isa<PrimitiveAbstractClosure>()) {
- auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
- value = real_fn->prim();
- } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
- auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
- value = real_fn->meta_func_graph();
- } else if (abs->isa<FuncGraphAbstractClosure>()) {
- auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
- value = real_fn->func_graph();
- } else {
- return nullptr;
- }
- if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
- (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
- return BuildValueNode(value, ival);
- } else {
- return nullptr;
- }
- } else {
- ValuePtr val = ival->BuildValue();
- if (val->isa<AnyValue>()) {
- return nullptr;
- } else {
- return BuildValueNode(val, ival);
- }
- }
- }
-
- AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
- return engine_->MakeConfig(node, context_);
- }
- } // namespace abstract
- } // namespace mindspore
|