From: @zhangzhaoju Reviewed-by: @zh_qh,@ginfung,@zh_qh Signed-off-by: @zh_qh,@zh_qhpull/14940/MERGE
| @@ -415,9 +415,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor | |||
| auto out_node = c_input->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(out_node); | |||
| out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward)); | |||
| // clear resource | |||
| fg->ClearAllManagerInfo(); | |||
| func_graph->ClearAllManagerInfo(); | |||
| } | |||
| bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { | |||
| @@ -468,10 +468,10 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python | |||
| FuncGraphPtr func_graph = nullptr; | |||
| ValuePtr value = nullptr; | |||
| bool is_cache = data_converter::GetObjectValue(obj_id, &value); | |||
| if (is_cache) { | |||
| if (value && value->isa<FuncGraph>()) { | |||
| MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; | |||
| func_graph = value->cast<FuncGraphPtr>(); | |||
| if (is_cache && value != nullptr && value->isa<FuncGraph>()) { | |||
| MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; | |||
| func_graph = value->cast<FuncGraphPtr>(); | |||
| if (!func_graph->dropped()) { | |||
| return func_graph; | |||
| } | |||
| } | |||
| @@ -197,10 +197,6 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &as | |||
| python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]); | |||
| MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << "."; | |||
| } | |||
| // Clear manager info after checking missing return | |||
| for (const auto &fg : mng->func_graphs()) { | |||
| fg->ClearAllManagerInfo(); | |||
| } | |||
| } | |||
| FuncGraphPtr Parser::ParseFuncGraph() { | |||
| @@ -1711,9 +1707,6 @@ void Parser::RemoveUnnecessaryPhis() { | |||
| new_parameters.resize(std::distance(new_parameters.begin(), it)); | |||
| func_graph->set_parameters(new_parameters); | |||
| } | |||
| for (const auto &fg : mng->func_graphs()) { | |||
| fg->ClearAllManagerInfo(); | |||
| } | |||
| } | |||
| // ParseAst class code | |||
| @@ -256,8 +256,6 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| } | |||
| this->debug_info_ = info; | |||
| } | |||
| // Clear all info from manager. | |||
| void ClearAllManagerInfo(); | |||
| // Get all nodes belonging to this func graph. | |||
| const AnfNodeSet &nodes(); | |||
| void CopyNodes(const FuncGraphPtr &source); | |||
| @@ -389,7 +387,17 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| int64_t stage() { return stage_; } | |||
| void set_stage(int64_t stage) { stage_ = stage; } | |||
| bool dropped() { return dropped_; } | |||
| void set_dropped(bool dropped) { dropped_ = dropped; } | |||
| private: | |||
| // Only used for func_graph manager to control resource free. | |||
| int attached_mng_cnt() { return attached_mng_cnt_; } | |||
| void IncAttachedMngCnt() { attached_mng_cnt_++; } | |||
| void DecAttachedMngCnt() { attached_mng_cnt_--; } | |||
| // Clear all info from manager. | |||
| void ClearAllManagerInfo(); | |||
| // Graph is manipulated by manager and others. | |||
| friend FuncGraphManager; | |||
| @@ -442,6 +450,7 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| // FuncGraphManager. In that special case, Manage() should be called to make the func graph | |||
| // managed. | |||
| std::weak_ptr<FuncGraphManager> manager_; | |||
| int attached_mng_cnt_ = 0; | |||
| GraphDebugInfoPtr debug_info_; | |||
| void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, | |||
| @@ -460,6 +469,10 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher, | |||
| abstract::AbstractBasePtrListEqual> | |||
| func_graph_cache_; | |||
| // If the graph was changed, it should be dropped in cache data_converter::object_map_ | |||
| // which used by ConvertToFuncGraph. | |||
| bool dropped_ = false; | |||
| }; | |||
| inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) { | |||
| @@ -211,6 +211,15 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { | |||
| // Clear the all information in manager | |||
| void FuncGraphManager::Clear() { | |||
| for (auto graph : func_graphs_) { | |||
| graph->DecAttachedMngCnt(); | |||
| if (graph->attached_mng_cnt() == 0) { | |||
| graph->ClearAllManagerInfo(); | |||
| } else if (graph->attached_mng_cnt() < 0) { | |||
| MS_LOG(EXCEPTION) << "graph:" << graph->ToString() << " attached cnt not right:" << graph->attached_mng_cnt(); | |||
| } | |||
| } | |||
| func_graphs_.clear(); | |||
| all_nodes_.clear(); | |||
| node_users_.clear(); | |||
| @@ -280,6 +289,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { | |||
| fg->set_manager(this_manager); | |||
| } | |||
| func_graphs_.add(fg); | |||
| fg->IncAttachedMngCnt(); | |||
| } | |||
| void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { | |||
| @@ -310,7 +320,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool | |||
| for (auto &fg : dropped) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| all_nodes_.difference_update(fg->parameters()); | |||
| (void)func_graphs_.erase(fg); | |||
| EraseOneGraph(fg.get()); | |||
| if (fg->manager().get() == this) { | |||
| fg->set_manager(nullptr); | |||
| } | |||
| @@ -485,7 +495,8 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||
| MoveAllNodes(source, target); | |||
| all_nodes_.difference_update(source->parameters()); | |||
| (void)func_graphs_.erase(source); | |||
| EraseOneGraph(source.get()); | |||
| source->set_dropped(true); | |||
| if (source->manager().get() == this) { | |||
| source->set_manager(nullptr); | |||
| } | |||
| @@ -514,7 +525,7 @@ void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||
| void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||
| auto fg = node->func_graph(); | |||
| if (input->isa<ValueNode>()) { | |||
| if (fg != nullptr && input->isa<ValueNode>()) { | |||
| fg->DropValueNode(input); | |||
| if (IsValueNode<FuncGraph>(input)) { | |||
| auto used = GetValueNode<FuncGraphPtr>(input); | |||
| @@ -540,8 +551,8 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||
| target->CopyFreeVariables(source); | |||
| target->CopyFuncGraphsUsed(source); | |||
| target->CopyJValueNodes(source); | |||
| signals_->InvalidateComputer(); | |||
| source->ClearAllManagerInfo(); | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| FuncGraphTransaction FuncGraphManager::Transact() { | |||
| @@ -636,6 +647,18 @@ void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) { | |||
| MaybeDropFuncGraphs(*drop_func_graphs); | |||
| } | |||
| void FuncGraphManager::EraseOneGraph(FuncGraph *fg) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| size_t erase_cnt = func_graphs_.erase(fg->shared_from_base<FuncGraph>()); | |||
| if (!erase_cnt) { | |||
| return; | |||
| } | |||
| fg->DecAttachedMngCnt(); | |||
| if (fg->attached_mng_cnt() == 0) { | |||
| fg->ClearAllManagerInfo(); | |||
| } | |||
| } | |||
| void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms) { | |||
| changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); | |||
| } | |||
| @@ -301,6 +301,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| if (is_manage_) { | |||
| RemoveRoots(); | |||
| } | |||
| Clear(); | |||
| } | |||
| void Reset(); | |||
| @@ -359,6 +360,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| std::shared_ptr<ParentComputer> func_graph_parent_; | |||
| private: | |||
| // Erase OneGraph From Manager | |||
| void EraseOneGraph(FuncGraph *fg); | |||
| void AddIntoManaged(const FuncGraphPtr &fg); | |||
| void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); | |||
| void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction); | |||