Merge pull request !28715 from 张清华/eliminate_tuple_unused_item2tags/v1.6.0
| @@ -60,16 +60,46 @@ void PrintKernelFormatAndType(std::ostringstream &buffer, const std::string &fmt | |||
| buffer << ">"; | |||
| } | |||
| void PrintTupleNodeUsedFlags(std::ostringstream &buffer, const abstract::AbstractSequencePtr &sequence_abs) { | |||
| if (sequence_abs == nullptr || sequence_abs->sequence_nodes().empty()) { | |||
| return; | |||
| } | |||
| buffer << ", sequence_nodes={"; | |||
| for (size_t i = 0; i < sequence_abs->sequence_nodes().size(); ++i) { | |||
| auto node = sequence_abs->sequence_nodes()[i].lock(); | |||
| if (node == nullptr) { | |||
| MS_LOG(DEBUG) << "The node in sequence_nodes is free."; | |||
| buffer << "node={<freed node>}"; | |||
| } else { | |||
| buffer << "node={" << node->DebugString(); | |||
| auto flags = GetSequenceNodeElementsUseFlags(node); | |||
| if (flags != nullptr) { | |||
| buffer << ", elements_use_flags=" << (*flags) << "}"; | |||
| } | |||
| } | |||
| if (i != sequence_abs->sequence_nodes().size() - 1) { | |||
| buffer << ", "; | |||
| } | |||
| } | |||
| buffer << "}"; | |||
| } | |||
| void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| ValuePtr tensor_value = nullptr; | |||
| abstract::AbstractSequencePtr sequence_abs = nullptr; | |||
| auto abstract = node->abstract(); | |||
| if (abstract != nullptr && abstract->isa<abstract::AbstractTensor>()) { | |||
| tensor_value = abstract->BuildValue(); | |||
| if (abstract != nullptr) { | |||
| if (abstract->isa<abstract::AbstractTensor>()) { | |||
| tensor_value = abstract->BuildValue(); | |||
| } | |||
| sequence_abs = dyn_cast<abstract::AbstractSequence>(abstract); | |||
| } | |||
| abstract::ShapePtr shape = dyn_cast<abstract::Shape>(node->Shape()); | |||
| TypePtr type = dyn_cast<Type>(node->Type()); | |||
| if ((shape != nullptr) && (type != nullptr)) { | |||
| @@ -77,12 +107,14 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) { | |||
| if (tensor_value != nullptr && tensor_value != kAnyValue) { | |||
| buffer << ", value=..."; | |||
| } | |||
| PrintTupleNodeUsedFlags(buffer, sequence_abs); | |||
| buffer << ">"; | |||
| } else if (type != nullptr) { | |||
| buffer << "<" << type; | |||
| if (tensor_value != nullptr && tensor_value != kAnyValue) { | |||
| buffer << ", value=..."; | |||
| } | |||
| PrintTupleNodeUsedFlags(buffer, sequence_abs); | |||
| buffer << ">"; | |||
| } else { | |||
| buffer << "<null>"; | |||
| @@ -244,7 +244,14 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap | |||
| inputs.emplace_back(call_node); | |||
| } | |||
| } | |||
| return func_graph->NewCNodeInOrder(inputs); | |||
| if (inputs.size() > 1) { | |||
| return func_graph->NewCNodeInOrder(inputs); | |||
| } | |||
| // Empty tuple. | |||
| auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList()); | |||
| auto empty_tuple = NewValueNode(empty_tuple_value); | |||
| return empty_tuple; | |||
| } | |||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, | |||
| @@ -452,86 +459,119 @@ bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bo | |||
| CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>()))); | |||
| } | |||
| namespace { | |||
| void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue, | |||
| const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) { | |||
| if (pos == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!"; | |||
| } | |||
| AnfNodePtr tuple_parameter = res->add_parameter(); | |||
| std::vector<AnfNodePtr> pos_elements; | |||
| PrimitivePtr pos_op = nullptr; | |||
| if (pos->isa<AbstractTuple>()) { | |||
| pos_elements.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| pos_op = prim::kPrimTupleGetItem; | |||
| } else { | |||
| pos_elements.push_back(NewValueNode(prim::kPrimMakeList)); | |||
| pos_op = prim::kPrimListGetItem; | |||
| } | |||
| AnfNodePtr pos_value = nullptr; | |||
| AnfNodePtr pos_value_adjust = nullptr; | |||
| auto pos_parameter = res->add_parameter(); | |||
| if (pos->size() == 1) { | |||
| pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(0))}); | |||
| pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))}); | |||
| if (CheckTailGradFristSequence(sequeue, enable_tuple_grad)) { | |||
| res->set_output(res->NewCNode({NewValueNode(pos_op), tuple_parameter, pos_value_adjust})); | |||
| } else { | |||
| res->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{}))); | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < pos->size(); ++i) { | |||
| pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(i))}); | |||
| pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))}); | |||
| pos_elements.push_back(res->NewCNodeInOrder({NewValueNode(pos_op), tuple_parameter, pos_value_adjust})); | |||
| } | |||
| if (pos_elements.size() > 1) { | |||
| res->set_output(res->NewCNodeInOrder(pos_elements)); | |||
| } else if (pos->isa<AbstractTuple>()) { // Empty tuple. | |||
| auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList()); | |||
| auto empty_tuple = NewValueNode(empty_tuple_value); | |||
| res->set_output(empty_tuple); | |||
| } else { // Empty list. | |||
| auto empty_list_value = std::make_shared<ValueList>(ValuePtrList()); | |||
| auto empty_list = NewValueNode(empty_list_value); | |||
| res->set_output(empty_list); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr &sequeue, | |||
| const abstract::AbstractSequencePtr &pos) const { | |||
| MS_EXCEPTION_IF_NULL(sequeue); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->debug_info()->set_name("tail"); | |||
| AnfNodePtr ptrTup = ret->add_parameter(); | |||
| std::vector<AnfNodePtr> elems; | |||
| PrimitivePtr op = nullptr; | |||
| if (sequeue->isa<AbstractTuple>()) { | |||
| elems.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| op = prim::kPrimTupleGetItem; | |||
| } else { | |||
| elems.push_back(NewValueNode(prim::kPrimMakeList)); | |||
| op = prim::kPrimListGetItem; | |||
| } | |||
| FuncGraphPtr res = std::make_shared<FuncGraph>(); | |||
| res->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| res->debug_info()->set_name("tail"); | |||
| if (tail_type_ == kGradFirst) { | |||
| AnfNodePtr tuple_parameter = res->add_parameter(); | |||
| PrimitivePtr getitem_op = nullptr; | |||
| if (sequeue->isa<AbstractTuple>()) { | |||
| getitem_op = prim::kPrimTupleGetItem; | |||
| } else { | |||
| getitem_op = prim::kPrimListGetItem; | |||
| } | |||
| if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) { | |||
| ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))})); | |||
| res->set_output(res->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))})); | |||
| } else { | |||
| ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{}))); | |||
| res->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList()))); | |||
| } | |||
| return ret; | |||
| return res; | |||
| } | |||
| if (tail_type_ == kGradByPosition) { | |||
| if (pos == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!"; | |||
| } | |||
| std::vector<AnfNodePtr> pos_elems; | |||
| PrimitivePtr pos_op = nullptr; | |||
| if (pos->isa<AbstractTuple>()) { | |||
| pos_elems.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| pos_op = prim::kPrimTupleGetItem; | |||
| } else { | |||
| pos_elems.push_back(NewValueNode(prim::kPrimMakeList)); | |||
| pos_op = prim::kPrimListGetItem; | |||
| } | |||
| AnfNodePtr pos_value = nullptr; | |||
| AnfNodePtr pos_value_adjust = nullptr; | |||
| auto ptrpos = ret->add_parameter(); | |||
| if (pos->size() == 1) { | |||
| pos_value = ret->NewCNode({NewValueNode(pos_op), ptrpos, NewValueNode(SizeToLong(0))}); | |||
| pos_value_adjust = ret->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))}); | |||
| if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) { | |||
| ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, pos_value_adjust})); | |||
| } else { | |||
| ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{}))); | |||
| } | |||
| return ret; | |||
| } else { | |||
| for (size_t i = 0; i < pos->size(); ++i) { | |||
| pos_value = ret->NewCNode({NewValueNode(pos_op), ptrpos, NewValueNode(SizeToLong(i))}); | |||
| pos_value_adjust = ret->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))}); | |||
| pos_elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, pos_value_adjust})); | |||
| } | |||
| } | |||
| ret->set_output(ret->NewCNodeInOrder(pos_elems)); | |||
| return ret; | |||
| GenerateSequenceFuncGraphByPosition(res, sequeue, pos, enable_tuple_grad_); | |||
| return res; | |||
| } | |||
| AnfNodePtr tuple_parameter = res->add_parameter(); | |||
| std::vector<AnfNodePtr> elements; | |||
| PrimitivePtr op = nullptr; | |||
| if (sequeue->isa<AbstractTuple>()) { | |||
| elements.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| op = prim::kPrimTupleGetItem; | |||
| } else { | |||
| elements.push_back(NewValueNode(prim::kPrimMakeList)); | |||
| op = prim::kPrimListGetItem; | |||
| } | |||
| for (size_t i = 1; i < sequeue->size(); ++i) { | |||
| if (tail_type_ == kGradAll) { | |||
| MS_EXCEPTION_IF_NULL((*sequeue)[i]); | |||
| if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || | |||
| (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr && | |||
| (*sequeue)[i]->BuildType()->isa<Number>())) { | |||
| elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); | |||
| elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))})); | |||
| } | |||
| } else { | |||
| elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); | |||
| elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))})); | |||
| } | |||
| } | |||
| ret->set_output(ret->NewCNodeInOrder(elems)); | |||
| return ret; | |||
| if (elements.size() > 1) { | |||
| res->set_output(res->NewCNodeInOrder(elements)); | |||
| return res; | |||
| } else if (sequeue->isa<AbstractTuple>()) { // Empty tuple. | |||
| auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList()); | |||
| auto empty_tuple = NewValueNode(empty_tuple_value); | |||
| res->set_output(empty_tuple); | |||
| return res; | |||
| } else { // Empty list. | |||
| auto empty_list_value = std::make_shared<ValueList>(ValuePtrList()); | |||
| auto empty_list = NewValueNode(empty_list_value); | |||
| res->set_output(empty_list); | |||
| return res; | |||
| } | |||
| } | |||
| FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| @@ -25,6 +25,12 @@ class EliminateDeadNodePass { | |||
| EliminateDeadNodePass() = default; | |||
| ~EliminateDeadNodePass() = default; | |||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) { | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (enable_eliminate_unused_element) { | |||
| return false; | |||
| } | |||
| static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; | |||
| MS_LOG(INFO) << "Closure enable:" << enable_closure; | |||
| if (!enable_closure) { | |||
| @@ -257,20 +257,20 @@ using CompileGraphs = compile::CompileGraphs; | |||
| using abstract::AnalysisResult; | |||
| using mindspore::abstract::AnalysisContextPtr; | |||
| abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, | |||
| abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const FuncGraphPtr &func_graph, | |||
| const abstract::AbstractBasePtrList &args_spec, bool clear) { | |||
| MS_LOG(DEBUG) << "AbstractAnalyze start"; | |||
| auto engine = res->engine(); | |||
| auto engine = resource->engine(); | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| if (clear) { | |||
| auto manager = res->manager(); | |||
| auto manager = resource->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| engine->Clear(); | |||
| for (auto &node : manager->all_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // Handle previous inferred value for CNode if is loaded from MindIR | |||
| if (res->is_load()) { | |||
| if (resource->is_load()) { | |||
| // If the primitive is not defined in front end,keep the inferred value loaded from MindIR. | |||
| auto primitive = GetCNodePrimitive(node); | |||
| if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) { | |||
| @@ -287,19 +287,19 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph | |||
| } | |||
| } | |||
| } | |||
| auto ret = engine->Run(func_graph, args_spec); | |||
| auto res = engine->Run(func_graph, args_spec); | |||
| MS_LOG(INFO) << "function call max depth: " << abstract::FunctionCallMaxDepth() | |||
| << ", simulate call max depth: " << abstract::StackFrameMaxDepth(); | |||
| MS_LOG(DEBUG) << "AbstractAnalyze end"; | |||
| return ret; | |||
| return res; | |||
| } | |||
| FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, | |||
| const abstract::AnalysisContextPtr &context) { | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| MS_LOG(DEBUG) << "ProgramSpecialize start"; | |||
| abstract::ProgramSpecializer spc(res->engine()); | |||
| FuncGraphPtr result = spc.Run(func_graph, context); | |||
| abstract::ProgramSpecializer specializer(res->engine()); | |||
| FuncGraphPtr result = specializer.Run(func_graph, context); | |||
| auto manager = res->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->KeepRoots({result}); | |||
| @@ -239,30 +239,30 @@ AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &c | |||
| return async_eval_result->GetResult(); | |||
| } | |||
| void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg, | |||
| void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr ¤t_abs, | |||
| AnalysisConfigAsyncResultCache *cache) { | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| MS_EXCEPTION_IF_NULL(cache); | |||
| if (arg == nullptr) { | |||
| if (current_abs == nullptr) { | |||
| MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr"; | |||
| } | |||
| std::lock_guard<std::mutex> lock(lock_); | |||
| AsyncAbstractPtr async_eval_result = cache->get(conf); | |||
| if (async_eval_result == nullptr) { | |||
| async_eval_result = std::make_shared<AsyncAbstract>(); | |||
| async_eval_result->set_result(arg); | |||
| async_eval_result->set_result(current_abs); | |||
| cache->set(conf, async_eval_result); | |||
| } else { | |||
| auto ab1 = async_eval_result->TryGetResult(); | |||
| AbstractBasePtrList absList; | |||
| if (ab1 != nullptr) { | |||
| absList.push_back(arg); | |||
| absList.push_back(ab1); | |||
| auto previous_abs = async_eval_result->TryGetResult(); | |||
| AbstractBasePtrList abstract_list; | |||
| if (previous_abs != nullptr) { | |||
| abstract_list.push_back(previous_abs); | |||
| abstract_list.push_back(current_abs); | |||
| // Join two branches's result | |||
| auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node()); | |||
| auto joined_result = AnalysisEngine::ProcessEvalResults(abstract_list, conf->node()); | |||
| async_eval_result->set_result(joined_result->abstract()); | |||
| } else { | |||
| async_eval_result->set_result(arg); | |||
| async_eval_result->set_result(current_abs); | |||
| } | |||
| } | |||
| } | |||
| @@ -240,7 +240,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||
| const auto &node = parameters[i]; | |||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg); | |||
| engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr)); | |||
| MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString(); | |||
| MS_LOG(DEBUG) << GetInferThread() << "Set parameter[" << i << "] for " << fg->ToString() | |||
| << ", conf: " << conf->ToString() << ", arg: " << arg->ToString(); | |||
| } | |||
| MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString() | |||
| << ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString() | |||
| @@ -416,8 +417,8 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt | |||
| EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| if (args_conf_list.empty()) { | |||
| MS_LOG(EXCEPTION) << "Size should be greater than 0"; | |||
| if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator") { | |||
| MS_LOG(EXCEPTION) << "Size should be greater than 0, during running " << identifier_; | |||
| } | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| @@ -517,7 +517,7 @@ TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_lis | |||
| } | |||
| return TypeJoin(args_type_list); | |||
| } | |||
| } // end anonymous namespace | |||
| } // namespace | |||
| py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) { | |||
| MS_EXCEPTION_IF_NULL(abs_base); | |||
| @@ -648,7 +648,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||
| CheckCustomPrimOutputInferResult(prim_py, res_spec); | |||
| return res_spec; | |||
| } | |||
| } // end anonymous namespace | |||
| } // namespace | |||
| EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base, | |||
| const AbstractBasePtrList &args) { | |||
| @@ -761,6 +761,14 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs | |||
| if (eval_result != nullptr) { | |||
| auto abs = eval_result->abstract()->Clone(); | |||
| auto attr = eval_result->attribute(); | |||
| // To check tuple/list operations with a white list of Python primitive. | |||
| if (prim_py_->name() == prim::kPrimStack->name()) { | |||
| // Set all used flags of tuple as true. | |||
| for (auto &arg : args) { | |||
| SetSequenceElementsUseFlags(arg, true); | |||
| } | |||
| } | |||
| return std::make_shared<EvalResult>(abs, attr); | |||
| } | |||
| @@ -774,6 +782,14 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs | |||
| MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; | |||
| auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs)); | |||
| evaluator_cache_mgr_->SetValue(args, infer_result); | |||
| // To check tuple/list operations with a white list of Python primitive. | |||
| if (prim_py_->name() == prim::kPrimStack->name()) { | |||
| // Set all used flags of tuple as true. | |||
| for (auto &arg : args) { | |||
| SetSequenceElementsUseFlags(arg, true); | |||
| } | |||
| } | |||
| return infer_result; | |||
| } | |||
| @@ -1103,7 +1119,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt | |||
| return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); | |||
| } | |||
| } | |||
| } // end anonymous namespace | |||
| } // namespace | |||
| namespace { | |||
| class EmbedEvaluator : public SymbolicPrimEvaluator { | |||
| @@ -1452,6 +1468,54 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator { | |||
| } | |||
| }; | |||
| class MakeTupleEvaluator : public TransitionPrimEvaluator { | |||
| public: | |||
| MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {} | |||
| ~MakeTupleEvaluator() override = default; | |||
| MS_DECLARE_PARENT(MakeTupleEvaluator, TransitionPrimEvaluator); | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, | |||
| const AnfNodeConfigPtr &out_conf) override { | |||
| if (args_spec_list.empty()) { | |||
| MS_LOG(WARNING) << "For MakeTuple, the inputs should not be empty."; | |||
| } | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (enable_eliminate_unused_element) { | |||
| SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size())); | |||
| } | |||
| AnfNodeWeakPtrList sequence_nodes = | |||
| (enable_eliminate_unused_element ? AnfNodeWeakPtrList({AnfNodeWeakPtr(out_conf->node())}) : AnfNodeWeakPtrList()); | |||
| auto abs = std::make_shared<AbstractTuple>(args_spec_list, sequence_nodes); | |||
| auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>()); | |||
| evaluator_cache_mgr_->SetValue(args_spec_list, res); | |||
| return res; | |||
| } | |||
| }; | |||
| class MakeListEvaluator : public TransitionPrimEvaluator { | |||
| public: | |||
| MakeListEvaluator() : TransitionPrimEvaluator("MakeListEvaluator") {} | |||
| ~MakeListEvaluator() override = default; | |||
| MS_DECLARE_PARENT(MakeListEvaluator, TransitionPrimEvaluator); | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, | |||
| const AnfNodeConfigPtr &out_conf) override { | |||
| if (args_spec_list.empty()) { | |||
| MS_LOG(WARNING) << "For MakeList, the inputs should not be empty."; | |||
| } | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (enable_eliminate_unused_element) { | |||
| SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size())); | |||
| } | |||
| AnfNodeWeakPtrList sequence_nodes = | |||
| (enable_eliminate_unused_element ? AnfNodeWeakPtrList({AnfNodeWeakPtr(out_conf->node())}) : AnfNodeWeakPtrList()); | |||
| auto abs = std::make_shared<AbstractList>(args_spec_list, sequence_nodes); | |||
| auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>()); | |||
| evaluator_cache_mgr_->SetValue(args_spec_list, res); | |||
| return res; | |||
| } | |||
| }; | |||
| class PartialEvaluator : public Evaluator { | |||
| public: | |||
| PartialEvaluator() : Evaluator("PartialEvaluator") {} | |||
| @@ -1597,6 +1661,8 @@ void InitPrimEvaluatorConstructors() { | |||
| constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>(); | |||
| constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>(); | |||
| constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>(); | |||
| constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>(); | |||
| constructor[prim::kPrimMakeList] = std::make_shared<MakeListEvaluator>(); | |||
| } | |||
| } // namespace | |||
| @@ -66,6 +66,56 @@ class PythonPrimEvaluator final : public TrivialPrimEvaluator { | |||
| PrimitivePyPtr prim_py_; | |||
| }; | |||
| using ValuePtrList = std::vector<ValuePtr>; | |||
| using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); | |||
| class UniformPrimEvaluator final : public TrivialPrimEvaluator { | |||
| public: | |||
| UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type) | |||
| : TrivialPrimEvaluator("UniformPrimEvaluator"), | |||
| impl_(impl), | |||
| eval_value_(eval_value), | |||
| func_desc_(func_desc), | |||
| nargs_(func_desc_->args().size()), | |||
| return_value_type_(func_desc_->retval()), | |||
| specify_out_type_(specify_out_type) { | |||
| for (size_t i = 0; i < nargs_; ++i) { | |||
| TypePtr type = func_desc_->args()[i]; | |||
| if (type_map_[type]) { | |||
| type_map_[type]->push_back(i); | |||
| } else { | |||
| type_map_[type] = std::make_shared<std::vector<size_t>>(); | |||
| type_map_[type]->push_back(i); | |||
| } | |||
| } | |||
| } | |||
| ~UniformPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| ValuePtr RunImpl(const ValuePtrList &args) const; | |||
| // If eval_value_ is False, return broadened arguments. | |||
| AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { | |||
| if (!eval_value_) { | |||
| AbstractBasePtrList broadened_args_spec_list; | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list), | |||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); | |||
| return broadened_args_spec_list; | |||
| } | |||
| return args_spec_list; | |||
| } | |||
| private: | |||
| PrimitiveImpl impl_; | |||
| bool eval_value_; | |||
| const FunctionPtr func_desc_; | |||
| const std::size_t nargs_; | |||
| const TypePtr return_value_type_; | |||
| const TypePtr specify_out_type_; | |||
| mindspore::HashMap<TypePtr, std::shared_ptr<std::vector<size_t>>, TypeHasher, TypeEqual> type_map_; | |||
| }; | |||
| class DoSignatureEvaluator final : public Evaluator { | |||
| public: | |||
| explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} | |||
| @@ -117,56 +167,6 @@ class MixedPrecisionCastEvaluator final : public Evaluator { | |||
| bool IsInWhiteList(const PrimitivePtr &primitive); | |||
| using ValuePtrList = std::vector<ValuePtr>; | |||
| using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); | |||
| class UniformPrimEvaluator final : public TrivialPrimEvaluator { | |||
| public: | |||
| UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type) | |||
| : TrivialPrimEvaluator("UniformPrimEvaluator"), | |||
| impl_(impl), | |||
| eval_value_(eval_value), | |||
| func_desc_(func_desc), | |||
| nargs_(func_desc_->args().size()), | |||
| return_value_type_(func_desc_->retval()), | |||
| specify_out_type_(specify_out_type) { | |||
| for (size_t i = 0; i < nargs_; ++i) { | |||
| TypePtr type = func_desc_->args()[i]; | |||
| if (type_map_[type]) { | |||
| type_map_[type]->push_back(i); | |||
| } else { | |||
| type_map_[type] = std::make_shared<std::vector<size_t>>(); | |||
| type_map_[type]->push_back(i); | |||
| } | |||
| } | |||
| } | |||
| ~UniformPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; | |||
| ValuePtr RunImpl(const ValuePtrList &args) const; | |||
| // If eval_value_ is False, return broadened arguments. | |||
| AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { | |||
| if (!eval_value_) { | |||
| AbstractBasePtrList broadened_args_spec_list; | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list), | |||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); | |||
| return broadened_args_spec_list; | |||
| } | |||
| return args_spec_list; | |||
| } | |||
| private: | |||
| PrimitiveImpl impl_; | |||
| bool eval_value_; | |||
| const FunctionPtr func_desc_; | |||
| const std::size_t nargs_; | |||
| const TypePtr return_value_type_; | |||
| const TypePtr specify_out_type_; | |||
| mindspore::HashMap<TypePtr, std::shared_ptr<std::vector<size_t>>, TypeHasher, TypeEqual> type_map_; | |||
| }; | |||
| PrimEvaluatorMap &GetPrimEvaluatorConstructors(); | |||
| // Check whether type x is a subtype of model. | |||
| @@ -67,7 +67,12 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte | |||
| top_context_ = context; | |||
| MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString(); | |||
| } | |||
| return SpecializeFuncGraph(fg, context); | |||
| auto res = SpecializeFuncGraph(fg, context); | |||
| // Call PurifyElements() to purify tuple/list elements. | |||
| for (auto &sequence_abs : sequence_abstract_list_) { | |||
| sequence_abs->PurifyElements(); | |||
| } | |||
| return res; | |||
| } | |||
| FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) { | |||
| @@ -80,10 +85,10 @@ FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, con | |||
| } | |||
| std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context); | |||
| FuncGraphPtr fg2 = fg_spec->specialized_func_graph(); | |||
| FuncGraphPtr specialized_func_graph = fg_spec->specialized_func_graph(); | |||
| specializations_[context->SpecializeKey()] = fg_spec; | |||
| fg_spec->Run(); | |||
| return fg2; | |||
| return specialized_func_graph; | |||
| } | |||
| std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) { | |||
| @@ -290,6 +295,133 @@ void FuncGraphSpecializer::SecondPass() { | |||
| } | |||
| } | |||
| namespace { | |||
| // Update elements use flags for MakeTuple/tuple node, | |||
| // and update the node's AbstractSequence 'sequence_nodes' info. | |||
| void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node, const AbstractBasePtr &old_abs) { | |||
| if (new_node == old_node) { | |||
| return; | |||
| } | |||
| AbstractSequencePtr old_sequence_abs = dyn_cast<AbstractSequence>(old_abs); | |||
| if (old_sequence_abs == nullptr || old_sequence_abs->sequence_nodes().empty()) { | |||
| MS_LOG(DEBUG) << "No sequence node in old abs, " << old_node->DebugString() << " --> " << new_node->DebugString(); | |||
| return; | |||
| } | |||
| for (auto &weak_node : old_sequence_abs->sequence_nodes()) { | |||
| auto sequence_node = weak_node.lock(); | |||
| if (sequence_node == nullptr) { | |||
| MS_LOG(DEBUG) << "The sequence_nodes is free. " << old_node->DebugString() << " --> " << new_node->DebugString(); | |||
| continue; | |||
| } | |||
| if (sequence_node != old_node) { | |||
| continue; | |||
| } | |||
| // Update new node's flags with old one, and update old sequence abstract's source node. | |||
| auto flags = GetSequenceNodeElementsUseFlags(old_node); | |||
| MS_LOG(DEBUG) << "Update sequence node, " << old_node->DebugString() << " --> " << new_node->DebugString() | |||
| << ", elements_use_flags: " << (*flags); | |||
| SetSequenceNodeElementsUseFlags(new_node, flags); | |||
| old_sequence_abs->update_sequence_node(sequence_node, new_node); | |||
| // Update new sequence abstract if it's not equal to old one. | |||
| const AbstractBasePtr &new_abs = new_node->abstract(); | |||
| if (old_abs == new_abs) { | |||
| continue; | |||
| } | |||
| AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_abs); | |||
| if (new_sequence_abs == nullptr) { | |||
| MS_LOG(EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString(); | |||
| } | |||
| if (new_sequence_abs->sequence_nodes().empty()) { | |||
| new_sequence_abs->set_sequence_nodes({AnfNodeWeakPtr(new_node)}); | |||
| } else { | |||
| new_sequence_abs->insert_sequence_node(new_node); | |||
| } | |||
| } | |||
| } | |||
| // Purify specific input of a CNode. | |||
| template <typename T> | |||
| void PurifySequenceValueNode(const CNodePtr &cnode, size_t index) { | |||
| const auto &old_input = cnode->input(index); | |||
| auto sequence_value = GetValueNode<std::shared_ptr<T>>(old_input); | |||
| if (sequence_value == nullptr) { | |||
| return; | |||
| } | |||
| auto flags = GetSequenceNodeElementsUseFlags(old_input); | |||
| if (flags == nullptr) { | |||
| return; | |||
| } | |||
| ValuePtrList elements; | |||
| for (size_t i = 0; i < (*flags).size(); ++i) { | |||
| if (!(*flags)[i]) { | |||
| auto zero = MakeValue(0); | |||
| elements.emplace_back(zero); | |||
| MS_LOG(INFO) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs[" | |||
| << index << "] of " << cnode->DebugString(); | |||
| } else { | |||
| elements.emplace_back(sequence_value->value()[i]); | |||
| } | |||
| } | |||
| auto new_sequence_value = std::make_shared<T>(elements); | |||
| auto new_input = NewValueNode(new_sequence_value); | |||
| auto new_input_abs = new_sequence_value->ToAbstract(); | |||
| AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_input_abs); | |||
| MS_EXCEPTION_IF_NULL(new_sequence_abs); | |||
| new_sequence_abs->set_sequence_nodes({AnfNodeWeakPtr(new_input)}); | |||
| new_input->set_abstract(new_sequence_abs); | |||
| // Always reset tuple value node's use flags as non-use. | |||
| SetSequenceNodeElementsUseFlags(new_input, std::make_shared<std::vector<bool>>(new_sequence_abs->elements().size())); | |||
| MS_LOG(DEBUG) << "Update ValueTuple/ValueList, " << old_input->DebugString() << " --> " << new_input->DebugString() | |||
| << ", which is inputs[" << index << "] of " << cnode->DebugString(); | |||
| cnode->set_input(index, new_input); | |||
| } | |||
| } // namespace | |||
| // Eliminate the unused items of Tuple/List. | |||
| void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) { | |||
| if (cnode == nullptr || cnode->abstract() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The parameter \'node\' and its abstract should not be null."; | |||
| } | |||
| const AbstractBasePtr abs = cnode->abstract(); | |||
| AbstractSequencePtr sequence_abs = dyn_cast<AbstractSequence>(abs); | |||
| if (sequence_abs == nullptr || sequence_abs->sequence_nodes().empty()) { | |||
| return; | |||
| } | |||
| // Not call PurifyElements() here, just add to list. | |||
| specializer_->sequence_abstract_list().emplace_back(sequence_abs); | |||
| // Purify MakeTuple/MakeList CNode. | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) { | |||
| auto flags = GetSequenceNodeElementsUseFlags(cnode); | |||
| if (flags != nullptr) { | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.emplace_back(cnode->input(0)); | |||
| for (size_t i = 0; i < (*flags).size(); ++i) { | |||
| if (!(*flags)[i]) { | |||
| auto zero_value = NewValueNode(MakeValue(0)); | |||
| zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0))); | |||
| inputs.emplace_back(zero_value); | |||
| MS_LOG(INFO) << "Erase inputs[" << i << "] as zero for " << cnode->DebugString(); | |||
| } else { | |||
| inputs.emplace_back(cnode->input(i + 1)); | |||
| } | |||
| } | |||
| cnode->set_inputs(std::move(inputs)); | |||
| cnode->set_abstract(sequence_abs); | |||
| } | |||
| } | |||
| // Purify each Tuple/List ValueNode in CNode. | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| if (IsValueNode<ValueTuple>(cnode->input(i))) { | |||
| PurifySequenceValueNode<ValueTuple>(cnode, i); | |||
| } else if (IsValueNode<ValueList>(cnode->input(i))) { | |||
| PurifySequenceValueNode<ValueList>(cnode, i); | |||
| } | |||
| } | |||
| } | |||
| void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| ScopeGuard scope_guard(node->scope()); | |||
| @@ -304,7 +436,11 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||
| << ", specialized_func_graph_: " << specialized_func_graph_->ToString(); | |||
| return; | |||
| } | |||
| new_node->set_abstract(GetEvaluatedValue(conf)); | |||
| try { | |||
| new_node->set_abstract(GetEvaluatedValue(conf)); | |||
| } catch (const std::exception &) { | |||
| MS_LOG(EXCEPTION) << "Fail to get abstract value with " << conf->ToString() << ", for " << new_node->DebugString(); | |||
| } | |||
| if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) { | |||
| auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract()); | |||
| if (partial_abstract->node() == node) { | |||
| @@ -315,35 +451,47 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||
| << ", func_graph_: " << func_graph_->ToString() | |||
| << ", specialized_func_graph_: " << specialized_func_graph_->ToString(); | |||
| if (node->isa<CNode>()) { | |||
| auto attrs = conf->ObtainEvalResult()->attribute(); | |||
| auto c_old = node->cast<CNodePtr>(); | |||
| auto c_new = new_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(c_new); | |||
| auto new_inputs = c_new->inputs(); | |||
| auto old_inputs = c_old->inputs(); | |||
| for (size_t i = 0; i < old_inputs.size(); ++i) { | |||
| auto node_input = old_inputs[i]; | |||
| AnfNodeConfigPtr iconf = MakeConfig(node_input); | |||
| AbstractBasePtr ival = GetEvaluatedValue(iconf); | |||
| // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if | |||
| // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. | |||
| AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs, node); | |||
| if (replace_node == nullptr) { | |||
| replace_node = BuildReplacedNode(iconf); | |||
| replace_node->set_abstract(ival); | |||
| MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString(); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString() | |||
| << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString(); | |||
| } | |||
| if (new_inputs[i] != replace_node) { | |||
| new_inputs[i] = replace_node; | |||
| MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString(); | |||
| } | |||
| if (!node->isa<CNode>()) { | |||
| return; | |||
| } | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| auto attrs = conf->ObtainEvalResult()->attribute(); | |||
| auto c_old = node->cast<CNodePtr>(); | |||
| auto c_new = new_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(c_new); | |||
| auto new_inputs = c_new->inputs(); | |||
| auto old_inputs = c_old->inputs(); | |||
| for (size_t i = 0; i < old_inputs.size(); ++i) { | |||
| auto node_input = old_inputs[i]; | |||
| AnfNodeConfigPtr input_conf = MakeConfig(node_input); | |||
| AbstractBasePtr abs; | |||
| try { | |||
| abs = GetEvaluatedValue(input_conf); | |||
| } catch (const std::exception &) { | |||
| MS_LOG(EXCEPTION) << "Fail to get input's abstract value, with input config: " << input_conf->ToString() | |||
| << ", in old node: " << c_old->DebugString(); | |||
| } | |||
| // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if | |||
| // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. | |||
| AnfNodePtr replace_node = BuildPossibleValueNode(node_input, abs, attrs, node); | |||
| if (replace_node == nullptr) { | |||
| replace_node = BuildReplacedNode(input_conf); | |||
| replace_node->set_abstract(abs); | |||
| MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << abs->ToString(); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString() | |||
| << ", abs: " << abs->ToString() << ", replace_node: " << replace_node->ToString(); | |||
| } | |||
| if (enable_eliminate_unused_element) { | |||
| UpdateSequenceNode(replace_node, node_input, abs); | |||
| } | |||
| if (new_inputs[i] != replace_node) { | |||
| new_inputs[i] = replace_node; | |||
| MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString(); | |||
| } | |||
| c_new->set_inputs(new_inputs); | |||
| } | |||
| c_new->set_inputs(new_inputs); | |||
| } | |||
| AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) { | |||
| @@ -506,10 +654,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &fun | |||
| << ", " << func->ToString(); | |||
| return func; | |||
| } | |||
| FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context); | |||
| MS_EXCEPTION_IF_NULL(v); | |||
| v->set_flag(kFuncGraphFlagUndetermined, false); | |||
| return BuildValueNode(v, abs); | |||
| FuncGraphPtr func_graph = specializer_->SpecializeFuncGraph(context->func_graph(), context); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| func_graph->set_flag(kFuncGraphFlagUndetermined, false); | |||
| return BuildValueNode(func_graph, abs); | |||
| } | |||
| AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine, | |||
| @@ -643,20 +791,21 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB | |||
| return std::make_pair(AbstractBasePtrList(), nullptr); | |||
| } | |||
| void FuncGraphSpecializer::ProcessCNode(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (specializer_->seen().count(node) > 0) { | |||
| void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (specializer_->seen().count(cnode) > 0) { | |||
| return; | |||
| } | |||
| specializer_->AddSeen(node); | |||
| auto new_inputs = node->inputs(); | |||
| specializer_->AddSeen(cnode); | |||
| auto new_inputs = cnode->inputs(); | |||
| if (new_inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; | |||
| } | |||
| AnfNodePtr func = new_inputs[0]; | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| constexpr auto recursive_level = 2; | |||
| MS_LOG(DEBUG) << "Handle node: " << node->DebugString(recursive_level); | |||
| MS_LOG(DEBUG) << "Handle node: " << cnode->DebugString(recursive_level); | |||
| // First element is func so arg start from 1 | |||
| std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end()); | |||
| @@ -685,7 +834,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func->func_graph()); | |||
| if (status == kSpecializePoly || | |||
| (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) { | |||
| auto wrapped_node = BuildSpecializedParameterNode(node); | |||
| auto wrapped_node = BuildSpecializedParameterNode(cnode); | |||
| MS_LOG(DEBUG) << "Partial closure is handled, wrapped_node: " << wrapped_node->DebugString(recursive_level); | |||
| new_inputs[0] = wrapped_node; | |||
| } | |||
| @@ -723,7 +872,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &node) { | |||
| } | |||
| // Set the updated inputs. | |||
| node->set_inputs(new_inputs); | |||
| cnode->set_inputs(new_inputs); | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (enable_eliminate_unused_element) { | |||
| EliminateUnusedSequenceItem(cnode); | |||
| } | |||
| } | |||
| namespace { | |||
| @@ -756,7 +911,7 @@ bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argv | |||
| } | |||
| return false; | |||
| } | |||
| } // end anonymous namespace | |||
| } // namespace | |||
| SpecializeStatusCode FuncGraphSpecializer::AcquireUniqueEvalVal(const AbstractFunctionPtr &func, | |||
| const EvaluatorPtr &eval, | |||
| @@ -64,6 +64,8 @@ class ProgramSpecializer { | |||
| const AnalysisContextPtr &top_context() { return top_context_; } | |||
| std::vector<AbstractSequencePtr> &sequence_abstract_list() { return sequence_abstract_list_; } | |||
| private: | |||
| std::shared_ptr<AnalysisEngine> engine_; | |||
| mindspore::HashSet<AnfNodePtr> seen_; | |||
| @@ -71,6 +73,8 @@ class ProgramSpecializer { | |||
| std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual> | |||
| specializations_; | |||
| AnalysisContextPtr top_context_; | |||
| // The list to purify tuple/list elements. | |||
| std::vector<AbstractSequencePtr> sequence_abstract_list_; | |||
| }; | |||
| class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> { | |||
| @@ -99,6 +103,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia | |||
| void ProcessNode(const AnfNodePtr &node); | |||
| void ProcessCNode(const CNodePtr &node); | |||
| void EliminateUnusedSequenceItem(const CNodePtr &cnode); | |||
| const NodeToNodeMap &cloned_nodes() const { return cloner_->cloned_nodes(); } | |||
| inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node); | |||
| @@ -122,7 +122,10 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac | |||
| MS_LOG(INFO) << func_graph->ToString() << ": Run finished."; | |||
| MS_EXCEPTION_IF_NULL(output_conf); | |||
| result.inferred = output_conf->ObtainEvalResult(); | |||
| auto eval_result = output_conf->ObtainEvalResult(); | |||
| // Set the sequence nodes' elements use flags all true. | |||
| SetSequenceElementsUseFlagsRecursively(eval_result->abstract(), true); | |||
| result.eval_result = eval_result; | |||
| result.context = root_context; | |||
| } catch (const std::exception &ex) { | |||
| MS_LOG(INFO) << "Eval " << func_graph->ToString() << " threw exception."; | |||
| @@ -361,13 +364,14 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| return std::make_shared<MixedPrecisionCastEvaluator>(prim); | |||
| } | |||
| // find prim infer function in the prim function map return a standard evaluator | |||
| // Find prim infer function in the prim function map return a standard evaluator | |||
| auto eval_impl = GetPrimitiveInferImpl(prim); | |||
| if (eval_impl.infer_shape_impl_ != nullptr) { | |||
| if (eval_impl.infer_shape_impl_ != nullptr && prim->name() != prim::kPrimMakeTuple->name() && | |||
| prim->name() != prim::kPrimMakeList->name()) { // Refactoring infer routine soon. | |||
| return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); | |||
| } | |||
| // use python infer function if the infer function not founded in the map return a python evaluator | |||
| // Use python infer function if the infer function not founded in the map return a python evaluator | |||
| EvaluatorPtr evaluator = nullptr; | |||
| if (prim->HasPyEvaluator()) { | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim); | |||
| @@ -388,7 +392,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| return nullptr; | |||
| } | |||
| // return a default evaluator | |||
| // Return a default evaluator | |||
| if (engine == nullptr) { | |||
| // If engine is nullptr, get constructor from default. | |||
| const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); | |||
| @@ -674,12 +678,13 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt | |||
| std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBasePtr &last_spec, | |||
| const AnfNodePtr &node, const std::string &error_info) { | |||
| constexpr int recursive_level = 2; | |||
| std::ostringstream buffer; | |||
| buffer << "The return values of different branches do not join. \n" | |||
| << error_info << "\nFor more details, please refer to the FAQ at https://www.mindspore.cn.\n" | |||
| << "The abstract type of the return value of the current branch is " << spec->ToString() | |||
| << ", and that of the previous branch is " << last_spec->ToString() << ".\n" | |||
| << "The node " << node->DebugString(); | |||
| << "The node is " << node->DebugString(recursive_level); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>()->input(0); | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { | |||
| @@ -803,10 +808,9 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar | |||
| AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs, | |||
| const std::vector<AsyncAbstractPtr> &pending_async_abstract_list, | |||
| const std::vector<std::size_t> &index) { | |||
| if (orig_abs->isa<AbstractSequence>()) { | |||
| const auto &orig_abstract_seq = orig_abs->cast<AbstractSequencePtr>(); | |||
| MS_EXCEPTION_IF_NULL(orig_abstract_seq); | |||
| const auto &orig_elements = orig_abstract_seq->elements(); | |||
| const auto sequence_abs = dyn_cast<AbstractSequence>(orig_abs); | |||
| if (sequence_abs != nullptr) { | |||
| const auto &orig_elements = sequence_abs->elements(); | |||
| AbstractBasePtrList new_elements; | |||
| for (size_t i = 0; i < orig_elements.size(); ++i) { | |||
| if (orig_elements[i]->isa<AbstractFuncAtom>()) { | |||
| @@ -826,11 +830,15 @@ AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs, | |||
| new_elements.push_back(orig_elements[i]); | |||
| } | |||
| } | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| AbstractBasePtr new_abs; | |||
| if (orig_abs->isa<AbstractTuple>()) { | |||
| new_abs = std::make_shared<AbstractTuple>(new_elements); | |||
| new_abs = std::make_shared<AbstractTuple>( | |||
| new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : AnfNodeWeakPtrList())); | |||
| } else if (orig_abs->isa<AbstractList>()) { | |||
| new_abs = std::make_shared<AbstractList>(new_elements); | |||
| new_abs = std::make_shared<AbstractList>( | |||
| new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : AnfNodeWeakPtrList())); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString(); | |||
| } | |||
| @@ -864,7 +872,7 @@ void BuildPossibleSpecs(const AbstractBasePtr &first_result, | |||
| MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString() | |||
| << ", new: " << new_first_result->ToString(); | |||
| std::replace_if( | |||
| out_specs->begin(), out_specs->end(), [first_result](const auto &elem) { return elem == first_result; }, | |||
| out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; }, | |||
| new_first_result); | |||
| } else { | |||
| MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result"; | |||
| @@ -1059,6 +1067,17 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont | |||
| auto prim = value->cast<PrimitivePtr>(); | |||
| return MakeAbstractClosure(prim, anf_node); | |||
| } | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (enable_eliminate_unused_element && value->isa<ValueSequence>()) { | |||
| auto abs = value->ToAbstract(); | |||
| auto sequence_abs = dyn_cast<AbstractSequence>(abs); | |||
| MS_EXCEPTION_IF_NULL(sequence_abs); | |||
| if (anf_node != nullptr) { | |||
| SetSequenceNodeElementsUseFlags(anf_node, std::make_shared<std::vector<bool>>(sequence_abs->elements().size())); | |||
| sequence_abs->set_sequence_nodes({AnfNodeWeakPtr(anf_node)}); | |||
| } | |||
| } | |||
| return value->ToAbstract(); | |||
| } | |||
| @@ -213,7 +213,7 @@ using AnfNodeConfigMap = | |||
| std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; | |||
| struct AnalysisResult { | |||
| EvalResultPtr inferred; | |||
| EvalResultPtr eval_result; | |||
| AnalysisContextPtr context; | |||
| }; | |||
| @@ -241,13 +241,13 @@ const AbstractBasePtr AbstractSequence::operator[](const std::size_t &dim) const | |||
| return elements_[dim]; | |||
| } | |||
| std::string AbstractSequence::ToString() const { | |||
| std::string AbstractSequence::ToStringInternal() const { | |||
| std::ostringstream buffer; | |||
| size_t i = 0; | |||
| size_t size = elements_.size(); | |||
| for (const auto &ele : elements_) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| buffer << "element[" << i << "]: " << ele->ToString(); | |||
| for (const auto &element : elements_) { | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| buffer << "element[" << i << "]: " << element->ToString(); | |||
| if (i < size - 1) { | |||
| buffer << ", "; | |||
| } | |||
| @@ -256,11 +256,169 @@ std::string AbstractSequence::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| std::string AbstractSequence::ToString() const { | |||
| std::stringstream ss; | |||
| ss << type_name(); | |||
| ss << "{"; | |||
| ss << ToStringInternal(); | |||
| if (!sequence_nodes_.empty()) { | |||
| ss << ", sequence_nodes: {"; | |||
| for (size_t i = 0; i < sequence_nodes_.size(); ++i) { | |||
| auto sequence_node = sequence_nodes_[i].lock(); | |||
| if (sequence_node == nullptr) { | |||
| ss << "<freed node>"; | |||
| continue; | |||
| } else { | |||
| ss << sequence_node->DebugString(); | |||
| } | |||
| auto flags = GetSequenceNodeElementsUseFlags(sequence_node); | |||
| if (flags != nullptr) { | |||
| ss << ", elements_use_flags: " << (*flags); | |||
| } | |||
| if (i != sequence_nodes_.size() - 1) { | |||
| ss << ", "; | |||
| } | |||
| } | |||
| ss << "}"; | |||
| } | |||
| ss << "}"; | |||
| return ss.str(); | |||
| } | |||
| namespace { | |||
| void CollectSequenceNodes(const AnfNodeWeakPtrList &source_sequence_nodes, AnfNodeWeakPtrList *sequence_nodes_ptr) { | |||
| AnfNodeWeakPtrList &sequence_nodes = *sequence_nodes_ptr; | |||
| auto sequence_nodes_size = source_sequence_nodes.size(); | |||
| for (size_t i = 0; i < sequence_nodes_size; ++i) { | |||
| // Lock sequence nodes of this. | |||
| auto &source_weak_node = source_sequence_nodes[i]; | |||
| auto this_sequence_node = source_weak_node.lock(); | |||
| if (this_sequence_node == nullptr) { | |||
| continue; | |||
| } | |||
| // Check and emplace sequence node for this. | |||
| auto this_iter = std::find_if( | |||
| sequence_nodes.begin(), sequence_nodes.end(), | |||
| [&this_sequence_node](const AnfNodeWeakPtr &weak_node) { return this_sequence_node == weak_node.lock(); }); | |||
| if (this_iter == sequence_nodes.end()) { | |||
| sequence_nodes.emplace_back(AnfNodeWeakPtr(this_sequence_node)); | |||
| } | |||
| } | |||
| } | |||
| void SynchronizeSequenceNodesElementsUseFlags(const AnfNodeWeakPtrList &sequence_nodes) { | |||
| // Synchronize the elements use flags for all sequence nodes. | |||
| auto current_sequence_node = sequence_nodes[0].lock(); | |||
| MS_EXCEPTION_IF_NULL(current_sequence_node); | |||
| for (size_t i = 1; i < sequence_nodes.size(); ++i) { | |||
| // Synchronize the 'elements_use_flags' for all sequence node. | |||
| // We set the same 'elements_use_flags' for them after here. | |||
| auto latter_sequence_node = sequence_nodes[i].lock(); | |||
| MS_EXCEPTION_IF_NULL(latter_sequence_node); | |||
| // The 'current_sequence_node' is not equal to 'latter_sequence_node'. | |||
| auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node); | |||
| auto latter_flags = GetSequenceNodeElementsUseFlags(latter_sequence_node); | |||
| std::shared_ptr<std::vector<bool>> unique_flags = nullptr; // Choose the ptr (use_count > 1) as unique flags. | |||
| if (current_flags.use_count() == 1 && latter_flags.use_count() == 1) { | |||
| unique_flags = current_flags; | |||
| } else { | |||
| MS_EXCEPTION_IF_CHECK_FAIL(current_flags.use_count() > 1 && latter_flags.use_count() > 1, | |||
| "Allow only one side has more than one use count."); | |||
| if (current_flags.use_count() > 1) { | |||
| unique_flags = current_flags; | |||
| } else { // If latter_flags.use_count() > 1 | |||
| unique_flags = latter_flags; | |||
| } | |||
| } | |||
| for (size_t j = 0; j < current_flags->size(); ++j) { | |||
| MS_LOG(DEBUG) << "Check elements_use_flags[" << j << "], this_flag: " << (*current_flags)[j] | |||
| << ", other_flag: " << (*latter_flags)[j]; | |||
| if ((*current_flags)[j] != (*latter_flags)[j]) { | |||
| (*unique_flags)[j] = true; | |||
| } else { | |||
| (*unique_flags)[j] = (*current_flags)[j]; | |||
| } | |||
| } | |||
| if (unique_flags != current_flags) { | |||
| SetSequenceNodeElementsUseFlags(current_sequence_node, unique_flags); | |||
| } | |||
| if (unique_flags != latter_flags) { | |||
| SetSequenceNodeElementsUseFlags(latter_sequence_node, unique_flags); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &other) { | |||
| AnfNodeWeakPtrList sequence_nodes; | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (!enable_eliminate_unused_element) { | |||
| return sequence_nodes; | |||
| } | |||
| MS_LOG(DEBUG) << "this: " << ToString() << ", other: " << other->ToString(); | |||
| auto other_sequence = dyn_cast<AbstractSequence>(other); | |||
| auto this_sequence_nodes_size = sequence_nodes_.size(); | |||
| auto other_sequence_nodes_size = (other_sequence != nullptr ? other_sequence->sequence_nodes_.size() : 0); | |||
| if (this_sequence_nodes_size == 0 && other_sequence_nodes_size == 0) { | |||
| return sequence_nodes; | |||
| } | |||
| // Collect this and other sequence nodes. | |||
| CollectSequenceNodes(sequence_nodes_, &sequence_nodes); | |||
| CollectSequenceNodes(other_sequence->sequence_nodes_, &sequence_nodes); | |||
| if (sequence_nodes.empty()) { | |||
| MS_LOG(EXCEPTION) << "Sequence nodes size should not be empty."; | |||
| } | |||
| // Synchronize the elements use flags for all sequence nodes. | |||
| SynchronizeSequenceNodesElementsUseFlags(sequence_nodes); | |||
| return sequence_nodes; | |||
| } | |||
| void AbstractSequence::PurifyElements() { | |||
| if (sequence_nodes_.empty()) { | |||
| return; | |||
| } | |||
| // Just use any sequence node's elements_use_flags. | |||
| std::shared_ptr<std::vector<bool>> elements_use_flags_ptr = nullptr; | |||
| for (auto &weak_node : sequence_nodes_) { | |||
| auto sequence_node = weak_node.lock(); | |||
| if (sequence_node == nullptr) { | |||
| MS_LOG(DEBUG) << "The node in sequence_nodes is free."; | |||
| continue; | |||
| } | |||
| auto flags = GetSequenceNodeElementsUseFlags(sequence_node); | |||
| if (flags != nullptr) { | |||
| elements_use_flags_ptr = flags; | |||
| break; | |||
| } | |||
| } | |||
| // Purify the elements. | |||
| if (elements_use_flags_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Check if all sequence nodes are released, or none elements use flags in them. " << ToString(); | |||
| return; | |||
| } | |||
| auto &elements_use_flags = *elements_use_flags_ptr; | |||
| if (elements_use_flags.size() != elements_.size()) { | |||
| MS_LOG(EXCEPTION) << "Elements size should be equal to elements use flags size."; | |||
| } | |||
| for (size_t i = 0; i < elements_use_flags.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(elements_[i]); | |||
| if (!elements_use_flags[i]) { | |||
| const auto unuse_node_none = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(0)); | |||
| elements_[i] = unuse_node_none; | |||
| MS_LOG(INFO) << "Set element[" << i << "] to Zero."; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Keep element[" << i << "] as " << elements_[i]->ToString(); | |||
| } | |||
| } | |||
| } | |||
| TypePtrList AbstractSequence::ElementsType() const { | |||
| TypePtrList element_type_list; | |||
| for (const auto &ele : elements_) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| TypePtr element_type = ele->BuildType(); | |||
| for (const auto &element : elements_) { | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| TypePtr element_type = element->BuildType(); | |||
| element_type_list.push_back(element_type); | |||
| } | |||
| return element_type_list; | |||
| @@ -268,50 +426,50 @@ TypePtrList AbstractSequence::ElementsType() const { | |||
| BaseShapePtrList AbstractSequence::ElementsShape() const { | |||
| BaseShapePtrList element_shape_list; | |||
| for (const auto &ele : elements_) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| BaseShapePtr element_shape = ele->BuildShape(); | |||
| for (const auto &element : elements_) { | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| BaseShapePtr element_shape = element->BuildShape(); | |||
| element_shape_list.push_back(element_shape); | |||
| } | |||
| return element_shape_list; | |||
| } | |||
| AbstractBasePtrList AbstractSequence::ElementsClone() const { | |||
| AbstractBasePtrList ele_list; | |||
| for (const auto &ele : elements_) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| AbstractBasePtr clone = ele->Clone(); | |||
| ele_list.push_back(clone); | |||
| AbstractBasePtrList element_list; | |||
| for (const auto &element : elements_) { | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| AbstractBasePtr clone = element->Clone(); | |||
| element_list.push_back(clone); | |||
| } | |||
| return ele_list; | |||
| return element_list; | |||
| } | |||
| AbstractBasePtrList AbstractSequence::ElementsBroaden() const { | |||
| AbstractBasePtrList ele_list; | |||
| for (const auto &ele : elements_) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| AbstractBasePtr broadend = ele->Broaden(); | |||
| ele_list.push_back(broadend); | |||
| AbstractBasePtrList element_list; | |||
| for (const auto &element : elements_) { | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| AbstractBasePtr broadend = element->Broaden(); | |||
| element_list.push_back(broadend); | |||
| } | |||
| return ele_list; | |||
| return element_list; | |||
| } | |||
| AbstractBasePtrList AbstractSequence::ElementsPartialBroaden() const { | |||
| AbstractBasePtrList ele_list; | |||
| for (const auto &ele : elements_) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| AbstractBasePtr broadend = ele->PartialBroaden(); | |||
| ele_list.push_back(broadend); | |||
| AbstractBasePtrList element_list; | |||
| for (const auto &element : elements_) { | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| AbstractBasePtr broadend = element->PartialBroaden(); | |||
| element_list.push_back(broadend); | |||
| } | |||
| return ele_list; | |||
| return element_list; | |||
| } | |||
| template <typename T> | |||
| ValuePtr AbstractSequence::ElementsBuildValue() const { | |||
| std::vector<ValuePtr> element_value_list; | |||
| for (const auto &ele : elements_) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| ValuePtr element_value = ele->BuildValue(); | |||
| for (const auto &element : elements_) { | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| ValuePtr element_value = element->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(element_value); | |||
| if (element_value->isa<AnyValue>()) { | |||
| return kAnyValue; | |||
| @@ -692,7 +692,9 @@ class MS_CORE_API AbstractSequence : public AbstractBase { | |||
| /// \brief Constructor of AbstractSequence. | |||
| /// | |||
| /// \param[in] elements A list of abstracts. | |||
| explicit AbstractSequence(const AbstractBasePtrList &elements) : elements_(elements) {} | |||
| /// \param[in] sequence_nodes The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes. | |||
| explicit AbstractSequence(const AbstractBasePtrList &elements, const AnfNodeWeakPtrList &sequence_nodes) | |||
| : elements_(elements), sequence_nodes_(sequence_nodes) {} | |||
| /// \brief Destructor of AbstractSequence. | |||
| ~AbstractSequence() override = default; | |||
| @@ -738,6 +740,12 @@ class MS_CORE_API AbstractSequence : public AbstractBase { | |||
| template <typename T> | |||
| AbstractBasePtr ElementsJoin(const AbstractBasePtr &other); | |||
| /// \brief Combine other sequence nodes with this one. | |||
| /// | |||
| /// \param[in] other The other abstract to be joined. | |||
| /// \return A sequence nodes list combined. | |||
| AnfNodeWeakPtrList SequenceNodesJoin(const AbstractBasePtr &other); | |||
| /// \brief Get the size of the stored elements. | |||
| /// | |||
| /// \return A size_t. | |||
| @@ -748,8 +756,51 @@ class MS_CORE_API AbstractSequence : public AbstractBase { | |||
| /// \return A vector of elements. | |||
| const AbstractBasePtrList &elements() const { return elements_; } | |||
| /// \brief Purify the elements list, and clean unused elements. | |||
| void PurifyElements(); | |||
| /// \brief Get the sequence nodes where these 'AbstractSequence' evaluated from. | |||
| /// | |||
| /// \return The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes. | |||
| const AnfNodeWeakPtrList &sequence_nodes() const { return sequence_nodes_; } | |||
| /// \brief Set the sequence nodes where these 'AbstractSequence' evaluated from. | |||
| /// | |||
| /// \param[in] sequence_nodes The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes. | |||
| void set_sequence_nodes(const AnfNodeWeakPtrList &sequence_nodes) { sequence_nodes_ = sequence_nodes; } | |||
| /// \brief Insert a node into the sequence nodes. | |||
| /// | |||
| /// \param[in] sequence_node The node to intert into sequence nodes. | |||
| void insert_sequence_node(const AnfNodePtr &sequence_node) { | |||
| auto iter = | |||
| std::find_if(sequence_nodes_.begin(), sequence_nodes_.end(), | |||
| [&sequence_node](const AnfNodeWeakPtr &weak_node) { return sequence_node == weak_node.lock(); }); | |||
| if (iter == sequence_nodes_.end()) { | |||
| sequence_nodes_.emplace_back(sequence_node); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Fail to insert node \'" << sequence_node->DebugString() << "\' into sequence nodes."; | |||
| } | |||
| } | |||
| /// \brief Update the sequence nodes. | |||
| /// | |||
| /// \param[in] old_sequence_node The old node in sequence nodes. | |||
| /// \param[in] new_sequence_node The new node to replace old node in sequence nodes. | |||
| void update_sequence_node(const AnfNodePtr &old_sequence_node, const AnfNodePtr &new_sequence_node) { | |||
| auto iter = std::find_if( | |||
| sequence_nodes_.begin(), sequence_nodes_.end(), | |||
| [&old_sequence_node](const AnfNodeWeakPtr &weak_node) { return old_sequence_node == weak_node.lock(); }); | |||
| if (iter != sequence_nodes_.end()) { | |||
| *iter = new_sequence_node; | |||
| return; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Not found old node \'" << old_sequence_node->DebugString() << "\' in sequence nodes."; | |||
| } | |||
| std::size_t hash() const override; | |||
| std::string ToStringInternal() const; | |||
| std::string ToString() const override; | |||
| /// \brief Overwrite the operator '[]' to get an element. | |||
| @@ -767,6 +818,7 @@ class MS_CORE_API AbstractSequence : public AbstractBase { | |||
| protected: | |||
| AbstractBasePtrList elements_; | |||
| AnfNodeWeakPtrList sequence_nodes_; | |||
| }; | |||
| using AbstractSequencePtr = std::shared_ptr<AbstractSequence>; | |||
| @@ -776,7 +828,9 @@ class MS_CORE_API AbstractTuple final : public AbstractSequence { | |||
| /// \brief Constructor of AbstractTuple. | |||
| /// | |||
| /// \param[in] elements A list of abstracts. | |||
| explicit AbstractTuple(const AbstractBasePtrList &elements) : AbstractSequence(elements) {} | |||
| /// \param[in] tuple_node The nodes of tuple, usually are MakeTuple CNodes or tuple ValueNodes. | |||
| explicit AbstractTuple(const AbstractBasePtrList &elements, const AnfNodeWeakPtrList &tuple_nodes = {}) | |||
| : AbstractSequence(elements, tuple_nodes) {} | |||
| /// \brief Destructor of AbstractTuple. | |||
| ~AbstractTuple() override = default; | |||
| @@ -786,15 +840,22 @@ class MS_CORE_API AbstractTuple final : public AbstractSequence { | |||
| BaseShapePtr BuildShape() const override { return std::make_shared<TupleShape>(ElementsShape()); } | |||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); } | |||
| AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); } | |||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone(), sequence_nodes_); } | |||
| AbstractBasePtr PartialBroaden() const override { return std::make_shared<AbstractTuple>(ElementsPartialBroaden()); } | |||
| AbstractBasePtr Broaden() const override { | |||
| return std::make_shared<AbstractTuple>(ElementsBroaden(), sequence_nodes_); | |||
| } | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); } | |||
| AbstractBasePtr PartialBroaden() const override { | |||
| return std::make_shared<AbstractTuple>(ElementsPartialBroaden(), sequence_nodes_); | |||
| } | |||
| std::string ToString() const override { return type_name() + "(" + AbstractSequence::ToString() + ")"; } | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) override { | |||
| auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractTuple>(other)); | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| res->set_sequence_nodes(SequenceNodesJoin(other)); | |||
| return res; | |||
| } | |||
| /// \brief Check whether all elements of the tuple are tensors. | |||
| /// | |||
| @@ -821,7 +882,9 @@ class MS_CORE_API AbstractList final : public AbstractSequence { | |||
| /// \brief Constructor of AbstractList. | |||
| /// | |||
| /// \param[in] elements A list of abstracts. | |||
| explicit AbstractList(const AbstractBasePtrList &elements) : AbstractSequence(elements) {} | |||
| /// \param[in] list_node The nodes of list, usually are MakeList CNodes or list ValueNodes. | |||
| explicit AbstractList(const AbstractBasePtrList &elements, const AnfNodeWeakPtrList &list_nodes = {}) | |||
| : AbstractSequence(elements, list_nodes) {} | |||
| /// \brief Destructor of AbstractList. | |||
| ~AbstractList() override = default; | |||
| @@ -831,15 +894,22 @@ class MS_CORE_API AbstractList final : public AbstractSequence { | |||
| BaseShapePtr BuildShape() const override { return std::make_shared<ListShape>(ElementsShape()); } | |||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); } | |||
| AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); } | |||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone(), sequence_nodes_); } | |||
| AbstractBasePtr PartialBroaden() const override { return std::make_shared<AbstractList>(ElementsPartialBroaden()); } | |||
| AbstractBasePtr Broaden() const override { | |||
| return std::make_shared<AbstractList>(ElementsBroaden(), sequence_nodes_); | |||
| } | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); } | |||
| AbstractBasePtr PartialBroaden() const override { | |||
| return std::make_shared<AbstractList>(ElementsPartialBroaden(), sequence_nodes_); | |||
| } | |||
| std::string ToString() const override { return type_name() + "[" + AbstractSequence::ToString() + "]"; } | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) override { | |||
| auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractList>(other)); | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| res->set_sequence_nodes(SequenceNodesJoin(other)); | |||
| return res; | |||
| } | |||
| /// \brief Overwrite the operator '==' to compare other abstract list. | |||
| /// | |||
| @@ -136,6 +136,8 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||
| tuple_len = arg->elements().size(); | |||
| tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0); | |||
| // For Stack(tuple), set all used flags of tuple as true. | |||
| SetSequenceElementsUseFlags(args_spec_list[0], true); | |||
| } else if (args_spec_list[0]->isa<AbstractTensor>()) { | |||
| tuple_len = args_spec_list.size(); | |||
| tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| @@ -191,6 +191,9 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| return args_spec_list[0]; | |||
| } | |||
| // For F.depend(x, MakeTuple()) or F.depend(x, tuple), set all used flags of tuple as true. | |||
| SetSequenceElementsUseFlags(dependant_abstract, true); | |||
| auto depends = args_spec_list[0]->Broaden(); // Avoid eliminating the dependent node. | |||
| if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) { | |||
| // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value. | |||
| @@ -207,6 +210,12 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP | |||
| MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| // For UpdateState(x, MakeTuple()) or UpdateState(x, tuple), set all used flags of tuple as true. | |||
| for (size_t i = 1; i < args_spec_list.size(); i++) { | |||
| SetSequenceElementsUseFlags(args_spec_list[i], true); | |||
| } | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| @@ -36,7 +36,8 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tuples. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| constexpr int args_spec_size = 2; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||
| AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| @@ -66,7 +67,8 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a string and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| constexpr int args_spec_size = 2; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| ValuePtr keyPtr = key->BuildValue(); | |||
| @@ -82,7 +84,8 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a string and a keyword. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| constexpr int args_spec_size = 2; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1); | |||
| @@ -103,7 +106,8 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list and a scalar whose value is an int32 number. | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| constexpr int args_spec_size = 2; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| auto queue = CheckArg<T>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| @@ -117,26 +121,41 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra | |||
| } | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index->ToString(); | |||
| } | |||
| auto idx_v = GetValue<int64_t>(index_value); | |||
| auto index_int64_value = GetValue<int64_t>(index_value); | |||
| std::size_t nelems = queue->elements().size(); | |||
| if (idx_v >= SizeToLong(nelems) || idx_v < -SizeToLong(nelems)) { | |||
| if (index_int64_value >= SizeToLong(nelems) || index_int64_value < -SizeToLong(nelems)) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToLong(nelems) << ", " | |||
| << SizeToLong(nelems) << "), but got " << idx_v << "."; | |||
| << SizeToLong(nelems) << "), but got " << index_int64_value << "."; | |||
| } | |||
| std::size_t uidx_v = 0; | |||
| if (idx_v >= 0) { | |||
| uidx_v = LongToSize(idx_v); | |||
| std::size_t index_unsigned_value = 0; | |||
| if (index_int64_value >= 0) { | |||
| index_unsigned_value = LongToSize(index_int64_value); | |||
| } else { | |||
| uidx_v = LongToSize(idx_v + SizeToLong(nelems)); | |||
| index_unsigned_value = LongToSize(index_int64_value + SizeToLong(nelems)); | |||
| } | |||
| return queue->elements()[uidx_v]; | |||
| if (!queue->sequence_nodes().empty()) { | |||
| for (auto &node : queue->sequence_nodes()) { | |||
| auto sequence_node = node.lock(); | |||
| if (sequence_node == nullptr) { | |||
| MS_LOG(DEBUG) << "The node in sequence_nodes is free."; | |||
| continue; | |||
| } | |||
| auto flags = GetSequenceNodeElementsUseFlags(sequence_node); | |||
| if (flags != nullptr) { | |||
| (*flags)[index_unsigned_value] = true; | |||
| MS_LOG(DEBUG) << "Set item[" << index_unsigned_value << "] as use flag for " << sequence_node->DebugString(); | |||
| } | |||
| } | |||
| } | |||
| return queue->elements()[index_unsigned_value]; | |||
| } | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase. | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| constexpr int args_spec_size = 3; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| auto queue = CheckArg<T>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| @@ -146,16 +165,17 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " | |||
| << index_value->ToString(); | |||
| } | |||
| auto idx_v = GetValue<int64_t>(index_value); | |||
| auto index_int64_value = GetValue<int64_t>(index_value); | |||
| AbstractBasePtrList elements = queue->elements(); | |||
| std::size_t nelems = elements.size(); | |||
| int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems); | |||
| if (idx_t < 0 || idx_t >= SizeToLong(nelems)) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << idx_v << " to set out of range: [-" << nelems | |||
| << "," << (nelems - 1) << "]."; | |||
| int64_t index_positive_value = index_int64_value >= 0 ? index_int64_value : index_int64_value + SizeToLong(nelems); | |||
| if (index_positive_value < 0 || index_positive_value >= SizeToLong(nelems)) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << index_int64_value << " to set out of range: [-" | |||
| << nelems << "," << (nelems - 1) << "]."; | |||
| } | |||
| size_t uidx_v = LongToSize(idx_t); | |||
| elements[uidx_v] = args_spec_list[2]; | |||
| size_t index_unsigned_value = LongToSize(index_positive_value); | |||
| constexpr int target_value_index = 2; | |||
| elements[index_unsigned_value] = args_spec_list[target_value_index]; | |||
| return std::make_shared<T>(elements); | |||
| } | |||
| @@ -183,7 +203,8 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict and a scalar whose value is a string. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| constexpr int args_spec_size = 2; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| @@ -206,7 +227,8 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| constexpr int args_spec_size = 3; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| @@ -235,7 +257,8 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| constexpr int args_spec_size = 1; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| AbstractBasePtrList keys; | |||
| @@ -248,7 +271,8 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| constexpr int args_spec_size = 1; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| AbstractBasePtrList values; | |||
| @@ -261,7 +285,8 @@ AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| constexpr int args_spec_size = 1; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| AbstractBasePtrList items; | |||
| @@ -276,7 +301,8 @@ AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePt | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a list and an object of a subclass of AbstractBase. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| constexpr int args_spec_size = 2; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0); | |||
| AbstractBasePtr item = dyn_cast<AbstractBase>(args_spec_list[1]); | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| @@ -298,7 +324,8 @@ AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| constexpr int args_spec_size = 1; | |||
| CheckArgsSize(op_name, args_spec_list, args_spec_size); | |||
| auto arg_abs = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto shape = arg_abs->BuildShape()->cast<ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| @@ -203,6 +203,7 @@ using AnfNodePtr = std::shared_ptr<AnfNode>; | |||
| using AnfNodePtrList = std::vector<AnfNodePtr>; | |||
| using AnfNodeSet = OrderedSet<AnfNodePtr>; | |||
| using AnfNodeWeakPtr = std::weak_ptr<AnfNode>; | |||
| using AnfNodeWeakPtrList = std::vector<AnfNodeWeakPtr>; | |||
| class FuncGraph; | |||
| using FuncGraphPtr = std::shared_ptr<FuncGraph>; | |||
| @@ -643,4 +643,54 @@ bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) | |||
| } | |||
| return IsOneOfPrimitive(cnode->input(0), prim_set); | |||
| } | |||
| // Set the sequence nodes' elements use flags all true. | |||
| void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag) { | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (!enable_eliminate_unused_element) { | |||
| return; | |||
| } | |||
| auto sequence_abs = dyn_cast<abstract::AbstractSequence>(abs); | |||
| if (sequence_abs == nullptr) { | |||
| return; | |||
| } | |||
| if (sequence_abs->sequence_nodes().empty()) { | |||
| return; | |||
| } | |||
| for (auto &weak_node : sequence_abs->sequence_nodes()) { | |||
| auto sequence_node = weak_node.lock(); | |||
| if (sequence_node == nullptr) { | |||
| MS_LOG(DEBUG) << "The node in sequence_nodes is free."; | |||
| continue; | |||
| } | |||
| auto flags = GetSequenceNodeElementsUseFlags(sequence_node); | |||
| if (flags != nullptr) { | |||
| auto &all_flags = (*flags); | |||
| std::transform(all_flags.begin(), all_flags.end(), all_flags.begin(), | |||
| [&new_flag](bool) -> bool { return new_flag; }); | |||
| } | |||
| } | |||
| } | |||
| // Set the sequence nodes' elements use flags all true recursively. | |||
| void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag) { | |||
| static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT"); | |||
| static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1"); | |||
| if (!enable_eliminate_unused_element) { | |||
| return; | |||
| } | |||
| SetSequenceElementsUseFlags(abs, new_flag); | |||
| // Check its elements if it's sequence node. | |||
| auto sequence_abs = dyn_cast<abstract::AbstractSequence>(abs); | |||
| if (sequence_abs == nullptr) { | |||
| return; | |||
| } | |||
| for (auto &element : sequence_abs->elements()) { | |||
| SetSequenceElementsUseFlagsRecursively(element, new_flag); | |||
| } | |||
| } | |||
| } // namespace mindspore | |||
| @@ -301,7 +301,7 @@ class MS_CORE_API AnfNode : public Base { | |||
| user_data_.set<T>(T::key, value); | |||
| } | |||
| /// \brief Set user data. | |||
| /// \brief Get user data. | |||
| /// | |||
| /// \param[in] key The key of user data. | |||
| /// \return Pointer to user data. | |||
| @@ -1200,6 +1200,18 @@ struct GraphSegment { | |||
| uint32_t graph_id_{0}; | |||
| }; | |||
| using GraphSegmentPtr = std::shared_ptr<GraphSegment>; | |||
| constexpr auto kElementsUseFlagsKey = "elements_use_flags"; | |||
| inline std::shared_ptr<std::vector<bool>> GetSequenceNodeElementsUseFlags(const AnfNodePtr &node) { | |||
| return node->template user_data<std::vector<bool>>(kElementsUseFlagsKey); | |||
| } | |||
| inline void SetSequenceNodeElementsUseFlags(const AnfNodePtr &node, const std::shared_ptr<std::vector<bool>> &flags) { | |||
| node->set_user_data(kElementsUseFlagsKey, flags); | |||
| } | |||
| void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag); | |||
| void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_ANF_H_ | |||
| @@ -140,7 +140,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) { | |||
| repl_node_[node] = std::move(new_const); | |||
| } | |||
| void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||
| void Cloner::CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(target); | |||
| auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_); | |||
| @@ -232,7 +232,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func | |||
| auto parent = cnode.first->first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(parent); | |||
| const auto &valuenode = parent->input(cnode.first->second); | |||
| CloneValueNode(valuenode, target_func_graph); | |||
| CloneFuncGraphValueNode(valuenode, target_func_graph); | |||
| } | |||
| } | |||
| @@ -88,7 +88,7 @@ class Cloner { | |||
| void SetDefaults(); | |||
| void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||
| void CloneValueNode(const AnfNodePtr &node); | |||
| void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||
| void CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||
| void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||
| void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false); | |||
| void CloneValueNodes(const FuncGraphPtr &func_graph); | |||
| @@ -98,7 +98,10 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args)); | |||
| auto res = abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args)); | |||
| // For AddN(MakeTuple()) or AddN(tuple), set all used flags of tuple as true. | |||
| SetSequenceElementsUseFlags(input_args[0], true); | |||
| return res; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true); | |||
| } // namespace ops | |||
| @@ -160,7 +160,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr ret = | |||
| dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -186,7 +187,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr ret = | |||
| dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -212,7 +214,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr ret = | |||
| dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -238,7 +241,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) { | |||
| auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||
| AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr ret = | |||
| dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -265,7 +269,8 @@ TEST_F(TestComposite, test_UnpackCall_3args) { | |||
| abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | |||
| AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr ret = | |||
| dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -292,7 +297,8 @@ TEST_F(TestComposite, test_UnpackCall_5args) { | |||
| abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); | |||
| AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr ret = | |||
| dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -314,7 +320,7 @@ TEST_F(TestComposite, test_ZipOperation) { | |||
| auto tuple = std::make_shared<AbstractTuple>(eles); | |||
| AbstractBasePtrList args_spec_list = {tuple}; | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).eval_result->abstract()); | |||
| if (ret == nullptr) { | |||
| FAIL() << "Cast ret to abstract tuple failed."; | |||
| } | |||
| @@ -362,7 +368,7 @@ TEST_F(TestComposite, test_shard) { | |||
| auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4}); | |||
| AbstractBasePtrList args_spec_list = {tensor}; | |||
| auto ret = engine_->Run(shard_func_graph, args_spec_list).inferred->abstract(); | |||
| auto ret = engine_->Run(shard_func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_NE(ret, nullptr); | |||
| ASSERT_TRUE(ret->isa<abstract::AbstractTensor>()); | |||
| auto build_shape = ret->BuildShape(); | |||
| @@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) { | |||
| std::vector<int64_t> shape = {2, 2, 6, 6}; | |||
| expected->set_shape(std::make_shared<Shape>(shape)); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -143,7 +143,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) { | |||
| AbstractBasePtr abstract_x = FromValue(x, false); | |||
| args_spec_list.push_back(abstract_x); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | |||
| } | |||
| @@ -159,7 +159,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) { | |||
| AbstractBasePtr abstract_x = FromValue(x, false); | |||
| args_spec_list.push_back(abstract_x); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); | |||
| } | |||
| @@ -178,7 +178,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -197,7 +197,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -216,7 +216,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -235,7 +235,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) { | |||
| args_spec_list.push_back(abstract_x); | |||
| args_spec_list.push_back(abstract_y); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); | |||
| } | |||
| @@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) { | |||
| auto prim_typeof = std::make_shared<Primitive>("typeof"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| res->dump(); | |||
| TypePtr res_value = res->GetValueTrack()->cast<TypePtr>(); | |||
| res_value->dump(); | |||
| @@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) { | |||
| auto prim_list_map = std::make_shared<Primitive>("list_map"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto expected = std::make_shared<AbstractList>( | |||
| AbstractBasePtrList({FromValue(static_cast<int64_t>(3), false), FromValue(static_cast<int64_t>(3), false)})); | |||
| res->dump(); | |||
| @@ -189,7 +189,7 @@ TEST_F(TestPrim, test_list_reduce) { | |||
| auto prim_list_reduce = std::make_shared<Primitive>("list_reduce"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| res->dump(); | |||
| TypePtr res_type = res->GetTypeTrack(); | |||
| res_type->dump(); | |||
| @@ -206,7 +206,7 @@ TEST_F(TestPrim, test_scalar_to_array) { | |||
| auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| res->dump(); | |||
| TypePtr res_type = res->BuildType(); | |||
| res_type->dump(); | |||
| @@ -224,7 +224,7 @@ TEST_F(TestPrim, test_array_to_scalar) { | |||
| auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| res->dump(); | |||
| TypePtr res_type = res->BuildType(); | |||
| res_type->dump(); | |||
| @@ -240,7 +240,7 @@ TEST_F(TestPrim, test_J_1) { | |||
| auto prim_J = std::make_shared<Primitive>("J"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res); | |||
| ASSERT_TRUE(res_J != nullptr); | |||
| ASSERT_TRUE(*(res_J->element()) == *abstract_v1); | |||
| @@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) { | |||
| int64_t v1 = 1; | |||
| AbstractBasePtr abstract_v1 = FromValue(v1, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_v1}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| res->dump(); | |||
| AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res); | |||
| ASSERT_TRUE(res_J != nullptr); | |||
| @@ -301,7 +301,7 @@ TEST_F(TestPrim, test_switch1) { | |||
| AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false); | |||
| AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *arg1); | |||
| } | |||
| @@ -314,7 +314,7 @@ TEST_F(TestPrim, test_switch2) { | |||
| AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false); | |||
| AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "make result res: " << res->ToString(); | |||
| MS_LOG(INFO) << "make result arg2: " << arg2->ToString(); | |||
| ASSERT_TRUE(*res == *arg2); | |||
| @@ -327,7 +327,7 @@ TEST_F(TestPrim, test_identity) { | |||
| AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false); | |||
| AbstractBasePtrList args_spec_list = {abstract_v1}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *abstract_v1); | |||
| } | |||
| @@ -341,7 +341,7 @@ TEST_F(TestPrim, test_broadcast_shape) { | |||
| AbstractBasePtrList args_spec_list = {a, b}; | |||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract()); | |||
| AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).eval_result->abstract()); | |||
| auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); | |||
| std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)}; | |||
| @@ -361,7 +361,7 @@ TEST_F(TestPrim, test_partial) { | |||
| AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(1), false); | |||
| AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2}; | |||
| auto expected = std::make_shared<PartialAbstractClosure>( | |||
| std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list); | |||
| @@ -377,7 +377,7 @@ TEST_F(TestPrim, test_environ_set) { | |||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | |||
| AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false); | |||
| AbstractBasePtrList args_spec_list = {abstract_x}; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract(); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvironSet, 3); | |||
| @@ -385,7 +385,7 @@ TEST_F(TestPrim, test_environ_set) { | |||
| AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false); | |||
| args_spec_list = {abstract_environ, embed_x, abstract_y}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | |||
| ASSERT_TRUE(*res == *exp); | |||
| } | |||
| @@ -397,7 +397,7 @@ TEST_F(TestPrim, test_environ_get) { | |||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | |||
| AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false); | |||
| AbstractBasePtrList args_spec_list = {abstract_x}; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract(); | |||
| FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3); | |||
| @@ -405,7 +405,7 @@ TEST_F(TestPrim, test_environ_get) { | |||
| AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false); | |||
| args_spec_list = {abstract_environ, embed_x, abstract_y}; | |||
| AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | |||
| ASSERT_TRUE(*res == *exp); | |||
| @@ -414,7 +414,7 @@ TEST_F(TestPrim, test_environ_get) { | |||
| AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false); | |||
| args_spec_list = {res, embed_x, abstract_z}; | |||
| res = engine_->Run(graph_environ_get, args_spec_list).inferred->abstract(); | |||
| res = engine_->Run(graph_environ_get, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *abstract_x); | |||
| } | |||
| @@ -426,7 +426,7 @@ TEST_F(TestPrim, test_environ_add) { | |||
| FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); | |||
| AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false); | |||
| AbstractBasePtrList args_spec_list = {abstract_x}; | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract(); | |||
| FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3); | |||
| @@ -434,19 +434,19 @@ TEST_F(TestPrim, test_environ_add) { | |||
| AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false); | |||
| args_spec_list = {abstract_environ, embed_x, abstract_y}; | |||
| AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | |||
| ASSERT_TRUE(*abstract_e1 == *exp); | |||
| AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false); | |||
| args_spec_list = {abstract_environ, embed_x, abstract_z}; | |||
| AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*abstract_e2 == *exp); | |||
| FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvironAdd, 2); | |||
| args_spec_list = {abstract_e1, abstract_e2}; | |||
| AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *exp); | |||
| } | |||
| @@ -459,7 +459,7 @@ TEST_F(TestPrim, test_relu) { | |||
| AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW | |||
| AbstractBasePtrList args_spec_list = {expected}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -472,7 +472,7 @@ TEST_F(TestPrim, test_relu2) { | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5}); | |||
| AbstractBasePtrList args_spec_list = {arr}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack())); | |||
| } | |||
| @@ -505,7 +505,7 @@ TEST_F(TestPrim, test_conv2d1) { | |||
| std::vector<int64_t> shape = {2, 64, 14, 14}; | |||
| expected->set_shape(std::make_shared<Shape>(shape)); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -523,7 +523,7 @@ TEST_F(TestPrim, test_conv2d) { | |||
| auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3}); | |||
| AbstractBasePtrList args_spec_list = {input, weight}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16}); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| @@ -539,7 +539,7 @@ TEST_F(TestPrim, test_conv2d_native) { | |||
| auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3}); | |||
| AbstractBasePtrList args_spec_list = {input, weight}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16}); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| @@ -555,7 +555,7 @@ TEST_F(TestPrim, test_biasAdd) { | |||
| auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32}); | |||
| AbstractBasePtrList args_spec_list = {value, bias}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto res = dyn_cast<AbstractTensor>(ret); | |||
| auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32}); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| @@ -571,7 +571,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) { | |||
| auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10}); | |||
| AbstractBasePtrList args_spec_list = {logits, labels}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_NE(ret, nullptr); | |||
| auto res = dyn_cast<AbstractTuple>(ret); | |||
| auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64}); | |||
| @@ -600,7 +600,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) { | |||
| auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10}); | |||
| AbstractBasePtrList args_spec_list = {logits, labels}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto res = dyn_cast<AbstractScalar>(ret); | |||
| AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64); | |||
| expected->set_type(UTPrimUtils::kF64); | |||
| @@ -627,7 +627,7 @@ TEST_F(TestPrim, test_pooling) { | |||
| inputs->set_shape(inputs_dims); | |||
| AbstractBasePtr abstract_input = FromValue(inputs, false); | |||
| AbstractBasePtrList args_spec_list = {abstract_input}; | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expected = abstract_input->Clone()->Broaden(); | |||
| std::vector<int64_t> expected_dims = {8, 64, 2, 2}; | |||
| @@ -652,7 +652,7 @@ TEST_F(TestPrim, test_hastype) { | |||
| auto prim = std::make_shared<Primitive>("hastype"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -666,7 +666,7 @@ TEST_F(TestPrim, test_array_len) { | |||
| auto prim = std::make_shared<Primitive>("array_len"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -680,7 +680,7 @@ TEST_F(TestPrim, test_list_len) { | |||
| auto prim = std::make_shared<Primitive>("list_len"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -694,7 +694,7 @@ TEST_F(TestPrim, test_tuple_len) { | |||
| auto prim = std::make_shared<Primitive>("tuple_len"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -708,7 +708,7 @@ TEST_F(TestPrim, test_tuple_reversed) { | |||
| auto prim = std::make_shared<Primitive>("tuple_reversed"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "expect=" << expected->ToString(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -730,7 +730,7 @@ TEST_F(TestPrim, test_list_getitem) { | |||
| auto prim = std::make_shared<Primitive>("list_getitem"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *elem); | |||
| } | |||
| @@ -749,7 +749,7 @@ TEST_F(TestPrim, test_list_setitem) { | |||
| auto prim = std::make_shared<Primitive>("list_setitem"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| AbstractBasePtrList elems_exp = {elem1, elem2}; | |||
| auto expected = std::make_shared<AbstractList>(elems_exp); | |||
| @@ -771,7 +771,7 @@ TEST_F(TestPrim, test_list_append) { | |||
| auto prim = std::make_shared<Primitive>("list_append"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2})); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -795,7 +795,7 @@ TEST_F(TestPrim, test_tuple_setitem) { | |||
| auto prim = std::make_shared<Primitive>("tuple_setitem"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| AbstractBasePtrList elems_exp = {elem1, elem2}; | |||
| auto expected = std::make_shared<AbstractTuple>(elems_exp); | |||
| @@ -821,7 +821,7 @@ TEST_F(TestPrim, test_make_list) { | |||
| auto prim = std::make_shared<Primitive>("make_list"); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| @@ -844,7 +844,7 @@ TEST_F(TestPrim, test_make_range) { | |||
| AbstractBasePtrList elem_list({ele1, ele2, ele3}); | |||
| AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "res=" << res->ToString(); | |||
| MS_LOG(INFO) << "expected=" << expected->ToString(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| @@ -887,7 +887,7 @@ TEST_F(TestPrim, test_layernorm) { | |||
| AbstractBasePtr expected1 = abstract_mean_var->Clone(); | |||
| AbstractBasePtr expected2 = abstract_mean_var->Clone(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected0: " << expected0->ToString(); | |||
| MS_LOG(INFO) << "expected1: " << expected1->ToString(); | |||
| @@ -933,7 +933,7 @@ TEST_F(TestPrim, test_DropoutGenMask) { | |||
| AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8), | |||
| std::make_shared<Shape>(std::vector<int64_t>{79})); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "res=" << res->ToString(); | |||
| MS_LOG(INFO) << "expected=" << expected->ToString(); | |||
| ASSERT_TRUE(*res == *expected); | |||
| @@ -963,7 +963,7 @@ TEST_F(TestPrim, test_dropout) { | |||
| std::vector<int64_t> shape = {2, 20, 32, 32}; | |||
| expected->set_shape(std::make_shared<Shape>(shape)); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected: " << expected->ToString(); | |||
| @@ -984,7 +984,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) { | |||
| auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | |||
| auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | |||
| AbstractBasePtrList args_spec_list = {x_input, y_input}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto res = dyn_cast<AbstractTuple>(ret); | |||
| AbstractBasePtrList x_idx_list; | |||
| auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | |||
| @@ -1008,7 +1008,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) { | |||
| auto x_input = std::make_shared<AbstractTuple>(x_arg_list); | |||
| auto y_input = std::make_shared<AbstractTuple>(y_arg_list); | |||
| AbstractBasePtrList args_spec_list = {x_input, y_input}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| auto res = dyn_cast<AbstractTuple>(ret); | |||
| AbstractBasePtrList x_idx_list({abstract::FromValue(1)}); | |||
| auto r_x = std::make_shared<AbstractTuple>(x_idx_list); | |||
| @@ -1033,7 +1033,7 @@ TEST_F(TestPrim, test_DictGetItem) { | |||
| AbstractBasePtr key = abstract::FromValue("x"); | |||
| AbstractBasePtrList args_spec_list = {array_dict, key}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | |||
| AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second)); | |||
| @@ -1052,7 +1052,7 @@ TEST_F(TestPrim, test_DictGetItem2) { | |||
| AbstractBasePtr key = abstract::FromValue("x"); | |||
| AbstractBasePtrList args_spec_list = {array_dict, key}; | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); | |||
| AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x); | |||
| @@ -164,7 +164,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) { | |||
| auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| } | |||
| @@ -262,7 +262,7 @@ TEST_F(TestInferGraph, test_inferred) { | |||
| MS_LOG(INFO) << "" << graph_f_->get_return()->ToString(); | |||
| AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false); | |||
| args_spec_list.push_back(abstract_v1); | |||
| AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| // now this test case failed randomly, have to debug. | |||
| @@ -273,7 +273,7 @@ TEST_F(TestInferGraph, test_inferred) { | |||
| args_spec_list.clear(); | |||
| args_spec_list.push_back(abstract_v1); | |||
| args_spec_list.push_back(abstract_v2); | |||
| abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract(); | |||
| abs_base_got = engine_->Run(graph_alpha_, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| } | |||
| @@ -359,7 +359,7 @@ TEST_F(TestInferMetaGraph, test_inferred) { | |||
| AbstractBasePtr abstract_v2 = FromValue(v1, false); | |||
| args_spec_list.push_back(abstract_v1); | |||
| args_spec_list.push_back(abstract_v2); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).eval_result->abstract(); | |||
| ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); | |||
| } | |||
| @@ -391,7 +391,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { | |||
| auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract(); | |||
| AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).eval_result->abstract(); | |||
| ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack())); | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt64); | |||
| } | |||
| @@ -446,7 +446,7 @@ void TestGraphEval::TearDown() { | |||
| TEST_F(TestGraphInfer, test_graph_infer_defaults) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(50), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(1), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(9), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(48), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_kwarg) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(7), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(46), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||
| @@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { | |||
| TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) { | |||
| FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults"); | |||
| AbstractBasePtrList args_spec_list = {}; | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); | |||
| AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract(); | |||
| AbstractBasePtr expect = FromValue(MakeValue(57), false); | |||
| ASSERT_EQ(*res, *expect); | |||
| } | |||