Merge pull request !5913 from limingqi107/mastertags/v1.0.0
| @@ -469,6 +469,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf | |||||
| } | } | ||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| new_parameter->IncreaseUsedGraphCount(); | |||||
| graph_inputs->push_back(new_parameter); | graph_inputs->push_back(new_parameter); | ||||
| valid_inputs->push_back(true); | valid_inputs->push_back(true); | ||||
| return new_parameter; | return new_parameter; | ||||
| @@ -812,6 +813,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph | |||||
| } | } | ||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| new_parameter->IncreaseUsedGraphCount(); | |||||
| return new_parameter; | return new_parameter; | ||||
| } | } | ||||
| @@ -803,11 +803,18 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs, | |||||
| if (!input_node->isa<Parameter>()) { | if (!input_node->isa<Parameter>()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto parameter = input_node->cast<ParameterPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(parameter); | |||||
| parameter->DecreaseUsedGraphCount(); | |||||
| // Only the parameter has no graph used, then clear the output address. | |||||
| if (parameter->used_graph_count() != 0) { | |||||
| continue; | |||||
| } | |||||
| for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) { | for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) { | ||||
| if (!AnfAlgo::OutputAddrExist(input_node, index)) { | if (!AnfAlgo::OutputAddrExist(input_node, index)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get()); | |||||
| AnfAlgo::SetOutputAddr(nullptr, index, input_node.get()); | |||||
| } | } | ||||
| } | } | ||||
| // clear input value node output address. | // clear input value node output address. | ||||
| @@ -282,7 +282,7 @@ class ANode : public AnfNode { | |||||
| class Parameter : public ANode { | class Parameter : public ANode { | ||||
| public: | public: | ||||
| explicit Parameter(const FuncGraphPtr &func_graph) | explicit Parameter(const FuncGraphPtr &func_graph) | ||||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {} | |||||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {} | |||||
| ~Parameter() override = default; | ~Parameter() override = default; | ||||
| MS_DECLARE_PARENT(Parameter, ANode); | MS_DECLARE_PARENT(Parameter, ANode); | ||||
| @@ -300,6 +300,10 @@ class Parameter : public ANode { | |||||
| ValuePtr default_param() const { return default_param_; } | ValuePtr default_param() const { return default_param_; } | ||||
| ParamInfoPtr param_info() const; | ParamInfoPtr param_info() const; | ||||
| void IncreaseUsedGraphCount() { used_graph_count_++; } | |||||
| void DecreaseUsedGraphCount() { used_graph_count_--; } | |||||
| int used_graph_count() const { return used_graph_count_; } | |||||
| bool operator==(const AnfNode &other) const override { | bool operator==(const AnfNode &other) const override { | ||||
| if (!other.isa<Parameter>()) { | if (!other.isa<Parameter>()) { | ||||
| return false; | return false; | ||||
| @@ -315,6 +319,8 @@ class Parameter : public ANode { | |||||
| std::string name_; | std::string name_; | ||||
| bool has_default_; | bool has_default_; | ||||
| ValuePtr default_param_; | ValuePtr default_param_; | ||||
| // The count of graphs using the parameter. | |||||
| int used_graph_count_; | |||||
| }; | }; | ||||
| using ParameterPtr = std::shared_ptr<Parameter>; | using ParameterPtr = std::shared_ptr<Parameter>; | ||||