| @@ -114,7 +114,7 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer) | |||
| return; | |||
| } | |||
| buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl; | |||
| buffer << "#IR entry : @" << graph->ToString() << std::endl; | |||
| buffer << "#attrs :" << std::endl; | |||
| for (const auto &attr : graph->attrs()) { | |||
| buffer << attr.first << " : "; | |||
| @@ -216,7 +216,7 @@ void DumpOperator(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g | |||
| if (IsValueNode<FuncGraph>(op)) { | |||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op); | |||
| if (fg != nullptr) { | |||
| gsub->buffer << "call @" << fg->ToString() << "." << fg->debug_info()->get_id(); | |||
| gsub->buffer << "call @" << fg->ToString(); | |||
| } | |||
| } else if (op->isa<CNode>()) { | |||
| if (gsub->local_var_map.find(op) != gsub->local_var_map.end()) { | |||
| @@ -224,7 +224,7 @@ void DumpOperator(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g | |||
| } else { | |||
| auto node = op->cast<CNodePtr>(); | |||
| auto fg = node->func_graph(); | |||
| gsub->buffer << "$(" << fg->ToString() << "." << fg->debug_info()->get_id() << ":" << node->ToString() << ")"; | |||
| gsub->buffer << "$(" << fg->ToString() << ":" << node->ToString() << ")"; | |||
| } | |||
| } else if (op->isa<ValueNode>()) { | |||
| gsub->buffer << GetValueNode(op)->ToString(); | |||
| @@ -262,14 +262,14 @@ void DumpOperands(const AnfNodePtr &nd, OrderedMap<AnfNodePtr, int32_t> *para_ma | |||
| } else { | |||
| auto node = in->cast<CNodePtr>(); | |||
| auto fg = node->func_graph(); | |||
| gsub->buffer << "$(" << fg->ToString() << "." << fg->debug_info()->get_id() << ":" << node->ToString() << ")"; | |||
| gsub->buffer << "$(" << fg->ToString() << ":" << node->ToString() << ")"; | |||
| } | |||
| } else if (in->isa<ValueNode>() && !IsValueNode<FuncGraph>(in)) { | |||
| // non Primitive valuenode | |||
| gsub->buffer << GetValueNode(in)->ToString(); | |||
| } else if (IsValueNode<FuncGraph>(in)) { | |||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(in); | |||
| gsub->buffer << "@" << fg->ToString() << "." << fg->debug_info()->get_id(); | |||
| gsub->buffer << "@" << fg->ToString(); | |||
| } else { | |||
| gsub->buffer << in->ToString(); | |||
| } | |||
| @@ -501,8 +501,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo> | |||
| } | |||
| fout << std::endl; | |||
| } | |||
| fout << "subgraph @" << sg.first->ToString() << "."; | |||
| fout << sg.first->debug_info()->get_id() << "("; | |||
| fout << "subgraph @" << sg.first->ToString() << "("; | |||
| if (sg.first != graph) { | |||
| std::vector<AnfNodePtr> parameters = sg.first->parameters(); | |||
| if (parameters.size() == 1) { | |||
| @@ -73,7 +73,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| if (!check_integrity_) { | |||
| break; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'"; | |||
| MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "'"; | |||
| } | |||
| auto param_map = exported[fg]; | |||
| if (param_map.find(param) != param_map.end()) { | |||
| @@ -83,7 +83,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| } | |||
| if (throw_excp) { | |||
| MS_LOG(EXCEPTION) << "Can not find index for param '" << param->DumpText() << "' for func graph '" | |||
| << func_graph->DumpText() << "." << func_graph->debug_info()->get_id() << "'"; | |||
| << func_graph->DumpText() << "'"; | |||
| } | |||
| return -1; | |||
| } | |||
| @@ -457,7 +457,7 @@ void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &nod | |||
| } | |||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg); | |||
| std::string func_graph_id = fg->debug_info()->get_id(); | |||
| comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id; | |||
| comment << " fg_" << func_graph_id << "=" << fg->ToString(); | |||
| } | |||
| if (has_comment) { | |||
| ofs << comment.str(); | |||
| @@ -552,8 +552,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun | |||
| if (*(func_graph->switch_layer_input())) { | |||
| ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n"; | |||
| } | |||
| ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "." | |||
| << func_graph->debug_info()->get_id() << "\n"; | |||
| ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "\n"; | |||
| if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { | |||
| ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "#" | |||
| << label_manage::Label(func_graph->debug_info()) << "\n"; | |||
| @@ -93,7 +93,7 @@ void DumpInferStack(std::ostringstream &oss) { | |||
| infer_vec.clear(); | |||
| break; | |||
| } | |||
| auto graph_context = graph_infer->graph_context(); | |||
| auto graph_context = graph_infer->context(); | |||
| if (graph_context == nullptr) { | |||
| MS_LOG(INFO) << "Null context continue"; | |||
| continue; | |||
| @@ -253,7 +253,7 @@ std::vector<AnalysisContextPtr> AnalyzedFuncGraphExporter::ProcessFuncGraphCall( | |||
| } | |||
| auto base_fg_evaluator = dyn_cast<abstract::BaseFuncGraphEvaluator>(evaluator); | |||
| auto ctx = base_fg_evaluator->graph_context(); | |||
| auto ctx = base_fg_evaluator->context(); | |||
| if (ctx != nullptr && context_map_.insert({ctx, false}).second) { | |||
| MS_LOG(DEBUG) << "Add new context, ctx.addr = " << ctx.get() << "ctx = " << ctx->ToString(); | |||
| context_vec_.push_back(ctx); | |||
| @@ -298,7 +298,7 @@ void AnalyzedFuncGraphExporter::OutputStatementComment(std::ofstream &ofs, const | |||
| } | |||
| FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg); | |||
| std::string func_graph_id = fg->debug_info()->get_id(); | |||
| comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id; | |||
| comment << " fg_" << func_graph_id << "=" << fg->ToString(); | |||
| if (ctxs.size() > i && ctxs[i] != nullptr) { | |||
| comment << "(@ctx.addr=" << ctxs[i].get() << ")"; | |||
| } | |||
| @@ -392,8 +392,7 @@ void AnalyzedFuncGraphExporter::ExportOneFuncGraph(std::ofstream &ofs, const Fun | |||
| std::vector<AnfNodePtr> parameters = func_graph->parameters(); | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map; | |||
| ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "." | |||
| << func_graph->debug_info()->get_id(); | |||
| ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText(); | |||
| if (cur_ctx_ != nullptr) { | |||
| ofs << " @ctx.addr=" << cur_ctx_.get(); | |||
| } | |||
| @@ -114,6 +114,8 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph | |||
| } | |||
| } | |||
| auto ret = engine->Run(func_graph, args_spec); | |||
| MS_LOG(INFO) << "function call max depth: " << engine->function_call_max_depth() | |||
| << ", simulate call max depth: " << engine->stack_frame_max_depth(); | |||
| MS_LOG(DEBUG) << "AbstractAnalyze end"; | |||
| return ret; | |||
| } | |||
| @@ -861,7 +861,7 @@ class SideEffectFinder { | |||
| const SccPtr &GetScc(const FuncGraphPtr &func_graph) const { | |||
| auto found = scc_map_.find(func_graph); | |||
| if (found == scc_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString() << "." << func_graph->debug_info()->get_id(); | |||
| MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString(); | |||
| } | |||
| return found->second; | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include "abstract/utils.h" | |||
| #include "debug/trace.h" | |||
| #include "utils/ms_context.h" | |||
| #include "pipeline/jit/static_analysis/stack_frame.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| @@ -66,62 +67,145 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr & | |||
| return context; | |||
| } | |||
| EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | |||
| FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | |||
| void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame, | |||
| const StackFramePtr &new_stack_frame) { | |||
| // Enter new func graph. | |||
| auto ¤t_node = current_stack_frame->CurrentNode(); | |||
| auto current_context = current_stack_frame->current_context(); | |||
| AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context); | |||
| auto evaluator = new_stack_frame->evaluator(); | |||
| MS_EXCEPTION_IF_NULL(evaluator); | |||
| trace::TraceGraphEvalEnter(evaluator, call_conf); | |||
| // Increase & Check the func graph call depth. | |||
| engine->IncreaseFunctionCallDepth(); | |||
| engine->IncreaseStackFrameDepth(); | |||
| if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) { | |||
| MS_LOG(EXCEPTION) << "Exceed function call depth limit " | |||
| << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) | |||
| << ", (function call depth: " << engine->function_call_depth() | |||
| << ", simulate call depth: " << engine->stack_frame_depth() | |||
| << "), please call 'context.set_context(max_call_depth=value)' to adjust this value."; | |||
| } | |||
| MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString() | |||
| << "), enter, function call depth: " << engine->function_call_depth() << " - " | |||
| << engine->stack_frame_depth(); | |||
| } | |||
| void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine, | |||
| const StackFramePtr ¤t_stack_frame) { | |||
| // Leave current func graph. | |||
| auto evaluator = current_stack_frame->evaluator(); | |||
| MS_EXCEPTION_IF_NULL(evaluator); | |||
| trace::TraceGraphEvalLeave(evaluator); | |||
| // Decrease the func graph call depth. | |||
| engine->DecreaseFunctionCallDepth(); | |||
| engine->DecreaseStackFrameDepth(); | |||
| MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString() | |||
| << "), leave, function call depth: " << engine->function_call_depth() << " - " | |||
| << engine->stack_frame_depth(); | |||
| } | |||
| // Start running stack frames in a Evaluator. | |||
| AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) { | |||
| EvalResultPtr eval_result = nullptr; | |||
| AbstractBasePtr res_base = nullptr; | |||
| std::stack<StackFramePtr> stack_frames; | |||
| auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_); | |||
| MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame; | |||
| stack_frames.push(current_stack_frame); | |||
| while (1) { | |||
| current_stack_frame = stack_frames.top(); | |||
| if (current_stack_frame->Done()) { | |||
| MS_EXCEPTION_IF_NULL(res_base); | |||
| MS_LOG(DEBUG) << "[" << this << "/StackFrame] Leave from func graph, " << current_stack_frame; | |||
| stack_frames.pop(); | |||
| if (stack_frames.empty()) { | |||
| MS_LOG(DEBUG) << "[" << this << "/StackFrame] Finish at func graph, " << current_stack_frame | |||
| << ", res_base: " << res_base->ToString(); | |||
| break; | |||
| } | |||
| // Save func graph eval result for specialize. | |||
| auto evaluator = current_stack_frame->evaluator(); | |||
| MS_EXCEPTION_IF_NULL(evaluator); | |||
| (*evaluator->evaluator_cache_map())[current_stack_frame->args_abs_list()] = eval_result; | |||
| // Leave current func graph. | |||
| LeaveStackFrame(engine, current_stack_frame); | |||
| // Switch the stack frame. | |||
| current_stack_frame = stack_frames.top(); | |||
| MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame; | |||
| current_stack_frame->Back(engine, eval_result); | |||
| continue; | |||
| } | |||
| auto new_stack_frame = current_stack_frame->Jump(engine); | |||
| if (new_stack_frame != nullptr) { | |||
| // Enter new func graph. | |||
| EnterStackFrame(engine, current_stack_frame, new_stack_frame); | |||
| // Update current stack frame. | |||
| stack_frames.push(new_stack_frame); | |||
| current_stack_frame = new_stack_frame; | |||
| MS_LOG(DEBUG) << "[" << this << "/StackFrame] Jump to new func graph, " << new_stack_frame; | |||
| } | |||
| eval_result = current_stack_frame->Step(engine); | |||
| MS_EXCEPTION_IF_NULL(eval_result); | |||
| res_base = eval_result->abstract(); | |||
| } | |||
| return res_base; | |||
| } | |||
| EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| engine->IncreaseFunctionCallDepth(); | |||
| if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) { | |||
| MS_LOG(EXCEPTION) << "Exceed function call depth limit " | |||
| << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) | |||
| << ", (function call depth: " << engine->function_call_depth() | |||
| << ", simulate call depth: " << engine->stack_frame_depth() | |||
| << "), please call 'context.set_context(max_call_depth=value)' to adjust this value."; | |||
| } | |||
| MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString() | |||
| << "), enter, function call depth: " << engine->function_call_depth() << " - " | |||
| << engine->stack_frame_depth(); | |||
| // Initialize evaluator starter with args_abs_list. | |||
| FuncGraphPtr fg = GetFuncGraph(engine, args_abs_list); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| std::size_t nargs = fg->parameters().size(); | |||
| if (args_spec_list.size() != nargs) { | |||
| if (args_abs_list.size() != nargs) { | |||
| MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is " | |||
| << fg->parameters().size() << ", but the number of provided arguments is " | |||
| << args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); | |||
| << args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(parent_context_); | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list); | |||
| context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list); | |||
| const auto ¶meters = fg->parameters(); | |||
| for (size_t i = 0; i < nargs; i++) { | |||
| const auto &arg = args_spec_list[i]; | |||
| const auto &arg = args_abs_list[i]; | |||
| const auto &node = parameters[i]; | |||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); | |||
| engine->analysis_cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr)); | |||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, context_); | |||
| engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr)); | |||
| } | |||
| const AnfNodePtr &func_node = fg->get_return(); | |||
| MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << "/" << fg->ToString() | |||
| << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString() | |||
| MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString() | |||
| << ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString() | |||
| << ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL") | |||
| << ", current function call depth: " << engine->function_call_depth(); | |||
| AbstractBasePtr ret_base = nullptr; | |||
| engine->IncreaseFunctionCallDepth(); | |||
| if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) { | |||
| MS_LOG(EXCEPTION) << "Exceed function call depth limit " | |||
| << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) | |||
| << ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; | |||
| } | |||
| const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType { | |||
| if (node->func_graph() != fg || node->isa<ValueNode>()) { | |||
| return EXCLUDE; | |||
| } | |||
| return FOLLOW; | |||
| }); | |||
| for (const auto &node : all_nodes) { | |||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); | |||
| MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << "/" << fg->ToString() | |||
| << ", node_conf: " << node_conf->ToString(); | |||
| auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf); | |||
| ret_base = node_eval_result->abstract(); | |||
| MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << "/" << fg->ToString() | |||
| << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); | |||
| } | |||
| engine->DecreaseFunctionCallDepth(); | |||
| MS_EXCEPTION_IF_NULL(ret_base); | |||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() | |||
| << ", is stub: " << fg->stub(); | |||
| auto res_base = LaunchStackFrame(engine, fg); | |||
| MS_EXCEPTION_IF_NULL(res_base); | |||
| MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString() | |||
| << ", evaluated abstract: " << res_base->ToString() << ", is stub: " << fg->stub(); | |||
| if (fg->stub()) { | |||
| ret_base = std::make_shared<AbstractUndetermined>(); | |||
| res_base = std::make_shared<AbstractUndetermined>(); | |||
| } | |||
| return std::make_shared<EvalResult>(ret_base, nullptr); | |||
| engine->DecreaseFunctionCallDepth(); | |||
| MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString() | |||
| << "), leave, function call depth: " << engine->function_call_depth() << " - " | |||
| << engine->stack_frame_depth(); | |||
| return std::make_shared<EvalResult>(res_base, nullptr); | |||
| } | |||
| AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||
| @@ -158,7 +242,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||
| if (parent_context_) { | |||
| MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() | |||
| << ", context: " << parent_context_->ToString(); | |||
| auto last_context = parent_context_->Filter(func_graph_); | |||
| auto last_context = parent_context_->FindParentContext(func_graph_); | |||
| if (last_context && last_context->func_graph() == func_graph_) { | |||
| MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString(); | |||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | |||
| @@ -23,6 +23,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <stack> | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -135,7 +136,7 @@ class SymbolicPrimEvaluator : public PrimEvaluator { | |||
| virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; | |||
| }; | |||
| // Evaluator will be stored in AnalysisEngine.constructors_ | |||
| // Evaluator will be stored in AnalysisEngine.evaluators_ | |||
| using EvaluatorPtrList = std::vector<EvaluatorPtr>; | |||
| class DummyEvaluator : public Evaluator { | |||
| @@ -179,6 +180,11 @@ class TrackedEvaluator : public Evaluator { | |||
| EvaluatorPtr sub_evaluator_; | |||
| }; | |||
| using FuncGraphCacheMap = | |||
| std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; | |||
| class StackFrame; | |||
| using StackFramePtr = std::shared_ptr<StackFrame>; | |||
| class BaseFuncGraphEvaluator : public Evaluator { | |||
| public: | |||
| explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context) | |||
| @@ -192,19 +198,26 @@ class BaseFuncGraphEvaluator : public Evaluator { | |||
| virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||
| AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list); | |||
| AnalysisContextPtr graph_context() const { return graph_context_; } | |||
| AnalysisContextPtr context() const { return context_; } | |||
| void set_context(const AnalysisContextPtr &context) { context_ = context; } | |||
| protected: | |||
| AnalysisContextPtr parent_context_; | |||
| private: | |||
| AnalysisContextPtr graph_context_; | |||
| // Add functions for stack frame routine. | |||
| AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg); | |||
| void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame, | |||
| const StackFramePtr &new_stack_frame); | |||
| void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame); | |||
| AnalysisContextPtr context_; | |||
| }; | |||
| class FuncGraphEvaluator : public BaseFuncGraphEvaluator { | |||
| public: | |||
| FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) | |||
| : BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {} | |||
| : BaseFuncGraphEvaluator(context->FindParentContext(func_graph)), func_graph_(func_graph) {} | |||
| ~FuncGraphEvaluator() override = default; | |||
| MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); | |||
| @@ -219,8 +232,7 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator { | |||
| private: | |||
| FuncGraphPtr func_graph_; | |||
| std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual> | |||
| func_graph_cache_; | |||
| FuncGraphCacheMap func_graph_cache_; | |||
| std::vector<AbstractBasePtrList> trace_; | |||
| }; | |||
| using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>; | |||
| @@ -243,8 +255,7 @@ class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator { | |||
| private: | |||
| MetaFuncGraphPtr meta_func_graph_; | |||
| std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual> | |||
| func_graph_cache_; | |||
| FuncGraphCacheMap func_graph_cache_; | |||
| ScopePtr scope_; | |||
| }; | |||
| @@ -0,0 +1,114 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pipeline/jit/static_analysis/stack_frame.h" | |||
| #include "debug/trace.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator, | |||
| const CNodePtr current_cnode) { | |||
| AbstractBasePtrList args_abs_list; | |||
| auto &inputs = current_cnode->inputs(); | |||
| for (std::size_t i = 1; i < inputs.size(); i++) { | |||
| auto config = engine->MakeConfig(inputs[i], current_context_); | |||
| auto abs = config->ObtainEvalResult()->abstract(); | |||
| args_abs_list.push_back(abs); | |||
| } | |||
| args_abs_list = evaluator->NormalizeArgs(args_abs_list); | |||
| args_abs_list = evaluator->BroadenUndeterminedArgs(args_abs_list); | |||
| return args_abs_list; | |||
| } | |||
| StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode, | |||
| const FuncGraphAbstractClosurePtr &graph_func) { | |||
| // Get the evaluator for func graph. | |||
| auto evaluator = engine->GetEvaluatorFor(graph_func); | |||
| auto fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator); | |||
| if (fg_evaluator == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Evaluator should be a BaseGraphEvaluator, but got " << evaluator->ToString(); | |||
| } | |||
| fg_evaluator->set_context(current_context_); | |||
| // Evaluate the inputs firstly. Build arguments for the func graph. | |||
| AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode); | |||
| // Generate func graph with arguments. | |||
| auto fg = fg_evaluator->GetFuncGraph(engine, args_abs_list); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| std::size_t nargs = fg->parameters().size(); | |||
| if (args_abs_list.size() != nargs) { | |||
| MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is " | |||
| << fg->parameters().size() << ", but the number of provided arguments is " | |||
| << args_abs_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); | |||
| } | |||
| MS_LOG(DEBUG) << "current_node: " << current_cnode->DebugString() << ", fg: " << fg->ToString() | |||
| << ", current_context_: " << current_context_->ToString(); | |||
| // Find parent context and create new context. | |||
| auto branch_fg = graph_func->func_graph(); | |||
| auto parent_context = graph_func->context()->FindParentContext(branch_fg); | |||
| auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list); | |||
| // Evaluate the parameters with new context. | |||
| for (size_t i = 0; i < nargs; i++) { | |||
| const auto &arg_abs = args_abs_list[i]; | |||
| const auto &node = fg->parameters()[i]; | |||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context); | |||
| engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg_abs, nullptr)); | |||
| } | |||
| // Create a new stack frame and set arguments for it. | |||
| auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, current_context()); | |||
| new_stack_frame->set_args_abs_list(std::move(args_abs_list)); | |||
| return new_stack_frame; | |||
| } | |||
| // Check if we need branch to another func graph. | |||
| StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) { | |||
| auto ¤t_node = CurrentNode(); | |||
| if (!current_node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = current_node->cast<CNodePtr>(); | |||
| auto func = engine->GetCNodeOperatorAbstract(cnode, current_context_); | |||
| auto graph_func = dyn_cast<FuncGraphAbstractClosure>(func); // Not handle MetaFuncGraphAbstractClosure by now. | |||
| if (graph_func == nullptr) { | |||
| return nullptr; // Not call FuncGraph. | |||
| } | |||
| // It's FuncGraph Call. | |||
| return DoJump(engine, cnode, graph_func); | |||
| } | |||
| EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) { | |||
| auto ¤t_node = NextNode(); | |||
| MS_LOG(DEBUG) << "current_node: " << current_node->DebugString() | |||
| << ", current_context_: " << current_context_->ToString(); | |||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_); | |||
| auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf); | |||
| return node_eval_result; | |||
| } | |||
| void StackFrame::Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result) { | |||
| auto ¤t_node = NextNode(); | |||
| MS_LOG(DEBUG) << "current_node: " << current_node->DebugString() | |||
| << ", current_context_: " << current_context_->ToString(); | |||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_); | |||
| engine->SaveEvalResultInCache(node_conf, result); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,138 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_ | |||
| #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "pipeline/jit/static_analysis/evaluator.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| class StackFrame; | |||
| using StackFramePtr = std::shared_ptr<StackFrame>; | |||
| using EvaluatorWeakPtr = std::weak_ptr<Evaluator>; | |||
| class StackFrame : public Base { | |||
| public: | |||
| StackFrame(const EvaluatorPtr &evaluator, const FuncGraphPtr &func_graph, const AnalysisContextPtr ¤t_context, | |||
| const AnalysisContextPtr &parent_context) | |||
| : evaluator_(EvaluatorWeakPtr(evaluator)), | |||
| func_graph_(func_graph), | |||
| current_context_(current_context), | |||
| parent_context_(parent_context), | |||
| slot_index_(0), | |||
| done_(false) { | |||
| Load(); | |||
| } | |||
| virtual ~StackFrame() = default; | |||
| void Load() { | |||
| node_slots = TopoSort(func_graph_->get_return(), SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType { | |||
| if (node->func_graph() != func_graph_ || node->isa<ValueNode>()) { | |||
| return EXCLUDE; | |||
| } | |||
| return FOLLOW; | |||
| }); | |||
| slot_index_ = 0; | |||
| args_abs_list_.clear(); | |||
| } | |||
| // Check if we need branch to another func graph. | |||
| StackFramePtr Jump(const AnalysisEnginePtr &engine); | |||
| // Run one step in current func graph. | |||
| EvalResultPtr Step(const AnalysisEnginePtr &engine); | |||
| // Return back from branch func graph. | |||
| void Back(const AnalysisEnginePtr &engine, const EvalResultPtr &result); | |||
| bool Done() { return done_; } | |||
| AnfNodePtr &CurrentNode() { | |||
| if (slot_index_ >= node_slots.size()) { | |||
| MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToAbstract() | |||
| << " is invalid. Try to access frame sequence by index " << slot_index_ | |||
| << ", while the size is " << node_slots.size() << "."; | |||
| } | |||
| return node_slots[slot_index_]; | |||
| } | |||
| AnfNodePtr &NextNode() { | |||
| auto ¤t_node = CurrentNode(); | |||
| // Set `done_` true, if the stack frames is being exhausted. | |||
| if (current_node == func_graph_->get_return()) { | |||
| done_ = true; | |||
| } | |||
| // Move cursor to next node. | |||
| slot_index_++; | |||
| return current_node; | |||
| } | |||
| EvaluatorPtr evaluator() const { return evaluator_.lock(); } | |||
| FuncGraphPtr func_graph() const { return func_graph_; } | |||
| AnalysisContextPtr current_context() const { return current_context_; } | |||
| AnalysisContextPtr parent_context() const { return parent_context_; } | |||
| AbstractBasePtrList &args_abs_list() { return args_abs_list_; } | |||
| void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; } | |||
| std::string ToString() const override { | |||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||
| std::ostringstream buffer; | |||
| buffer << "StackFrame: " << this << ", " << func_graph_->ToString(); | |||
| if (slot_index_ < node_slots.size()) { | |||
| auto current_node = node_slots[slot_index_]; | |||
| buffer << "(#" << slot_index_ << " / Running " << current_node->DebugString() << ")"; | |||
| } else { | |||
| buffer << "(Exhausted..)"; | |||
| } | |||
| buffer << ", parent: "; | |||
| auto parent_graph = parent_context_->func_graph(); | |||
| if (parent_graph != nullptr) { | |||
| buffer << parent_graph << "/" << parent_graph->ToString(); | |||
| } else { | |||
| buffer << "NULL"; | |||
| } | |||
| return buffer.str(); | |||
| } | |||
| friend std::ostream &operator<<(std::ostream &os, const StackFramePtr &frame) { | |||
| MS_EXCEPTION_IF_NULL(frame); | |||
| os << frame->ToString(); | |||
| return os; | |||
| } | |||
| private: | |||
| AbstractBasePtrList GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator, | |||
| const CNodePtr current_cnode); | |||
| StackFramePtr DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode, | |||
| const FuncGraphAbstractClosurePtr &graph_func); | |||
| EvaluatorWeakPtr evaluator_; | |||
| FuncGraphPtr func_graph_; | |||
| AnalysisContextPtr current_context_; | |||
| AnalysisContextPtr parent_context_; | |||
| AbstractBasePtrList args_abs_list_; | |||
| std::vector<AnfNodePtr> node_slots; | |||
| size_t slot_index_; | |||
| bool done_; | |||
| }; | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_STACK_FRAME_H_ | |||
| @@ -115,6 +115,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac | |||
| // Running the analyzer. | |||
| ResetFunctionCallDepth(); | |||
| ResetStackFrameDepth(); | |||
| AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); | |||
| MS_EXCEPTION_IF_NULL(root_context); | |||
| MS_EXCEPTION_IF_NULL(root_context->func_graph()); | |||
| @@ -133,7 +134,7 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana | |||
| const ConfigPtrList &args_conf_list) { | |||
| std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context); | |||
| (void)eval->Run(shared_from_this(), args_conf_list, nullptr); | |||
| return eval->graph_context(); | |||
| return eval->context(); | |||
| } | |||
| EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) { | |||
| @@ -152,7 +153,7 @@ EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr & | |||
| } | |||
| MS_LOG(DEBUG) << "Evaluate node on demond for NodeConfig: " << conf->ToString() | |||
| << ", result: " << result->abstract().get() << ", " << result->abstract()->ToString(); | |||
| analysis_cache_.set_value(conf, result); | |||
| SaveEvalResultInCache(conf, result); | |||
| return result; | |||
| } | |||
| @@ -179,7 +180,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||
| auto abstract = EvalValueNode(value_node, conf); | |||
| eval_result = std::make_shared<EvalResult>(abstract, std::make_shared<AttrValueMap>()); | |||
| } else if (node->isa<CNode>()) { | |||
| CheckNoStackInSameFuncGraph(conf); | |||
| // CheckNoStackInSameFuncGraph(conf); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| trace::TraceEvalCNodeEnter(conf); | |||
| eval_result = EvalCNode(cnode, conf); | |||
| @@ -224,7 +225,7 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) { | |||
| MS_LOG(EXCEPTION) << "Top evaluator is " << top_evaluator->ToString(); | |||
| } | |||
| auto top_fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(top_evaluator); | |||
| auto top_context_fg = top_fg_evaluator->graph_context()->func_graph(); | |||
| auto top_context_fg = top_fg_evaluator->context()->func_graph(); | |||
| if (current_cnode_fg != top_context_fg) { // Ignore FV call. | |||
| return; | |||
| } | |||
| @@ -249,18 +250,15 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co | |||
| return out; | |||
| } | |||
| EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| AbstractFunctionPtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString(); | |||
| } | |||
| AnfNodePtr func_node = inputs[0]; | |||
| MS_EXCEPTION_IF_NULL(func_node); | |||
| MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString(); | |||
| AnalysisContextPtr context = conf->context(); | |||
| AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); | |||
| MS_EXCEPTION_IF_NULL(func_conf); | |||
| // Keep it in a local variable, otherwise smart pointer will free it. | |||
| @@ -270,32 +268,40 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf | |||
| MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString() | |||
| << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); | |||
| } | |||
| if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||
| MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; | |||
| return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>()); | |||
| } | |||
| AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func); | |||
| if (func == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Not AbstractFunction: " << maybe_func->ToString() << ", func_conf: " << func_conf->ToString() | |||
| << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); | |||
| } | |||
| return func; | |||
| } | |||
| EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| AbstractFunctionPtr func = GetCNodeOperatorAbstract(cnode, conf->context()); | |||
| if (func->BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||
| MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; | |||
| return std::make_shared<EvalResult>(func->Clone(), std::make_shared<AttrValueMap>()); | |||
| } | |||
| ConfigPtrList args_conf_list; | |||
| // ignore the first node which is function name | |||
| // Ignore the first node which is function name | |||
| auto &inputs = cnode->inputs(); | |||
| for (std::size_t i = 1; i < inputs.size(); i++) { | |||
| const AnfNodePtr &node = inputs[i]; | |||
| args_conf_list.push_back(MakeConfig(node, context)); | |||
| args_conf_list.push_back(MakeConfig(node, conf->context())); | |||
| } | |||
| std::vector<EvaluatorPtr> infs; | |||
| std::vector<EvaluatorPtr> evaluators; | |||
| auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) { | |||
| auto build_evaluator = [this, &evaluators, &cnode](const AbstractFuncAtomPtr &poss) { | |||
| auto evaluator = this->GetEvaluatorFor(poss); | |||
| evaluator->set_bound_node(cnode); | |||
| infs.push_back(evaluator); | |||
| evaluators.push_back(evaluator); | |||
| }; | |||
| func->Visit(build_evaluator); | |||
| auto eval_result = ExecuteEvaluators(infs, conf, args_conf_list); | |||
| auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list); | |||
| return eval_result; | |||
| } | |||
| @@ -314,7 +320,7 @@ EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const Abs | |||
| } | |||
| void AnalysisEngine::ClearEvaluatorCache() { | |||
| for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) { | |||
| for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : evaluators_) { | |||
| EvaluatorPtr evaluator = element.second; | |||
| MS_EXCEPTION_IF_NULL(evaluator); | |||
| MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map()); | |||
| @@ -338,7 +344,7 @@ void AnalysisEngine::Clear() { | |||
| analysis_cache_.Clear(); | |||
| anfnode_config_map_.clear(); | |||
| eval_trace_.clear(); | |||
| constructors_.clear(); | |||
| evaluators_.clear(); | |||
| constructors_app_.clear(); | |||
| continued_evals_.clear(); | |||
| } | |||
| @@ -407,38 +413,38 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| } // namespace | |||
| EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) { | |||
| auto inf_pair = constructors_.find(func); | |||
| if (inf_pair != constructors_.end()) { | |||
| auto inf_pair = evaluators_.find(func); | |||
| if (inf_pair != evaluators_.end()) { | |||
| return inf_pair->second; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| auto primitive = func->prim(); | |||
| auto evaluator = GetPrimEvaluator(primitive, shared_from_this()); | |||
| constructors_[func] = evaluator; | |||
| evaluators_[func] = evaluator; | |||
| return evaluator; | |||
| } | |||
| EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) { | |||
| auto inf_pair = constructors_.find(func); | |||
| if (inf_pair != constructors_.end()) { | |||
| auto inf_pair = evaluators_.find(func); | |||
| if (inf_pair != evaluators_.end()) { | |||
| return inf_pair->second; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator = | |||
| std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context()); | |||
| constructors_[func] = func_graph_evaluator; | |||
| evaluators_[func] = func_graph_evaluator; | |||
| return func_graph_evaluator; | |||
| } | |||
| EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) { | |||
| auto inf_pair = constructors_.find(func); | |||
| if (inf_pair != constructors_.end()) { | |||
| auto inf_pair = evaluators_.find(func); | |||
| if (inf_pair != evaluators_.end()) { | |||
| return inf_pair->second; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| std::shared_ptr<MetaFuncGraphEvaluator> evaluator = | |||
| std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->context(), func->GetScope()); | |||
| constructors_[func] = evaluator; | |||
| evaluators_[func] = evaluator; | |||
| return evaluator; | |||
| } | |||
| @@ -505,18 +511,18 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||
| } | |||
| EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| MS_LOG(DEBUG) << "The func value: " << func->ToString(); | |||
| if (func->tracking_id() != nullptr) { | |||
| MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() || | |||
| func->isa<abstract::FuncGraphAbstractClosure>()) { | |||
| EvaluatorPtr evaluator = _GetEvaluatorFor(func); | |||
| return evaluator; | |||
| } | |||
| auto inf_pair = constructors_.find(func); | |||
| if (inf_pair != constructors_.end()) { | |||
| auto inf_pair = evaluators_.find(func); | |||
| if (inf_pair != evaluators_.end()) { | |||
| return inf_pair->second; | |||
| } | |||
| @@ -524,7 +530,7 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||
| func_generic->set_tracking_id(nullptr); | |||
| EvaluatorPtr eval = _GetEvaluatorFor(func_generic); | |||
| auto tracked_eval = std::make_shared<TrackedEvaluator>(eval); | |||
| constructors_[func] = tracked_eval; | |||
| evaluators_[func] = tracked_eval; | |||
| return tracked_eval; | |||
| } | |||
| @@ -89,7 +89,7 @@ class AnfNodeConfig : public Config { | |||
| } | |||
| context_ = nullptr; | |||
| if (context != nullptr) { | |||
| context_ = context->Filter(fg); | |||
| context_ = context->FindParentContext(fg); | |||
| } | |||
| } | |||
| @@ -116,8 +116,8 @@ class AnfNodeConfig : public Config { | |||
| std::string ToString() const override { | |||
| std::ostringstream buffer; | |||
| buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId() | |||
| << "), Context: " << context_->ToString(); | |||
| buffer << "Node: " << node_ << "/" << node_->DebugString() << "-uid(" << node_->UniqueId() | |||
| << "), Context: " << context_ << "/" << context_->ToString(); | |||
| return buffer.str(); | |||
| } | |||
| @@ -190,6 +190,9 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| prim_constructors_(prim_evaluator_map), | |||
| func_graph_manager_(func_graph_manager) { | |||
| function_call_depth_ = 0; | |||
| function_call_max_depth_ = 0; | |||
| stack_frame_depth_ = 0; | |||
| stack_frame_max_depth_ = 0; | |||
| forward_count_ = 0; | |||
| } | |||
| ~AnalysisEngine() = default; | |||
| @@ -197,10 +200,16 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| // func_graph: The func_graph to analyze. | |||
| // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. | |||
| AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); | |||
| void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| MS_EXCEPTION_IF_NULL(result); | |||
| analysis_cache_.set_value(conf, result); | |||
| } | |||
| EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf); | |||
| // Return the Evaluator for the given function. | |||
| EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); | |||
| AbstractFunctionPtr GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context); | |||
| AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); | |||
| EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); | |||
| // Infer the result of fn(args). | |||
| @@ -231,18 +240,43 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| AnalysisCache analysis_cache_; | |||
| std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | |||
| void ResetFunctionCallDepth() { function_call_depth_ = 0; } | |||
| void IncreaseFunctionCallDepth() { function_call_depth_++; } | |||
| void ResetFunctionCallDepth() { | |||
| function_call_depth_ = 0; | |||
| function_call_max_depth_ = 0; | |||
| } | |||
| void IncreaseFunctionCallDepth() { | |||
| function_call_depth_++; | |||
| if (function_call_max_depth_ < function_call_depth_) { | |||
| function_call_max_depth_ = function_call_depth_; | |||
| } | |||
| } | |||
| void DecreaseFunctionCallDepth() { | |||
| if (function_call_depth_ == 0) { | |||
| MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it."; | |||
| } | |||
| function_call_depth_--; | |||
| } | |||
| size_t function_call_depth() { return function_call_depth_; } | |||
| size_t function_call_max_depth() { return function_call_max_depth_; } | |||
| uint64_t function_call_depth() { return function_call_depth_; } | |||
| void ResetStackFrameDepth() { | |||
| stack_frame_depth_ = 0; | |||
| stack_frame_max_depth_ = 0; | |||
| } | |||
| void IncreaseStackFrameDepth() { | |||
| stack_frame_depth_++; | |||
| if (stack_frame_max_depth_ < stack_frame_depth_) { | |||
| stack_frame_max_depth_ = stack_frame_depth_; | |||
| } | |||
| } | |||
| void DecreaseStackFrameDepth() { | |||
| if (stack_frame_depth_ == 0) { | |||
| MS_LOG(EXCEPTION) << "Current stack frame depth is already 0, can not decrease it."; | |||
| } | |||
| stack_frame_depth_--; | |||
| } | |||
| size_t stack_frame_depth() { return stack_frame_depth_; } | |||
| size_t stack_frame_max_depth() { return stack_frame_max_depth_; } | |||
| void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf); | |||
| @@ -282,7 +316,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| const PrimEvaluatorMap &prim_constructors_; | |||
| FuncGraphManagerPtr func_graph_manager_; | |||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_; | |||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_; | |||
| std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher> | |||
| constructors_app_; | |||
| AnfNodeConfigMap anfnode_config_map_; | |||
| @@ -299,10 +333,15 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| const ConfigPtrList &args_conf_list); | |||
| EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list); | |||
| // record current depth of function call statck | |||
| uint64_t function_call_depth_; | |||
| // Record current depth of function call stack, including `stack_frame_depth_`. | |||
| size_t function_call_depth_; | |||
| size_t function_call_max_depth_; | |||
| // Record current depth of stack frames call. | |||
| size_t stack_frame_depth_; | |||
| size_t stack_frame_max_depth_; | |||
| uint64_t forward_count_; | |||
| size_t forward_count_; | |||
| #ifdef DEBUG | |||
| std::vector<AnfNodePtr> compute_conf_stack_; | |||
| @@ -133,7 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| // so different tracking_id will produce different FuncGraphAbstractClosure, | |||
| // different FuncGraphEvaluator. | |||
| // Espcecially useful for recursive func graph call, so it will not mess up | |||
| // the graph_context_ in FuncGraphEvaluator. | |||
| // the `context_` in FuncGraphEvaluator. | |||
| // Notes: Be careful to use nullptr for this variable. | |||
| // store it as weak_ptr to break reference cycle. | |||
| AnfNodeWeakPtr tracking_id_; | |||
| @@ -23,39 +23,39 @@ | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, | |||
| AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent_context, FuncGraphPtr fg, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| auto children_context_map_iter = parent->children_cache_.find(fg); | |||
| if (children_context_map_iter != parent->children_cache_.end()) { | |||
| auto children_context_map_iter = parent_context->children_cache_.find(fg); | |||
| if (children_context_map_iter != parent_context->children_cache_.end()) { | |||
| auto children_context_map = children_context_map_iter->second; | |||
| auto children_context_iter = children_context_map.find(args_spec_list); | |||
| if (children_context_iter != children_context_map.end()) { | |||
| return children_context_iter->second.lock(); | |||
| } | |||
| } | |||
| AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(parent, fg, args_spec_list); | |||
| AnalysisContextPtr new_context = std::make_shared<AnalysisContext>(parent_context, fg, args_spec_list); | |||
| // Reference to myself, so use weak_ptr to break reference cycle. | |||
| auto weak_context = std::weak_ptr<AnalysisContext>(context_new); | |||
| context_new->parent_cache_[fg] = weak_context; | |||
| parent->children_cache_[fg][args_spec_list] = weak_context; | |||
| return context_new; | |||
| auto weak_context = std::weak_ptr<AnalysisContext>(new_context); | |||
| new_context->parent_cache_[fg] = weak_context; | |||
| parent_context->children_cache_[fg][args_spec_list] = weak_context; | |||
| return new_context; | |||
| } | |||
| AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| FuncGraphPtr graph_parent = func_graph->parent(); | |||
| auto iter = parent_cache_.find(graph_parent); | |||
| FuncGraphPtr parent_graph = func_graph->parent(); | |||
| AnalysisContextPtr parent_context = nullptr; | |||
| auto iter = parent_cache_.find(parent_graph); | |||
| if (iter != parent_cache_.end()) { | |||
| parent_context = iter->second.lock(); | |||
| } | |||
| // if this happen, it will be bug in code. but we raise exception to keep the scene. | |||
| // If this happen, it will be a bug in code. But we raise exception to keep the scene. | |||
| if (parent_context == nullptr) { | |||
| std::ostringstream oss; | |||
| oss << "BUG: cannot found parent_context in current context: " << this->ToString() | |||
| << ", func_graph: " << func_graph->ToString() << ", graph_parent: "; | |||
| if (graph_parent != nullptr) { | |||
| oss << graph_parent->ToString(); | |||
| oss << "BUG: Failed to find parent context in current context: " << this->ToString() | |||
| << ", func_graph: " << func_graph->ToString() << ", parent_graph: "; | |||
| if (parent_graph != nullptr) { | |||
| oss << parent_graph->ToString(); | |||
| } else { | |||
| oss << "nullptr"; | |||
| } | |||
| @@ -64,7 +64,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func | |||
| return NewContext(parent_context, func_graph, args_spec_list); | |||
| } | |||
| AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) { | |||
| AnalysisContextPtr AnalysisContext::FindParentContext(const FuncGraphPtr &func_graph) { | |||
| auto p_iter = parent_cache_.find(func_graph); | |||
| AnalysisContextPtr parent_context = nullptr; | |||
| if (p_iter != parent_cache_.end()) { | |||
| @@ -75,10 +75,10 @@ AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) { | |||
| parent_context = iter_parent->second.lock(); | |||
| } | |||
| } | |||
| // if this happen, it will be bug in code. but we raise exception to keep the scene. | |||
| // If this happen, it would be a bug in code. But we raise exception to keep the scene. | |||
| if (parent_context == nullptr) { | |||
| std::ostringstream oss; | |||
| oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: "; | |||
| oss << "BUG: Failed to find parent context for: " << func_graph->ToString() << ", parent_graph: "; | |||
| if (func_graph->parent() != nullptr) { | |||
| oss << func_graph->parent()->ToString(); | |||
| } else { | |||
| @@ -52,7 +52,7 @@ class AnalysisContext { | |||
| AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); | |||
| // Return a context restricted to a graph's dependencies. | |||
| AnalysisContextPtr Filter(const FuncGraphPtr &graph); | |||
| AnalysisContextPtr FindParentContext(const FuncGraphPtr &graph); | |||
| bool operator==(const AnalysisContext &other) const; | |||
| std::size_t hash(); | |||
| static AnalysisContextPtr DummyContext(); | |||
| @@ -176,7 +176,11 @@ void FuncGraph::DumpCNodeList() { | |||
| } | |||
| std::string FuncGraph::ToString() const { | |||
| return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info()); | |||
| std::ostringstream buffer; | |||
| auto debug_info = const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info(); | |||
| buffer << mindspore::label_manage::Label(debug_info); | |||
| buffer << "." << debug_info->get_id(); | |||
| return buffer.str(); | |||
| } | |||
| GraphDebugInfoPtr FuncGraph::debug_info() { | |||
| @@ -295,22 +295,22 @@ TEST_F(TestInferGraph, test_context) { | |||
| AnalysisContextPtr dummy_context = AnalysisContext::DummyContext(); | |||
| AnalysisContextPtr f_context = dummy_context->NewFuncGraphContext(graph_f_, AbstractBasePtrList()); | |||
| ASSERT_TRUE(f_context->Filter(graph_f_) = f_context); | |||
| ASSERT_TRUE(f_context->Filter(nullptr) = dummy_context); | |||
| ASSERT_TRUE(f_context->FindParentContext(graph_f_) = f_context); | |||
| ASSERT_TRUE(f_context->FindParentContext(nullptr) = dummy_context); | |||
| AnalysisContextPtr g_context = f_context->NewFuncGraphContext(graph_g_, AbstractBasePtrList()); | |||
| ASSERT_TRUE(g_context->Filter(graph_g_) = g_context); | |||
| ASSERT_TRUE(g_context->Filter(graph_f_) = dummy_context); | |||
| ASSERT_TRUE(g_context->Filter(nullptr) = dummy_context); | |||
| ASSERT_TRUE(g_context->FindParentContext(graph_g_) = g_context); | |||
| ASSERT_TRUE(g_context->FindParentContext(graph_f_) = dummy_context); | |||
| ASSERT_TRUE(g_context->FindParentContext(nullptr) = dummy_context); | |||
| AnalysisContextPtr alpha_context = dummy_context->NewFuncGraphContext(graph_alpha_, AbstractBasePtrList()); | |||
| ASSERT_TRUE(alpha_context->Filter(graph_alpha_) = alpha_context); | |||
| ASSERT_TRUE(alpha_context->Filter(nullptr) = dummy_context); | |||
| ASSERT_TRUE(alpha_context->FindParentContext(graph_alpha_) = alpha_context); | |||
| ASSERT_TRUE(alpha_context->FindParentContext(nullptr) = dummy_context); | |||
| AnalysisContextPtr beta_context = alpha_context->NewFuncGraphContext(graph_beta_, AbstractBasePtrList()); | |||
| ASSERT_TRUE(beta_context->Filter(graph_beta_) = beta_context); | |||
| ASSERT_TRUE(beta_context->Filter(graph_alpha_) = alpha_context); | |||
| ASSERT_TRUE(beta_context->Filter(nullptr) = dummy_context); | |||
| ASSERT_TRUE(beta_context->FindParentContext(graph_beta_) = beta_context); | |||
| ASSERT_TRUE(beta_context->FindParentContext(graph_alpha_) = alpha_context); | |||
| ASSERT_TRUE(beta_context->FindParentContext(nullptr) = dummy_context); | |||
| } | |||
| class TestInferMetaGraph : public UT::Common { | |||