/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/static_analysis/evaluator.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "ir/func_graph_cloner.h"
#include "pipeline/static_analysis/utils.h"
#include "debug/trace.h"

namespace mindspore {
namespace abstract {
namespace {
// Renders an abstract value for debug logs without dereferencing a null pointer.
std::string AbstractToLogString(const AbstractBasePtr &abs) {
  if (abs == nullptr) {
    return "null";
  }
  return abs->ToString();
}

// Evaluates each argument config in order and collects the resulting abstract values.
// Shared by every Run() in this file so null configs are rejected the same way everywhere.
AbstractBasePtrList CollectEvaluatedArgs(const ConfigPtrList &args_conf_list) {
  AbstractBasePtrList args_spec_list;
  args_spec_list.reserve(args_conf_list.size());
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                       });
  return args_spec_list;
}

// Scope guard pairing TraceManager::DebugTrace with TraceManager::EndTrace. The frame pushed in the
// constructor is popped in the destructor, so the trace stack stays balanced even if the guarded
// graph generation throws (the previous explicit DebugTrace/EndTrace pair skipped EndTrace on throw).
class DebugTraceScope {
 public:
  template <typename TraceInfoT>
  explicit DebugTraceScope(const std::shared_ptr<TraceInfoT> &trace_info) {
    TraceManager::DebugTrace(trace_info);
  }
  ~DebugTraceScope() { TraceManager::EndTrace(); }
  DebugTraceScope(const DebugTraceScope &) = delete;
  DebugTraceScope &operator=(const DebugTraceScope &) = delete;
};

// Debug-logs which node an evaluator is about to run for, followed by each input abstract.
void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list,
                       const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(evaluator);
  if (out_conf != nullptr) {
    MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
  }
  for (size_t i = 0; i < arg_spec_list.size(); i++) {
    MS_LOG(DEBUG) << evaluator->ToString() << " input[" << i
                  << "] abstract value: " << AbstractToLogString(arg_spec_list[i]);
  }
}

// Error-logs the node an evaluator failed on, with its debug info. Does nothing without an out_conf.
void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(evaluator);
  if (out_conf == nullptr) {
    return;
  }
  auto node = out_conf->node();
  // Primitive value nodes are reported by their scoped name; any other node by its debug string.
  const std::string node_desc = IsValueNode<Primitive>(node) ? node->fullname_with_scope() : node->DebugString();
  MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node_desc
                << ", with debug info: " << trace::GetDebugInfo(node->debug_info());
}
}  // namespace

// Builds the analysis context for the graph selected by the normalized argument list, as a child of
// parent_context_.
AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine,
                                                       const AbstractBasePtrList &args_spec_list) {
  AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list);
  FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list);
  MS_EXCEPTION_IF_NULL(parent_context_);
  AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list);
  return context;
}

// Collects `ret_node` followed by every CNode reachable from it through inputs that belong to the same
// func graph, in breadth-first order. Each reachable CNode appears at most once.
static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
  MS_EXCEPTION_IF_NULL(ret_node);
  auto ori_func_graph = ret_node->func_graph();
  MS_EXCEPTION_IF_NULL(ori_func_graph);

  std::vector<AnfNodePtr> sorted_nodes;
  std::unordered_set<AnfNodePtr> checked_cnodes;
  sorted_nodes.emplace_back(ret_node);
  // `sorted_nodes` doubles as the BFS queue; index-based access stays valid while it grows.
  for (std::size_t index = 0; index < sorted_nodes.size(); ++index) {
    // Hold a copy: emplace_back below may reallocate and invalidate references into the vector.
    const AnfNodePtr current = sorted_nodes[index];
    MS_EXCEPTION_IF_NULL(current);
    if (!current->isa<CNode>()) {
      continue;
    }
    for (const auto &input : current->cast<CNodePtr>()->inputs()) {
      // insert().second is true only the first time a node is seen, replacing a separate find().
      if (input != nullptr && input->isa<CNode>() && input->func_graph() == ori_func_graph &&
          checked_cnodes.insert(input).second) {
        sorted_nodes.emplace_back(input);
      }
    }
  }
  return sorted_nodes;
}

// Infers the abstract result of the func graph selected for `args_spec_list`.
// Binds each argument abstract to the matching parameter in a fresh context, then evaluates the
// graph's CNodes, ending with the return node, whose abstract is returned.
// Throws ValueError when the argument count does not match the parameter count.
AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
  FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
  MS_EXCEPTION_IF_NULL(fg);
  const auto &parameters = fg->parameters();
  std::size_t nargs = parameters.size();
  if (args_spec_list.size() != nargs) {
    MS_EXCEPTION(ValueError) << "Function " << fg->ToString() << ", The number of parameters of this function is "
                             << nargs << ", but the number of provided arguments is " << args_spec_list.size()
                             << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info());
  }
  MS_EXCEPTION_IF_NULL(parent_context_);
  MS_EXCEPTION_IF_NULL(engine);
  graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list);
  // Seed the engine cache so each parameter node resolves to its argument abstract in this context.
  for (size_t i = 0; i < nargs; i++) {
    AnfNodeConfigPtr conf = engine->MakeConfig(parameters[i], graph_context_);
    engine->cache().set_value(conf, args_spec_list[i]);
  }
  const AnfNodePtr &func_node = fg->get_return();
  MS_EXCEPTION_IF_NULL(func_node);

  MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
                << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
  AbstractBasePtr ret_base = nullptr;
  std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
  // FastShadowSort lists the return node first, then its inputs breadth-first. Walking it backwards
  // visits (roughly) inputs before users and finishes on the return node, so ret_base ends up holding
  // the return node's abstract.
  for (auto it = nodes.crbegin(); it != nodes.crend(); ++it) {
    AnfNodeConfigPtr node_conf = engine->MakeConfig(*it, graph_context_);
    MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
    ret_base = engine->GetEvaluatedValue(node_conf);
    MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
                  << ", abstract: " << AbstractToLogString(ret_base);
  }

  MS_EXCEPTION_IF_NULL(ret_base);
  MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString();
  return ret_base;
}

// Normalizes call arguments before graph selection and caching:
//  - graphs flagged FUNC_GRAPH_FLAG_IGNORE_VALUES get every argument broadened;
//  - graphs flagged kFuncGraphFlagUndetermined are joined with the arguments of the most recent context for
//    the same graph (found via parent_context_->Filter). If the join differs from the current arguments,
//    the graph is flagged FUNC_GRAPH_FLAG_IGNORE_VALUES for subsequent calls;
//  - otherwise the arguments are returned unchanged.
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
    AbstractBasePtrList broadened_list;
    broadened_list.reserve(args_spec_list.size());
    (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_list),
                         [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                           MS_EXCEPTION_IF_NULL(arg);
                           return arg->Broaden();
                         });
    MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
                  << ", broaded: " << mindspore::ToString(broadened_list);
    return broadened_list;
  }

  if (func_graph_->has_flag(kFuncGraphFlagUndetermined) && parent_context_) {
    MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
                  << ", context: " << parent_context_->ToString();
    auto last_context = parent_context_->Filter(func_graph_);
    if (last_context && last_context->func_graph() == func_graph_) {
      MS_LOG(DEBUG) << "Find last infer context: " << last_context->ToString();
      MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
      MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list());
      // Join the last eval arguments and current arguments to check if there are loop variants.
      auto joined_args_spec_list = AbstractJoin(args_spec_list, last_context->args_spec_list());
      MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list);
      // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation.
      if (!(joined_args_spec_list == args_spec_list)) {
        func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
      }
      return joined_args_spec_list;
    }
  }
  return args_spec_list;
}

// Returns the graph generated from func_graph() for `args_spec_list`, generating it (and registering it
// with the engine's manager) on first use and caching it in func_graph_cache_.
FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
  FuncGraphPtr ret = nullptr;
  auto iter = func_graph_cache_.find(args_spec_list);
  if (iter != func_graph_cache_.end()) {
    ret = iter->second;
  } else {
    auto fg = func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    FuncGraphPtr generated_graph = nullptr;
    {
      DebugTraceScope trace_scope(std::make_shared<TraceEvaluatorGenGraph>(fg->debug_info()));
      generated_graph = fg->GenerateGraph(args_spec_list);
    }
    func_graph_cache_[args_spec_list] = generated_graph;
    MS_EXCEPTION_IF_NULL(engine);
    engine->func_graph_manager()->AddFuncGraph(generated_graph);
    ret = generated_graph;
  }

  // For the top graph, if it is replaced by generated graph, update the top graph to the new one.
  if (parse::Parser::GetTopFuncGraph() == func_graph() && ret != func_graph()) {
    parse::Parser::UpdateTopFuncGraph(ret);
  }
  return ret;
}

// Returns the graph produced by meta_func_graph_ for `args_spec_list`. On a cache miss the generated graph
// is cloned with BasicClone, cached in func_graph_cache_ and registered with the engine's manager.
FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
  auto iter = func_graph_cache_.find(args_spec_list);
  if (iter != func_graph_cache_.end()) {
    return iter->second;
  }

  MS_EXCEPTION_IF_NULL(meta_func_graph_);
  FuncGraphPtr generated_func_graph = nullptr;
  if (this->bound_node() != nullptr) {
    DebugTraceScope trace_scope(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
    generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
  } else {
    generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
  }

  FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph);
  func_graph_cache_[args_spec_list] = cloned_func_graph;
  MS_EXCEPTION_IF_NULL(engine);
  engine->func_graph_manager()->AddFuncGraph(cloned_func_graph);
  return cloned_func_graph;
}

// Evaluates and normalizes the arguments, then returns the cached result for them or calls Infer() and
// caches its result. Throws when Infer() returns nullptr.
// Note: TraceGraphInferLeave is deliberately called only on the normal return paths, matching the
// original control flow; an exception leaves the infer trace entered.
AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) {
  const std::string evaluator_name = ToString();
  AbstractBasePtrList args_spec_list = NormalizeArgs(CollectEvaluatedArgs(args_conf_list));
  const EvaluatorPtr self = shared_from_base<Evaluator>();
  trace::TraceGraphInferEnter(self, out_conf);
  InferEntryLogging(self, args_spec_list, out_conf);
  MS_EXCEPTION_IF_NULL(cache_);

  auto iter = cache_->find(args_spec_list);
  if (iter != cache_->end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
    trace::TraceGraphInferLeave(self);
    return iter->second;
  }

  MS_LOG(DEBUG) << evaluator_name << " cache miss, call Infer().";
  AbstractBasePtr ret = Infer(engine, args_spec_list);
  if (ret == nullptr) {
    InferFailLogging(self, args_spec_list, out_conf);
    MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
  }
  MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
  (*cache_)[args_spec_list] = ret;
  trace::TraceGraphInferLeave(self);
  return ret;
}

// Evaluates the arguments and forwards them to EvalPrim without caching.
AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                          AnfNodeConfigPtr) {
  AbstractBasePtrList args_spec_list = CollectEvaluatedArgs(args_conf_list);
  AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
  return ret;
}

// Evaluates the arguments and forwards them, plus the first argument's config and out_conf, to EvalPrim.
// Throws when no arguments are given.
AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                             AnfNodeConfigPtr out_conf) {
  if (args_conf_list.empty()) {
    MS_LOG(EXCEPTION) << "Size should greater than 0";
  }
  AbstractBasePtrList args_spec_list = CollectEvaluatedArgs(args_conf_list);
  AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
  // No need to cache.
  return ret;
}

// Passes the unevaluated argument configs straight to EvalPrim.
AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
  AbstractBasePtr ret = EvalPrim(args_conf_list);
  return ret;
}

// Always delegates to sub_evaluator_ and records the result in cache_ keyed by the evaluated arguments.
AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                      AnfNodeConfigPtr out_conf) {
  AbstractBasePtrList args_spec_list = CollectEvaluatedArgs(args_conf_list);
  MS_EXCEPTION_IF_NULL(sub_evaluator_);
  AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
  // Don't lookup from cache, as different out_conf with same node but different context
  // may add different entry to anfnode_config_map_, like getattr primitive.
  MS_EXCEPTION_IF_NULL(cache_);
  (*cache_)[args_spec_list] = ret;
  return ret;
}

// Runs evaluator_ on the partially-applied arguments (args_spec_list_) followed by the call arguments,
// caching the result by the call arguments.
AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
  AbstractBasePtrList args_spec_list = CollectEvaluatedArgs(args_conf_list);
  MS_EXCEPTION_IF_NULL(cache_);
  auto iter = cache_->find(args_spec_list);
  if (iter != cache_->end()) {
    return iter->second;
  }

  // Join arguments in partial and the rest arguments from args_conf_list.
  auto to_config = [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); };
  ConfigPtrList partial_args_conf_list;
  partial_args_conf_list.reserve(args_spec_list_.size() + args_spec_list.size());
  (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list),
                       to_config);
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
                       to_config);

  MS_EXCEPTION_IF_NULL(evaluator_);
  AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
  (*cache_)[args_spec_list] = ret;
  return ret;
}

// Computes the abstract of J(f)(x): a tuple (y, bprop_f) where y is the abstract result of f(x) and
// bprop_f is a virtual closure from sense(y) to (sense(f), sense(x)...). Cached by the evaluated arguments.
AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
  AbstractBasePtrList args_spec_list = CollectEvaluatedArgs(args_conf_list);
  MS_EXCEPTION_IF_NULL(cache_);
  auto iter = cache_->find(args_spec_list);
  if (iter != cache_->end()) {
    return iter->second;
  }

  // Call the original evaluator, get the result: y = f(x)
  MS_EXCEPTION_IF_NULL(evaluator_);
  AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr);
  // Build a virtual function: bprop_f which uses sense of y as input, returns sense of function free variables
  // and input parameters. (sense_f, sense_x, ...)(*bprop_f) (sense_y)
  AbstractBasePtrList bparams;
  bparams.reserve(args_spec_list.size() + 1);
  bparams.push_back(SensitivityTransform(orig_func_));
  (void)std::transform(
    args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
    [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
  AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
  AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final);

  // J(f)(J(x)) return a tuple (y, bprop_f)
  AbstractBasePtrList jargs = {result, bprop};
  AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
  (*cache_)[args_spec_list] = jtuple;
  return jtuple;
}

// Returns the fixed output_ abstract after checking the argument count against args_spec_list_ and
// joining each argument with its declared parameter abstract (the join result itself is discarded).
AbstractBasePtr VirtualEvaluator::Infer(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() != args_spec_list_.size()) {
    MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
                      << ", arguments no: " << args_spec_list.size();
  }
  // Check each parameter and argument match.
  for (std::size_t i = 0; i < args_spec_list.size(); i++) {
    MS_EXCEPTION_IF_NULL(args_spec_list[i]);
    (void)args_spec_list[i]->Join(args_spec_list_[i]);
  }
  return output_;
}
}  // namespace abstract
}  // namespace mindspore