Browse Source

!28313 Fix forward op release

Merge pull request !28313 from zjun/fix_forward_release
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
04f83c9782
5 changed files with 29 additions and 25 deletions
  1. +22
    -18
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +2
    -2
      mindspore/ccsrc/backend/session/session_basic.h
  3. +2
    -2
      mindspore/ccsrc/runtime/framework/graph_compiler.cc
  4. +2
    -2
      mindspore/ccsrc/runtime/framework/graph_compiler.h
  5. +1
    -1
      mindspore/ccsrc/vm/backend.h

+ 22
- 18
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1364,37 +1364,41 @@ void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithInde
}

// Fills *forward_op_output_tensor_id with the ids of value-node tensors in `graph` that are also recorded as
// forward op outputs by the PyNative grad executor.
// NOTE(review): this span appears to be a rendered diff with removed and added lines shown together, which is
// why the second parameter line and parts of the body occur twice. Presumably the std::multiset parameter and
// the value-node scan are the removed form and the std::map parameter and the execution-order scan are the
// added form - confirm against the repository.
// NOTE(review): neither `graph` nor `forward_op_output_tensor_id` is null-checked before being dereferenced
// here, whereas ReleaseForwardOpOutput checks its container pointer with MS_EXCEPTION_IF_NULL.
void SessionBasic::GetForwardOpOutputRefCount(const KernelGraph *graph,
std::multiset<std::string> *forward_op_output_tensor_id) {
std::map<std::string, size_t> *forward_op_output_tensor_id) {
// Nothing is collected unless the grad executor reports that grad is running.
if (!pynative::PynativeExecutor::GetInstance()->grad_executor()->grad_is_running()) {
return;
}
// Value-node scan: gather the tensors held by every value node of the graph into tensor_value_list.
const auto &graph_value_nodes = graph->graph_value_nodes();
std::vector<tensor::TensorPtr> tensor_value_list;
for (const auto &v : graph_value_nodes) {
const auto &value = GetValueNode(v);
MS_EXCEPTION_IF_NULL(value);
TensorValueToTensor(value, &tensor_value_list);
}
// Ids of forward op outputs, as recorded by the grad executor; used below as the membership filter.
const auto &forward_op_output_id = pynative::PynativeExecutor::GetInstance()->grad_executor()->forward_op_output_id();
MS_LOG(DEBUG) << "Total forward op out put size " << forward_op_output_id.size();
// Value-node scan, continued: one emplace per gathered tensor whose id is a forward op output id.
for (const auto &t : tensor_value_list) {
if (forward_op_output_id.find(t->id()) != forward_op_output_id.end()) {
(*forward_op_output_tensor_id).emplace(t->id());
// Execution-order scan: visit every input of every kernel in the graph's execution order.
for (const auto &kernel : graph->execution_order()) {
const auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
// Inputs are read from index 1 through input_tensor_num inclusive; input(0) is never visited
// (presumably it is the primitive rather than a data input - TODO confirm).
for (size_t i = 1; i <= input_tensor_num; ++i) {
const auto &input = kernel->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
// Only inputs that resolve to a ValueNode are considered.
if (real_input->isa<ValueNode>()) {
const auto &tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
MS_EXCEPTION_IF_NULL(tensor);
if (forward_op_output_id.find(tensor->id()) != forward_op_output_id.end()) {
// Counted once per matching kernel input, so a tensor consumed by several inputs gets a count above 1.
(*forward_op_output_tensor_id)[tensor->id()] += 1;
}
}
}
}
MS_LOG(DEBUG) << "Total value nodes in graph size " << graph_value_nodes.size() << ", total tensor value node size "
<< tensor_value_list.size() << ", forward op output tensor size "
<< forward_op_output_tensor_id->size();
MS_LOG(DEBUG) << "Forward op output tensor in bprop graph size " << forward_op_output_tensor_id->size();
}

// For every tensor in `input_tensors` whose id is present in *forward_op_output_tensor_id, clears the tensor's
// device address by setting it to nullptr and removes the id from the container.
// NOTE(review): this span appears to be a rendered diff with removed and added lines shown together.
// Presumably the std::multiset parameter and the unconditional set_device_address/erase pair are the removed
// form, and the std::map parameter and the count-based release are the added form - confirm against the
// repository. Taken literally, `it` would be used after the first erase.
// NOTE(review): each `tensor` is dereferenced without a null check.
void SessionBasic::ReleaseForwardOpOutput(const std::vector<tensor::TensorPtr> &input_tensors,
std::multiset<std::string> *forward_op_output_tensor_id) {
std::map<std::string, size_t> *forward_op_output_tensor_id) {
MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
for (const auto &tensor : input_tensors) {
auto it = forward_op_output_tensor_id->find(tensor->id());
if (it != forward_op_output_tensor_id->end()) {
// Unconditional form: release on the first match and drop the found element.
tensor->set_device_address(nullptr);
forward_op_output_tensor_id->erase(it);
// Count-based form: decrement the stored count and release only when it reaches zero, so the device
// address stays set until the last counted use has been seen.
if (--(it->second) == 0) {
tensor->set_device_address(nullptr);
forward_op_output_tensor_id->erase(it);
}
}
}
}
@@ -2425,7 +2429,7 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
graph_output_info.graph_outputs = outputs;
CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
std::map<KernelWithIndex, size_t> cnode_refcount;
std::multiset<std::string> forward_op_output_tensor_id;
std::map<std::string, size_t> forward_op_output_tensor_id;
GetRefCount(kernel_graph.get(), &cnode_refcount);
GetForwardOpOutputRefCount(kernel_graph.get(), &forward_op_output_tensor_id);
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);


+ 2
- 2
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -196,9 +196,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
VectorRef *const outputs,
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);
void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count);
// Collect the ids of forward op output tensors used as value nodes in the graph; in the std::map form the
// mapped value is the number of kernel inputs that use the tensor.
// NOTE(review): each declaration below appears twice (std::multiset and std::map signatures) because this
// span looks like a rendered diff showing removed and added lines together - confirm against the repository.
void GetForwardOpOutputRefCount(const KernelGraph *graph, std::multiset<std::string> *forward_op_output_tensor_id);
void GetForwardOpOutputRefCount(const KernelGraph *graph, std::map<std::string, size_t> *forward_op_output_tensor_id);
// Clear the device address of input tensors recorded in forward_op_output_tensor_id and drop their entries.
void ReleaseForwardOpOutput(const std::vector<tensor::TensorPtr> &input_tensors,
std::multiset<std::string> *forward_op_output_tensor_id);
std::map<std::string, size_t> *forward_op_output_tensor_id);
void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map);



+ 2
- 2
mindspore/ccsrc/runtime/framework/graph_compiler.cc View File

@@ -638,7 +638,7 @@ void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<Kern
}

void GraphCompiler::CalculateForwardOpOutputCount(const KernelGraphPtr &graph,
std::multiset<std::string> *forward_op_output_tensor_id) const {
std::map<std::string, size_t> *forward_op_output_tensor_id) const {
MS_EXCEPTION_IF_NULL(session_);
forward_op_output_tensor_id->clear();
session_->GetForwardOpOutputRefCount(graph.get(), forward_op_output_tensor_id);
@@ -652,7 +652,7 @@ void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernel
}

void GraphCompiler::UpdateForwardOpOutputRefCount(const std::vector<tensor::TensorPtr> &input_tensor,
std::multiset<std::string> *forward_op_output_tensor_id) const {
std::map<std::string, size_t> *forward_op_output_tensor_id) const {
MS_EXCEPTION_IF_NULL(session_);
MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
session_->ReleaseForwardOpOutput(input_tensor, forward_op_output_tensor_id);


+ 2
- 2
mindspore/ccsrc/runtime/framework/graph_compiler.h View File

@@ -146,7 +146,7 @@ class GraphCompiler {

// Calculate forward op output ref count of PyNative back graph.
void CalculateForwardOpOutputCount(const KernelGraphPtr &graph,
std::multiset<std::string> *forward_op_output_tensor_id) const;
std::map<std::string, size_t> *forward_op_output_tensor_id) const;

// Update ref count of PyNative back propagation operators.
void UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
@@ -155,7 +155,7 @@ class GraphCompiler {

// Update forward op output ref count of PyNative back graph.
void UpdateForwardOpOutputRefCount(const std::vector<tensor::TensorPtr> &input_tensor,
std::multiset<std::string> *forward_op_output_tensor_id) const;
std::map<std::string, size_t> *forward_op_output_tensor_id) const;

// Handle single op output tensor and recover output of original complete kernel graph.
void RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,


+ 1
- 1
mindspore/ccsrc/vm/backend.h View File

@@ -195,7 +195,7 @@ class MindRTBackend : public Backend {
std::map<GraphId, std::map<KernelWithIndex, size_t>> cnode_ref_counts_;

// Cache forward op output value node tensor ref count of kernels for back propagation graph in PyNative mode.
std::multiset<std::string> forward_op_output_tensor_id_;
std::map<std::string, size_t> forward_op_output_tensor_id_;

FuncGraph *root_graph_;
GraphPartitionPtr graph_partition_;


Loading…
Cancel
Save