Merge pull request !1383 from amongo/KeepPrimAttrInCNodetags/v0.5.0-beta
| @@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { | |||
| auto ctx = node_cfg_->context(); | |||
| auto engine = node_cfg_->engine(); | |||
| auto cfg = engine->MakeConfig(node, ctx); | |||
| auto abs = engine->cache().GetValue(cfg); | |||
| if (abs == nullptr) { | |||
| auto eval_result = engine->cache().GetValue(cfg); | |||
| if (eval_result == nullptr || eval_result->abstract() == nullptr) { | |||
| return "Undefined"; | |||
| } | |||
| auto abs = eval_result->abstract(); | |||
| auto dtype = abs->BuildType(); | |||
| auto shape = abs->BuildShape(); | |||
| std::ostringstream oss; | |||
| @@ -42,7 +42,11 @@ enum PrimType { | |||
| class Primitive : public Named { | |||
| public: | |||
| explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) | |||
| : Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {} | |||
| : Named(name), | |||
| is_base_(is_base), | |||
| has_signature_(false), | |||
| prim_type_(prim_type), | |||
| record_evaluate_add_attr_(false) {} | |||
| Primitive(const Primitive &prim) | |||
| : Named(prim), | |||
| @@ -50,14 +54,23 @@ class Primitive : public Named { | |||
| instance_name_(prim.instance_name_), | |||
| is_base_(prim.is_base_), | |||
| has_signature_(prim.has_signature_), | |||
| prim_type_(prim.prim_type_) {} | |||
| prim_type_(prim.prim_type_), | |||
| record_evaluate_add_attr_(false) {} | |||
| MS_DECLARE_PARENT(Primitive, Named); | |||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||
| std::string ToString() const override { return name(); } | |||
| void BeginRecordAddAttr() { | |||
| evaluate_added_attrs_.clear(); | |||
| record_evaluate_add_attr_ = true; | |||
| } | |||
| void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } | |||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | |||
| attrs_[name] = attr; | |||
| if (record_evaluate_add_attr_) { | |||
| evaluate_added_attrs_[name] = attr; | |||
| } | |||
| return *this; | |||
| } | |||
| @@ -80,6 +93,7 @@ class Primitive : public Named { | |||
| py::function hook() const { return hook_; } | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; } | |||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | |||
| bool HasAttr() const { return !attrs_.empty(); } | |||
| @@ -106,6 +120,7 @@ class Primitive : public Named { | |||
| protected: | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_; | |||
| private: | |||
| std::string instance_name_; | |||
| @@ -113,6 +128,7 @@ class Primitive : public Named { | |||
| bool is_base_; | |||
| bool has_signature_; | |||
| PrimType prim_type_; | |||
| bool record_evaluate_add_attr_; | |||
| }; | |||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | |||
| @@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv | |||
| } | |||
| subargs.push_back(AbstractJoin(l_ptr->elements())); | |||
| } | |||
| AbstractBasePtr engin_exc = engine->Execute(fn, subargs); | |||
| EvalResultPtr engin_exc = engine->Execute(fn, subargs); | |||
| AbstractBasePtrList result; | |||
| for (std::size_t i = 1; i < args_spec_list.size(); i++) { | |||
| result.push_back(engin_exc); | |||
| result.push_back(engin_exc->abstract()); | |||
| } | |||
| return std::make_shared<AbstractList>(result); | |||
| } | |||
| @@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi | |||
| AbstractBasePtr list_type = AbstractJoin(lst->elements()); | |||
| auto result1 = engine->Execute(fn, lst->elements()); | |||
| auto result2 = engine->Execute(fn, {dflt, list_type}); | |||
| MS_EXCEPTION_IF_NULL(result1); | |||
| return result1->Join(result2); | |||
| MS_EXCEPTION_IF_NULL(result1->abstract()); | |||
| MS_EXCEPTION_IF_NULL(result2->abstract()); | |||
| return result1->abstract()->Join(result2->abstract()); | |||
| } | |||
| AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) { | |||
| return sorted_nodes; | |||
| } | |||
| AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | |||
| EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | |||
| FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| std::size_t nargs = fg->parameters().size(); | |||
| @@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs | |||
| const auto &arg = args_spec_list[i]; | |||
| const auto &node = parameters[i]; | |||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); | |||
| engine->cache().set_value(conf, arg); | |||
| engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr)); | |||
| } | |||
| const AnfNodePtr &func_node = fg->get_return(); | |||
| @@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs | |||
| const auto &node = *it; | |||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); | |||
| MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); | |||
| ret_base = engine->GetEvaluatedValue(node_conf); | |||
| ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); | |||
| MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() | |||
| << ", abstract: " << ret_base->ToString(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(ret_base); | |||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString(); | |||
| return ret_base; | |||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString(); | |||
| return std::make_shared<EvalResult>(ret_base, nullptr); | |||
| } | |||
| AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||
| @@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons | |||
| return cloned_func_graph; | |||
| } | |||
| AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { | |||
| const std::string &evaluator_name = ToString(); | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| args_spec_list = NormalizeArgs(args_spec_list); | |||
| args_spec_list = BroadenUndeterminedArgs(args_spec_list); | |||
| @@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar | |||
| auto iter = cache_->find(args_spec_list); | |||
| if (iter == cache_->end()) { | |||
| MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; | |||
| AbstractBasePtr ret = Eval(engine, args_spec_list); | |||
| if (ret == nullptr) { | |||
| EvalResultPtr ret = Eval(engine, args_spec_list); | |||
| if (ret->abstract() == nullptr) { | |||
| EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | |||
| MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(ret); | |||
| MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << "."; | |||
| MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; | |||
| (*cache_)[args_spec_list] = ret; | |||
| trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | |||
| return ret; | |||
| } else { | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << "."; | |||
| MS_EXCEPTION_IF_NULL(iter->second->abstract()); | |||
| MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << "."; | |||
| trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); | |||
| return iter->second; | |||
| } | |||
| } | |||
| AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr) { | |||
| EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr) { | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| AbstractBasePtr ret = EvalPrim(engine, args_spec_list); | |||
| EvalResultPtr ret = EvalPrim(engine, args_spec_list); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| if (args_conf_list.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "Size should greater than 0"; | |||
| } | |||
| AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); | |||
| EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); | |||
| // No need to cache. | |||
| return ret; | |||
| } | |||
| AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||
| AbstractBasePtr ret = EvalPrim(args_conf_list); | |||
| EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||
| EvalResultPtr ret = EvalPrim(args_conf_list); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); | |||
| EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); | |||
| // Don't lookup from cache, as different out_conf with same node but different context | |||
| // may add different entry to anfnode_config_map_, like getattr primitive. | |||
| (*cache_)[args_spec_list] = ret; | |||
| return ret; | |||
| } | |||
| AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| MS_EXCEPTION_IF_NULL(cache_); | |||
| auto iter = cache_->find(args_spec_list); | |||
| @@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), | |||
| [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); | |||
| AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); | |||
| EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); | |||
| (*cache_)[args_spec_list] = ret; | |||
| return ret; | |||
| } | |||
| AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||
| EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| MS_EXCEPTION_IF_NULL(cache_); | |||
| auto iter = cache_->find(args_spec_list); | |||
| @@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a | |||
| } | |||
| // Call the original evaluator, get the result: y = f(x) | |||
| AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr); | |||
| EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr); | |||
| // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input | |||
| // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) | |||
| AbstractBasePtrList bparams; | |||
| @@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a | |||
| args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), | |||
| [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); | |||
| AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams); | |||
| AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final); | |||
| AbstractFunctionPtr bprop = | |||
| std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final); | |||
| // J(f)(J(x)) return a tuple (y, bprop_f) | |||
| AbstractBasePtrList jargs = {result, bprop}; | |||
| AbstractBasePtrList jargs = {result->abstract(), bprop}; | |||
| AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs); | |||
| (*cache_)[args_spec_list] = jtuple; | |||
| return jtuple; | |||
| auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>()); | |||
| (*cache_)[args_spec_list] = infer_reuslt; | |||
| return infer_reuslt; | |||
| } | |||
| AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { | |||
| EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { | |||
| if (args_spec_list.size() != args_spec_list_.size()) { | |||
| MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() | |||
| << ", arguments no: " << args_spec_list.size(); | |||
| @@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[i]); | |||
| (void)args_spec_list[i]->Join(args_spec_list_[i]); | |||
| } | |||
| return output_; | |||
| return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>()); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -29,21 +29,28 @@ | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| using EvaluatorCacheMap = | |||
| std::unordered_map<AbstractBasePtrList, AbstractBasePtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; | |||
| std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; | |||
| using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>; | |||
| using EvaluatorAttrMap = | |||
| std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; | |||
| using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>; | |||
| class Evaluator : public Base { | |||
| public: | |||
| explicit Evaluator(const std::string &id) : cache_(std::make_shared<EvaluatorCacheMap>()), identifier_(id) {} | |||
| explicit Evaluator(const std::string &id) | |||
| : cache_(std::make_shared<EvaluatorCacheMap>()), | |||
| attr_cache_(std::make_shared<EvaluatorAttrMap>()), | |||
| identifier_(id) {} | |||
| ~Evaluator() override = default; | |||
| MS_DECLARE_PARENT(Evaluator, Base); | |||
| // difference between Run() and Eval(): | |||
| // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. | |||
| // Run() will modify cache_ member, so it cannot marked as const; | |||
| virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); | |||
| virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); | |||
| virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||
| virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||
| virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | |||
| @@ -58,9 +65,10 @@ class Evaluator : public Base { | |||
| virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } | |||
| EvaluatorCacheMapPtr &cache() { return cache_; } | |||
| EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } | |||
| EvaluatorCacheMapPtr cache_; | |||
| EvaluatorAttrMapPtr attr_cache_; | |||
| std::string identifier_; | |||
| AnfNodeWeakPtr bound_node_; | |||
| @@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator { | |||
| explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} | |||
| ~PrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(PrimEvaluator, Evaluator); | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { | |||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||
| } | |||
| }; | |||
| @@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator { | |||
| explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | |||
| ~TrivialPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||
| virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||
| virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; | |||
| }; | |||
| class TransitionPrimEvaluator : public PrimEvaluator { | |||
| @@ -90,10 +98,10 @@ class TransitionPrimEvaluator : public PrimEvaluator { | |||
| explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | |||
| ~TransitionPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||
| // Parameter in_conf0 : the first element in args_conf_list; | |||
| virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; | |||
| virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; | |||
| }; | |||
| class SymbolicPrimEvaluator : public PrimEvaluator { | |||
| @@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator { | |||
| explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} | |||
| ~SymbolicPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||
| virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; | |||
| virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; | |||
| }; | |||
| // Evaluator will be stored in AnalysisEngine.constructors_ | |||
| @@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator { | |||
| DummyEvaluator() : Evaluator("dummy") {} | |||
| ~DummyEvaluator() override = default; | |||
| MS_DECLARE_PARENT(DummyEvaluator, Evaluator); | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } | |||
| }; | |||
| // Wrap another evaluator to track a subset of uses. | |||
| @@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator { | |||
| bound_node_ = AnfNodeWeakPtr(node); | |||
| } | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||
| } | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) override; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; | |||
| std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } | |||
| private: | |||
| @@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator { | |||
| ~BaseFuncGraphEvaluator() override = default; | |||
| MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); | |||
| AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||
| EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||
| virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; | |||
| @@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator { | |||
| } | |||
| bound_node_ = AnfNodeWeakPtr(node); | |||
| } | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | |||
| } | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) override; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; | |||
| std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } | |||
| private: | |||
| @@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator { | |||
| ~VirtualEvaluator() override = default; | |||
| MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); | |||
| AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||
| EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; | |||
| std::string ToString() const override { return identifier_; } | |||
| private: | |||
| @@ -285,11 +292,11 @@ class JEvaluator : public Evaluator { | |||
| } | |||
| bound_node_ = AnfNodeWeakPtr(node); | |||
| } | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; | |||
| } | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) override; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; | |||
| std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } | |||
| private: | |||
| @@ -135,13 +135,17 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| using mindspore::parse::PyObjectWrapper; | |||
| AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | |||
| return abs_base; | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| return infer_result; | |||
| } | |||
| AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| AbstractBasePtrList args_spec_list; | |||
| if (!prim_->isa<prim::DoSignaturePrimitive>()) { | |||
| MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString(); | |||
| @@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config | |||
| AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||
| ScopePtr scope = kDefaultScope; | |||
| if (out_conf != nullptr) { | |||
| @@ -212,8 +216,8 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s | |||
| return graph_specialize_args; | |||
| } | |||
| AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | |||
| } | |||
| @@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config | |||
| AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||
| // get the forward graph | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>(); | |||
| @@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||
| } | |||
| } // end anonymous namespace | |||
| AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||
| EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||
| MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); | |||
| const auto &iter = cache_->find(args); | |||
| @@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A | |||
| MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; | |||
| } | |||
| auto infer_fuc = pyobj.attr("__infer__"); | |||
| prim_py_->BeginRecordAddAttr(); | |||
| py::dict output = infer_fuc(*py_args); | |||
| prim_py_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_py_->evaluate_added_attrs(); | |||
| MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); | |||
| auto res_spec = PyInferRes2Abstract(prim_py_, output); | |||
| MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; | |||
| (*cache_)[args] = res_spec; | |||
| return res_spec; | |||
| auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs)); | |||
| (*cache_)[args] = infer_result; | |||
| return infer_result; | |||
| } | |||
| AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||
| EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||
| // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. | |||
| if (nargs_ != args.size()) { | |||
| MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; | |||
| @@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const | |||
| } | |||
| AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type); | |||
| return abs_base; | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>()); | |||
| } | |||
| ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { | |||
| @@ -553,8 +560,8 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun | |||
| manager->AddFuncGraph(func_graph); | |||
| } | |||
| AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, | |||
| const AnfNodeConfigPtr &old_conf) { | |||
| EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, | |||
| const AnfNodeConfigPtr &old_conf) { | |||
| MS_EXCEPTION_IF_NULL(old_conf); | |||
| AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); | |||
| @@ -585,9 +592,9 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat | |||
| return eng->ForwardConfig(old_conf, fn_conf); | |||
| } | |||
| AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, | |||
| const AbstractBasePtrList &args_spec_list, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, | |||
| const AbstractBasePtrList &args_spec_list, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| // args_spec_list: same as StaticGetter | |||
| if (args_spec_list.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2"; | |||
| @@ -627,9 +634,9 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng | |||
| return eng->ForwardConfig(out_conf, fn_conf); | |||
| } | |||
| AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, | |||
| const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, | |||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||
| EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, | |||
| const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, | |||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||
| if (args_spec_list.empty()) { | |||
| MS_LOG(EXCEPTION) << "args_spec_list is empty"; | |||
| } | |||
| @@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e | |||
| AbstractBasePtr attr = cls->GetAttribute(item_name); | |||
| if (attr != nullptr) { | |||
| return attr; | |||
| return std::make_shared<EvalResult>(attr, nullptr); | |||
| } | |||
| ValuePtr method = cls->GetMethod(item_name); | |||
| @@ -660,9 +667,9 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e | |||
| return StaticGetterInferred(converted_v, data_conf, out_conf); | |||
| } | |||
| AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, | |||
| const TypePtr &data_type, const ConfigPtr &data_conf, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, | |||
| const TypePtr &data_type, const ConfigPtr &data_conf, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| MS_EXCEPTION_IF_NULL(item_v); | |||
| MS_EXCEPTION_IF_NULL(data_type); | |||
| // The method maybe a Primitive or Composite | |||
| @@ -689,8 +696,8 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e | |||
| return StaticGetterInferred(converted_v, data_conf, out_conf); | |||
| } | |||
| AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||
| EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { | |||
| // Inputs: namespace and its static function; or class and its member function | |||
| CheckArgsSize("StaticGetter", args_spec_list, 2); | |||
| @@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator { | |||
| EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} | |||
| ~EmbedEvaluator() override = default; | |||
| MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||
| EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||
| // arg: free variable to be embedded | |||
| if (args_conf_list.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); | |||
| @@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator { | |||
| AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]); | |||
| MS_EXCEPTION_IF_NULL(node_conf); | |||
| AbstractBasePtr x = node_conf->GetEvaluatedValue(); | |||
| AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract(); | |||
| x = SensitivityTransform(x); | |||
| SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x); | |||
| AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>()); | |||
| return abs_scalar; | |||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | |||
| } | |||
| }; | |||
| @@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} | |||
| ~RefToEmbedEvaluator() override = default; | |||
| MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||
| EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { | |||
| if (args_conf_list.size() != 1) { | |||
| MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); | |||
| return nullptr; | |||
| @@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; | |||
| return nullptr; | |||
| } | |||
| AbstractBasePtr abs = node_conf->GetEvaluatedValue(); | |||
| AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); | |||
| AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>(); | |||
| if (ref_abs == nullptr) { | |||
| MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref."; | |||
| @@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| } | |||
| auto refkey = key_value->cast<RefKeyPtr>(); | |||
| if (refkey == nullptr) { | |||
| return std::make_shared<AbstractScalar>(type); | |||
| return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>()); | |||
| } | |||
| std::string name = refkey->tag(); | |||
| @@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| x = SensitivityTransform(x); | |||
| std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | |||
| std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | |||
| return abs_scalar; | |||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | |||
| } | |||
| }; | |||
| @@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator { | |||
| GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} | |||
| ~GetAttrEvaluator() override = default; | |||
| MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||
| // Inputs: data, item | |||
| if (args_spec_list.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | |||
| } | |||
| AbstractBasePtr ret = nullptr; | |||
| EvalResultPtr ret = nullptr; | |||
| if (bound_node() != nullptr) { | |||
| TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); | |||
| ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); | |||
| @@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator { | |||
| ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} | |||
| ~ResolveEvaluator() override = default; | |||
| MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||
| // Inputs: namespace, symbol | |||
| if (args_spec_list.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | |||
| } | |||
| AbstractBasePtr ret = nullptr; | |||
| EvalResultPtr ret = nullptr; | |||
| if (bound_node() != nullptr) { | |||
| TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); | |||
| ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); | |||
| @@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { | |||
| CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} | |||
| ~CreateInstanceEvaluator() override = default; | |||
| MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override { | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, | |||
| const AnfNodeConfigPtr &out_conf) override { | |||
| if (args_spec_list.empty()) { | |||
| MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; | |||
| } | |||
| @@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { | |||
| } | |||
| AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); | |||
| (*cache_)[args_spec_list] = ret; | |||
| return ret; | |||
| auto infer_result = std::make_shared<EvalResult>(ret, nullptr); | |||
| (*cache_)[args_spec_list] = infer_result; | |||
| return infer_result; | |||
| } | |||
| pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { | |||
| @@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator { | |||
| public: | |||
| PartialEvaluator() : Evaluator("PartialEvaluator") {} | |||
| ~PartialEvaluator() override = default; | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf = nullptr) override { | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf = nullptr) override { | |||
| if (args_conf_list.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "Args size should be greater than 0"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_conf->node()); | |||
| auto arg0_value = args_conf_list[0]->GetEvaluatedValue(); | |||
| auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); | |||
| AbstractBasePtrList args_spec_list{arg0_value}; | |||
| // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. | |||
| if (arg0_value->isa<AbstractError>()) { | |||
| auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node()); | |||
| MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() | |||
| << " as func is: " << arg0_value->ToString(); | |||
| (*cache_)[args_spec_list] = ret; | |||
| return ret; | |||
| auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | |||
| (*cache_)[args_spec_list] = eval_result; | |||
| return eval_result; | |||
| } | |||
| auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0); | |||
| // Sometimes, node[0] in out_conf becomes phi0; | |||
| @@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator { | |||
| } | |||
| } | |||
| (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); }); | |||
| (void)std::transform( | |||
| args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); }); | |||
| AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); | |||
| auto cnode = out_conf->node()->cast<CNodePtr>(); | |||
| @@ -989,16 +999,17 @@ class PartialEvaluator : public Evaluator { | |||
| func->Visit(build_partial); | |||
| auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); | |||
| (*cache_)[args_spec_list] = ret; | |||
| return ret; | |||
| auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | |||
| (*cache_)[args_spec_list] = infer_result; | |||
| return infer_result; | |||
| } | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||
| } | |||
| AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, | |||
| const AnfNodeConfigPtr &out_conf = nullptr) const { | |||
| EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, | |||
| const AnfNodeConfigPtr &out_conf = nullptr) const { | |||
| MS_EXCEPTION_IF_NULL(out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_conf->node()); | |||
| auto cnode = out_conf->node()->cast<CNodePtr>(); | |||
| @@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator { | |||
| : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} | |||
| ~StandardPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| PrimitivePtr prim() { return prim_; } | |||
| std::string ToString() const override { return identifier_ + prim_->name(); } | |||
| @@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator { | |||
| : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} | |||
| ~PythonPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); } | |||
| std::string ToString() const override { return identifier_ + prim_py_->name(); } | |||
| @@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator { | |||
| public: | |||
| explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} | |||
| ~DoSignatureEvaluator() override = default; | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||
| AnfNodeConfigPtr out_config = nullptr) override; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||
| AnfNodeConfigPtr out_config = nullptr) override; | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||
| } | |||
| @@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator { | |||
| public: | |||
| explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} | |||
| ~UnpackGraphEvaluator() override = default; | |||
| AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||
| AnfNodeConfigPtr out_config = nullptr) override; | |||
| EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, | |||
| AnfNodeConfigPtr out_config = nullptr) override; | |||
| AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { | |||
| MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; | |||
| } | |||
| @@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator { | |||
| ~UniformPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); | |||
| AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| ValuePtr RunImpl(const ValuePtrList &args) const; | |||
| // If eval_value_ is False, return broadened arguments. | |||
| @@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) { | |||
| if (conf->node()->intermediate_abstract()) { | |||
| return conf->node()->intermediate_abstract(); | |||
| } | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| } | |||
| AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { | |||
| @@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() { | |||
| // Specialize CNode in func graphs | |||
| void FuncGraphSpecializer::SecondPass() { | |||
| for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) { | |||
| for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { | |||
| if (node->isa<CNode>()) { | |||
| ProcessCNode(node->cast<CNodePtr>()); | |||
| } | |||
| @@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||
| AnfNodeConfigPtr conf = MakeConfig(node); | |||
| AnfNodePtr new_node = GetReplicatedNode(node); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| if (new_node->func_graph() != specialized_func_graph_) { | |||
| MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() | |||
| << ", new_node: " << new_node->DebugString() | |||
| @@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||
| MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); | |||
| if (node->isa<CNode>()) { | |||
| auto attrs = conf->GetEvaluatedValue()->attribute(); | |||
| auto c_old = node->cast<CNodePtr>(); | |||
| auto c_new = new_node->cast<CNodePtr>(); | |||
| auto new_inputs = c_new->inputs(); | |||
| @@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||
| AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); | |||
| // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if | |||
| // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. | |||
| AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival); | |||
| AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); | |||
| if (replace_node == nullptr) { | |||
| replace_node = BuildReplacedNode(iconf); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| @@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n | |||
| MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() | |||
| << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); | |||
| } | |||
| auto attrs = std::make_shared<AttrValueMap>(); | |||
| for (size_t i = 0; i < partial_closure->args().size(); i++) { | |||
| auto old_node = cnode->input(i + 2); | |||
| auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]); | |||
| auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); | |||
| if (possibile_value_node != nullptr) { | |||
| partial_node_list.push_back(possibile_value_node); | |||
| } else { | |||
| @@ -455,7 +456,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB | |||
| const EvaluatorPtr &eval) { | |||
| MS_EXCEPTION_IF_NULL(eval); | |||
| std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices; | |||
| AbstractBasePtr ret = nullptr; | |||
| EvalResultPtr ret = nullptr; | |||
| AbstractBasePtrList broaded_argvals; | |||
| for (auto &argvals_map : *evalcaches_[eval]) { | |||
| auto argvals = argvals_map.first; | |||
| @@ -478,7 +479,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB | |||
| (*real)[broaded_argvals] = ret; | |||
| evalcaches_[eval] = real; | |||
| return std::make_pair(broaded_argvals, ret); | |||
| return std::make_pair(broaded_argvals, ret->abstract()); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Choices.size: " << choices.size(); | |||
| return std::make_pair(AbstractBasePtrList(), nullptr); | |||
| @@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||
| return; | |||
| } | |||
| specializer_->AddSeen(new_node); | |||
| auto new_inputs = new_node->inputs(); | |||
| if (new_inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; | |||
| @@ -530,7 +530,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||
| } | |||
| if (CanSpecializeNode(func)) { | |||
| new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); | |||
| // for primitive node , we build the primitive node with infered attributes in the first pass | |||
| // so we do not build replaced node again here in second pass | |||
| if (IsValueNode<Primitive>(func)) { | |||
| new_inputs[0] = func; | |||
| } else { | |||
| new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < argvals.size();) { | |||
| @@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||
| } | |||
| i = next; | |||
| } | |||
| new_node->set_inputs(new_inputs); | |||
| } | |||
| @@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct | |||
| EvaluatorCacheMap evaluator_cache_map = *eval->cache(); | |||
| if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { | |||
| *result = std::make_pair(argvals, evaluator_cache_map[argvals]); | |||
| *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract()); | |||
| return kSpecializeSuccess; | |||
| } | |||
| DumpEvaluatorCache(evaluator_cache_map, argvals); | |||
| @@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct | |||
| MS_EXCEPTION_IF_NULL(choices); | |||
| if (choices->count(argvals)) { | |||
| *result = std::make_pair(argvals, (*choices)[argvals]); | |||
| *result = std::make_pair(argvals, (*choices)[argvals]->abstract()); | |||
| return kSpecializeSuccess; | |||
| } else if (choices->size() == 1) { | |||
| MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; | |||
| *result = std::make_pair(choices->begin()->first, choices->begin()->second); | |||
| *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); | |||
| return kSpecializeSuccess; | |||
| } else if (choices->empty()) { | |||
| MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; | |||
| @@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct | |||
| return kSpecializeFindUniqueArgvalPoly; | |||
| } | |||
| } | |||
| static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) { | |||
| auto &prim_attrs = prim->attrs(); | |||
| bool is_attr_same = true; | |||
| for (auto &item : *attrs) { | |||
| auto itr = prim_attrs.find(item.first); | |||
| if (itr != prim_attrs.end()) { | |||
| if (!(*(itr->second) == *(item.second))) { | |||
| is_attr_same = false; | |||
| break; | |||
| } | |||
| } else { | |||
| is_attr_same = false; | |||
| break; | |||
| } | |||
| } | |||
| if (!is_attr_same) { | |||
| if (prim->isa<PrimitivePy>()) { | |||
| PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>(); | |||
| auto clone_fn = prim_py->GetPyObj().attr("_clone"); | |||
| py::object new_obj = clone_fn(); | |||
| auto cloned_prim = new_obj.cast<PrimitivePyPtr>(); | |||
| for (auto &item : *attrs) { | |||
| cloned_prim->AddAttr(item.first, item.second); | |||
| } | |||
| return cloned_prim; | |||
| } | |||
| auto cloned_prim = std::make_shared<Primitive>(*prim); | |||
| for (auto &item : *attrs) { | |||
| cloned_prim->AddAttr(item.first, item.second); | |||
| } | |||
| return cloned_prim; | |||
| } | |||
| return prim; | |||
| } | |||
| AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) { | |||
| AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, | |||
| const AttrValueMapPtr &attrs) { | |||
| MS_EXCEPTION_IF_NULL(origin_node); | |||
| MS_EXCEPTION_IF_NULL(ival); | |||
| @@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin | |||
| ValuePtr value = nullptr; | |||
| if (abs->isa<PrimitiveAbstractClosure>()) { | |||
| auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs); | |||
| value = real_fn->prim(); | |||
| // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one | |||
| if (attrs != nullptr) { | |||
| value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); | |||
| } else { | |||
| value = real_fn->prim(); | |||
| } | |||
| } else if (abs->isa<MetaFuncGraphAbstractClosure>()) { | |||
| auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs); | |||
| value = real_fn->meta_func_graph(); | |||
| @@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia | |||
| AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); | |||
| // Build a value node if ival is constant and not any-value | |||
| AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival); | |||
| AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, | |||
| const AttrValueMapPtr &attrs); | |||
| // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a | |||
| // replicated node. | |||
| AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); | |||
| @@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase | |||
| return nullptr; | |||
| } | |||
| void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) { | |||
| void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { | |||
| MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() | |||
| << ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString() | |||
| << ", Pointer: " << arg.get(); | |||
| cache_[conf] = arg; | |||
| << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() | |||
| << ", Pointer: " << result->abstract().get(); | |||
| cache_[conf] = result; | |||
| // Set intermediate abstract value. | |||
| if (IsIntermediateAbstract(arg)) { | |||
| if (IsIntermediateAbstract(result->abstract())) { | |||
| if (conf->node()->intermediate_abstract() == nullptr) { | |||
| conf->node()->set_intermediate_abstract(arg); | |||
| MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString(); | |||
| conf->node()->set_intermediate_abstract(result->abstract()); | |||
| MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); | |||
| } else { | |||
| auto old_spec = conf->node()->intermediate_abstract(); | |||
| auto joined_spec = IntermediateJoin(arg, old_spec); | |||
| auto joined_spec = IntermediateJoin(result->abstract(), old_spec); | |||
| conf->node()->set_intermediate_abstract(joined_spec); | |||
| MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" | |||
| << arg->ToString() << "\njoined_spec:\t" | |||
| << result->abstract()->ToString() << "\njoined_spec:\t" | |||
| << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); | |||
| } | |||
| } | |||
| } | |||
| AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { | |||
| EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { | |||
| auto value = cache_.find(conf); | |||
| if (value == cache_.end()) { | |||
| return nullptr; | |||
| @@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana | |||
| return eval->graph_context(); | |||
| } | |||
| AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { | |||
| EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| auto value = cache_.GetValue(conf); | |||
| if (value != nullptr) { | |||
| MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", " | |||
| << value->ToString(); | |||
| MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() | |||
| << ", " << value->abstract()->ToString(); | |||
| return value; | |||
| } | |||
| @@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) | |||
| return value; | |||
| } | |||
| AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||
| EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| AnfNodePtr node = conf->node(); | |||
| AbstractBasePtr ret_abstract = nullptr; | |||
| EvalResultPtr eval_result = nullptr; | |||
| #ifdef DEBUG | |||
| compute_conf_stack_.push_back(node); | |||
| std::ostringstream buffer; | |||
| @@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->abstract() != nullptr) { | |||
| MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); | |||
| ret_abstract = node->abstract(); | |||
| eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>()); | |||
| } else if (node->isa<ValueNode>()) { | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| ret_abstract = EvalValueNode(value_node, conf); | |||
| eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr); | |||
| } else if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| trace::TraceEvalCNodeEnter(conf); | |||
| ret_abstract = EvalCNode(cnode, conf); | |||
| eval_result = EvalCNode(cnode, conf); | |||
| trace::TraceEvalCNodeLeave(); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() | |||
| @@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { | |||
| #ifdef DEBUG | |||
| compute_conf_stack_.pop_back(); | |||
| if (ret_abstract == nullptr) { | |||
| if (eval_result == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() | |||
| << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | |||
| } | |||
| #endif | |||
| MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString(); | |||
| return ret_abstract; | |||
| MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); | |||
| return eval_result; | |||
| } | |||
| AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { | |||
| @@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co | |||
| return ToAbstract(value_node->value(), conf->context(), conf); | |||
| } | |||
| AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||
| EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| @@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo | |||
| AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); | |||
| MS_EXCEPTION_IF_NULL(func_conf); | |||
| // Keep it in a local variable, otherwise smart pointer will free it. | |||
| AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue(); | |||
| AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); | |||
| if (maybe_func == nullptr) { | |||
| MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() | |||
| << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); | |||
| @@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo | |||
| return ExecuteEvaluators(infs, conf, args_conf_list); | |||
| } | |||
| AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { | |||
| EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { | |||
| ConfigPtrList args_conf_list; | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), | |||
| [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); | |||
| @@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||
| return tracked_eval; | |||
| } | |||
| AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||
| const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list) { | |||
| EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||
| const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { | |||
| if (evaluators.size() == 1) { | |||
| EvaluatorPtr eval = evaluators[0]; | |||
| MS_EXCEPTION_IF_NULL(eval); | |||
| @@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr | |||
| return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); | |||
| } | |||
| AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||
| const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list) { | |||
| EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||
| const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list) { | |||
| AbstractBasePtrList out_specs; | |||
| if (!multi_poss_.count(evaluators[0])) { | |||
| multi_poss_[evaluators[0]] = evaluators[1]; | |||
| @@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->GetEvaluatedValue(); | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| for (auto eval : evaluators) { | |||
| auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | |||
| @@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||
| eval_trace_.push_back(current_inf); | |||
| MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); | |||
| MS_EXCEPTION_IF_NULL(eval); | |||
| auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_spec); | |||
| MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); | |||
| out_specs.push_back(out_spec); | |||
| MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString(); | |||
| auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); | |||
| MS_EXCEPTION_IF_NULL(eval_result->abstract()); | |||
| MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); | |||
| out_specs.push_back(eval_result->abstract()); | |||
| eval_trace_.pop_back(); | |||
| if (eval_trace_.empty()) { | |||
| multi_poss_.clear(); | |||
| @@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||
| // Try to travel the latest undetermined. | |||
| if (latest_entry != eval_trace_.rbegin()->first) { | |||
| MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); | |||
| auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_spec); | |||
| MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString(); | |||
| return out_spec; | |||
| auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); | |||
| MS_EXCEPTION_IF_NULL(eval_result->abstract()); | |||
| MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() | |||
| << " return out_spec: " << eval_result->abstract()->ToString(); | |||
| return eval_result; | |||
| } | |||
| } | |||
| } | |||
| @@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||
| if (out_specs.size() == 1) { | |||
| MS_EXCEPTION_IF_NULL(out_specs[0]); | |||
| // If only one result derived, then broaden it to avoid wrong constant propagation. | |||
| return out_specs[0]->Broaden(); | |||
| return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>()); | |||
| } | |||
| auto joined_spec = AbstractJoin(out_specs); | |||
| MS_EXCEPTION_IF_NULL(joined_spec); | |||
| MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); | |||
| return joined_spec; | |||
| return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>()); | |||
| } | |||
| AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() { | |||
| EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { | |||
| AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>(); | |||
| return engine_.lock()->GetEvaluatedValue(self); | |||
| } | |||
| @@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { | |||
| return a; | |||
| } | |||
| AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { | |||
| EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { | |||
| auto evaluator = GetPrimEvaluator(primitive, nullptr); | |||
| MS_EXCEPTION_IF_NULL(evaluator); | |||
| if (!evaluator->isa<TrivialPrimEvaluator>()) { | |||
| @@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr | |||
| << evaluator->ToString(); | |||
| } | |||
| auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator); | |||
| auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs); | |||
| return res_spec; | |||
| auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); | |||
| return eval_result; | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -40,13 +40,33 @@ | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| // define attribute value map | |||
| using AttrValueMap = std::unordered_map<std::string, ValuePtr>; | |||
| using AttrValueMapPtr = std::shared_ptr<AttrValueMap>; | |||
| // the class to save evaluated result: abstract value and modified attribute | |||
| class EvalResult : public Base { | |||
| public: | |||
| EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} | |||
| ~EvalResult() override = default; | |||
| MS_DECLARE_PARENT(EvalResult, Base); | |||
| AbstractBasePtr abstract() { return abstract_; } | |||
| AttrValueMapPtr attribute() { return attribute_; } | |||
| private: | |||
| AbstractBasePtr abstract_; | |||
| AttrValueMapPtr attribute_; | |||
| }; | |||
| using EvalResultPtr = std::shared_ptr<EvalResult>; | |||
| // Superclass for AnfNodeConfig and VirtualConfig. | |||
| class Config : public Base { | |||
| public: | |||
| Config() = default; | |||
| ~Config() override = default; | |||
| MS_DECLARE_PARENT(Config, Base); | |||
| virtual AbstractBasePtr GetEvaluatedValue() = 0; | |||
| virtual EvalResultPtr GetEvaluatedValue() = 0; | |||
| }; | |||
| // Config will be stored in AnalysisCache | |||
| @@ -74,7 +94,7 @@ class AnfNodeConfig : public Config { | |||
| ~AnfNodeConfig() override = default; | |||
| MS_DECLARE_PARENT(AnfNodeConfig, Config); | |||
| AbstractBasePtr GetEvaluatedValue() override; | |||
| EvalResultPtr GetEvaluatedValue() override; | |||
| AnalysisContextPtr context() const { return context_; } | |||
| @@ -123,7 +143,9 @@ class VirtualConfig : public Config { | |||
| ~VirtualConfig() override = default; | |||
| MS_DECLARE_PARENT(VirtualConfig, Config); | |||
| AbstractBasePtr GetEvaluatedValue() override { return abstract_; } | |||
| EvalResultPtr GetEvaluatedValue() override { | |||
| return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>()); | |||
| } | |||
| private: | |||
| AbstractBasePtr abstract_; | |||
| @@ -135,11 +157,11 @@ class AnalysisCache { | |||
| AnalysisCache() = default; | |||
| ~AnalysisCache() = default; | |||
| void Clear() { cache_.clear(); } | |||
| void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg); | |||
| AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf); | |||
| void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); | |||
| EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); | |||
| private: | |||
| std::unordered_map<AnfNodeConfigPtr, AbstractBasePtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_; | |||
| std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_; | |||
| }; | |||
| using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>; | |||
| @@ -147,7 +169,7 @@ using AnfNodeConfigMap = | |||
| std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; | |||
| struct AnalysisResult { | |||
| AbstractBasePtr inferred; | |||
| EvalResultPtr inferred; | |||
| AnalysisContextPtr context; | |||
| }; | |||
| @@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| // func_graph: The func_graph to analyze. | |||
| // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. | |||
| AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); | |||
| EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); | |||
| // Return the Evaluator for the given function. | |||
| EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); | |||
| AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); | |||
| AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); | |||
| EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); | |||
| // Infer the result of fn(args). | |||
| AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); | |||
| EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); | |||
| void Clear(); | |||
| void ClearEvaluatorCache(); | |||
| AnalysisCache &cache() { return cache_; } | |||
| @@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| // Set the analysis result for orig to the result for new. | |||
| // This sets an entry in anfnode_config_map from orig to new. | |||
| AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { | |||
| EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { | |||
| // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. | |||
| (void)anfnode_config_map_.emplace(orig_conf, new_conf); | |||
| MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() | |||
| @@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | |||
| const ConfigPtrList &args_conf_list); | |||
| AbstractBasePtr Eval(const AnfNodeConfigPtr &conf); | |||
| EvalResultPtr Eval(const AnfNodeConfigPtr &conf); | |||
| EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); | |||
| AbstractBasePtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list); | |||
| AbstractBasePtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||
| const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list); | |||
| EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list); | |||
| EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list); | |||
| #ifdef DEBUG | |||
| std::vector<AnfNodePtr> compute_conf_stack_; | |||
| @@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { | |||
| return FromValueInside(MakeValue(value), broaden); | |||
| } | |||
| AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); | |||
| EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI | |||
| args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); | |||
| } | |||
| } | |||
| AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list); | |||
| AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); | |||
| op_exec_info->abstract = infer_res; | |||
| } | |||
| @@ -26,6 +26,8 @@ | |||
| #include <list> | |||
| #include <string> | |||
| #include <fstream> | |||
| #include <queue> | |||
| #include <set> | |||
| #include "ir/visitor.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -223,6 +225,31 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c | |||
| return res; | |||
| } | |||
| // search the cnodes inside this graph only | |||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) { | |||
| std::queue<CNodePtr> todo; | |||
| todo.push(ret); | |||
| std::vector<CNodePtr> sorted_nodes; | |||
| auto seen = NewSeenGeneration(); | |||
| while (!todo.empty()) { | |||
| CNodePtr top = todo.front(); | |||
| todo.pop(); | |||
| sorted_nodes.push_back(top); | |||
| auto inputs = top->inputs(); | |||
| for (auto &item : inputs) { | |||
| if (item->seen_ == seen) { | |||
| continue; | |||
| } | |||
| if (item->isa<CNode>()) { | |||
| todo.push(item->cast<CNodePtr>()); | |||
| } | |||
| item->seen_ = seen; | |||
| } | |||
| } | |||
| return sorted_nodes; | |||
| } | |||
| std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { | |||
| std::vector<AnfNodePtr> vecs; | |||
| if (node == nullptr) { | |||
| @@ -57,6 +57,7 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl | |||
| std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, | |||
| const IncludeFunc &include = AlwaysInclude); | |||
| std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret); | |||
| class FuncGraphIndex { | |||
| public: | |||
| explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, | |||
| @@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init ExpandDims""" | |||
| self.__setattr_flag__ = True | |||
| self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output']) | |||
| def __infer__(self, x, axis): | |||
| @@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer): | |||
| # if primitive need setattr in __infer__ need add this flag | |||
| """init Cast""" | |||
| self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) | |||
| self.__setattr_flag__ = True | |||
| def __infer__(self, x, t): | |||
| src_type = x['dtype'] | |||
| @@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer): | |||
| def __init__(self): | |||
| """init Reshape""" | |||
| self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output']) | |||
| self.__setattr_flag__ = True | |||
| def __infer__(self, x, shape): | |||
| shape_v = shape['value'] | |||
| @@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init Transpose""" | |||
| self.__setattr_flag__ = True | |||
| self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) | |||
| def __infer__(self, x, perm): | |||
| @@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init index_select""" | |||
| self.__setattr_flag__ = True | |||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | |||
| def __infer__(self, params, indices, axis): | |||
| @@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, axis=0): | |||
| """init Tile""" | |||
| self.__setattr_flag__ = True | |||
| validator.check_value_type("axis", axis, [int], self.name) | |||
| def __infer__(self, input_x): | |||
| @@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, axis=0): | |||
| """init Pack""" | |||
| self.__setattr_flag__ = True | |||
| validator.check_value_type("axis", axis, [int], self.name) | |||
| self.axis = axis | |||
| @@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, axis=0): | |||
| """init Unpack""" | |||
| self.__setattr_flag__ = True | |||
| validator.check_value_type("axis", axis, [int], self.name) | |||
| self.axis = axis | |||
| @@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init""" | |||
| self.__setattr_flag__ = True | |||
| def infer_shape(self, cond_shape, x_shape, y_shape): | |||
| if cond_shape != x_shape or x_shape != y_shape: | |||
| @@ -516,7 +516,6 @@ class MatMul(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, transpose_a=False, transpose_b=False): | |||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | |||
| self.__setattr_flag__ = True | |||
| cls_name = self.name | |||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | |||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | |||
| @@ -596,7 +595,6 @@ class BatchMatMul(MatMul): | |||
| @prim_attr_register | |||
| def __init__(self, transpose_a=False, transpose_b=False): | |||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | |||
| self.__setattr_flag__ = True | |||
| cls_name = self.name | |||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | |||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | |||
| @@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.__setattr_flag__ = True | |||
| self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | |||
| def infer_shape(self, inputs): | |||
| @@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer): | |||
| """init Conv2D""" | |||
| self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) | |||
| self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) | |||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name) | |||
| self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) | |||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) | |||
| self.add_prim_attr('stride', self.stride) | |||
| self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) | |||
| self.add_prim_attr('dilation', self.dilation) | |||
| validator.check_value_type('pad', pad, (int,), self.name) | |||
| @@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer): | |||
| self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] | |||
| self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) | |||
| out_channel = self.out_channel | |||
| out_shape = [x_shape[0], out_channel, h_out, w_out] | |||
| return out_shape | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test nn ops """ | |||
| import functools | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| def test_cast_op_attr(): | |||
| class CastNet(nn.Cell): | |||
| def __init__(self): | |||
| super(CastNet, self).__init__() | |||
| self.cast = P.Cast() | |||
| def construct(self, x, t): | |||
| return self.cast(x, t) | |||
| class CastTypeTest(nn.Cell): | |||
| def __init__(self, net): | |||
| super(CastTypeTest, self).__init__() | |||
| self.net = net | |||
| self.cast = P.Cast() | |||
| def construct(self, x, y, z): | |||
| cast_op = self.cast | |||
| t1 = cast_op(x, mstype.float32) | |||
| t2 = cast_op(y, mstype.int32) | |||
| cast_net = self.net | |||
| t3 = cast_net(x, mstype.float16) | |||
| t4 = cast_net(y, mstype.int32) | |||
| t5 = cast_net(z, mstype.float16) | |||
| return (t1, t2, t3, t4, t5) | |||
| net = CastTypeTest(CastNet()) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.int32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| t3 = Tensor(np.ones([1,16,1,1918]).astype(np.int32)) | |||
| out = net(t1, t2, t3) | |||
| assert out[0].asnumpy().dtype == np.float32 | |||
| assert out[1].asnumpy().dtype == np.int32 | |||
| assert out[2].asnumpy().dtype == np.float16 | |||
| assert out[3].asnumpy().dtype == np.int32 | |||
| assert out[4].asnumpy().dtype == np.float16 | |||
| @@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) { | |||
| AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tensor, slice}; | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred); | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract array failed."; | |||
| } | |||
| @@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) { | |||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract array failed."; | |||
| } | |||
| @@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) { | |||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract array failed."; | |||
| } | |||
| @@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) { | |||
| AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2); | |||
| AbstractBasePtrList args_spec_list = {tensor, start_index}; | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract array failed."; | |||
| } | |||
| @@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) { | |||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract array failed."; | |||
| } | |||
| @@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) { | |||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract array failed."; | |||
| } | |||
| @@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) { | |||
| abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | |||
| AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) { | |||
| abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | |||
| AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) { | |||
| auto tuple = std::make_shared<AbstractTuple>(eles); | |||
| AbstractBasePtrList args_spec_list = {tuple}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) { | |||
| AbstractBasePtr abstract_v2 = FromValue(2, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2}; | |||
| AbstractBasePtr abstract_val = FromValue(10, false); | |||
| cache[args_spec_list] = abstract_val; | |||
| cache[args_spec_list] = std::make_shared<EvalResult>(abstract_val, std::make_shared<AttrValueMap>()); | |||
| auto iter = cache.find(args_spec_list); | |||
| ASSERT_TRUE(iter != cache.end()); | |||
| ASSERT_TRUE(iter->second == abstract_val); | |||
| ASSERT_TRUE(iter->second->abstract() == abstract_val); | |||
| AbstractBasePtr abstract_v1_variant1 = FromValue(1, false); | |||
| AbstractBasePtr abstract_v2_variant1 = FromValue(2, false); | |||
| @@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) { | |||
| iter = cache.find(args_spec_list_variant1); | |||
| ASSERT_TRUE(iter != cache.end()); | |||
| ASSERT_TRUE(iter->second == abstract_val); | |||
| ASSERT_TRUE(iter->second->abstract() == abstract_val); | |||
| AbstractBasePtr abstract_v1_variant2 = FromValue(1, false); | |||
| AbstractBasePtr abstract_v2_variant2 = FromValue(3, false); | |||
| @@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) { | |||
| std::vector<int> shape = {2, 2, 6, 6}; | |||
| expected->set_shape(std::make_shared<Shape>(shape)); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) { | |||
| AbstractBasePtr abstract_x = FromValue(x, false); | |||
| args_spec_list.push_back(abstract_x); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | |||
| } | |||
| @@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) { | |||
| AbstractBasePtr abstract_x = FromValue(x, false); | |||
| args_spec_list.push_back(abstract_x); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | |||
| } | |||
| @@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) { | |||
| auto prim_typeof = std::make_shared<Primitive>("typeof"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| res->dump(); | |||
| TypePtr res_value = res->GetValueTrack()->cast<TypePtr>(); | |||
| res_value->dump(); | |||
| @@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) { | |||
| auto prim_list_map = std::make_shared<Primitive>("list_map"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)})); | |||
| res->dump(); | |||
| MS_LOG(INFO) << "result res: " << res->ToString(); | |||
| @@ -188,7 +188,7 @@ TEST_F(TestPrim, test_list_reduce) { | |||
| auto prim_list_reduce = std::make_shared<Primitive>("list_reduce"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| res->dump(); | |||
| TypePtr res_type = res->GetTypeTrack(); | |||
| res_type->dump(); | |||
| @@ -205,7 +205,7 @@ TEST_F(TestPrim, test_scalar_to_array) { | |||
| auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| res->dump(); | |||
| TypePtr res_type = res->BuildType(); | |||
| res_type->dump(); | |||
| @@ -223,7 +223,7 @@ TEST_F(TestPrim, test_array_to_scalar) { | |||
| auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| res->dump(); | |||
| TypePtr res_type = res->BuildType(); | |||
| res_type->dump(); | |||
| @@ -239,7 +239,7 @@ TEST_F(TestPrim, test_J_1) { | |||
| auto prim_J = std::make_shared<Primitive>("J"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res); | |||
| ASSERT_TRUE(res_J != nullptr); | |||
| ASSERT_TRUE(*(res_J->element()) == *abstract_v1); | |||
| @@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) { | |||
| int v1 = 1; | |||
| AbstractBasePtr abstract_v1 = FromValue(v1, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_v1}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| res->dump(); | |||
| AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res); | |||
| ASSERT_TRUE(res_J != nullptr); | |||
| @@ -302,7 +302,7 @@ TEST_F(TestPrim, test_dot) { | |||
| AbstractBasePtrList args_spec_list = {a1, a2}; | |||
| AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred); | |||
| AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||
| ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack()))); | |||
| } | |||
| @@ -317,7 +317,7 @@ TEST_F(TestPrim, test_switch1) { | |||
| AbstractBasePtr arg2 = FromValue(2, false); | |||
| AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *arg1); | |||
| } | |||
| @@ -330,7 +330,7 @@ TEST_F(TestPrim, test_switch2) { | |||
| AbstractBasePtr arg2 = FromValue(2, false); | |||
| AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "make result res: " << res->ToString(); | |||
| MS_LOG(INFO) << "make result arg2: " << arg2->ToString(); | |||
| ASSERT_TRUE(*res == *arg2); | |||
| @@ -343,7 +343,7 @@ TEST_F(TestPrim, test_identity) { | |||
| AbstractBasePtr abstract_v1 = FromValue(1, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_v1}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *abstract_v1); | |||
| } | |||
| @@ -357,7 +357,7 @@ TEST_F(TestPrim, test_broadcast_shape) { | |||
| AbstractBasePtrList args_spec_list = {a, b}; | |||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred); | |||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||
| auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | |||
| std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)}; | |||
| @@ -377,7 +377,7 @@ TEST_F(TestPrim, test_partial) { | |||
| AbstractBasePtr abstract_v2 = FromValue(1, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2}; | |||
| auto expected = std::make_shared<PartialAbstractClosure>( | |||
| std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list); | |||
| @@ -392,7 +392,7 @@ TEST_F(TestPrim, test_env_setitem) { | |||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | |||
| AbstractBasePtr abstract_x = FromValue(1, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_x}; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | |||
| @@ -400,7 +400,7 @@ TEST_F(TestPrim, test_env_setitem) { | |||
| AbstractBasePtr abstract_y = FromValue(2, false); | |||
| args_spec_list = {abstract_env, embed_x, abstract_y}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | |||
| ASSERT_TRUE(*res == *exp); | |||
| } | |||
| @@ -412,7 +412,7 @@ TEST_F(TestPrim, test_env_getitem) { | |||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | |||
| AbstractBasePtr abstract_x = FromValue(1, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_x}; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||
| FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | |||
| @@ -420,7 +420,7 @@ TEST_F(TestPrim, test_env_getitem) { | |||
| AbstractBasePtr abstract_y = FromValue(2, false); | |||
| args_spec_list = {abstract_env, embed_x, abstract_y}; | |||
| AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | |||
| ASSERT_TRUE(*res == *exp); | |||
| @@ -429,7 +429,7 @@ TEST_F(TestPrim, test_env_getitem) { | |||
| AbstractBasePtr abstract_z = FromValue(3, false); | |||
| args_spec_list = {res, embed_x, abstract_z}; | |||
| res = engine_->Run(graph_getitem, args_spec_list).inferred; | |||
| res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *abstract_x); | |||
| } | |||
| @@ -442,7 +442,7 @@ TEST_F(TestPrim, test_env_add) { | |||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | |||
| AbstractBasePtr abstract_x = FromValue(1, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_x}; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||
| FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); | |||
| @@ -450,19 +450,19 @@ TEST_F(TestPrim, test_env_add) { | |||
| AbstractBasePtr abstract_y = FromValue(2, false); | |||
| args_spec_list = {abstract_env, embed_x, abstract_y}; | |||
| AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred; | |||
| AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | |||
| ASSERT_TRUE(*abstract_e1 == *exp); | |||
| AbstractBasePtr abstract_z = FromValue(3, false); | |||
| args_spec_list = {abstract_env, embed_x, abstract_z}; | |||
| AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred; | |||
| AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*abstract_e2 == *exp); | |||
| FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2); | |||
| args_spec_list = {abstract_e1, abstract_e2}; | |||
| AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *exp); | |||
| } | |||
| @@ -475,7 +475,7 @@ TEST_F(TestPrim, test_shape) { | |||
| AbstractBasePtrList args_spec_list = {a}; | |||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred); | |||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||
| auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | |||
| std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)}; | |||
| @@ -493,7 +493,7 @@ TEST_F(TestPrim, test_relu) { | |||
| AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW | |||
| AbstractBasePtrList args_spec_list = {expected}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -507,7 +507,7 @@ TEST_F(TestPrim, test_relu2) { | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5}); | |||
| AbstractBasePtrList args_spec_list = {arr}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack())); | |||
| } | |||
| @@ -540,7 +540,7 @@ TEST_F(TestPrim, test_conv2d1) { | |||
| std::vector<int> shape = {2, 64, 14, 14}; | |||
| expected->set_shape(std::make_shared<Shape>(shape)); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -558,7 +558,7 @@ TEST_F(TestPrim, test_conv2d) { | |||
| auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3}); | |||
| AbstractBasePtrList args_spec_list = {input, weight}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16}); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| @@ -574,7 +574,7 @@ TEST_F(TestPrim, test_conv2d_native) { | |||
| auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3}); | |||
| AbstractBasePtrList args_spec_list = {input, weight}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16}); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| @@ -590,7 +590,7 @@ TEST_F(TestPrim, test_biasAdd) { | |||
| auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32}); | |||
| AbstractBasePtrList args_spec_list = {value, bias}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32}); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| @@ -606,7 +606,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) { | |||
| auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10}); | |||
| AbstractBasePtrList args_spec_list = {logits, labels}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_NE(ret, nullptr); | |||
| auto res = dyn_cast<AbstractTuple>(ret); | |||
| auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64}); | |||
| @@ -636,7 +636,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) { | |||
| auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10}); | |||
| AbstractBasePtrList args_spec_list = {logits, labels}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto res = dyn_cast<AbstractScalar>(ret); | |||
| AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64); | |||
| expected->set_type(UTPrimUtils::kF64); | |||
| @@ -690,7 +690,7 @@ TEST_F(TestPrim, test_fused_batch_norm) { | |||
| AbstractBasePtr expected0 = abstract_inputs->Clone(); | |||
| AbstractBasePtr expected1 = abstract_scale->Clone(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected0: " << expected0->ToString(); | |||
| MS_LOG(INFO) << "expected1: " << expected1->ToString(); | |||
| @@ -722,7 +722,7 @@ TEST_F(TestPrim, test_pooling) { | |||
| inputs->set_shape(inputs_dims); | |||
| AbstractBasePtr abstract_input = FromValue(inputs, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_input}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expected = abstract_input->Clone()->Broaden(); | |||
| std::vector<int> expected_dims = {8, 64, 2, 2}; | |||
| @@ -747,7 +747,7 @@ TEST_F(TestPrim, test_hastype) { | |||
| auto prim = std::make_shared<Primitive>("hastype"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -761,7 +761,7 @@ TEST_F(TestPrim, test_array_len) { | |||
| auto prim = std::make_shared<Primitive>("array_len"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -775,7 +775,7 @@ TEST_F(TestPrim, test_list_len) { | |||
| auto prim = std::make_shared<Primitive>("list_len"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -789,7 +789,7 @@ TEST_F(TestPrim, test_tuple_len) { | |||
| auto prim = std::make_shared<Primitive>("tuple_len"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -803,7 +803,7 @@ TEST_F(TestPrim, test_tuple_reversed) { | |||
| auto prim = std::make_shared<Primitive>("tuple_reversed"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "expect=" << expected->ToString(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -825,7 +825,7 @@ TEST_F(TestPrim, test_list_getitem) { | |||
| auto prim = std::make_shared<Primitive>("list_getitem"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *elem); | |||
| } | |||
| @@ -844,7 +844,7 @@ TEST_F(TestPrim, test_list_setitem) { | |||
| auto prim = std::make_shared<Primitive>("list_setitem"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| AbstractBasePtrList elems_exp = {elem1, elem2}; | |||
| auto expected = std::make_shared<AbstractList>(elems_exp); | |||
| @@ -866,7 +866,7 @@ TEST_F(TestPrim, test_list_append) { | |||
| auto prim = std::make_shared<Primitive>("list_append"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2})); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -890,7 +890,7 @@ TEST_F(TestPrim, test_tuple_setitem) { | |||
| auto prim = std::make_shared<Primitive>("tuple_setitem"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| AbstractBasePtrList elems_exp = {elem1, elem2}; | |||
| auto expected = std::make_shared<AbstractTuple>(elems_exp); | |||
| @@ -916,7 +916,7 @@ TEST_F(TestPrim, test_make_list) { | |||
| auto prim = std::make_shared<Primitive>("make_list"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -939,7 +939,7 @@ TEST_F(TestPrim, test_make_range) { | |||
| AbstractBasePtrList elem_list({ele1, ele2, ele3}); | |||
| AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "res=" << res->ToString(); | |||
| MS_LOG(INFO) << "expected=" << expected->ToString(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| @@ -982,7 +982,7 @@ TEST_F(TestPrim, test_layernorm) { | |||
| AbstractBasePtr expected1 = abstract_mean_var->Clone(); | |||
| AbstractBasePtr expected2 = abstract_mean_var->Clone(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected0: " << expected0->ToString(); | |||
| MS_LOG(INFO) << "expected1: " << expected1->ToString(); | |||
| @@ -1028,7 +1028,7 @@ TEST_F(TestPrim, test_DropoutGenMask) { | |||
| AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8), | |||
| std::make_shared<Shape>(std::vector<int>{79})); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "res=" << res->ToString(); | |||
| MS_LOG(INFO) << "expected=" << expected->ToString(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| @@ -1058,7 +1058,7 @@ TEST_F(TestPrim, test_dropout) { | |||
| std::vector<int> shape = {2, 20, 32, 32}; | |||
| expected->set_shape(std::make_shared<Shape>(shape)); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -1079,7 +1079,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) { | |||
| auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | |||
| auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | |||
| AbstractBasePtrList args_spec_list = {x_input, y_input}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto res = dyn_cast<AbstractTuple>(ret); | |||
| AbstractBasePtrList x_idx_list; | |||
| auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | |||
| @@ -1103,7 +1103,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) { | |||
| auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | |||
| auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | |||
| AbstractBasePtrList args_spec_list = {x_input, y_input}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| auto res = dyn_cast<AbstractTuple>(ret); | |||
| AbstractBasePtrList x_idx_list({abstract::FromValue(1)}); | |||
| auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | |||
| @@ -1128,7 +1128,7 @@ TEST_F(TestPrim, test_DictGetItem) { | |||
| AbstractBasePtr key = abstract::FromValue("x"); | |||
| AbstractBasePtrList args_spec_list = {array_dict, key}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | |||
| AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second)); | |||
| @@ -1147,7 +1147,7 @@ TEST_F(TestPrim, test_DictGetItem2) { | |||
| AbstractBasePtr key = abstract::FromValue("x"); | |||
| AbstractBasePtrList args_spec_list = {array_dict, key}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | |||
| AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x); | |||
| @@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) { | |||
| auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| } | |||
| @@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) { | |||
| MS_LOG(INFO) << "" << graph_f_->get_return()->ToString(); | |||
| AbstractBasePtr abstract_v1 = FromValue(1, false); | |||
| args_spec_list.push_back(abstract_v1); | |||
| AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| // now this test case failed randomly, have to debug. | |||
| @@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) { | |||
| args_spec_list.clear(); | |||
| args_spec_list.push_back(abstract_v1); | |||
| args_spec_list.push_back(abstract_v2); | |||
| abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred; | |||
| abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| } | |||
| @@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) { | |||
| AbstractBasePtr abstract_v2 = FromValue(v1, false); | |||
| args_spec_list.push_back(abstract_v1); | |||
| args_spec_list.push_back(abstract_v2); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| } | |||
| @@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { | |||
| auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred; | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32); | |||
| } | |||
| @@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) { | |||
| AbstractBasePtr base1 = FromValue(x1, false); | |||
| AbstractBasePtr base2 = FromValue(x2, false); | |||
| AbstractBasePtrList base_list = {base1, base2}; | |||
| auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list); | |||
| auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract(); | |||
| MS_LOG(INFO) << "result spec: " << res->ToString(); | |||
| AbstractBasePtr exp = FromValue(x3, false); | |||
| MS_LOG(INFO) << "result exp: " << exp->ToString(); | |||
| @@ -446,7 +446,7 @@ void TestGraphEval::TearDown() { | |||
| TEST_F(TestGraphInfer, test_graph_infer_defaults) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(50), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(1), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(9), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(48), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_kwarg) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(7), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(46), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(57), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | |||
| from ....mindspore_test_framework.pipeline.forward.verify_exception \ | |||
| import pipeline_for_verify_exception_for_case_by_case_config | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| def conv3x3(in_channels, out_channels, stride=1, padding=1): | |||
| """3x3 convolution """ | |||
| @@ -377,6 +378,21 @@ class StateNet(nn.Cell): | |||
| return x | |||
| def test_conv2d_same_primitive(): | |||
| class Conv2DSameNet(nn.Cell): | |||
| def __init__(self): | |||
| super(Conv2DSameNet, self).__init__() | |||
| self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||
| self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||
| def construct(self, x, y): | |||
| r1 = self.conv1(x) | |||
| r2 = self.conv2(y) | |||
| return (r1, r2) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| net = Conv2DSameNet() | |||
| out = net(t1, t2) | |||
| class ComparisonNet(nn.Cell): | |||
| def __init__(self): | |||
| """ ComparisonNet definition """ | |||
| @@ -0,0 +1,276 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test nn ops """ | |||
| import functools | |||
| import numpy as np | |||
| import mindspore | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | |||
| from mindspore.ops.primitive import constexpr | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | |||
| from ....mindspore_test_framework.pipeline.forward.verify_exception \ | |||
| import pipeline_for_verify_exception_for_case_by_case_config | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| class FakeOp(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """""" | |||
| def infer_shape(self, x, y): | |||
| self.second_shape = y | |||
| self.add_prim_attr("second_shape", y) | |||
| return x | |||
| def infer_dtype(self, x, y): | |||
| return x | |||
| # test the normal case that should generate independent primitive because of different | |||
| # generated attributes after inference | |||
| def test_conv2d_same_primitive(): | |||
| class Conv2DSameNet(nn.Cell): | |||
| def __init__(self): | |||
| super(Conv2DSameNet, self).__init__() | |||
| self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||
| self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) | |||
| def construct(self, x, y): | |||
| r1 = self.conv1(x) | |||
| r2 = self.conv2(y) | |||
| return (r1, r2) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| net = Conv2DSameNet() | |||
| out = net(t1, t2) | |||
| # test cell as high order argument | |||
| # The graph with free variables used as argument is not supported yet | |||
| # because of the limit of inference specialize system | |||
| def Xtest_conv2d_op_with_arg(): | |||
| class Conv2dNet(nn.Cell): | |||
| def __init__(self): | |||
| super(Conv2dNet, self).__init__() | |||
| def construct(self, op, x): | |||
| return op(x) | |||
| class OpsNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(OpsNet, self).__init__() | |||
| self.opnet = net | |||
| self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) | |||
| def construct(self, x, y): | |||
| conv_op = self.conv2 | |||
| a = self.opnet(conv_op, x) | |||
| b = self.opnet(conv_op, y) | |||
| return (a, b) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| net = OpsNet(Conv2dNet()) | |||
| out = net(t1, t2) | |||
| def test_conv2d_op_with_arg(): | |||
| class FackOpNet(nn.Cell): | |||
| def __init__(self): | |||
| super(FackOpNet, self).__init__() | |||
| self.op = FakeOp() | |||
| def construct(self, x, y): | |||
| return self.op(x, y) | |||
| class OpNet(nn.Cell): | |||
| def __init__(self): | |||
| super(OpNet, self).__init__() | |||
| def construct(self, op, x, y): | |||
| return op(x, y) | |||
| class OpsNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(OpsNet, self).__init__() | |||
| self.opnet = net | |||
| self.op = FackOpNet() | |||
| def construct(self, x, y): | |||
| op = self.op | |||
| a = self.opnet(op, x, y) | |||
| b = self.opnet(op, y, x) | |||
| return (a, b) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| net = OpsNet(OpNet()) | |||
| out = net(t1, t2) | |||
| def test_conv2d_op_with_arg_same_input(): | |||
| class FackOpNet(nn.Cell): | |||
| def __init__(self): | |||
| super(FackOpNet, self).__init__() | |||
| self.op = FakeOp() | |||
| def construct(self, x, y): | |||
| return self.op(x, y) | |||
| class OpNet(nn.Cell): | |||
| def __init__(self): | |||
| super(OpNet, self).__init__() | |||
| def construct(self, op, x, y): | |||
| return op(x, y) | |||
| class OpsNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(OpsNet, self).__init__() | |||
| self.opnet = net | |||
| self.op = FackOpNet() | |||
| def construct(self, x, y): | |||
| op = self.op | |||
| a = self.opnet(op, x, x) | |||
| b = self.opnet(op, y, x) | |||
| return (a, b) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| net = OpsNet(OpNet()) | |||
| out = net(t1, t2) | |||
| # test op with partial | |||
| def test_op_as_partial(): | |||
| class OpAsPartial(nn.Cell): | |||
| def __init__(self): | |||
| super(OpAsPartial, self).__init__() | |||
| self.op = FakeOp() | |||
| def construct(self, x, y, z): | |||
| partial_op = F.partial(self.op, x) | |||
| a = partial_op(y) | |||
| b = partial_op(z) | |||
| return a, b | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||
| net = OpAsPartial() | |||
| out = net(t1, t2, t3) | |||
| # test op with partial | |||
| def test_op_as_partial_inside(): | |||
| class OpAsPartial(nn.Cell): | |||
| def __init__(self): | |||
| super(OpAsPartial, self).__init__() | |||
| self.op = FakeOp() | |||
| def construct(self, x, y, z): | |||
| partial_op = F.partial(self.op, x) | |||
| a = partial_op(y) | |||
| b = partial_op(z) | |||
| return a, b | |||
| class OuterNet(nn.Cell): | |||
| def __init__(self): | |||
| super(OuterNet, self).__init__() | |||
| self.net = OpAsPartial() | |||
| def construct(self, x, y, z): | |||
| a,b = self.net(x, y, z) | |||
| return a, b | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||
| net = OuterNet() | |||
| out = net(t1, t2, t3) | |||
| # test op with partial case 2 | |||
| def test_op_as_partial_independent(): | |||
| class OpAsPartial(nn.Cell): | |||
| def __init__(self): | |||
| super(OpAsPartial, self).__init__() | |||
| self.op = FakeOp() | |||
| def construct(self, x, y, z): | |||
| partial_op1 = F.partial(self.op, x) | |||
| a = partial_op1(y) | |||
| partial_op2 = F.partial(self.op, x) | |||
| b = partial_op2(z) | |||
| return a, b | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||
| net = OpAsPartial() | |||
| out = net(t1, t2, t3) | |||
| def test_nest_partial(): | |||
| class NestPartial(nn.Cell): | |||
| def __init__(self): | |||
| super(NestPartial, self).__init__() | |||
| self.op = FakeOp() | |||
| def construct(self, x, y, z): | |||
| partial_op1 = F.partial(self.op) | |||
| partial_op2 = F.partial(partial_op1, x) | |||
| a = partial_op2(y) | |||
| partial_op3 = F.partial(self.op) | |||
| partial_op4 = F.partial(partial_op3, x) | |||
| b = partial_op4(z) | |||
| return a, b | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||
| net = NestPartial() | |||
| out = net(t1, t2, t3) | |||
| # high order argument | |||
| # op and op args as network arguments | |||
| def test_op_with_arg_as_input(): | |||
| class WithOpArgNet(nn.Cell): | |||
| def __init__(self): | |||
| super(WithOpArgNet, self).__init__() | |||
| def construct(self, op, x, y): | |||
| return op(x, y) | |||
| class OpsNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(OpsNet, self).__init__() | |||
| self.opnet = net | |||
| self.op = FakeOp() | |||
| def construct(self, x, y, z): | |||
| op = self.op | |||
| a = self.opnet(op, x, z) | |||
| b = self.opnet(op, x, y) | |||
| return (a, b) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||
| net = OpsNet(WithOpArgNet()) | |||
| out = net(t1, t2, t3) | |||
| # The partial application used as argument is not supported yet | |||
| # because of the limit of inference specialize system | |||
| def Xtest_partial_as_arg(): | |||
| class PartialArgNet(nn.Cell): | |||
| def __init__(self): | |||
| super(PartialArgNet, self).__init__() | |||
| def construct(self, partial_op, y): | |||
| return partial_op(y) | |||
| class OpsNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(OpsNet, self).__init__() | |||
| self.partial_net = net | |||
| self.op = FakeOp() | |||
| def construct(self, x, y, z): | |||
| partial_op = F.partial(self.op, x) | |||
| a = self.partial_net(partial_op, z) | |||
| b = self.partial_net(partial_op, y) | |||
| return (a, b) | |||
| t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) | |||
| t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) | |||
| t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) | |||
| net = OpsNet(PartialArgNet()) | |||
| out = net(t1, t2, t3) | |||