Browse Source

!5617 clear graph output address in graph destructor

Merge pull request !5617 from limingqi107/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
24f00cc6dc
4 changed files with 42 additions and 3 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/optimizer/gpu/remove_format_transform_pair.cc
  2. +4
    -2
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
  3. +34
    -0
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  4. +3
    -0
      mindspore/ccsrc/runtime/device/kernel_runtime.h

+ 1
- 1
mindspore/ccsrc/backend/optimizer/gpu/remove_format_transform_pair.cc View File

@@ -46,7 +46,7 @@ const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, c
MS_LOG(EXCEPTION) << "The pattern is not transpose pair, "
<< "node:" << AnfAlgo::GetCNodeName(node) << " node input:" << AnfAlgo::GetCNodeName(input_node);
}
// If transpose operator used by more than one other operators, it cant not be deleted directly.
// If transpose operator used by more than one other operators, it cant not be deleted directly.
if (IsUsedByOthers(graph, input_node)) {
MS_LOG(DEBUG) << "The transpose node [" << input_node->fullname_with_scope()
<< "] is used by more than one other operators.";


+ 4
- 2
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc View File

@@ -397,8 +397,8 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
bin_map->RemoveKernelCache();
}
void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &,
const std::unordered_set<ValueNodePtr> &,
void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) {
MS_LOG(INFO) << "Clear graph:" << graph_id << " GPU runtime resource";
// Release the kernel resource.
@@ -409,6 +409,8 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
}
kernel_mod->ReleaseResource();
}
// Clear the output address of graph.
ClearOutputAddress(inputs, value_nodes, execution_order);
}
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {


+ 34
- 0
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -854,6 +854,40 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vect
MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}

void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) {
// clear input parameter output address.
for (const auto &input_node : inputs) {
MS_EXCEPTION_IF_NULL(input_node);
if (!input_node->isa<Parameter>()) {
continue;
}
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) {
if (!AnfAlgo::OutputAddrExist(input_node, index)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
}
}
// clear input value node output address.
for (const auto &value_node : value_nodes) {
if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
}
// clear cnode output address.
for (const auto &cnode : execution_order) {
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
if (!AnfAlgo::OutputAddrExist(cnode, index)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
}
}
}

bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr,
const AddressPtrList &kernel_inputs,
const AddressPtrList &kernel_outputs,


+ 3
- 0
mindspore/ccsrc/runtime/device/kernel_runtime.h View File

@@ -72,6 +72,9 @@ class KernelRuntime {
virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order);
virtual void ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order);
virtual bool SyncStream() = 0;

#ifdef ENABLE_DUMP_E2E


Loading…
Cancel
Save