Merge pull request !1383 from amongo/KeepPrimAttrInCNodetags/v0.5.0-beta
| @@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { | |||||
| auto ctx = node_cfg_->context(); | auto ctx = node_cfg_->context(); | ||||
| auto engine = node_cfg_->engine(); | auto engine = node_cfg_->engine(); | ||||
| auto cfg = engine->MakeConfig(node, ctx); | auto cfg = engine->MakeConfig(node, ctx); | ||||
| auto abs = engine->cache().GetValue(cfg); | |||||
| if (abs == nullptr) { | |||||
| auto eval_result = engine->cache().GetValue(cfg); | |||||
| if (eval_result == nullptr || eval_result->abstract() == nullptr) { | |||||
| return "Undefined"; | return "Undefined"; | ||||
| } | } | ||||
| auto abs = eval_result->abstract(); | |||||
| auto dtype = abs->BuildType(); | auto dtype = abs->BuildType(); | ||||
| auto shape = abs->BuildShape(); | auto shape = abs->BuildShape(); | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| @@ -42,7 +42,11 @@ enum PrimType { | |||||
| class Primitive : public Named { | class Primitive : public Named { | ||||
| public: | public: | ||||
| explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) | explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) | ||||
| : Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {} | |||||
| : Named(name), | |||||
| is_base_(is_base), | |||||
| has_signature_(false), | |||||
| prim_type_(prim_type), | |||||
| record_evaluate_add_attr_(false) {} | |||||
| Primitive(const Primitive &prim) | Primitive(const Primitive &prim) | ||||
| : Named(prim), | : Named(prim), | ||||
| @@ -50,14 +54,23 @@ class Primitive : public Named { | |||||
| instance_name_(prim.instance_name_), | instance_name_(prim.instance_name_), | ||||
| is_base_(prim.is_base_), | is_base_(prim.is_base_), | ||||
| has_signature_(prim.has_signature_), | has_signature_(prim.has_signature_), | ||||
| prim_type_(prim.prim_type_) {} | |||||
| prim_type_(prim.prim_type_), | |||||
| record_evaluate_add_attr_(false) {} | |||||
| MS_DECLARE_PARENT(Primitive, Named); | MS_DECLARE_PARENT(Primitive, Named); | ||||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | ||||
| std::string ToString() const override { return name(); } | std::string ToString() const override { return name(); } | ||||
| void BeginRecordAddAttr() { | |||||
| evaluate_added_attrs_.clear(); | |||||
| record_evaluate_add_attr_ = true; | |||||
| } | |||||
| void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } | |||||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | ||||
| attrs_[name] = attr; | attrs_[name] = attr; | ||||
| if (record_evaluate_add_attr_) { | |||||
| evaluate_added_attrs_[name] = attr; | |||||
| } | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -80,6 +93,7 @@ class Primitive : public Named { | |||||
| py::function hook() const { return hook_; } | py::function hook() const { return hook_; } | ||||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | ||||
| std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; } | |||||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | ||||
| bool HasAttr() const { return !attrs_.empty(); } | bool HasAttr() const { return !attrs_.empty(); } | ||||
| @@ -106,6 +120,7 @@ class Primitive : public Named { | |||||
| protected: | protected: | ||||
| std::unordered_map<std::string, ValuePtr> attrs_; | std::unordered_map<std::string, ValuePtr> attrs_; | ||||
| std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_; | |||||
| private: | private: | ||||
| std::string instance_name_; | std::string instance_name_; | ||||
| @@ -113,6 +128,7 @@ class Primitive : public Named { | |||||
| bool is_base_; | bool is_base_; | ||||
| bool has_signature_; | bool has_signature_; | ||||
| PrimType prim_type_; | PrimType prim_type_; | ||||
| bool record_evaluate_add_attr_; | |||||
| }; | }; | ||||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | ||||
| @@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv | |||||
| } | } | ||||
| subargs.push_back(AbstractJoin(l_ptr->elements())); | subargs.push_back(AbstractJoin(l_ptr->elements())); | ||||
| } | } | ||||
| AbstractBasePtr engin_exc = engine->Execute(fn, subargs); | |||||
| EvalResultPtr engin_exc = engine->Execute(fn, subargs); | |||||
| AbstractBasePtrList result; | AbstractBasePtrList result; | ||||
| for (std::size_t i = 1; i < args_spec_list.size(); i++) { | for (std::size_t i = 1; i < args_spec_list.size(); i++) { | ||||
| result.push_back(engin_exc); | |||||
| result.push_back(engin_exc->abstract()); | |||||
| } | } | ||||
| return std::make_shared<AbstractList>(result); | return std::make_shared<AbstractList>(result); | ||||
| } | } | ||||
| @@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi | |||||
| AbstractBasePtr list_type = AbstractJoin(lst->elements()); | AbstractBasePtr list_type = AbstractJoin(lst->elements()); | ||||
| auto result1 = engine->Execute(fn, lst->elements()); | auto result1 = engine->Execute(fn, lst->elements()); | ||||
| auto result2 = engine->Execute(fn, {dflt, list_type}); | auto result2 = engine->Execute(fn, {dflt, list_type}); | ||||
| MS_EXCEPTION_IF_NULL(result1); | |||||
| return result1->Join(result2); | |||||
| MS_EXCEPTION_IF_NULL(result1->abstract()); | |||||
| MS_EXCEPTION_IF_NULL(result2->abstract()); | |||||
| return result1->abstract()->Join(result2->abstract()); | |||||
| } | } | ||||
| AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) { | |||||
| return sorted_nodes; | return sorted_nodes; | ||||
| } | } | ||||
| AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | |||||
| EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | |||||
| FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | ||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| std::size_t nargs = fg->parameters().size(); | std::size_t nargs = fg->parameters().size(); | ||||
| @@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs | |||||
| const auto &arg = args_spec_list[i]; | const auto &arg = args_spec_list[i]; | ||||
| const auto &node = parameters[i]; | const auto &node = parameters[i]; | ||||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); | AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); | ||||
| engine->cache().set_value(conf, arg); | |||||
| engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr)); | |||||
| } | } | ||||
| const AnfNodePtr &func_node = fg->get_return(); | const AnfNodePtr &func_node = fg->get_return(); | ||||
| @@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs | |||||
| const auto &node = *it; | const auto &node = *it; | ||||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); | AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); | ||||
| MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); | MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); | ||||
| ret_base = engine->GetEvaluatedValue(node_conf); | |||||
| ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); | |||||
| MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() | MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() | ||||
| << ", abstract: " << ret_base->ToString(); | << ", abstract: " << ret_base->ToString(); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(ret_base); | MS_EXCEPTION_IF_NULL(ret_base); | ||||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString(); | |||||
| return ret_base; | |||||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString(); | |||||
| return std::make_shared<EvalResult>(ret_base, nullptr); | |||||
| } | } | ||||
| AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | ||||
| @@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons | |||||
| return cloned_func_graph; | return cloned_func_graph; | ||||
| } | } | ||||
| AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { | |||||
| const std::string &evaluator_name = ToString(); | const std::string &evaluator_name = ToString(); | ||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| }); | }); | ||||
| args_spec_list = NormalizeArgs(args_spec_list); | args_spec_list = NormalizeArgs(args_spec_list); | ||||
| args_spec_list = BroadenUndeterminedArgs(args_spec_list); | args_spec_list = BroadenUndeterminedArgs(args_spec_list); | ||||
| @@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar | |||||
| auto iter = cache_->find(args_spec_list); | auto iter = cache_->find(args_spec_list); | ||||
| if (iter == cache_->end()) { | if (iter == cache_->end()) { | ||||
| MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; | MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; | ||||
| AbstractBasePtr ret = Eval(engine, args_spec_list); | |||||
| if (ret == nullptr) { | |||||
| EvalResultPtr ret = Eval(engine, args_spec_list); | |||||
| if (ret->abstract() == nullptr) { | |||||
| EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | ||||
| MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; | MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(ret); | |||||
| MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << "."; | |||||
| MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; | |||||
| (*cache_)[args_spec_list] = ret; | (*cache_)[args_spec_list] = ret; | ||||
| trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | ||||
| return ret; | return ret; | ||||
| } else { | } else { | ||||
| MS_EXCEPTION_IF_NULL(iter->second); | MS_EXCEPTION_IF_NULL(iter->second); | ||||
| MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << "."; | |||||
| MS_EXCEPTION_IF_NULL(iter->second->abstract()); | |||||
| MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << "."; | |||||
| trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | ||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| } | } | ||||
| AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr) { | |||||
| EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr) { | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| }); | }); | ||||
| AbstractBasePtr ret = EvalPrim(engine, args_spec_list); | |||||
| EvalResultPtr ret = EvalPrim(engine, args_spec_list); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| }); | }); | ||||
| if (args_conf_list.size() == 0) { | if (args_conf_list.size() == 0) { | ||||
| MS_LOG(EXCEPTION) << "Size should greater than 0"; | MS_LOG(EXCEPTION) << "Size should greater than 0"; | ||||
| } | } | ||||
| AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); | |||||
| EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); | |||||
| // No need to cache. | // No need to cache. | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||||
| AbstractBasePtr ret = EvalPrim(args_conf_list); | |||||
| EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||||
| EvalResultPtr ret = EvalPrim(args_conf_list); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| }); | }); | ||||
| AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); | |||||
| EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); | |||||
| // Don't lookup from cache, as different out_conf with same node but different context | // Don't lookup from cache, as different out_conf with same node but different context | ||||
| // may add different entry to anfnode_config_map_, like getattr primitive. | // may add different entry to anfnode_config_map_, like getattr primitive. | ||||
| (*cache_)[args_spec_list] = ret; | (*cache_)[args_spec_list] = ret; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| }); | }); | ||||
| MS_EXCEPTION_IF_NULL(cache_); | MS_EXCEPTION_IF_NULL(cache_); | ||||
| auto iter = cache_->find(args_spec_list); | auto iter = cache_->find(args_spec_list); | ||||
| @@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP | |||||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), | (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), | ||||
| [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); | [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); | ||||
| AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); | |||||
| EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); | |||||
| (*cache_)[args_spec_list] = ret; | (*cache_)[args_spec_list] = ret; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||||
| EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| }); | }); | ||||
| MS_EXCEPTION_IF_NULL(cache_); | MS_EXCEPTION_IF_NULL(cache_); | ||||
| auto iter = cache_->find(args_spec_list); | auto iter = cache_->find(args_spec_list); | ||||
| @@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a | |||||
| } | } | ||||
| // Call the original evaluator, get the result: y = f(x) | // Call the original evaluator, get the result: y = f(x) | ||||
| AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr); | |||||
| EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr); | |||||
| // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input | // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input | ||||
| // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) | // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) | ||||
| AbstractBasePtrList bparams; | AbstractBasePtrList bparams; | ||||
| @@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a | |||||
| args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), | args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), | ||||
| [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); | [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); | ||||
| AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams); | AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams); | ||||
| AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final); | |||||
| AbstractFunctionPtr bprop = | |||||
| std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final); | |||||
| // J(f)(J(x)) return a tuple (y, bprop_f) | // J(f)(J(x)) return a tuple (y, bprop_f) | ||||
| AbstractBasePtrList jargs = {result, bprop}; | |||||
| AbstractBasePtrList jargs = {result->abstract(), bprop}; | |||||
| AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs); | AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs); | ||||
| (*cache_)[args_spec_list] = jtuple; | |||||
| return jtuple; | |||||
| auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>()); | |||||
| (*cache_)[args_spec_list] = infer_reuslt; | |||||
| return infer_reuslt; | |||||
| } | } | ||||
| AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { | |||||
| EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { | |||||
| if (args_spec_list.size() != args_spec_list_.size()) { | if (args_spec_list.size() != args_spec_list_.size()) { | ||||
| MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() | MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() | ||||
| << ", arguments no: " << args_spec_list.size(); | << ", arguments no: " << args_spec_list.size(); | ||||
| @@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[i]); | MS_EXCEPTION_IF_NULL(args_spec_list[i]); | ||||
| (void)args_spec_list[i]->Join(args_spec_list_[i]); | (void)args_spec_list[i]->Join(args_spec_list_[i]); | ||||
| } | } | ||||
| return output_; | |||||
| return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,21 +29,28 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| using EvaluatorCacheMap = | using EvaluatorCacheMap = | ||||
| std::unordered_map<AbstractBasePtrList, AbstractBasePtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; | |||||
| std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; | |||||
| using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>; | using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>; | ||||
| using EvaluatorAttrMap = | |||||
| std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; | |||||
| using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>; | |||||
| class Evaluator : public Base { | class Evaluator : public Base { | ||||
| public: | public: | ||||
| explicit Evaluator(const std::string &id) : cache_(std::make_shared<EvaluatorCacheMap>()), identifier_(id) {} | |||||
| explicit Evaluator(const std::string &id) | |||||
| : cache_(std::make_shared<EvaluatorCacheMap>()), | |||||
| attr_cache_(std::make_shared<EvaluatorAttrMap>()), | |||||
| identifier_(id) {} | |||||
| ~Evaluator() override = default; | ~Evaluator() override = default; | ||||
| MS_DECLARE_PARENT(Evaluator, Base); | MS_DECLARE_PARENT(Evaluator, Base); | ||||
| // difference between Run() and Eval(): | // difference between Run() and Eval(): | ||||
| // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. | // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. | ||||
| // Run() will modify cache_ member, so it cannot marked as const; | // Run() will modify cache_ member, so it cannot marked as const; | ||||
| virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); | |||||
| virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); | |||||
| virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||||
| virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||||
| virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | ||||
| @@ -58,9 +65,10 @@ class Evaluator : public Base { | |||||
| virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } | virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } | ||||
| EvaluatorCacheMapPtr &cache() { return cache_; } | EvaluatorCacheMapPtr &cache() { return cache_; } | ||||
| EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } | |||||
| EvaluatorCacheMapPtr cache_; | EvaluatorCacheMapPtr cache_; | ||||
| EvaluatorAttrMapPtr attr_cache_; | |||||
| std::string identifier_; | std::string identifier_; | ||||
| AnfNodeWeakPtr bound_node_; | AnfNodeWeakPtr bound_node_; | ||||
| @@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator { | |||||
| explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} | explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} | ||||
| ~PrimEvaluator() override = default; | ~PrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(PrimEvaluator, Evaluator); | MS_DECLARE_PARENT(PrimEvaluator, Evaluator); | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator { | |||||
| explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | ||||
| ~TrivialPrimEvaluator() override = default; | ~TrivialPrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); | MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||||
| virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||||
| virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; | |||||
| }; | }; | ||||
| class TransitionPrimEvaluator : public PrimEvaluator { | class TransitionPrimEvaluator : public PrimEvaluator { | ||||
| @@ -90,10 +98,10 @@ class TransitionPrimEvaluator : public PrimEvaluator { | |||||
| explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | ||||
| ~TransitionPrimEvaluator() override = default; | ~TransitionPrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); | MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||||
| // Parameter in_conf0 : the first element in args_conf_list; | // Parameter in_conf0 : the first element in args_conf_list; | ||||
| virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; | |||||
| virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; | |||||
| }; | }; | ||||
| class SymbolicPrimEvaluator : public PrimEvaluator { | class SymbolicPrimEvaluator : public PrimEvaluator { | ||||
| @@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator { | |||||
| explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | ||||
| ~SymbolicPrimEvaluator() override = default; | ~SymbolicPrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); | MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||||
| virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||||
| virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; | |||||
| }; | }; | ||||
| // Evaluator will be stored in AnalysisEngine.constructors_ | // Evaluator will be stored in AnalysisEngine.constructors_ | ||||
| @@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator { | |||||
| DummyEvaluator() : Evaluator("dummy") {} | DummyEvaluator() : Evaluator("dummy") {} | ||||
| ~DummyEvaluator() override = default; | ~DummyEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(DummyEvaluator, Evaluator); | MS_DECLARE_PARENT(DummyEvaluator, Evaluator); | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } | |||||
| }; | }; | ||||
| // Wrap another evaluator to track a subset of uses. | // Wrap another evaluator to track a subset of uses. | ||||
| @@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator { | |||||
| bound_node_ = AnfNodeWeakPtr(node); | bound_node_ = AnfNodeWeakPtr(node); | ||||
| } | } | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | ||||
| } | } | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) override; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; | |||||
| std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } | std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } | ||||
| private: | private: | ||||
| @@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator { | |||||
| ~BaseFuncGraphEvaluator() override = default; | ~BaseFuncGraphEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); | MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | ||||
| @@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator { | |||||
| } | } | ||||
| bound_node_ = AnfNodeWeakPtr(node); | bound_node_ = AnfNodeWeakPtr(node); | ||||
| } | } | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | ||||
| } | } | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) override; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; | |||||
| std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } | std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } | ||||
| private: | private: | ||||
| @@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator { | |||||
| ~VirtualEvaluator() override = default; | ~VirtualEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); | MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||||
| std::string ToString() const override { return identifier_; } | std::string ToString() const override { return identifier_; } | ||||
| private: | private: | ||||
| @@ -285,11 +292,11 @@ class JEvaluator : public Evaluator { | |||||
| } | } | ||||
| bound_node_ = AnfNodeWeakPtr(node); | bound_node_ = AnfNodeWeakPtr(node); | ||||
| } | } | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | ||||
| } | } | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) override; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; | |||||
| std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } | std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } | ||||
| private: | private: | ||||
| @@ -135,13 +135,17 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| using mindspore::parse::PyObjectWrapper; | using mindspore::parse::PyObjectWrapper; | ||||
| AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||||
| EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||||
| prim_->BeginRecordAddAttr(); | |||||
| AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | ||||
| return abs_base; | |||||
| prim_->EndRecordAddAttr(); | |||||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||||
| auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||||
| return infer_result; | |||||
| } | } | ||||
| AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| if (!prim_->isa<prim::DoSignaturePrimitive>()) { | if (!prim_->isa<prim::DoSignaturePrimitive>()) { | ||||
| MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString(); | MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString(); | ||||
| @@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config | |||||
| AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); | |||||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||||
| ScopePtr scope = kDefaultScope; | ScopePtr scope = kDefaultScope; | ||||
| if (out_conf != nullptr) { | if (out_conf != nullptr) { | ||||
| @@ -212,8 +216,8 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s | |||||
| return graph_specialize_args; | return graph_specialize_args; | ||||
| } | } | ||||
| AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf) { | |||||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | ||||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | ||||
| } | } | ||||
| @@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config | |||||
| AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | ||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); | |||||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||||
| // get the forward graph | // get the forward graph | ||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | MS_EXCEPTION_IF_NULL(args_spec_list[0]); | ||||
| AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>(); | AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>(); | ||||
| @@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||||
| } | } | ||||
| } // end anonymous namespace | } // end anonymous namespace | ||||
| AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||||
| EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||||
| MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); | MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); | ||||
| const auto &iter = cache_->find(args); | const auto &iter = cache_->find(args); | ||||
| @@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A | |||||
| MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; | MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; | ||||
| } | } | ||||
| auto infer_fuc = pyobj.attr("__infer__"); | auto infer_fuc = pyobj.attr("__infer__"); | ||||
| prim_py_->BeginRecordAddAttr(); | |||||
| py::dict output = infer_fuc(*py_args); | py::dict output = infer_fuc(*py_args); | ||||
| prim_py_->EndRecordAddAttr(); | |||||
| auto added_attrs = prim_py_->evaluate_added_attrs(); | |||||
| MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); | MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); | ||||
| auto res_spec = PyInferRes2Abstract(prim_py_, output); | auto res_spec = PyInferRes2Abstract(prim_py_, output); | ||||
| MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; | MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; | ||||
| (*cache_)[args] = res_spec; | |||||
| return res_spec; | |||||
| auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs)); | |||||
| (*cache_)[args] = infer_result; | |||||
| return infer_result; | |||||
| } | } | ||||
| AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||||
| EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||||
| // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. | // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. | ||||
| if (nargs_ != args.size()) { | if (nargs_ != args.size()) { | ||||
| MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; | MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; | ||||
| @@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const | |||||
| } | } | ||||
| AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type); | AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type); | ||||
| return abs_base; | |||||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { | ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { | ||||
| @@ -553,8 +560,8 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun | |||||
| manager->AddFuncGraph(func_graph); | manager->AddFuncGraph(func_graph); | ||||
| } | } | ||||
| AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, | |||||
| const AnfNodeConfigPtr &old_conf) { | |||||
| EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, | |||||
| const AnfNodeConfigPtr &old_conf) { | |||||
| MS_EXCEPTION_IF_NULL(old_conf); | MS_EXCEPTION_IF_NULL(old_conf); | ||||
| AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); | AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); | ||||
| @@ -585,9 +592,9 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat | |||||
| return eng->ForwardConfig(old_conf, fn_conf); | return eng->ForwardConfig(old_conf, fn_conf); | ||||
| } | } | ||||
| AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, | |||||
| const AbstractBasePtrList &args_spec_list, | |||||
| const AnfNodeConfigPtr &out_conf) { | |||||
| EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, | |||||
| const AbstractBasePtrList &args_spec_list, | |||||
| const AnfNodeConfigPtr &out_conf) { | |||||
| // args_spec_list: same as StaticGetter | // args_spec_list: same as StaticGetter | ||||
| if (args_spec_list.size() < 2) { | if (args_spec_list.size() < 2) { | ||||
| MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2"; | MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2"; | ||||
| @@ -627,9 +634,9 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng | |||||
| return eng->ForwardConfig(out_conf, fn_conf); | return eng->ForwardConfig(out_conf, fn_conf); | ||||
| } | } | ||||
| AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, | |||||
| const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, | |||||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||||
| EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, | |||||
| const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, | |||||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||||
| if (args_spec_list.empty()) { | if (args_spec_list.empty()) { | ||||
| MS_LOG(EXCEPTION) << "args_spec_list is empty"; | MS_LOG(EXCEPTION) << "args_spec_list is empty"; | ||||
| } | } | ||||
| @@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e | |||||
| AbstractBasePtr attr = cls->GetAttribute(item_name); | AbstractBasePtr attr = cls->GetAttribute(item_name); | ||||
| if (attr != nullptr) { | if (attr != nullptr) { | ||||
| return attr; | |||||
| return std::make_shared<EvalResult>(attr, nullptr); | |||||
| } | } | ||||
| ValuePtr method = cls->GetMethod(item_name); | ValuePtr method = cls->GetMethod(item_name); | ||||
| @@ -660,9 +667,9 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e | |||||
| return StaticGetterInferred(converted_v, data_conf, out_conf); | return StaticGetterInferred(converted_v, data_conf, out_conf); | ||||
| } | } | ||||
| AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, | |||||
| const TypePtr &data_type, const ConfigPtr &data_conf, | |||||
| const AnfNodeConfigPtr &out_conf) { | |||||
| EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, | |||||
| const TypePtr &data_type, const ConfigPtr &data_conf, | |||||
| const AnfNodeConfigPtr &out_conf) { | |||||
| MS_EXCEPTION_IF_NULL(item_v); | MS_EXCEPTION_IF_NULL(item_v); | ||||
| MS_EXCEPTION_IF_NULL(data_type); | MS_EXCEPTION_IF_NULL(data_type); | ||||
| // The method maybe a Primitive or Composite | // The method maybe a Primitive or Composite | ||||
| @@ -689,8 +696,8 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e | |||||
| return StaticGetterInferred(converted_v, data_conf, out_conf); | return StaticGetterInferred(converted_v, data_conf, out_conf); | ||||
| } | } | ||||
| AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||||
| EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||||
| // Inputs: namespace and its static function; or class and its member function | // Inputs: namespace and its static function; or class and its member function | ||||
| CheckArgsSize("StaticGetter", args_spec_list, 2); | CheckArgsSize("StaticGetter", args_spec_list, 2); | ||||
| @@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} | EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} | ||||
| ~EmbedEvaluator() override = default; | ~EmbedEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); | MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||||
| EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||||
| // arg: free variable to be embedded | // arg: free variable to be embedded | ||||
| if (args_conf_list.size() != 1) { | if (args_conf_list.size() != 1) { | ||||
| MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); | MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); | ||||
| @@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]); | AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]); | ||||
| MS_EXCEPTION_IF_NULL(node_conf); | MS_EXCEPTION_IF_NULL(node_conf); | ||||
| AbstractBasePtr x = node_conf->GetEvaluatedValue(); | |||||
| AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract(); | |||||
| x = SensitivityTransform(x); | x = SensitivityTransform(x); | ||||
| SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x); | SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x); | ||||
| AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>()); | AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>()); | ||||
| return abs_scalar; | |||||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} | RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} | ||||
| ~RefToEmbedEvaluator() override = default; | ~RefToEmbedEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); | MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||||
| EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||||
| if (args_conf_list.size() != 1) { | if (args_conf_list.size() != 1) { | ||||
| MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); | MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; | MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| AbstractBasePtr abs = node_conf->GetEvaluatedValue(); | |||||
| AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); | |||||
| AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>(); | AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>(); | ||||
| if (ref_abs == nullptr) { | if (ref_abs == nullptr) { | ||||
| MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref."; | MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref."; | ||||
| @@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| } | } | ||||
| auto refkey = key_value->cast<RefKeyPtr>(); | auto refkey = key_value->cast<RefKeyPtr>(); | ||||
| if (refkey == nullptr) { | if (refkey == nullptr) { | ||||
| return std::make_shared<AbstractScalar>(type); | |||||
| return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| std::string name = refkey->tag(); | std::string name = refkey->tag(); | ||||
| @@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| x = SensitivityTransform(x); | x = SensitivityTransform(x); | ||||
| std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | ||||
| std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | ||||
| return abs_scalar; | |||||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator { | |||||
| GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} | GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} | ||||
| ~GetAttrEvaluator() override = default; | ~GetAttrEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); | MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||||
| // Inputs: data, item | // Inputs: data, item | ||||
| if (args_spec_list.size() != 2) { | if (args_spec_list.size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | ||||
| } | } | ||||
| AbstractBasePtr ret = nullptr; | |||||
| EvalResultPtr ret = nullptr; | |||||
| if (bound_node() != nullptr) { | if (bound_node() != nullptr) { | ||||
| TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); | ||||
| ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); | ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); | ||||
| @@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator { | |||||
| ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} | ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} | ||||
| ~ResolveEvaluator() override = default; | ~ResolveEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); | MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||||
| // Inputs: namespace, symbol | // Inputs: namespace, symbol | ||||
| if (args_spec_list.size() != 2) { | if (args_spec_list.size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | ||||
| } | } | ||||
| AbstractBasePtr ret = nullptr; | |||||
| EvalResultPtr ret = nullptr; | |||||
| if (bound_node() != nullptr) { | if (bound_node() != nullptr) { | ||||
| TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); | ||||
| ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); | ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); | ||||
| @@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { | |||||
| CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} | CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} | ||||
| ~CreateInstanceEvaluator() override = default; | ~CreateInstanceEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); | MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||||
| const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override { | |||||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, | |||||
| const AnfNodeConfigPtr &out_conf) override { | |||||
| if (args_spec_list.empty()) { | if (args_spec_list.empty()) { | ||||
| MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; | MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; | ||||
| } | } | ||||
| @@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { | |||||
| } | } | ||||
| AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); | AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); | ||||
| (*cache_)[args_spec_list] = ret; | |||||
| return ret; | |||||
| auto infer_result = std::make_shared<EvalResult>(ret, nullptr); | |||||
| (*cache_)[args_spec_list] = infer_result; | |||||
| return infer_result; | |||||
| } | } | ||||
| pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { | pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { | ||||
| @@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator { | |||||
| public: | public: | ||||
| PartialEvaluator() : Evaluator("PartialEvaluator") {} | PartialEvaluator() : Evaluator("PartialEvaluator") {} | ||||
| ~PartialEvaluator() override = default; | ~PartialEvaluator() override = default; | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf = nullptr) override { | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||||
| AnfNodeConfigPtr out_conf = nullptr) override { | |||||
| if (args_conf_list.size() == 0) { | if (args_conf_list.size() == 0) { | ||||
| MS_LOG(EXCEPTION) << "Args size should be greater than 0"; | MS_LOG(EXCEPTION) << "Args size should be greater than 0"; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(out_conf); | MS_EXCEPTION_IF_NULL(out_conf); | ||||
| MS_EXCEPTION_IF_NULL(out_conf->node()); | MS_EXCEPTION_IF_NULL(out_conf->node()); | ||||
| auto arg0_value = args_conf_list[0]->GetEvaluatedValue(); | |||||
| auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); | |||||
| AbstractBasePtrList args_spec_list{arg0_value}; | AbstractBasePtrList args_spec_list{arg0_value}; | ||||
| // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. | // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. | ||||
| if (arg0_value->isa<AbstractError>()) { | if (arg0_value->isa<AbstractError>()) { | ||||
| auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node()); | auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node()); | ||||
| MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() | MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() | ||||
| << " as func is: " << arg0_value->ToString(); | << " as func is: " << arg0_value->ToString(); | ||||
| (*cache_)[args_spec_list] = ret; | |||||
| return ret; | |||||
| auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | |||||
| (*cache_)[args_spec_list] = eval_result; | |||||
| return eval_result; | |||||
| } | } | ||||
| auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0); | auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0); | ||||
| // Sometimes, node[0] in out_conf becomes phi0; | // Sometimes, node[0] in out_conf becomes phi0; | ||||
| @@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator { | |||||
| } | } | ||||
| } | } | ||||
| (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), | |||||
| [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); }); | |||||
| (void)std::transform( | |||||
| args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), | |||||
| [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); }); | |||||
| AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); | AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); | ||||
| auto cnode = out_conf->node()->cast<CNodePtr>(); | auto cnode = out_conf->node()->cast<CNodePtr>(); | ||||
| @@ -989,16 +999,17 @@ class PartialEvaluator : public Evaluator { | |||||
| func->Visit(build_partial); | func->Visit(build_partial); | ||||
| auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); | auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); | ||||
| (*cache_)[args_spec_list] = ret; | |||||
| return ret; | |||||
| auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | |||||
| (*cache_)[args_spec_list] = infer_result; | |||||
| return infer_result; | |||||
| } | } | ||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | ||||
| } | } | ||||
| AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, | |||||
| const AnfNodeConfigPtr &out_conf = nullptr) const { | |||||
| EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, | |||||
| const AnfNodeConfigPtr &out_conf = nullptr) const { | |||||
| MS_EXCEPTION_IF_NULL(out_conf); | MS_EXCEPTION_IF_NULL(out_conf); | ||||
| MS_EXCEPTION_IF_NULL(out_conf->node()); | MS_EXCEPTION_IF_NULL(out_conf->node()); | ||||
| auto cnode = out_conf->node()->cast<CNodePtr>(); | auto cnode = out_conf->node()->cast<CNodePtr>(); | ||||
| @@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator { | |||||
| : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} | : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} | ||||
| ~StandardPrimEvaluator() override = default; | ~StandardPrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); | MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||||
| PrimitivePtr prim() { return prim_; } | PrimitivePtr prim() { return prim_; } | ||||
| std::string ToString() const override { return identifier_ + prim_->name(); } | std::string ToString() const override { return identifier_ + prim_->name(); } | ||||
| @@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator { | |||||
| : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} | : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} | ||||
| ~PythonPrimEvaluator() override = default; | ~PythonPrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); | MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||||
| PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); } | PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); } | ||||
| std::string ToString() const override { return identifier_ + prim_py_->name(); } | std::string ToString() const override { return identifier_ + prim_py_->name(); } | ||||
| @@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator { | |||||
| public: | public: | ||||
| explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} | explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} | ||||
| ~DoSignatureEvaluator() override = default; | ~DoSignatureEvaluator() override = default; | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||||
| AnfNodeConfigPtr out_config = nullptr) override; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||||
| AnfNodeConfigPtr out_config = nullptr) override; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | ||||
| } | } | ||||
| @@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator { | |||||
| public: | public: | ||||
| explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} | explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} | ||||
| ~UnpackGraphEvaluator() override = default; | ~UnpackGraphEvaluator() override = default; | ||||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||||
| AnfNodeConfigPtr out_config = nullptr) override; | |||||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||||
| AnfNodeConfigPtr out_config = nullptr) override; | |||||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | ||||
| } | } | ||||
| @@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator { | |||||
| ~UniformPrimEvaluator() override = default; | ~UniformPrimEvaluator() override = default; | ||||
| MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); | MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); | ||||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||||
| ValuePtr RunImpl(const ValuePtrList &args) const; | ValuePtr RunImpl(const ValuePtrList &args) const; | ||||
| // If eval_value_ is False, return broadened arguments. | // If eval_value_ is False, return broadened arguments. | ||||
| @@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) { | |||||
| if (conf->node()->intermediate_abstract()) { | if (conf->node()->intermediate_abstract()) { | ||||
| return conf->node()->intermediate_abstract(); | return conf->node()->intermediate_abstract(); | ||||
| } | } | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| } | } | ||||
| AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { | AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { | ||||
| @@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() { | |||||
| // Specialize CNode in func graphs | // Specialize CNode in func graphs | ||||
| void FuncGraphSpecializer::SecondPass() { | void FuncGraphSpecializer::SecondPass() { | ||||
| for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) { | |||||
| for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { | |||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| ProcessCNode(node->cast<CNodePtr>()); | ProcessCNode(node->cast<CNodePtr>()); | ||||
| } | } | ||||
| @@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||||
| AnfNodeConfigPtr conf = MakeConfig(node); | AnfNodeConfigPtr conf = MakeConfig(node); | ||||
| AnfNodePtr new_node = GetReplicatedNode(node); | AnfNodePtr new_node = GetReplicatedNode(node); | ||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| if (new_node->func_graph() != specialized_func_graph_) { | if (new_node->func_graph() != specialized_func_graph_) { | ||||
| MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() | MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() | ||||
| << ", new_node: " << new_node->DebugString() | << ", new_node: " << new_node->DebugString() | ||||
| @@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||||
| MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); | MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto attrs = conf->GetEvaluatedValue()->attribute(); | |||||
| auto c_old = node->cast<CNodePtr>(); | auto c_old = node->cast<CNodePtr>(); | ||||
| auto c_new = new_node->cast<CNodePtr>(); | auto c_new = new_node->cast<CNodePtr>(); | ||||
| auto new_inputs = c_new->inputs(); | auto new_inputs = c_new->inputs(); | ||||
| @@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||||
| AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); | AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); | ||||
| // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if | // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if | ||||
| // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. | // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. | ||||
| AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival); | |||||
| AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); | |||||
| if (replace_node == nullptr) { | if (replace_node == nullptr) { | ||||
| replace_node = BuildReplacedNode(iconf); | replace_node = BuildReplacedNode(iconf); | ||||
| MS_EXCEPTION_IF_NULL(replace_node); | MS_EXCEPTION_IF_NULL(replace_node); | ||||
| @@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n | |||||
| MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() | MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() | ||||
| << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); | << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); | ||||
| } | } | ||||
| auto attrs = std::make_shared<AttrValueMap>(); | |||||
| for (size_t i = 0; i < partial_closure->args().size(); i++) { | for (size_t i = 0; i < partial_closure->args().size(); i++) { | ||||
| auto old_node = cnode->input(i + 2); | auto old_node = cnode->input(i + 2); | ||||
| auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]); | |||||
| auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); | |||||
| if (possibile_value_node != nullptr) { | if (possibile_value_node != nullptr) { | ||||
| partial_node_list.push_back(possibile_value_node); | partial_node_list.push_back(possibile_value_node); | ||||
| } else { | } else { | ||||
| @@ -455,7 +456,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB | |||||
| const EvaluatorPtr &eval) { | const EvaluatorPtr &eval) { | ||||
| MS_EXCEPTION_IF_NULL(eval); | MS_EXCEPTION_IF_NULL(eval); | ||||
| std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices; | std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices; | ||||
| AbstractBasePtr ret = nullptr; | |||||
| EvalResultPtr ret = nullptr; | |||||
| AbstractBasePtrList broaded_argvals; | AbstractBasePtrList broaded_argvals; | ||||
| for (auto &argvals_map : *evalcaches_[eval]) { | for (auto &argvals_map : *evalcaches_[eval]) { | ||||
| auto argvals = argvals_map.first; | auto argvals = argvals_map.first; | ||||
| @@ -478,7 +479,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB | |||||
| (*real)[broaded_argvals] = ret; | (*real)[broaded_argvals] = ret; | ||||
| evalcaches_[eval] = real; | evalcaches_[eval] = real; | ||||
| return std::make_pair(broaded_argvals, ret); | |||||
| return std::make_pair(broaded_argvals, ret->abstract()); | |||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Choices.size: " << choices.size(); | MS_LOG(DEBUG) << "Choices.size: " << choices.size(); | ||||
| return std::make_pair(AbstractBasePtrList(), nullptr); | return std::make_pair(AbstractBasePtrList(), nullptr); | ||||
| @@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||||
| return; | return; | ||||
| } | } | ||||
| specializer_->AddSeen(new_node); | specializer_->AddSeen(new_node); | ||||
| auto new_inputs = new_node->inputs(); | auto new_inputs = new_node->inputs(); | ||||
| if (new_inputs.empty()) { | if (new_inputs.empty()) { | ||||
| MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; | MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; | ||||
| @@ -530,7 +530,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||||
| } | } | ||||
| if (CanSpecializeNode(func)) { | if (CanSpecializeNode(func)) { | ||||
| new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); | |||||
| // for primitive node , we build the primitive node with infered attributes in the first pass | |||||
| // so we do not build replaced node again here in second pass | |||||
| if (IsValueNode<Primitive>(func)) { | |||||
| new_inputs[0] = func; | |||||
| } else { | |||||
| new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); | |||||
| } | |||||
| } | } | ||||
| for (size_t i = 0; i < argvals.size();) { | for (size_t i = 0; i < argvals.size();) { | ||||
| @@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||||
| } | } | ||||
| i = next; | i = next; | ||||
| } | } | ||||
| new_node->set_inputs(new_inputs); | new_node->set_inputs(new_inputs); | ||||
| } | } | ||||
| @@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct | |||||
| EvaluatorCacheMap evaluator_cache_map = *eval->cache(); | EvaluatorCacheMap evaluator_cache_map = *eval->cache(); | ||||
| if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { | if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { | ||||
| *result = std::make_pair(argvals, evaluator_cache_map[argvals]); | |||||
| *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract()); | |||||
| return kSpecializeSuccess; | return kSpecializeSuccess; | ||||
| } | } | ||||
| DumpEvaluatorCache(evaluator_cache_map, argvals); | DumpEvaluatorCache(evaluator_cache_map, argvals); | ||||
| @@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct | |||||
| MS_EXCEPTION_IF_NULL(choices); | MS_EXCEPTION_IF_NULL(choices); | ||||
| if (choices->count(argvals)) { | if (choices->count(argvals)) { | ||||
| *result = std::make_pair(argvals, (*choices)[argvals]); | |||||
| *result = std::make_pair(argvals, (*choices)[argvals]->abstract()); | |||||
| return kSpecializeSuccess; | return kSpecializeSuccess; | ||||
| } else if (choices->size() == 1) { | } else if (choices->size() == 1) { | ||||
| MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; | MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; | ||||
| *result = std::make_pair(choices->begin()->first, choices->begin()->second); | |||||
| *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); | |||||
| return kSpecializeSuccess; | return kSpecializeSuccess; | ||||
| } else if (choices->empty()) { | } else if (choices->empty()) { | ||||
| MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; | MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; | ||||
| @@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct | |||||
| return kSpecializeFindUniqueArgvalPoly; | return kSpecializeFindUniqueArgvalPoly; | ||||
| } | } | ||||
| } | } | ||||
| static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) { | |||||
| auto &prim_attrs = prim->attrs(); | |||||
| bool is_attr_same = true; | |||||
| for (auto &item : *attrs) { | |||||
| auto itr = prim_attrs.find(item.first); | |||||
| if (itr != prim_attrs.end()) { | |||||
| if (!(*(itr->second) == *(item.second))) { | |||||
| is_attr_same = false; | |||||
| break; | |||||
| } | |||||
| } else { | |||||
| is_attr_same = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!is_attr_same) { | |||||
| if (prim->isa<PrimitivePy>()) { | |||||
| PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>(); | |||||
| auto clone_fn = prim_py->GetPyObj().attr("_clone"); | |||||
| py::object new_obj = clone_fn(); | |||||
| auto cloned_prim = new_obj.cast<PrimitivePyPtr>(); | |||||
| for (auto &item : *attrs) { | |||||
| cloned_prim->AddAttr(item.first, item.second); | |||||
| } | |||||
| return cloned_prim; | |||||
| } | |||||
| auto cloned_prim = std::make_shared<Primitive>(*prim); | |||||
| for (auto &item : *attrs) { | |||||
| cloned_prim->AddAttr(item.first, item.second); | |||||
| } | |||||
| return cloned_prim; | |||||
| } | |||||
| return prim; | |||||
| } | |||||
| AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) { | |||||
| AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, | |||||
| const AttrValueMapPtr &attrs) { | |||||
| MS_EXCEPTION_IF_NULL(origin_node); | MS_EXCEPTION_IF_NULL(origin_node); | ||||
| MS_EXCEPTION_IF_NULL(ival); | MS_EXCEPTION_IF_NULL(ival); | ||||
| @@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin | |||||
| ValuePtr value = nullptr; | ValuePtr value = nullptr; | ||||
| if (abs->isa<PrimitiveAbstractClosure>()) { | if (abs->isa<PrimitiveAbstractClosure>()) { | ||||
| auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs); | auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs); | ||||
| value = real_fn->prim(); | |||||
| // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one | |||||
| if (attrs != nullptr) { | |||||
| value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); | |||||
| } else { | |||||
| value = real_fn->prim(); | |||||
| } | |||||
| } else if (abs->isa<MetaFuncGraphAbstractClosure>()) { | } else if (abs->isa<MetaFuncGraphAbstractClosure>()) { | ||||
| auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs); | auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs); | ||||
| value = real_fn->meta_func_graph(); | value = real_fn->meta_func_graph(); | ||||
| @@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia | |||||
| AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); | AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); | ||||
| // Build a value node if ival is constant and not any-value | // Build a value node if ival is constant and not any-value | ||||
| AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival); | |||||
| AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, | |||||
| const AttrValueMapPtr &attrs); | |||||
| // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a | // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a | ||||
| // replicated node. | // replicated node. | ||||
| AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); | AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); | ||||
| @@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) { | |||||
| void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { | |||||
| MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() | MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() | ||||
| << ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString() | |||||
| << ", Pointer: " << arg.get(); | |||||
| cache_[conf] = arg; | |||||
| << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() | |||||
| << ", Pointer: " << result->abstract().get(); | |||||
| cache_[conf] = result; | |||||
| // Set intermediate abstract value. | // Set intermediate abstract value. | ||||
| if (IsIntermediateAbstract(arg)) { | |||||
| if (IsIntermediateAbstract(result->abstract())) { | |||||
| if (conf->node()->intermediate_abstract() == nullptr) { | if (conf->node()->intermediate_abstract() == nullptr) { | ||||
| conf->node()->set_intermediate_abstract(arg); | |||||
| MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString(); | |||||
| conf->node()->set_intermediate_abstract(result->abstract()); | |||||
| MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); | |||||
| } else { | } else { | ||||
| auto old_spec = conf->node()->intermediate_abstract(); | auto old_spec = conf->node()->intermediate_abstract(); | ||||
| auto joined_spec = IntermediateJoin(arg, old_spec); | |||||
| auto joined_spec = IntermediateJoin(result->abstract(), old_spec); | |||||
| conf->node()->set_intermediate_abstract(joined_spec); | conf->node()->set_intermediate_abstract(joined_spec); | ||||
| MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" | MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" | ||||
| << arg->ToString() << "\njoined_spec:\t" | |||||
| << result->abstract()->ToString() << "\njoined_spec:\t" | |||||
| << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); | << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { | |||||
| EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { | |||||
| auto value = cache_.find(conf); | auto value = cache_.find(conf); | ||||
| if (value == cache_.end()) { | if (value == cache_.end()) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana | |||||
| return eval->graph_context(); | return eval->graph_context(); | ||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { | |||||
| EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { | |||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| auto value = cache_.GetValue(conf); | auto value = cache_.GetValue(conf); | ||||
| if (value != nullptr) { | if (value != nullptr) { | ||||
| MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", " | |||||
| << value->ToString(); | |||||
| MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() | |||||
| << ", " << value->abstract()->ToString(); | |||||
| return value; | return value; | ||||
| } | } | ||||
| @@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) | |||||
| return value; | return value; | ||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||||
| EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| AnfNodePtr node = conf->node(); | AnfNodePtr node = conf->node(); | ||||
| AbstractBasePtr ret_abstract = nullptr; | |||||
| EvalResultPtr eval_result = nullptr; | |||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| compute_conf_stack_.push_back(node); | compute_conf_stack_.push_back(node); | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| @@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (node->abstract() != nullptr) { | if (node->abstract() != nullptr) { | ||||
| MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); | MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); | ||||
| ret_abstract = node->abstract(); | |||||
| eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>()); | |||||
| } else if (node->isa<ValueNode>()) { | } else if (node->isa<ValueNode>()) { | ||||
| auto value_node = node->cast<ValueNodePtr>(); | auto value_node = node->cast<ValueNodePtr>(); | ||||
| ret_abstract = EvalValueNode(value_node, conf); | |||||
| eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr); | |||||
| } else if (node->isa<CNode>()) { | } else if (node->isa<CNode>()) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| trace::TraceEvalCNodeEnter(conf); | trace::TraceEvalCNodeEnter(conf); | ||||
| ret_abstract = EvalCNode(cnode, conf); | |||||
| eval_result = EvalCNode(cnode, conf); | |||||
| trace::TraceEvalCNodeLeave(); | trace::TraceEvalCNodeLeave(); | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() | MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() | ||||
| @@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| compute_conf_stack_.pop_back(); | compute_conf_stack_.pop_back(); | ||||
| if (ret_abstract == nullptr) { | |||||
| if (eval_result == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() | MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() | ||||
| << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | ||||
| } | } | ||||
| #endif | #endif | ||||
| MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString(); | |||||
| return ret_abstract; | |||||
| MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); | |||||
| return eval_result; | |||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { | AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { | ||||
| @@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co | |||||
| return ToAbstract(value_node->value(), conf->context(), conf); | return ToAbstract(value_node->value(), conf->context(), conf); | ||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||||
| EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto &inputs = cnode->inputs(); | auto &inputs = cnode->inputs(); | ||||
| @@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo | |||||
| AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); | AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); | ||||
| MS_EXCEPTION_IF_NULL(func_conf); | MS_EXCEPTION_IF_NULL(func_conf); | ||||
| // Keep it in a local variable, otherwise smart pointer will free it. | // Keep it in a local variable, otherwise smart pointer will free it. | ||||
| AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue(); | |||||
| AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); | |||||
| if (maybe_func == nullptr) { | if (maybe_func == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() | MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() | ||||
| << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); | << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); | ||||
| @@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo | |||||
| return ExecuteEvaluators(infs, conf, args_conf_list); | return ExecuteEvaluators(infs, conf, args_conf_list); | ||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { | |||||
| EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { | |||||
| ConfigPtrList args_conf_list; | ConfigPtrList args_conf_list; | ||||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), | (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), | ||||
| [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); | [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); | ||||
| @@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||||
| return tracked_eval; | return tracked_eval; | ||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||||
| const AnfNodeConfigPtr &out_conf, | |||||
| const ConfigPtrList &args_conf_list) { | |||||
| EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||||
| const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { | |||||
| if (evaluators.size() == 1) { | if (evaluators.size() == 1) { | ||||
| EvaluatorPtr eval = evaluators[0]; | EvaluatorPtr eval = evaluators[0]; | ||||
| MS_EXCEPTION_IF_NULL(eval); | MS_EXCEPTION_IF_NULL(eval); | ||||
| @@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr | |||||
| return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); | return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); | ||||
| } | } | ||||
| AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||||
| const AnfNodeConfigPtr &out_conf, | |||||
| const ConfigPtrList &args_conf_list) { | |||||
| EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||||
| const AnfNodeConfigPtr &out_conf, | |||||
| const ConfigPtrList &args_conf_list) { | |||||
| AbstractBasePtrList out_specs; | AbstractBasePtrList out_specs; | ||||
| if (!multi_poss_.count(evaluators[0])) { | if (!multi_poss_.count(evaluators[0])) { | ||||
| multi_poss_[evaluators[0]] = evaluators[1]; | multi_poss_[evaluators[0]] = evaluators[1]; | ||||
| @@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| MS_EXCEPTION_IF_NULL(conf); | MS_EXCEPTION_IF_NULL(conf); | ||||
| return conf->GetEvaluatedValue(); | |||||
| return conf->GetEvaluatedValue()->abstract(); | |||||
| }); | }); | ||||
| for (auto eval : evaluators) { | for (auto eval : evaluators) { | ||||
| auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | ||||
| @@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||||
| eval_trace_.push_back(current_inf); | eval_trace_.push_back(current_inf); | ||||
| MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); | MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); | ||||
| MS_EXCEPTION_IF_NULL(eval); | MS_EXCEPTION_IF_NULL(eval); | ||||
| auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); | |||||
| MS_EXCEPTION_IF_NULL(out_spec); | |||||
| MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); | |||||
| out_specs.push_back(out_spec); | |||||
| MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString(); | |||||
| auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); | |||||
| MS_EXCEPTION_IF_NULL(eval_result->abstract()); | |||||
| MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); | |||||
| out_specs.push_back(eval_result->abstract()); | |||||
| eval_trace_.pop_back(); | eval_trace_.pop_back(); | ||||
| if (eval_trace_.empty()) { | if (eval_trace_.empty()) { | ||||
| multi_poss_.clear(); | multi_poss_.clear(); | ||||
| @@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||||
| // Try to travel the latest undetermined. | // Try to travel the latest undetermined. | ||||
| if (latest_entry != eval_trace_.rbegin()->first) { | if (latest_entry != eval_trace_.rbegin()->first) { | ||||
| MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); | MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); | ||||
| auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); | |||||
| MS_EXCEPTION_IF_NULL(out_spec); | |||||
| MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString(); | |||||
| return out_spec; | |||||
| auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); | |||||
| MS_EXCEPTION_IF_NULL(eval_result->abstract()); | |||||
| MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() | |||||
| << " return out_spec: " << eval_result->abstract()->ToString(); | |||||
| return eval_result; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||||
| if (out_specs.size() == 1) { | if (out_specs.size() == 1) { | ||||
| MS_EXCEPTION_IF_NULL(out_specs[0]); | MS_EXCEPTION_IF_NULL(out_specs[0]); | ||||
| // If only one result derived, then broaden it to avoid wrong constant propagation. | // If only one result derived, then broaden it to avoid wrong constant propagation. | ||||
| return out_specs[0]->Broaden(); | |||||
| return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| auto joined_spec = AbstractJoin(out_specs); | auto joined_spec = AbstractJoin(out_specs); | ||||
| MS_EXCEPTION_IF_NULL(joined_spec); | MS_EXCEPTION_IF_NULL(joined_spec); | ||||
| MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); | MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); | ||||
| return joined_spec; | |||||
| return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() { | |||||
| EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { | |||||
| AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>(); | AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>(); | ||||
| return engine_.lock()->GetEvaluatedValue(self); | return engine_.lock()->GetEvaluatedValue(self); | ||||
| } | } | ||||
| @@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { | |||||
| return a; | return a; | ||||
| } | } | ||||
| AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { | |||||
| EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { | |||||
| auto evaluator = GetPrimEvaluator(primitive, nullptr); | auto evaluator = GetPrimEvaluator(primitive, nullptr); | ||||
| MS_EXCEPTION_IF_NULL(evaluator); | MS_EXCEPTION_IF_NULL(evaluator); | ||||
| if (!evaluator->isa<TrivialPrimEvaluator>()) { | if (!evaluator->isa<TrivialPrimEvaluator>()) { | ||||
| @@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr | |||||
| << evaluator->ToString(); | << evaluator->ToString(); | ||||
| } | } | ||||
| auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator); | auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator); | ||||
| auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs); | |||||
| return res_spec; | |||||
| auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); | |||||
| return eval_result; | |||||
| } | } | ||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -40,13 +40,33 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| // define attribute value map | |||||
| using AttrValueMap = std::unordered_map<std::string, ValuePtr>; | |||||
| using AttrValueMapPtr = std::shared_ptr<AttrValueMap>; | |||||
| // the class to save evaluated result: abstract value and modified attribute | |||||
| class EvalResult : public Base { | |||||
| public: | |||||
| EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} | |||||
| ~EvalResult() override = default; | |||||
| MS_DECLARE_PARENT(EvalResult, Base); | |||||
| AbstractBasePtr abstract() { return abstract_; } | |||||
| AttrValueMapPtr attribute() { return attribute_; } | |||||
| private: | |||||
| AbstractBasePtr abstract_; | |||||
| AttrValueMapPtr attribute_; | |||||
| }; | |||||
| using EvalResultPtr = std::shared_ptr<EvalResult>; | |||||
| // Superclass for AnfNodeConfig and VirtualConfig. | // Superclass for AnfNodeConfig and VirtualConfig. | ||||
| class Config : public Base { | class Config : public Base { | ||||
| public: | public: | ||||
| Config() = default; | Config() = default; | ||||
| ~Config() override = default; | ~Config() override = default; | ||||
| MS_DECLARE_PARENT(Config, Base); | MS_DECLARE_PARENT(Config, Base); | ||||
| virtual AbstractBasePtr GetEvaluatedValue() = 0; | |||||
| virtual EvalResultPtr GetEvaluatedValue() = 0; | |||||
| }; | }; | ||||
| // Config will be stored in AnalysisCache | // Config will be stored in AnalysisCache | ||||
| @@ -74,7 +94,7 @@ class AnfNodeConfig : public Config { | |||||
| ~AnfNodeConfig() override = default; | ~AnfNodeConfig() override = default; | ||||
| MS_DECLARE_PARENT(AnfNodeConfig, Config); | MS_DECLARE_PARENT(AnfNodeConfig, Config); | ||||
| AbstractBasePtr GetEvaluatedValue() override; | |||||
| EvalResultPtr GetEvaluatedValue() override; | |||||
| AnalysisContextPtr context() const { return context_; } | AnalysisContextPtr context() const { return context_; } | ||||
| @@ -123,7 +143,9 @@ class VirtualConfig : public Config { | |||||
| ~VirtualConfig() override = default; | ~VirtualConfig() override = default; | ||||
| MS_DECLARE_PARENT(VirtualConfig, Config); | MS_DECLARE_PARENT(VirtualConfig, Config); | ||||
| AbstractBasePtr GetEvaluatedValue() override { return abstract_; } | |||||
| EvalResultPtr GetEvaluatedValue() override { | |||||
| return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>()); | |||||
| } | |||||
| private: | private: | ||||
| AbstractBasePtr abstract_; | AbstractBasePtr abstract_; | ||||
| @@ -135,11 +157,11 @@ class AnalysisCache { | |||||
| AnalysisCache() = default; | AnalysisCache() = default; | ||||
| ~AnalysisCache() = default; | ~AnalysisCache() = default; | ||||
| void Clear() { cache_.clear(); } | void Clear() { cache_.clear(); } | ||||
| void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg); | |||||
| AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf); | |||||
| void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); | |||||
| EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); | |||||
| private: | private: | ||||
| std::unordered_map<AnfNodeConfigPtr, AbstractBasePtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_; | |||||
| std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_; | |||||
| }; | }; | ||||
| using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>; | using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>; | ||||
| @@ -147,7 +169,7 @@ using AnfNodeConfigMap = | |||||
| std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; | std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; | ||||
| struct AnalysisResult { | struct AnalysisResult { | ||||
| AbstractBasePtr inferred; | |||||
| EvalResultPtr inferred; | |||||
| AnalysisContextPtr context; | AnalysisContextPtr context; | ||||
| }; | }; | ||||
| @@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| // func_graph: The func_graph to analyze. | // func_graph: The func_graph to analyze. | ||||
| // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. | // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. | ||||
| AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); | AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); | |||||
| EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); | |||||
| // Return the Evaluator for the given function. | // Return the Evaluator for the given function. | ||||
| EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); | EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); | ||||
| AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); | AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); | ||||
| AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); | |||||
| EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); | |||||
| // Infer the result of fn(args). | // Infer the result of fn(args). | ||||
| AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); | |||||
| EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); | |||||
| void Clear(); | void Clear(); | ||||
| void ClearEvaluatorCache(); | void ClearEvaluatorCache(); | ||||
| AnalysisCache &cache() { return cache_; } | AnalysisCache &cache() { return cache_; } | ||||
| @@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| // Set the analysis result for orig to the result for new. | // Set the analysis result for orig to the result for new. | ||||
| // This sets an entry in anfnode_config_map from orig to new. | // This sets an entry in anfnode_config_map from orig to new. | ||||
| AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { | |||||
| EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { | |||||
| // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. | // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. | ||||
| (void)anfnode_config_map_.emplace(orig_conf, new_conf); | (void)anfnode_config_map_.emplace(orig_conf, new_conf); | ||||
| MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() | MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() | ||||
| @@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | ||||
| const ConfigPtrList &args_conf_list); | const ConfigPtrList &args_conf_list); | ||||
| AbstractBasePtr Eval(const AnfNodeConfigPtr &conf); | |||||
| EvalResultPtr Eval(const AnfNodeConfigPtr &conf); | |||||
| EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); | EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); | ||||
| AbstractBasePtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||||
| const ConfigPtrList &args_conf_list); | |||||
| AbstractBasePtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||||
| const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list); | |||||
| EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||||
| const ConfigPtrList &args_conf_list); | |||||
| EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||||
| const ConfigPtrList &args_conf_list); | |||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| std::vector<AnfNodePtr> compute_conf_stack_; | std::vector<AnfNodePtr> compute_conf_stack_; | ||||
| @@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { | |||||
| return FromValueInside(MakeValue(value), broaden); | return FromValueInside(MakeValue(value), broaden); | ||||
| } | } | ||||
| AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); | |||||
| EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI | |||||
| args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); | args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); | ||||
| } | } | ||||
| } | } | ||||
| AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list); | |||||
| AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); | |||||
| op_exec_info->abstract = infer_res; | op_exec_info->abstract = infer_res; | ||||
| } | } | ||||
| @@ -26,6 +26,8 @@ | |||||
| #include <list> | #include <list> | ||||
| #include <string> | #include <string> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include <queue> | |||||
| #include <set> | |||||
| #include "ir/visitor.h" | #include "ir/visitor.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -223,6 +225,31 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||||
| return res; | return res; | ||||
| } | } | ||||
| // search the cnodes inside this graph only | |||||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) { | |||||
| std::queue<CNodePtr> todo; | |||||
| todo.push(ret); | |||||
| std::vector<CNodePtr> sorted_nodes; | |||||
| auto seen = NewSeenGeneration(); | |||||
| while (!todo.empty()) { | |||||
| CNodePtr top = todo.front(); | |||||
| todo.pop(); | |||||
| sorted_nodes.push_back(top); | |||||
| auto inputs = top->inputs(); | |||||
| for (auto &item : inputs) { | |||||
| if (item->seen_ == seen) { | |||||
| continue; | |||||
| } | |||||
| if (item->isa<CNode>()) { | |||||
| todo.push(item->cast<CNodePtr>()); | |||||
| } | |||||
| item->seen_ = seen; | |||||
| } | |||||
| } | |||||
| return sorted_nodes; | |||||
| } | |||||
| std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | ||||
| std::vector<AnfNodePtr> vecs; | std::vector<AnfNodePtr> vecs; | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| @@ -57,6 +57,7 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl | |||||
| std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, | std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, | ||||
| const IncludeFunc &include = AlwaysInclude); | const IncludeFunc &include = AlwaysInclude); | ||||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret); | |||||
| class FuncGraphIndex { | class FuncGraphIndex { | ||||
| public: | public: | ||||
| explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | ||||
| @@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """init ExpandDims""" | """init ExpandDims""" | ||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output']) | ||||
| def __infer__(self, x, axis): | def __infer__(self, x, axis): | ||||
| @@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer): | |||||
| # if primitive need setattr in __infer__ need add this flag | # if primitive need setattr in __infer__ need add this flag | ||||
| """init Cast""" | """init Cast""" | ||||
| self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) | ||||
| self.__setattr_flag__ = True | |||||
| def __infer__(self, x, t): | def __infer__(self, x, t): | ||||
| src_type = x['dtype'] | src_type = x['dtype'] | ||||
| @@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| """init Reshape""" | """init Reshape""" | ||||
| self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output']) | self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output']) | ||||
| self.__setattr_flag__ = True | |||||
| def __infer__(self, x, shape): | def __infer__(self, x, shape): | ||||
| shape_v = shape['value'] | shape_v = shape['value'] | ||||
| @@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """init Transpose""" | """init Transpose""" | ||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) | ||||
| def __infer__(self, x, perm): | def __infer__(self, x, perm): | ||||
| @@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """init index_select""" | """init index_select""" | ||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | ||||
| def __infer__(self, params, indices, axis): | def __infer__(self, params, indices, axis): | ||||
| @@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, axis=0): | def __init__(self, axis=0): | ||||
| """init Tile""" | """init Tile""" | ||||
| self.__setattr_flag__ = True | |||||
| validator.check_value_type("axis", axis, [int], self.name) | validator.check_value_type("axis", axis, [int], self.name) | ||||
| def __infer__(self, input_x): | def __infer__(self, input_x): | ||||
| @@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, axis=0): | def __init__(self, axis=0): | ||||
| """init Pack""" | """init Pack""" | ||||
| self.__setattr_flag__ = True | |||||
| validator.check_value_type("axis", axis, [int], self.name) | validator.check_value_type("axis", axis, [int], self.name) | ||||
| self.axis = axis | self.axis = axis | ||||
| @@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, axis=0): | def __init__(self, axis=0): | ||||
| """init Unpack""" | """init Unpack""" | ||||
| self.__setattr_flag__ = True | |||||
| validator.check_value_type("axis", axis, [int], self.name) | validator.check_value_type("axis", axis, [int], self.name) | ||||
| self.axis = axis | self.axis = axis | ||||
| @@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """init""" | """init""" | ||||
| self.__setattr_flag__ = True | |||||
| def infer_shape(self, cond_shape, x_shape, y_shape): | def infer_shape(self, cond_shape, x_shape, y_shape): | ||||
| if cond_shape != x_shape or x_shape != y_shape: | if cond_shape != x_shape or x_shape != y_shape: | ||||
| @@ -516,7 +516,6 @@ class MatMul(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, transpose_a=False, transpose_b=False): | def __init__(self, transpose_a=False, transpose_b=False): | ||||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | ||||
| self.__setattr_flag__ = True | |||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | ||||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | ||||
| @@ -596,7 +595,6 @@ class BatchMatMul(MatMul): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, transpose_a=False, transpose_b=False): | def __init__(self, transpose_a=False, transpose_b=False): | ||||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | ||||
| self.__setattr_flag__ = True | |||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | ||||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | ||||
| @@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | ||||
| def infer_shape(self, inputs): | def infer_shape(self, inputs): | ||||
| @@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer): | |||||
| """init Conv2D""" | """init Conv2D""" | ||||
| self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) | ||||
| self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) | self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) | ||||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name) | |||||
| self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) | |||||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) | |||||
| self.add_prim_attr('stride', self.stride) | |||||
| self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) | self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) | ||||
| self.add_prim_attr('dilation', self.dilation) | self.add_prim_attr('dilation', self.dilation) | ||||
| validator.check_value_type('pad', pad, (int,), self.name) | validator.check_value_type('pad', pad, (int,), self.name) | ||||
| @@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer): | |||||
| self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] | self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] | ||||
| self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) | self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) | ||||
| out_channel = self.out_channel | out_channel = self.out_channel | ||||
| out_shape = [x_shape[0], out_channel, h_out, w_out] | out_shape = [x_shape[0], out_channel, h_out, w_out] | ||||
| return out_shape | return out_shape | ||||
| @@ -0,0 +1,65 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test nn ops """ | |||||
| import functools | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.ops import Primitive | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | |||||
| from mindspore.ops.primitive import constexpr | |||||
| from mindspore import context | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| def test_cast_op_attr(): | |||||
| class CastNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(CastNet, self).__init__() | |||||
| self.cast = P.Cast() | |||||
| def construct(self, x, t): | |||||
| return self.cast(x, t) | |||||
| class CastTypeTest(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(CastTypeTest, self).__init__() | |||||
| self.net = net | |||||
| self.cast = P.Cast() | |||||
| def construct(self, x, y, z): | |||||
| cast_op = self.cast | |||||
| t1 = cast_op(x, mstype.float32) | |||||
| t2 = cast_op(y, mstype.int32) | |||||
| cast_net = self.net | |||||
| t3 = cast_net(x, mstype.float16) | |||||
| t4 = cast_net(y, mstype.int32) | |||||
| t5 = cast_net(z, mstype.float16) | |||||
| return (t1, t2, t3, t4, t5) | |||||
| net = CastTypeTest(CastNet()) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.int32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| t3 = Tensor(np.ones([1,16,1,1918]).astype(np.int32)) | |||||
| out = net(t1, t2, t3) | |||||
| assert out[0].asnumpy().dtype == np.float32 | |||||
| assert out[1].asnumpy().dtype == np.int32 | |||||
| assert out[2].asnumpy().dtype == np.float16 | |||||
| assert out[3].asnumpy().dtype == np.int32 | |||||
| assert out[4].asnumpy().dtype == np.float16 | |||||
| @@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) { | |||||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | ||||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | ||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract tuple failed."; | FAIL() << "Cast ret to abstract tuple failed."; | ||||
| } | } | ||||
| @@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) { | |||||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | ||||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | ||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract tuple failed."; | FAIL() << "Cast ret to abstract tuple failed."; | ||||
| } | } | ||||
| @@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) { | |||||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | ||||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | ||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract tuple failed."; | FAIL() << "Cast ret to abstract tuple failed."; | ||||
| } | } | ||||
| @@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) { | |||||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | ||||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | ||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract tuple failed."; | FAIL() << "Cast ret to abstract tuple failed."; | ||||
| } | } | ||||
| @@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) { | |||||
| AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | ||||
| AbstractBasePtrList args_spec_list = {tensor, slice}; | AbstractBasePtrList args_spec_list = {tensor, slice}; | ||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred); | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract array failed."; | FAIL() << "Cast ret to abstract array failed."; | ||||
| } | } | ||||
| @@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) { | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | ||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | ||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract array failed."; | FAIL() << "Cast ret to abstract array failed."; | ||||
| } | } | ||||
| @@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) { | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | ||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | ||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract array failed."; | FAIL() << "Cast ret to abstract array failed."; | ||||
| } | } | ||||
| @@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) { | |||||
| AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2); | AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2); | ||||
| AbstractBasePtrList args_spec_list = {tensor, start_index}; | AbstractBasePtrList args_spec_list = {tensor, start_index}; | ||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract array failed."; | FAIL() << "Cast ret to abstract array failed."; | ||||
| } | } | ||||
| @@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) { | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | ||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | ||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract array failed."; | FAIL() << "Cast ret to abstract array failed."; | ||||
| } | } | ||||
| @@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) { | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | ||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | ||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract array failed."; | FAIL() << "Cast ret to abstract array failed."; | ||||
| } | } | ||||
| @@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) { | |||||
| abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | ||||
| AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict}; | AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict}; | ||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); | |||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract tuple failed."; | FAIL() << "Cast ret to abstract tuple failed."; | ||||
| } | } | ||||
| @@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) { | |||||
| abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | ||||
| AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple}; | AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple}; | ||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); | |||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract tuple failed."; | FAIL() << "Cast ret to abstract tuple failed."; | ||||
| } | } | ||||
| @@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) { | |||||
| auto tuple = std::make_shared<AbstractTuple>(eles); | auto tuple = std::make_shared<AbstractTuple>(eles); | ||||
| AbstractBasePtrList args_spec_list = {tuple}; | AbstractBasePtrList args_spec_list = {tuple}; | ||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred); | |||||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| FAIL() << "Cast ret to abstract tuple failed."; | FAIL() << "Cast ret to abstract tuple failed."; | ||||
| } | } | ||||
| @@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) { | |||||
| AbstractBasePtr abstract_v2 = FromValue(2, false); | AbstractBasePtr abstract_v2 = FromValue(2, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2}; | AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2}; | ||||
| AbstractBasePtr abstract_val = FromValue(10, false); | AbstractBasePtr abstract_val = FromValue(10, false); | ||||
| cache[args_spec_list] = abstract_val; | |||||
| cache[args_spec_list] = std::make_shared<EvalResult>(abstract_val, std::make_shared<AttrValueMap>()); | |||||
| auto iter = cache.find(args_spec_list); | auto iter = cache.find(args_spec_list); | ||||
| ASSERT_TRUE(iter != cache.end()); | ASSERT_TRUE(iter != cache.end()); | ||||
| ASSERT_TRUE(iter->second == abstract_val); | |||||
| ASSERT_TRUE(iter->second->abstract() == abstract_val); | |||||
| AbstractBasePtr abstract_v1_variant1 = FromValue(1, false); | AbstractBasePtr abstract_v1_variant1 = FromValue(1, false); | ||||
| AbstractBasePtr abstract_v2_variant1 = FromValue(2, false); | AbstractBasePtr abstract_v2_variant1 = FromValue(2, false); | ||||
| @@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) { | |||||
| iter = cache.find(args_spec_list_variant1); | iter = cache.find(args_spec_list_variant1); | ||||
| ASSERT_TRUE(iter != cache.end()); | ASSERT_TRUE(iter != cache.end()); | ||||
| ASSERT_TRUE(iter->second == abstract_val); | |||||
| ASSERT_TRUE(iter->second->abstract() == abstract_val); | |||||
| AbstractBasePtr abstract_v1_variant2 = FromValue(1, false); | AbstractBasePtr abstract_v1_variant2 = FromValue(1, false); | ||||
| AbstractBasePtr abstract_v2_variant2 = FromValue(3, false); | AbstractBasePtr abstract_v2_variant2 = FromValue(3, false); | ||||
| @@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) { | |||||
| std::vector<int> shape = {2, 2, 6, 6}; | std::vector<int> shape = {2, 2, 6, 6}; | ||||
| expected->set_shape(std::make_shared<Shape>(shape)); | expected->set_shape(std::make_shared<Shape>(shape)); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | MS_LOG(INFO) << "expected: " << expected->ToString(); | ||||
| @@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) { | |||||
| AbstractBasePtr abstract_x = FromValue(x, false); | AbstractBasePtr abstract_x = FromValue(x, false); | ||||
| args_spec_list.push_back(abstract_x); | args_spec_list.push_back(abstract_x); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ||||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | ||||
| } | } | ||||
| @@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) { | |||||
| AbstractBasePtr abstract_x = FromValue(x, false); | AbstractBasePtr abstract_x = FromValue(x, false); | ||||
| args_spec_list.push_back(abstract_x); | args_spec_list.push_back(abstract_x); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ||||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | ||||
| } | } | ||||
| @@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) { | |||||
| args_spec_list.push_back(abstract_x); | args_spec_list.push_back(abstract_x); | ||||
| args_spec_list.push_back(abstract_y); | args_spec_list.push_back(abstract_y); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ||||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ||||
| } | } | ||||
| @@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) { | |||||
| args_spec_list.push_back(abstract_x); | args_spec_list.push_back(abstract_x); | ||||
| args_spec_list.push_back(abstract_y); | args_spec_list.push_back(abstract_y); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ||||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ||||
| } | } | ||||
| @@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) { | |||||
| args_spec_list.push_back(abstract_x); | args_spec_list.push_back(abstract_x); | ||||
| args_spec_list.push_back(abstract_y); | args_spec_list.push_back(abstract_y); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ||||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ||||
| } | } | ||||
| @@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) { | |||||
| args_spec_list.push_back(abstract_x); | args_spec_list.push_back(abstract_x); | ||||
| args_spec_list.push_back(abstract_y); | args_spec_list.push_back(abstract_y); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | ||||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | ||||
| } | } | ||||
| @@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) { | |||||
| auto prim_typeof = std::make_shared<Primitive>("typeof"); | auto prim_typeof = std::make_shared<Primitive>("typeof"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| res->dump(); | res->dump(); | ||||
| TypePtr res_value = res->GetValueTrack()->cast<TypePtr>(); | TypePtr res_value = res->GetValueTrack()->cast<TypePtr>(); | ||||
| res_value->dump(); | res_value->dump(); | ||||
| @@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) { | |||||
| auto prim_list_map = std::make_shared<Primitive>("list_map"); | auto prim_list_map = std::make_shared<Primitive>("list_map"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3); | FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)})); | auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)})); | ||||
| res->dump(); | res->dump(); | ||||
| MS_LOG(INFO) << "result res: " << res->ToString(); | MS_LOG(INFO) << "result res: " << res->ToString(); | ||||
| @@ -188,7 +188,7 @@ TEST_F(TestPrim, test_list_reduce) { | |||||
| auto prim_list_reduce = std::make_shared<Primitive>("list_reduce"); | auto prim_list_reduce = std::make_shared<Primitive>("list_reduce"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3); | FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| res->dump(); | res->dump(); | ||||
| TypePtr res_type = res->GetTypeTrack(); | TypePtr res_type = res->GetTypeTrack(); | ||||
| res_type->dump(); | res_type->dump(); | ||||
| @@ -205,7 +205,7 @@ TEST_F(TestPrim, test_scalar_to_array) { | |||||
| auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array"); | auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| res->dump(); | res->dump(); | ||||
| TypePtr res_type = res->BuildType(); | TypePtr res_type = res->BuildType(); | ||||
| res_type->dump(); | res_type->dump(); | ||||
| @@ -223,7 +223,7 @@ TEST_F(TestPrim, test_array_to_scalar) { | |||||
| auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar"); | auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| res->dump(); | res->dump(); | ||||
| TypePtr res_type = res->BuildType(); | TypePtr res_type = res->BuildType(); | ||||
| res_type->dump(); | res_type->dump(); | ||||
| @@ -239,7 +239,7 @@ TEST_F(TestPrim, test_J_1) { | |||||
| auto prim_J = std::make_shared<Primitive>("J"); | auto prim_J = std::make_shared<Primitive>("J"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res); | AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res); | ||||
| ASSERT_TRUE(res_J != nullptr); | ASSERT_TRUE(res_J != nullptr); | ||||
| ASSERT_TRUE(*(res_J->element()) == *abstract_v1); | ASSERT_TRUE(*(res_J->element()) == *abstract_v1); | ||||
| @@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) { | |||||
| int v1 = 1; | int v1 = 1; | ||||
| AbstractBasePtr abstract_v1 = FromValue(v1, false); | AbstractBasePtr abstract_v1 = FromValue(v1, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_v1}; | AbstractBasePtrList args_spec_list = {abstract_v1}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| res->dump(); | res->dump(); | ||||
| AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res); | AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res); | ||||
| ASSERT_TRUE(res_J != nullptr); | ASSERT_TRUE(res_J != nullptr); | ||||
| @@ -302,7 +302,7 @@ TEST_F(TestPrim, test_dot) { | |||||
| AbstractBasePtrList args_spec_list = {a1, a2}; | AbstractBasePtrList args_spec_list = {a1, a2}; | ||||
| AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred); | |||||
| AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||||
| ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack()))); | ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack()))); | ||||
| } | } | ||||
| @@ -317,7 +317,7 @@ TEST_F(TestPrim, test_switch1) { | |||||
| AbstractBasePtr arg2 = FromValue(2, false); | AbstractBasePtr arg2 = FromValue(2, false); | ||||
| AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *arg1); | ASSERT_TRUE(*res == *arg1); | ||||
| } | } | ||||
| @@ -330,7 +330,7 @@ TEST_F(TestPrim, test_switch2) { | |||||
| AbstractBasePtr arg2 = FromValue(2, false); | AbstractBasePtr arg2 = FromValue(2, false); | ||||
| AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "make result res: " << res->ToString(); | MS_LOG(INFO) << "make result res: " << res->ToString(); | ||||
| MS_LOG(INFO) << "make result arg2: " << arg2->ToString(); | MS_LOG(INFO) << "make result arg2: " << arg2->ToString(); | ||||
| ASSERT_TRUE(*res == *arg2); | ASSERT_TRUE(*res == *arg2); | ||||
| @@ -343,7 +343,7 @@ TEST_F(TestPrim, test_identity) { | |||||
| AbstractBasePtr abstract_v1 = FromValue(1, false); | AbstractBasePtr abstract_v1 = FromValue(1, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_v1}; | AbstractBasePtrList args_spec_list = {abstract_v1}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *abstract_v1); | ASSERT_TRUE(*res == *abstract_v1); | ||||
| } | } | ||||
| @@ -357,7 +357,7 @@ TEST_F(TestPrim, test_broadcast_shape) { | |||||
| AbstractBasePtrList args_spec_list = {a, b}; | AbstractBasePtrList args_spec_list = {a, b}; | ||||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred); | |||||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||||
| auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | ||||
| std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)}; | std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)}; | ||||
| @@ -377,7 +377,7 @@ TEST_F(TestPrim, test_partial) { | |||||
| AbstractBasePtr abstract_v2 = FromValue(1, false); | AbstractBasePtr abstract_v2 = FromValue(1, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2}; | AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2}; | AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2}; | ||||
| auto expected = std::make_shared<PartialAbstractClosure>( | auto expected = std::make_shared<PartialAbstractClosure>( | ||||
| std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list); | std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list); | ||||
| @@ -392,7 +392,7 @@ TEST_F(TestPrim, test_env_setitem) { | |||||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | ||||
| AbstractBasePtr abstract_x = FromValue(1, false); | AbstractBasePtr abstract_x = FromValue(1, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_x}; | AbstractBasePtrList args_spec_list = {abstract_x}; | ||||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; | |||||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | ||||
| @@ -400,7 +400,7 @@ TEST_F(TestPrim, test_env_setitem) { | |||||
| AbstractBasePtr abstract_y = FromValue(2, false); | AbstractBasePtr abstract_y = FromValue(2, false); | ||||
| args_spec_list = {abstract_env, embed_x, abstract_y}; | args_spec_list = {abstract_env, embed_x, abstract_y}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | ||||
| ASSERT_TRUE(*res == *exp); | ASSERT_TRUE(*res == *exp); | ||||
| } | } | ||||
| @@ -412,7 +412,7 @@ TEST_F(TestPrim, test_env_getitem) { | |||||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | ||||
| AbstractBasePtr abstract_x = FromValue(1, false); | AbstractBasePtr abstract_x = FromValue(1, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_x}; | AbstractBasePtrList args_spec_list = {abstract_x}; | ||||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; | |||||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||||
| FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | ||||
| @@ -420,7 +420,7 @@ TEST_F(TestPrim, test_env_getitem) { | |||||
| AbstractBasePtr abstract_y = FromValue(2, false); | AbstractBasePtr abstract_y = FromValue(2, false); | ||||
| args_spec_list = {abstract_env, embed_x, abstract_y}; | args_spec_list = {abstract_env, embed_x, abstract_y}; | ||||
| AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | ||||
| ASSERT_TRUE(*res == *exp); | ASSERT_TRUE(*res == *exp); | ||||
| @@ -429,7 +429,7 @@ TEST_F(TestPrim, test_env_getitem) { | |||||
| AbstractBasePtr abstract_z = FromValue(3, false); | AbstractBasePtr abstract_z = FromValue(3, false); | ||||
| args_spec_list = {res, embed_x, abstract_z}; | args_spec_list = {res, embed_x, abstract_z}; | ||||
| res = engine_->Run(graph_getitem, args_spec_list).inferred; | |||||
| res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *abstract_x); | ASSERT_TRUE(*res == *abstract_x); | ||||
| } | } | ||||
| @@ -442,7 +442,7 @@ TEST_F(TestPrim, test_env_add) { | |||||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | ||||
| AbstractBasePtr abstract_x = FromValue(1, false); | AbstractBasePtr abstract_x = FromValue(1, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_x}; | AbstractBasePtrList args_spec_list = {abstract_x}; | ||||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; | |||||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||||
| FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | ||||
| @@ -450,19 +450,19 @@ TEST_F(TestPrim, test_env_add) { | |||||
| AbstractBasePtr abstract_y = FromValue(2, false); | AbstractBasePtr abstract_y = FromValue(2, false); | ||||
| args_spec_list = {abstract_env, embed_x, abstract_y}; | args_spec_list = {abstract_env, embed_x, abstract_y}; | ||||
| AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred; | |||||
| AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | ||||
| ASSERT_TRUE(*abstract_e1 == *exp); | ASSERT_TRUE(*abstract_e1 == *exp); | ||||
| AbstractBasePtr abstract_z = FromValue(3, false); | AbstractBasePtr abstract_z = FromValue(3, false); | ||||
| args_spec_list = {abstract_env, embed_x, abstract_z}; | args_spec_list = {abstract_env, embed_x, abstract_z}; | ||||
| AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred; | |||||
| AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*abstract_e2 == *exp); | ASSERT_TRUE(*abstract_e2 == *exp); | ||||
| FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2); | FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2); | ||||
| args_spec_list = {abstract_e1, abstract_e2}; | args_spec_list = {abstract_e1, abstract_e2}; | ||||
| AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *exp); | ASSERT_TRUE(*res == *exp); | ||||
| } | } | ||||
| @@ -475,7 +475,7 @@ TEST_F(TestPrim, test_shape) { | |||||
| AbstractBasePtrList args_spec_list = {a}; | AbstractBasePtrList args_spec_list = {a}; | ||||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred); | |||||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||||
| auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | ||||
| std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)}; | std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)}; | ||||
| @@ -493,7 +493,7 @@ TEST_F(TestPrim, test_relu) { | |||||
| AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW | AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW | ||||
| AbstractBasePtrList args_spec_list = {expected}; | AbstractBasePtrList args_spec_list = {expected}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| } | } | ||||
| @@ -507,7 +507,7 @@ TEST_F(TestPrim, test_relu2) { | |||||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5}); | auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5}); | ||||
| AbstractBasePtrList args_spec_list = {arr}; | AbstractBasePtrList args_spec_list = {arr}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto res = dyn_cast<AbstractTensor>(ret); | auto res = dyn_cast<AbstractTensor>(ret); | ||||
| ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack())); | ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack())); | ||||
| } | } | ||||
| @@ -540,7 +540,7 @@ TEST_F(TestPrim, test_conv2d1) { | |||||
| std::vector<int> shape = {2, 64, 14, 14}; | std::vector<int> shape = {2, 64, 14, 14}; | ||||
| expected->set_shape(std::make_shared<Shape>(shape)); | expected->set_shape(std::make_shared<Shape>(shape)); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | MS_LOG(INFO) << "expected: " << expected->ToString(); | ||||
| @@ -558,7 +558,7 @@ TEST_F(TestPrim, test_conv2d) { | |||||
| auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3}); | auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3}); | ||||
| AbstractBasePtrList args_spec_list = {input, weight}; | AbstractBasePtrList args_spec_list = {input, weight}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto res = dyn_cast<AbstractTensor>(ret); | auto res = dyn_cast<AbstractTensor>(ret); | ||||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16}); | auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16}); | ||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| @@ -574,7 +574,7 @@ TEST_F(TestPrim, test_conv2d_native) { | |||||
| auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3}); | auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3}); | ||||
| AbstractBasePtrList args_spec_list = {input, weight}; | AbstractBasePtrList args_spec_list = {input, weight}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto res = dyn_cast<AbstractTensor>(ret); | auto res = dyn_cast<AbstractTensor>(ret); | ||||
| auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16}); | auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16}); | ||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| @@ -590,7 +590,7 @@ TEST_F(TestPrim, test_biasAdd) { | |||||
| auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32}); | auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32}); | ||||
| AbstractBasePtrList args_spec_list = {value, bias}; | AbstractBasePtrList args_spec_list = {value, bias}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto res = dyn_cast<AbstractTensor>(ret); | auto res = dyn_cast<AbstractTensor>(ret); | ||||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32}); | auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32}); | ||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| @@ -606,7 +606,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) { | |||||
| auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10}); | auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10}); | ||||
| AbstractBasePtrList args_spec_list = {logits, labels}; | AbstractBasePtrList args_spec_list = {logits, labels}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_NE(ret, nullptr); | ASSERT_NE(ret, nullptr); | ||||
| auto res = dyn_cast<AbstractTuple>(ret); | auto res = dyn_cast<AbstractTuple>(ret); | ||||
| auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64}); | auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64}); | ||||
| @@ -636,7 +636,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) { | |||||
| auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10}); | auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10}); | ||||
| AbstractBasePtrList args_spec_list = {logits, labels}; | AbstractBasePtrList args_spec_list = {logits, labels}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto res = dyn_cast<AbstractScalar>(ret); | auto res = dyn_cast<AbstractScalar>(ret); | ||||
| AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64); | AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64); | ||||
| expected->set_type(UTPrimUtils::kF64); | expected->set_type(UTPrimUtils::kF64); | ||||
| @@ -690,7 +690,7 @@ TEST_F(TestPrim, test_fused_batch_norm) { | |||||
| AbstractBasePtr expected0 = abstract_inputs->Clone(); | AbstractBasePtr expected0 = abstract_inputs->Clone(); | ||||
| AbstractBasePtr expected1 = abstract_scale->Clone(); | AbstractBasePtr expected1 = abstract_scale->Clone(); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| MS_LOG(INFO) << "expected0: " << expected0->ToString(); | MS_LOG(INFO) << "expected0: " << expected0->ToString(); | ||||
| MS_LOG(INFO) << "expected1: " << expected1->ToString(); | MS_LOG(INFO) << "expected1: " << expected1->ToString(); | ||||
| @@ -722,7 +722,7 @@ TEST_F(TestPrim, test_pooling) { | |||||
| inputs->set_shape(inputs_dims); | inputs->set_shape(inputs_dims); | ||||
| AbstractBasePtr abstract_input = FromValue(inputs, false); | AbstractBasePtr abstract_input = FromValue(inputs, false); | ||||
| AbstractBasePtrList args_spec_list = {abstract_input}; | AbstractBasePtrList args_spec_list = {abstract_input}; | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expected = abstract_input->Clone()->Broaden(); | AbstractBasePtr expected = abstract_input->Clone()->Broaden(); | ||||
| std::vector<int> expected_dims = {8, 64, 2, 2}; | std::vector<int> expected_dims = {8, 64, 2, 2}; | ||||
| @@ -747,7 +747,7 @@ TEST_F(TestPrim, test_hastype) { | |||||
| auto prim = std::make_shared<Primitive>("hastype"); | auto prim = std::make_shared<Primitive>("hastype"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| } | } | ||||
| @@ -761,7 +761,7 @@ TEST_F(TestPrim, test_array_len) { | |||||
| auto prim = std::make_shared<Primitive>("array_len"); | auto prim = std::make_shared<Primitive>("array_len"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| } | } | ||||
| @@ -775,7 +775,7 @@ TEST_F(TestPrim, test_list_len) { | |||||
| auto prim = std::make_shared<Primitive>("list_len"); | auto prim = std::make_shared<Primitive>("list_len"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| } | } | ||||
| @@ -789,7 +789,7 @@ TEST_F(TestPrim, test_tuple_len) { | |||||
| auto prim = std::make_shared<Primitive>("tuple_len"); | auto prim = std::make_shared<Primitive>("tuple_len"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| } | } | ||||
| @@ -803,7 +803,7 @@ TEST_F(TestPrim, test_tuple_reversed) { | |||||
| auto prim = std::make_shared<Primitive>("tuple_reversed"); | auto prim = std::make_shared<Primitive>("tuple_reversed"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "expect=" << expected->ToString(); | MS_LOG(INFO) << "expect=" << expected->ToString(); | ||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| } | } | ||||
| @@ -825,7 +825,7 @@ TEST_F(TestPrim, test_list_getitem) { | |||||
| auto prim = std::make_shared<Primitive>("list_getitem"); | auto prim = std::make_shared<Primitive>("list_getitem"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *elem); | ASSERT_TRUE(*res == *elem); | ||||
| } | } | ||||
| @@ -844,7 +844,7 @@ TEST_F(TestPrim, test_list_setitem) { | |||||
| auto prim = std::make_shared<Primitive>("list_setitem"); | auto prim = std::make_shared<Primitive>("list_setitem"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| AbstractBasePtrList elems_exp = {elem1, elem2}; | AbstractBasePtrList elems_exp = {elem1, elem2}; | ||||
| auto expected = std::make_shared<AbstractList>(elems_exp); | auto expected = std::make_shared<AbstractList>(elems_exp); | ||||
| @@ -866,7 +866,7 @@ TEST_F(TestPrim, test_list_append) { | |||||
| auto prim = std::make_shared<Primitive>("list_append"); | auto prim = std::make_shared<Primitive>("list_append"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2})); | auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2})); | ||||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | MS_LOG(INFO) << "expected: " << expected->ToString(); | ||||
| @@ -890,7 +890,7 @@ TEST_F(TestPrim, test_tuple_setitem) { | |||||
| auto prim = std::make_shared<Primitive>("tuple_setitem"); | auto prim = std::make_shared<Primitive>("tuple_setitem"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| AbstractBasePtrList elems_exp = {elem1, elem2}; | AbstractBasePtrList elems_exp = {elem1, elem2}; | ||||
| auto expected = std::make_shared<AbstractTuple>(elems_exp); | auto expected = std::make_shared<AbstractTuple>(elems_exp); | ||||
| @@ -916,7 +916,7 @@ TEST_F(TestPrim, test_make_list) { | |||||
| auto prim = std::make_shared<Primitive>("make_list"); | auto prim = std::make_shared<Primitive>("make_list"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| } | } | ||||
| @@ -939,7 +939,7 @@ TEST_F(TestPrim, test_make_range) { | |||||
| AbstractBasePtrList elem_list({ele1, ele2, ele3}); | AbstractBasePtrList elem_list({ele1, ele2, ele3}); | ||||
| AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list); | AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "res=" << res->ToString(); | MS_LOG(INFO) << "res=" << res->ToString(); | ||||
| MS_LOG(INFO) << "expected=" << expected->ToString(); | MS_LOG(INFO) << "expected=" << expected->ToString(); | ||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| @@ -982,7 +982,7 @@ TEST_F(TestPrim, test_layernorm) { | |||||
| AbstractBasePtr expected1 = abstract_mean_var->Clone(); | AbstractBasePtr expected1 = abstract_mean_var->Clone(); | ||||
| AbstractBasePtr expected2 = abstract_mean_var->Clone(); | AbstractBasePtr expected2 = abstract_mean_var->Clone(); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| MS_LOG(INFO) << "expected0: " << expected0->ToString(); | MS_LOG(INFO) << "expected0: " << expected0->ToString(); | ||||
| MS_LOG(INFO) << "expected1: " << expected1->ToString(); | MS_LOG(INFO) << "expected1: " << expected1->ToString(); | ||||
| @@ -1028,7 +1028,7 @@ TEST_F(TestPrim, test_DropoutGenMask) { | |||||
| AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8), | AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8), | ||||
| std::make_shared<Shape>(std::vector<int>{79})); | std::make_shared<Shape>(std::vector<int>{79})); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "res=" << res->ToString(); | MS_LOG(INFO) << "res=" << res->ToString(); | ||||
| MS_LOG(INFO) << "expected=" << expected->ToString(); | MS_LOG(INFO) << "expected=" << expected->ToString(); | ||||
| ASSERT_TRUE(*res == *expected); | ASSERT_TRUE(*res == *expected); | ||||
| @@ -1058,7 +1058,7 @@ TEST_F(TestPrim, test_dropout) { | |||||
| std::vector<int> shape = {2, 20, 32, 32}; | std::vector<int> shape = {2, 20, 32, 32}; | ||||
| expected->set_shape(std::make_shared<Shape>(shape)); | expected->set_shape(std::make_shared<Shape>(shape)); | ||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| MS_LOG(INFO) << "result: " << res->ToString(); | MS_LOG(INFO) << "result: " << res->ToString(); | ||||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | MS_LOG(INFO) << "expected: " << expected->ToString(); | ||||
| @@ -1079,7 +1079,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) { | |||||
| auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | ||||
| auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | ||||
| AbstractBasePtrList args_spec_list = {x_input, y_input}; | AbstractBasePtrList args_spec_list = {x_input, y_input}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto res = dyn_cast<AbstractTuple>(ret); | auto res = dyn_cast<AbstractTuple>(ret); | ||||
| AbstractBasePtrList x_idx_list; | AbstractBasePtrList x_idx_list; | ||||
| auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | ||||
| @@ -1103,7 +1103,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) { | |||||
| auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | ||||
| auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | ||||
| AbstractBasePtrList args_spec_list = {x_input, y_input}; | AbstractBasePtrList args_spec_list = {x_input, y_input}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| auto res = dyn_cast<AbstractTuple>(ret); | auto res = dyn_cast<AbstractTuple>(ret); | ||||
| AbstractBasePtrList x_idx_list({abstract::FromValue(1)}); | AbstractBasePtrList x_idx_list({abstract::FromValue(1)}); | ||||
| auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | ||||
| @@ -1128,7 +1128,7 @@ TEST_F(TestPrim, test_DictGetItem) { | |||||
| AbstractBasePtr key = abstract::FromValue("x"); | AbstractBasePtr key = abstract::FromValue("x"); | ||||
| AbstractBasePtrList args_spec_list = {array_dict, key}; | AbstractBasePtrList args_spec_list = {array_dict, key}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | ||||
| AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second)); | AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second)); | ||||
| @@ -1147,7 +1147,7 @@ TEST_F(TestPrim, test_DictGetItem2) { | |||||
| AbstractBasePtr key = abstract::FromValue("x"); | AbstractBasePtr key = abstract::FromValue("x"); | ||||
| AbstractBasePtrList args_spec_list = {array_dict, key}; | AbstractBasePtrList args_spec_list = {array_dict, key}; | ||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | ||||
| AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x); | AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x); | ||||
| @@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) { | |||||
| auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); | auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ||||
| } | } | ||||
| @@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) { | |||||
| MS_LOG(INFO) << "" << graph_f_->get_return()->ToString(); | MS_LOG(INFO) << "" << graph_f_->get_return()->ToString(); | ||||
| AbstractBasePtr abstract_v1 = FromValue(1, false); | AbstractBasePtr abstract_v1 = FromValue(1, false); | ||||
| args_spec_list.push_back(abstract_v1); | args_spec_list.push_back(abstract_v1); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ||||
| // now this test case failed randomly, have to debug. | // now this test case failed randomly, have to debug. | ||||
| @@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) { | |||||
| args_spec_list.clear(); | args_spec_list.clear(); | ||||
| args_spec_list.push_back(abstract_v1); | args_spec_list.push_back(abstract_v1); | ||||
| args_spec_list.push_back(abstract_v2); | args_spec_list.push_back(abstract_v2); | ||||
| abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred; | |||||
| abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ||||
| } | } | ||||
| @@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) { | |||||
| AbstractBasePtr abstract_v2 = FromValue(v1, false); | AbstractBasePtr abstract_v2 = FromValue(v1, false); | ||||
| args_spec_list.push_back(abstract_v1); | args_spec_list.push_back(abstract_v1); | ||||
| args_spec_list.push_back(abstract_v2); | args_spec_list.push_back(abstract_v2); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract(); | |||||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | ||||
| } | } | ||||
| @@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { | |||||
| auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); | auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); | ||||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | ||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred; | |||||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract(); | |||||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack())); | ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack())); | ||||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32); | ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32); | ||||
| } | } | ||||
| @@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) { | |||||
| AbstractBasePtr base1 = FromValue(x1, false); | AbstractBasePtr base1 = FromValue(x1, false); | ||||
| AbstractBasePtr base2 = FromValue(x2, false); | AbstractBasePtr base2 = FromValue(x2, false); | ||||
| AbstractBasePtrList base_list = {base1, base2}; | AbstractBasePtrList base_list = {base1, base2}; | ||||
| auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list); | |||||
| auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract(); | |||||
| MS_LOG(INFO) << "result spec: " << res->ToString(); | MS_LOG(INFO) << "result spec: " << res->ToString(); | ||||
| AbstractBasePtr exp = FromValue(x3, false); | AbstractBasePtr exp = FromValue(x3, false); | ||||
| MS_LOG(INFO) << "result exp: " << exp->ToString(); | MS_LOG(INFO) << "result exp: " << exp->ToString(); | ||||
| @@ -446,7 +446,7 @@ void TestGraphEval::TearDown() { | |||||
| TEST_F(TestGraphInfer, test_graph_infer_defaults) { | TEST_F(TestGraphInfer, test_graph_infer_defaults) { | ||||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults"); | FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults"); | ||||
| AbstractBasePtrList args_spec_list = {}; | AbstractBasePtrList args_spec_list = {}; | ||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expect = FromValue(MakeValue(50), false); | AbstractBasePtr expect = FromValue(MakeValue(50), false); | ||||
| ASSERT_EQ(*res, *expect); | ASSERT_EQ(*res, *expect); | ||||
| } | } | ||||
| @@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) { | |||||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { | TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { | ||||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0"); | FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0"); | ||||
| AbstractBasePtrList args_spec_list = {}; | AbstractBasePtrList args_spec_list = {}; | ||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expect = FromValue(MakeValue(1), false); | AbstractBasePtr expect = FromValue(MakeValue(1), false); | ||||
| ASSERT_EQ(*res, *expect); | ASSERT_EQ(*res, *expect); | ||||
| } | } | ||||
| @@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { | |||||
| TEST_F(TestGraphInfer, test_graph_infer_vararg) { | TEST_F(TestGraphInfer, test_graph_infer_vararg) { | ||||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg"); | FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg"); | ||||
| AbstractBasePtrList args_spec_list = {}; | AbstractBasePtrList args_spec_list = {}; | ||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expect = FromValue(MakeValue(9), false); | AbstractBasePtr expect = FromValue(MakeValue(9), false); | ||||
| ASSERT_EQ(*res, *expect); | ASSERT_EQ(*res, *expect); | ||||
| } | } | ||||
| @@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) { | |||||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { | TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { | ||||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs"); | FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs"); | ||||
| AbstractBasePtrList args_spec_list = {}; | AbstractBasePtrList args_spec_list = {}; | ||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expect = FromValue(MakeValue(48), false); | AbstractBasePtr expect = FromValue(MakeValue(48), false); | ||||
| ASSERT_EQ(*res, *expect); | ASSERT_EQ(*res, *expect); | ||||
| } | } | ||||
| @@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { | |||||
| TEST_F(TestGraphInfer, test_graph_infer_kwarg) { | TEST_F(TestGraphInfer, test_graph_infer_kwarg) { | ||||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg"); | FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg"); | ||||
| AbstractBasePtrList args_spec_list = {}; | AbstractBasePtrList args_spec_list = {}; | ||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expect = FromValue(MakeValue(7), false); | AbstractBasePtr expect = FromValue(MakeValue(7), false); | ||||
| ASSERT_EQ(*res, *expect); | ASSERT_EQ(*res, *expect); | ||||
| } | } | ||||
| @@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) { | |||||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { | TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { | ||||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg"); | FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg"); | ||||
| AbstractBasePtrList args_spec_list = {}; | AbstractBasePtrList args_spec_list = {}; | ||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expect = FromValue(MakeValue(46), false); | AbstractBasePtr expect = FromValue(MakeValue(46), false); | ||||
| ASSERT_EQ(*res, *expect); | ASSERT_EQ(*res, *expect); | ||||
| } | } | ||||
| @@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { | |||||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) { | TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) { | ||||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults"); | FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults"); | ||||
| AbstractBasePtrList args_spec_list = {}; | AbstractBasePtrList args_spec_list = {}; | ||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||||
| AbstractBasePtr expect = FromValue(MakeValue(57), false); | AbstractBasePtr expect = FromValue(MakeValue(57), false); | ||||
| ASSERT_EQ(*res, *expect); | ASSERT_EQ(*res, *expect); | ||||
| } | } | ||||
| @@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | ||||
| from ....mindspore_test_framework.pipeline.forward.verify_exception \ | from ....mindspore_test_framework.pipeline.forward.verify_exception \ | ||||
| import pipeline_for_verify_exception_for_case_by_case_config | import pipeline_for_verify_exception_for_case_by_case_config | ||||
| from mindspore import context | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| def conv3x3(in_channels, out_channels, stride=1, padding=1): | def conv3x3(in_channels, out_channels, stride=1, padding=1): | ||||
| """3x3 convolution """ | """3x3 convolution """ | ||||
| @@ -377,6 +378,21 @@ class StateNet(nn.Cell): | |||||
| return x | return x | ||||
| def test_conv2d_same_primitive(): | |||||
| class Conv2DSameNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Conv2DSameNet, self).__init__() | |||||
| self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||||
| self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||||
| def construct(self, x, y): | |||||
| r1 = self.conv1(x) | |||||
| r2 = self.conv2(y) | |||||
| return (r1, r2) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| net = Conv2DSameNet() | |||||
| out = net(t1, t2) | |||||
| class ComparisonNet(nn.Cell): | class ComparisonNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| """ ComparisonNet definition """ | """ ComparisonNet definition """ | ||||
| @@ -0,0 +1,276 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test nn ops """ | |||||
| import functools | |||||
| import numpy as np | |||||
| import mindspore | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.ops import Primitive | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | |||||
| from mindspore.ops.primitive import constexpr | |||||
| from ..ut_filter import non_graph_engine | |||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | |||||
| from ....mindspore_test_framework.pipeline.forward.verify_exception \ | |||||
| import pipeline_for_verify_exception_for_case_by_case_config | |||||
| from mindspore import context | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| class FakeOp(PrimitiveWithInfer): | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """""" | |||||
| def infer_shape(self, x, y): | |||||
| self.second_shape = y | |||||
| self.add_prim_attr("second_shape", y) | |||||
| return x | |||||
| def infer_dtype(self, x, y): | |||||
| return x | |||||
| # test the normal case that should generate independent primitive because of different | |||||
| # generated attributes after inference | |||||
| def test_conv2d_same_primitive(): | |||||
| class Conv2DSameNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Conv2DSameNet, self).__init__() | |||||
| self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||||
| self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||||
| def construct(self, x, y): | |||||
| r1 = self.conv1(x) | |||||
| r2 = self.conv2(y) | |||||
| return (r1, r2) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| net = Conv2DSameNet() | |||||
| out = net(t1, t2) | |||||
| # test cell as high order argument | |||||
| # The graph with free variables used as argument is not supported yet | |||||
| # because of the limit of inference specialize system | |||||
| def Xtest_conv2d_op_with_arg(): | |||||
| class Conv2dNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Conv2dNet, self).__init__() | |||||
| def construct(self, op, x): | |||||
| return op(x) | |||||
| class OpsNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(OpsNet, self).__init__() | |||||
| self.opnet = net | |||||
| self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) | |||||
| def construct(self, x, y): | |||||
| conv_op = self.conv2 | |||||
| a = self.opnet(conv_op, x) | |||||
| b = self.opnet(conv_op, y) | |||||
| return (a, b) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| net = OpsNet(Conv2dNet()) | |||||
| out = net(t1, t2) | |||||
| def test_conv2d_op_with_arg(): | |||||
| class FackOpNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(FackOpNet, self).__init__() | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y): | |||||
| return self.op(x, y) | |||||
| class OpNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(OpNet, self).__init__() | |||||
| def construct(self, op, x, y): | |||||
| return op(x, y) | |||||
| class OpsNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(OpsNet, self).__init__() | |||||
| self.opnet = net | |||||
| self.op = FackOpNet() | |||||
| def construct(self, x, y): | |||||
| op = self.op | |||||
| a = self.opnet(op, x, y) | |||||
| b = self.opnet(op, y, x) | |||||
| return (a, b) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| net = OpsNet(OpNet()) | |||||
| out = net(t1, t2) | |||||
| def test_conv2d_op_with_arg_same_input(): | |||||
| class FackOpNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(FackOpNet, self).__init__() | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y): | |||||
| return self.op(x, y) | |||||
| class OpNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(OpNet, self).__init__() | |||||
| def construct(self, op, x, y): | |||||
| return op(x, y) | |||||
| class OpsNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(OpsNet, self).__init__() | |||||
| self.opnet = net | |||||
| self.op = FackOpNet() | |||||
| def construct(self, x, y): | |||||
| op = self.op | |||||
| a = self.opnet(op, x, x) | |||||
| b = self.opnet(op, y, x) | |||||
| return (a, b) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| net = OpsNet(OpNet()) | |||||
| out = net(t1, t2) | |||||
| # test op with partial | |||||
| def test_op_as_partial(): | |||||
| class OpAsPartial(nn.Cell): | |||||
| def __init__(self): | |||||
| super(OpAsPartial, self).__init__() | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y, z): | |||||
| partial_op = F.partial(self.op, x) | |||||
| a = partial_op(y) | |||||
| b = partial_op(z) | |||||
| return a, b | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||||
| net = OpAsPartial() | |||||
| out = net(t1, t2, t3) | |||||
| # test op with partial | |||||
| def test_op_as_partial_inside(): | |||||
| class OpAsPartial(nn.Cell): | |||||
| def __init__(self): | |||||
| super(OpAsPartial, self).__init__() | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y, z): | |||||
| partial_op = F.partial(self.op, x) | |||||
| a = partial_op(y) | |||||
| b = partial_op(z) | |||||
| return a, b | |||||
| class OuterNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(OuterNet, self).__init__() | |||||
| self.net = OpAsPartial() | |||||
| def construct(self, x, y, z): | |||||
| a,b = self.net(x, y, z) | |||||
| return a, b | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||||
| net = OuterNet() | |||||
| out = net(t1, t2, t3) | |||||
| # test op with partial case 2 | |||||
| def test_op_as_partial_independent(): | |||||
| class OpAsPartial(nn.Cell): | |||||
| def __init__(self): | |||||
| super(OpAsPartial, self).__init__() | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y, z): | |||||
| partial_op1 = F.partial(self.op, x) | |||||
| a = partial_op1(y) | |||||
| partial_op2 = F.partial(self.op, x) | |||||
| b = partial_op2(z) | |||||
| return a, b | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||||
| net = OpAsPartial() | |||||
| out = net(t1, t2, t3) | |||||
| def test_nest_partial(): | |||||
| class NestPartial(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NestPartial, self).__init__() | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y, z): | |||||
| partial_op1 = F.partial(self.op) | |||||
| partial_op2 = F.partial(partial_op1, x) | |||||
| a = partial_op2(y) | |||||
| partial_op3 = F.partial(self.op) | |||||
| partial_op4 = F.partial(partial_op3, x) | |||||
| b = partial_op4(z) | |||||
| return a, b | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||||
| net = NestPartial() | |||||
| out = net(t1, t2, t3) | |||||
| # high order argument | |||||
| # op and op args as network arguments | |||||
| def test_op_with_arg_as_input(): | |||||
| class WithOpArgNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(WithOpArgNet, self).__init__() | |||||
| def construct(self, op, x, y): | |||||
| return op(x, y) | |||||
| class OpsNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(OpsNet, self).__init__() | |||||
| self.opnet = net | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y, z): | |||||
| op = self.op | |||||
| a = self.opnet(op, x, z) | |||||
| b = self.opnet(op, x, y) | |||||
| return (a, b) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||||
| net = OpsNet(WithOpArgNet()) | |||||
| out = net(t1, t2, t3) | |||||
| # The partial application used as argument is not supported yet | |||||
| # because of the limit of inference specialize system | |||||
| def Xtest_partial_as_arg(): | |||||
| class PartialArgNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(PartialArgNet, self).__init__() | |||||
| def construct(self, partial_op, y): | |||||
| return partial_op(y) | |||||
| class OpsNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(OpsNet, self).__init__() | |||||
| self.partial_net = net | |||||
| self.op = FakeOp() | |||||
| def construct(self, x, y, z): | |||||
| partial_op = F.partial(self.op, x) | |||||
| a = self.partial_net(partial_op, z) | |||||
| b = self.partial_net(partial_op, y) | |||||
| return (a, b) | |||||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||||
| net = OpsNet(PartialArgNet()) | |||||
| out = net(t1, t2, t3) | |||||