Merge pull request !1120 from Kang/mastertags/v0.3.0-alpha
| @@ -182,7 +182,7 @@ void DumpInferStack(std::ostringstream &oss) { | |||||
| } | } | ||||
| } | } | ||||
| void TraceGraphInfer() { | |||||
| void TraceGraphEval() { | |||||
| auto &infer_stack = GetCurrenGraphInferStack(); | auto &infer_stack = GetCurrenGraphInferStack(); | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (infer_stack.empty()) { | if (infer_stack.empty()) { | ||||
| @@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, | |||||
| ofs.close(); | ofs.close(); | ||||
| } | } | ||||
| void GetInferStackInfo(std::ostringstream &oss) { | |||||
| void GetEvalStackInfo(std::ostringstream &oss) { | |||||
| MS_LOG(INFO) << "Get graph analysis information begin"; | MS_LOG(INFO) << "Get graph analysis information begin"; | ||||
| auto stack = GetCNodeDebugStack(); | auto stack = GetCNodeDebugStack(); | ||||
| if (stack.empty()) { | if (stack.empty()) { | ||||
| @@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream &oss) { | |||||
| static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; | static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; | ||||
| // trace the cnode infer debug info | // trace the cnode infer debug info | ||||
| static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; | static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; | ||||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { | |||||
| void TraceGraphEvalEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { | |||||
| if (eval == nullptr) { | if (eval == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | ||||
| } | } | ||||
| @@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::An | |||||
| } | } | ||||
| } | } | ||||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { | |||||
| void TraceGraphEvalLeave(const abstract::EvaluatorPtr &eval) { | |||||
| if (eval == nullptr) { | if (eval == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | ||||
| } | } | ||||
| @@ -354,9 +354,9 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { | |||||
| } | } | ||||
| } | } | ||||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||||
| void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||||
| void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } | |||||
| void TraceEvalCNodeLeave() { cnode_debug_stack.pop_back(); } | |||||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; } | std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; } | ||||
| @@ -35,12 +35,12 @@ std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLi | |||||
| std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, | std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, | ||||
| SourceLineTip tip = kSourceLineTipNextLine); | SourceLineTip tip = kSourceLineTipNextLine); | ||||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); | DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); | ||||
| void TraceGraphInfer(); | |||||
| void GetInferStackInfo(std::ostringstream &oss); | |||||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); | |||||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval); | |||||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); | |||||
| void TraceInferCNodeLeave(); | |||||
| void TraceGraphEval(); | |||||
| void GetEvalStackInfo(std::ostringstream &oss); | |||||
| void TraceGraphEvalEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); | |||||
| void TraceGraphEvalLeave(const abstract::EvaluatorPtr &eval); | |||||
| void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); | |||||
| void TraceEvalCNodeLeave(); | |||||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack(); | std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack(); | ||||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack(); | std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack(); | ||||
| std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); | std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); | ||||
| @@ -430,8 +430,8 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py: | |||||
| } catch (const py::error_already_set &ex) { | } catch (const py::error_already_set &ex) { | ||||
| // print function call stack info before release | // print function call stack info before release | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| trace::TraceGraphInfer(); | |||||
| trace::GetInferStackInfo(oss); | |||||
| trace::TraceGraphEval(); | |||||
| trace::GetEvalStackInfo(oss); | |||||
| // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see | // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see | ||||
| // these info from screen, no need to open log file to find these info | // these info from screen, no need to open log file to find these info | ||||
| py::print(oss.str()); | py::print(oss.str()); | ||||
| @@ -38,7 +38,7 @@ namespace abstract { | |||||
| class AbstractBase; | class AbstractBase; | ||||
| using AbstractBasePtrList = std::vector<AbstractBasePtr>; | using AbstractBasePtrList = std::vector<AbstractBasePtr>; | ||||
| // The base class for abstract value. The abstract value is used in inferring | |||||
| // The base class for abstract value. The abstract value is used in evaluating | |||||
| // to express the type, shape, and value of the real value. | // to express the type, shape, and value of the real value. | ||||
| class AbstractBase : public Base { | class AbstractBase : public Base { | ||||
| public: | public: | ||||
| @@ -153,7 +153,7 @@ bool AnalysisContext::operator==(const AnalysisContext &other) const { | |||||
| // free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what | // free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what | ||||
| // graph can be reused. | // graph can be reused. | ||||
| // The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined | // The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined | ||||
| // and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in infer, thus the reused | |||||
| // and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in eval, thus the reused | |||||
| // graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize. | // graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize. | ||||
| // The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies | // The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies | ||||
| // on correct shape to specialize a tensor constant. | // on correct shape to specialize a tensor constant. | ||||
| @@ -26,8 +26,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| namespace { | namespace { | ||||
| void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list, | |||||
| const AnfNodeConfigPtr &out_conf) { | |||||
| void EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list, | |||||
| const AnfNodeConfigPtr &out_conf) { | |||||
| MS_EXCEPTION_IF_NULL(evaluator); | MS_EXCEPTION_IF_NULL(evaluator); | ||||
| if (out_conf != nullptr) { | if (out_conf != nullptr) { | ||||
| MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name(); | MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name(); | ||||
| @@ -37,7 +37,7 @@ void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList | |||||
| } | } | ||||
| } | } | ||||
| void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) { | |||||
| void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) { | |||||
| MS_EXCEPTION_IF_NULL(evaluator); | MS_EXCEPTION_IF_NULL(evaluator); | ||||
| if (out_conf != nullptr) { | if (out_conf != nullptr) { | ||||
| auto node = out_conf->node(); | auto node = out_conf->node(); | ||||
| @@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) { | |||||
| return sorted_nodes; | return sorted_nodes; | ||||
| } | } | ||||
| AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | |||||
| AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | |||||
| FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | ||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| std::size_t nargs = fg->parameters().size(); | std::size_t nargs = fg->parameters().size(); | ||||
| @@ -124,7 +124,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab | |||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(ret_base); | MS_EXCEPTION_IF_NULL(ret_base); | ||||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString(); | |||||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString(); | |||||
| return ret_base; | return ret_base; | ||||
| } | } | ||||
| @@ -155,7 +155,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||||
| << ", context: " << parent_context_->ToString(); | << ", context: " << parent_context_->ToString(); | ||||
| auto last_context = parent_context_->Filter(func_graph_); | auto last_context = parent_context_->Filter(func_graph_); | ||||
| if (last_context && last_context->func_graph() == func_graph_) { | if (last_context && last_context->func_graph() == func_graph_) { | ||||
| MS_LOG(DEBUG) << "Find last infer context: " << last_context->ToString(); | |||||
| MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString(); | |||||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | ||||
| MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); | MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); | ||||
| // Join the last eval arguments and current arguments to check if there are loop variant. | // Join the last eval arguments and current arguments to check if there are loop variant. | ||||
| @@ -248,26 +248,26 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar | |||||
| }); | }); | ||||
| args_spec_list = NormalizeArgs(args_spec_list); | args_spec_list = NormalizeArgs(args_spec_list); | ||||
| args_spec_list = BroadenUndeterminedArgs(args_spec_list); | args_spec_list = BroadenUndeterminedArgs(args_spec_list); | ||||
| trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf); | |||||
| InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | |||||
| trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf); | |||||
| EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | |||||
| MS_EXCEPTION_IF_NULL(cache_); | MS_EXCEPTION_IF_NULL(cache_); | ||||
| auto iter = cache_->find(args_spec_list); | auto iter = cache_->find(args_spec_list); | ||||
| if (iter == cache_->end()) { | if (iter == cache_->end()) { | ||||
| MS_LOG(DEBUG) << evaluator_name << " cache miss, call Infer()."; | |||||
| AbstractBasePtr ret = Infer(engine, args_spec_list); | |||||
| MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; | |||||
| AbstractBasePtr ret = Eval(engine, args_spec_list); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| InferFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | |||||
| EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | |||||
| MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; | MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(ret); | MS_EXCEPTION_IF_NULL(ret); | ||||
| MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << "."; | MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << "."; | ||||
| (*cache_)[args_spec_list] = ret; | (*cache_)[args_spec_list] = ret; | ||||
| trace::TraceGraphInferLeave(shared_from_base<Evaluator>()); | |||||
| trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | |||||
| return ret; | return ret; | ||||
| } else { | } else { | ||||
| MS_EXCEPTION_IF_NULL(iter->second); | MS_EXCEPTION_IF_NULL(iter->second); | ||||
| MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << "."; | MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << "."; | ||||
| trace::TraceGraphInferLeave(shared_from_base<Evaluator>()); | |||||
| trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| } | } | ||||
| @@ -378,7 +378,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a | |||||
| return jtuple; | return jtuple; | ||||
| } | } | ||||
| AbstractBasePtr VirtualEvaluator::Infer(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { | |||||
| AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { | |||||
| if (args_spec_list.size() != args_spec_list_.size()) { | if (args_spec_list.size() != args_spec_list_.size()) { | ||||
| MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() | MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() | ||||
| << ", arguments no: " << args_spec_list.size(); | << ", arguments no: " << args_spec_list.size(); | ||||
| @@ -38,12 +38,12 @@ class Evaluator : public Base { | |||||
| ~Evaluator() override = default; | ~Evaluator() override = default; | ||||
| MS_DECLARE_PARENT(Evaluator, Base); | MS_DECLARE_PARENT(Evaluator, Base); | ||||
| // difference between Run() and Infer(): | |||||
| // Run() will be called with ConfigPtrList, but Infer() will be called with AbstractBasePtr. | |||||
| // difference between Run() and Eval(): | |||||
| // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. | |||||
| // Run() will modify cache_ member, so it cannot marked as const; | // Run() will modify cache_ member, so it cannot marked as const; | ||||
| virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); | virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); | ||||
| virtual AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||||
| virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||||
| virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | ||||
| @@ -71,8 +71,8 @@ class PrimEvaluator : public Evaluator { | |||||
| explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} | explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} | ||||
| ~PrimEvaluator() override = default; | ~PrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(PrimEvaluator, Evaluator); | MS_DECLARE_PARENT(PrimEvaluator, Evaluator); | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) final { | |||||
| MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -113,7 +113,7 @@ class DummyEvaluator : public Evaluator { | |||||
| DummyEvaluator() : Evaluator("dummy") {} | DummyEvaluator() : Evaluator("dummy") {} | ||||
| ~DummyEvaluator() override = default; | ~DummyEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(DummyEvaluator, Evaluator); | MS_DECLARE_PARENT(DummyEvaluator, Evaluator); | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } | |||||
| }; | }; | ||||
| // Wrap another evaluator to track a subset of uses. | // Wrap another evaluator to track a subset of uses. | ||||
| @@ -139,8 +139,8 @@ class TrackedEvaluator : public Evaluator { | |||||
| bound_node_ = AnfNodeWeakPtr(node); | bound_node_ = AnfNodeWeakPtr(node); | ||||
| } | } | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||||
| } | } | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | ||||
| AnfNodeConfigPtr out_conf) override; | AnfNodeConfigPtr out_conf) override; | ||||
| @@ -158,7 +158,7 @@ class BaseFuncGraphEvaluator : public Evaluator { | |||||
| ~BaseFuncGraphEvaluator() override = default; | ~BaseFuncGraphEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); | MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | ||||
| @@ -238,7 +238,7 @@ class PartialAppEvaluator : public Evaluator { | |||||
| } | } | ||||
| bound_node_ = AnfNodeWeakPtr(node); | bound_node_ = AnfNodeWeakPtr(node); | ||||
| } | } | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | ||||
| } | } | ||||
| @@ -258,7 +258,7 @@ class VirtualEvaluator : public Evaluator { | |||||
| ~VirtualEvaluator() override = default; | ~VirtualEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); | MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| std::string ToString() const override { return identifier_; } | std::string ToString() const override { return identifier_; } | ||||
| private: | private: | ||||
| @@ -285,7 +285,7 @@ class JEvaluator : public Evaluator { | |||||
| } | } | ||||
| bound_node_ = AnfNodeWeakPtr(node); | bound_node_ = AnfNodeWeakPtr(node); | ||||
| } | } | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | ||||
| } | } | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | ||||
| @@ -470,16 +470,16 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const | |||||
| } | } | ||||
| } | } | ||||
| ValuePtr inferred_value = RunImpl(value_list); | |||||
| if (!(*inferred_value == *kAnyValue)) { | |||||
| ret_value_type = inferred_value->type(); | |||||
| ValuePtr evaluated_value = RunImpl(value_list); | |||||
| if (!(*evaluated_value == *kAnyValue)) { | |||||
| ret_value_type = evaluated_value->type(); | |||||
| } | } | ||||
| // for comparison primitives , return type shall have be specified to be bool. | // for comparison primitives , return type shall have be specified to be bool. | ||||
| if (specify_out_type_ != nullptr) { | if (specify_out_type_ != nullptr) { | ||||
| ret_value_type = specify_out_type_; | ret_value_type = specify_out_type_; | ||||
| } | } | ||||
| AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(inferred_value, ret_value_type); | |||||
| AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type); | |||||
| return abs_base; | return abs_base; | ||||
| } | } | ||||
| @@ -997,8 +997,8 @@ class PartialEvaluator : public Evaluator { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||||
| } | } | ||||
| AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, | AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, | ||||
| @@ -79,8 +79,8 @@ class DoSignatureEvaluator : public Evaluator { | |||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | ||||
| AnfNodeConfigPtr out_config = nullptr) override; | AnfNodeConfigPtr out_config = nullptr) override; | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||||
| } | } | ||||
| private: | private: | ||||
| @@ -94,8 +94,8 @@ class UnpackGraphEvaluator : public Evaluator { | |||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | ||||
| AnfNodeConfigPtr out_config = nullptr) override; | AnfNodeConfigPtr out_config = nullptr) override; | ||||
| AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||||
| } | } | ||||
| private: | private: | ||||
| @@ -183,11 +183,11 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||||
| ret_abstract = EvalValueNode(value_node, conf); | ret_abstract = EvalValueNode(value_node, conf); | ||||
| } else if (node->isa<CNode>()) { | } else if (node->isa<CNode>()) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| trace::TraceInferCNodeEnter(conf); | |||||
| ret_abstract = InferCNode(cnode, conf); | |||||
| trace::TraceInferCNodeLeave(); | |||||
| trace::TraceEvalCNodeEnter(conf); | |||||
| ret_abstract = EvalCNode(cnode, conf); | |||||
| trace::TraceEvalCNodeLeave(); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Illegal AnfNode for inferring, " << node->DebugString() | |||||
| MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() | |||||
| << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | ||||
| } | } | ||||
| @@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co | |||||
| return ToAbstract(value_node->value(), conf->context(), conf); | return ToAbstract(value_node->value(), conf->context(), conf); | ||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||||
| AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto &inputs = cnode->inputs(); | auto &inputs = cnode->inputs(); | ||||
| @@ -496,7 +496,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||||
| auto current_inf = std::make_pair(eval, args_spec_list); | auto current_inf = std::make_pair(eval, args_spec_list); | ||||
| MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | ||||
| // If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring. | |||||
| // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. | |||||
| auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); | auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); | ||||
| if (it == eval_trace_.rend()) { | if (it == eval_trace_.rend()) { | ||||
| eval_trace_.push_back(current_inf); | eval_trace_.push_back(current_inf); | ||||
| @@ -607,7 +607,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { | |||||
| return a; | return a; | ||||
| } | } | ||||
| AbstractBasePtr InferOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { | |||||
| AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { | |||||
| auto evaluator = GetPrimEvaluator(primitive, nullptr); | auto evaluator = GetPrimEvaluator(primitive, nullptr); | ||||
| MS_EXCEPTION_IF_NULL(evaluator); | MS_EXCEPTION_IF_NULL(evaluator); | ||||
| if (!evaluator->isa<TrivialPrimEvaluator>()) { | if (!evaluator->isa<TrivialPrimEvaluator>()) { | ||||
| @@ -165,7 +165,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); | EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); | ||||
| AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); | AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); | ||||
| AbstractBasePtr InferCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); | |||||
| AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); | |||||
| // Infer the result of fn(args). | // Infer the result of fn(args). | ||||
| AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); | AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); | ||||
| void Clear(); | void Clear(); | ||||
| @@ -244,7 +244,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { | |||||
| return FromValueInside(MakeValue(value), broaden); | return FromValueInside(MakeValue(value), broaden); | ||||
| } | } | ||||
| AbstractBasePtr InferOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); | |||||
| AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI | |||||
| args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); | args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); | ||||
| } | } | ||||
| } | } | ||||
| AbstractBasePtr infer_res = InferOnePrim(prim, args_spec_list); | |||||
| AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list); | |||||
| op_exec_info->abstract = infer_res; | op_exec_info->abstract = infer_res; | ||||
| } | } | ||||
| @@ -216,8 +216,8 @@ void LogWriter::operator^(const LogStream &stream) const { | |||||
| } | } | ||||
| oss << msg.str(); | oss << msg.str(); | ||||
| trace::TraceGraphInfer(); | |||||
| trace::GetInferStackInfo(oss); | |||||
| trace::TraceGraphEval(); | |||||
| trace::GetEvalStackInfo(oss); | |||||
| if (exception_type_ == IndexError) { | if (exception_type_ == IndexError) { | ||||
| throw pybind11::index_error(oss.str()); | throw pybind11::index_error(oss.str()); | ||||
| @@ -396,9 +396,9 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { | |||||
| } | } | ||||
| class TestInferOnePrim : public UT::Common { | |||||
| class TestEvalOnePrim : public UT::Common { | |||||
| public: | public: | ||||
| TestInferOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {} | |||||
| TestEvalOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {} | |||||
| void SetUp(); | void SetUp(); | ||||
| void TearDown(); | void TearDown(); | ||||
| @@ -406,37 +406,37 @@ class TestInferOnePrim : public UT::Common { | |||||
| AnalysisEnginePtr engine_; | AnalysisEnginePtr engine_; | ||||
| }; | }; | ||||
| void TestInferOnePrim::SetUp() { engine_ = SetupAnalysisEngineStub(); } | |||||
| void TestEvalOnePrim::SetUp() { engine_ = SetupAnalysisEngineStub(); } | |||||
| void TestInferOnePrim::TearDown() { | |||||
| void TestEvalOnePrim::TearDown() { | |||||
| // destroy resource | // destroy resource | ||||
| } | } | ||||
| TEST_F(TestInferOnePrim, test_scalar_add) { | |||||
| TEST_F(TestEvalOnePrim, test_scalar_add) { | |||||
| double x1 = 1.1; | double x1 = 1.1; | ||||
| double x2 = 1.1; | double x2 = 1.1; | ||||
| double x3 = 2.2; | double x3 = 2.2; | ||||
| AbstractBasePtr base1 = FromValue(x1, false); | AbstractBasePtr base1 = FromValue(x1, false); | ||||
| AbstractBasePtr base2 = FromValue(x2, false); | AbstractBasePtr base2 = FromValue(x2, false); | ||||
| AbstractBasePtrList base_list = {base1, base2}; | AbstractBasePtrList base_list = {base1, base2}; | ||||
| auto res = InferOnePrim(std::make_shared<Primitive>("scalar_add"), base_list); | |||||
| auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list); | |||||
| MS_LOG(INFO) << "result spec: " << res->ToString(); | MS_LOG(INFO) << "result spec: " << res->ToString(); | ||||
| AbstractBasePtr exp = FromValue(x3, false); | AbstractBasePtr exp = FromValue(x3, false); | ||||
| MS_LOG(INFO) << "result exp: " << exp->ToString(); | MS_LOG(INFO) << "result exp: " << exp->ToString(); | ||||
| ASSERT_EQ(*res, *exp); | ASSERT_EQ(*res, *exp); | ||||
| } | } | ||||
| class TestGraphInfer : public UT::Common { | |||||
| class TestGraphEval : public UT::Common { | |||||
| public: | public: | ||||
| TestGraphInfer() : getPyFun("gtest_input.pipeline.infer.infer_test", true){}; | |||||
| TestGraphEval() : getPyFun("gtest_input.pipeline.infer.infer_test", true){}; | |||||
| void SetUp(); | void SetUp(); | ||||
| void TearDown(); | void TearDown(); | ||||
| AnalysisEnginePtr engine_; | AnalysisEnginePtr engine_; | ||||
| UT::PyFuncGraphFetcher getPyFun; | UT::PyFuncGraphFetcher getPyFun; | ||||
| }; | }; | ||||
| void TestGraphInfer::SetUp() { engine_ = SetupAnalysisEngine(); } | |||||
| void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); } | |||||
| void TestGraphInfer::TearDown() { | |||||
| void TestGraphEval::TearDown() { | |||||
| // destroy resource | // destroy resource | ||||
| engine_->ClearEvaluatorCache(); | engine_->ClearEvaluatorCache(); | ||||
| parse::data_converter::ClearObjectCache(); | parse::data_converter::ClearObjectCache(); | ||||