Browse Source

!5913 add count of graphs using the parameter

Merge pull request !5913 from limingqi107/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f480e48271
3 changed files with 17 additions and 2 deletions
  1. +2
    -0
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +8
    -1
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  3. +7
    -1
      mindspore/core/ir/anf.h

+ 2
- 0
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -469,6 +469,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();
graph_inputs->push_back(new_parameter);
valid_inputs->push_back(true);
return new_parameter;
@@ -812,6 +813,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();

return new_parameter;
}


+ 8
- 1
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -803,11 +803,18 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
if (!input_node->isa<Parameter>()) {
continue;
}
auto parameter = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
parameter->DecreaseUsedGraphCount();
// Only the parameter has no graph used, then clear the output address.
if (parameter->used_graph_count() != 0) {
continue;
}
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) {
if (!AnfAlgo::OutputAddrExist(input_node, index)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
}
}
// clear input value node output address.


+ 7
- 1
mindspore/core/ir/anf.h View File

@@ -282,7 +282,7 @@ class ANode : public AnfNode {
class Parameter : public ANode {
public:
explicit Parameter(const FuncGraphPtr &func_graph)
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {}
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {}
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);

@@ -300,6 +300,10 @@ class Parameter : public ANode {
ValuePtr default_param() const { return default_param_; }
ParamInfoPtr param_info() const;

void IncreaseUsedGraphCount() { used_graph_count_++; }
void DecreaseUsedGraphCount() { used_graph_count_--; }
int used_graph_count() const { return used_graph_count_; }

bool operator==(const AnfNode &other) const override {
if (!other.isa<Parameter>()) {
return false;
@@ -315,6 +319,8 @@ class Parameter : public ANode {
std::string name_;
bool has_default_;
ValuePtr default_param_;
// The count of graphs using the parameter.
int used_graph_count_;
};
using ParameterPtr = std::shared_ptr<Parameter>;



Loading…
Cancel
Save