# Adds a KernelMod::ReleaseResource() hook and passes a graph's inputs, value nodes and execution order
# through KernelRuntimeManager::ClearGraphResource / KernelRuntime::ClearGraphRuntimeResource, so that the
# GPU runtime can release per-kernel resources when a KernelGraph is destroyed.
#
# Review changes relative to the patch as originally written:
#   * DatasetIteratorKernel::ReleaseResource() resets handle_ after closing it, because the destructor
#     calls Close(handle_) again.
#   * GPUKernelRuntime::ClearGraphRuntimeResource() logs and skips a kernel without a kernel mod instead
#     of throwing, because it is reached from ~KernelGraph().
#
# NOTE(review): template arguments and <header> names were already missing from this copy of the patch
# and are left as found; the post-image hashes on the "index" lines of the two edited files are those of
# the original patch. Restore/regenerate both before applying.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
index 67a487ce28..c8de6b349e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
@@ -31,6 +31,13 @@ DatasetIteratorKernel::DatasetIteratorKernel() : handle_(HandleMgr::INVALID_HAND
 
 DatasetIteratorKernel::~DatasetIteratorKernel() { GpuBufferMgr::GetInstance().Close(handle_); }
 
+// Closes handle_ through GpuBufferMgr on request. The handle is reset afterwards so that the
+// destructor's own Close(handle_) call does not act on the already-closed handle a second time.
+void DatasetIteratorKernel::ReleaseResource() {
+  GpuBufferMgr::GetInstance().Close(handle_);
+  handle_ = HandleMgr::INVALID_HANDLE;
+}
+
 const std::vector &DatasetIteratorKernel::GetInputSizeList() const { return input_size_list_; }
 
 const std::vector &DatasetIteratorKernel::GetOutputSizeList() const { return output_size_list_; }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h
index 746aed3294..b20df721a6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h
@@ -35,6 +35,7 @@ class DatasetIteratorKernel : public GpuKernel {
   bool Launch(const std::vector &inputs, const std::vector &workspace,
               const std::vector &outputs, void *stream_ptr) override;
   bool Init(const CNodePtr &kernel_node) override;
+  void ReleaseResource() override;
 
  protected:
   void InitSizeLists() override;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel.h b/mindspore/ccsrc/backend/kernel_compiler/kernel.h
index 01f8e75f49..c41223220c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel.h
@@ -119,6 +119,8 @@ class KernelMod {
   virtual bool Launch(const std::vector &inputs, const std::vector &workspace,
                       const std::vector &outputs, void *stream_ptr) = 0;
   virtual std::vector GenParameters() { return {}; }
+  // Hook for kernels that own runtime resources to release them on request; does nothing by default.
+  virtual void ReleaseResource() {}
 
   virtual ~KernelMod() = default;
   void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index 93a39469d5..9cba4ba6de 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -1125,6 +1125,9 @@ void KernelGraph::UpdateChildGraphOrder() {
 
 std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
 
-KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
+KernelGraph::~KernelGraph() {
+  device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_, *inputs_, graph_value_nodes_,
+                                                              execution_order_);
+}
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
index 7640d276fa..4734f482bc 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
@@ -115,7 +115,9 @@ void AscendKernelRuntime::ClearGraphModelMap() {
   }
 }
 
-void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
+void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &,
+                                                    const std::unordered_set &,
+                                                    const std::vector &) {
   MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
   if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
     MS_LOG(DEBUG) << "Unload dump info " << graph_id;
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
index 995ee96c75..42b9dcda99 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 #include "runtime/device/kernel_runtime.h"
 #include "runtime/context.h"
 #include "framework/ge_runtime/davinci_model.h"
@@ -43,7 +44,9 @@ class AscendKernelRuntime : public KernelRuntime {
   bool GenTask(const session::KernelGraph *graph) override;
   bool RunTask(const session::KernelGraph *graph) override;
   bool LoadTask(const session::KernelGraph *graph) override;
-  void ClearGraphRuntimeResource(uint32_t graph_id) override;
+  void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &inputs,
+                                 const std::unordered_set &value_nodes,
+                                 const std::vector &execution_order) override;
   bool SyncStream() override;
 
  protected:
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index 4406c3cb68..5602b86d1f 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -397,6 +397,23 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
   bin_map->RemoveKernelCache();
 }
 
+// Asks every kernel in the graph's execution order to release its resources. This is reached from
+// ~KernelGraph(), so it must not throw: a kernel without a kernel mod is reported and skipped.
+void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &,
+                                                 const std::unordered_set &,
+                                                 const std::vector &execution_order) {
+  MS_LOG(INFO) << "Clear graph:" << graph_id << " GPU runtime resource";
+  // Release the kernel resource.
+  for (const auto &kernel : execution_order) {
+    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+    if (kernel_mod == nullptr) {
+      MS_LOG(ERROR) << "Kernel mod is nullptr";
+      continue;
+    }
+    kernel_mod->ReleaseResource();
+  }
+}
+
 void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
index 8f3cb9cb25..8ff8b773fb 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include
 #include "runtime/device/kernel_runtime.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "backend/optimizer/mem_reuse/mem_swap_manager.h"
@@ -37,6 +38,9 @@ class GPUKernelRuntime : public KernelRuntime {
   ~GPUKernelRuntime() override = default;
   bool Init() override;
   void ReleaseDeviceRes() override;
+  void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &inputs,
+                                 const std::unordered_set &value_nodes,
+                                 const std::vector &execution_order) override;
   void AssignMemory(session::KernelGraph *graph) override;
   bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
 #ifdef ENABLE_DUMP_E2E
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index 9d57e64caa..4e5eaaae3b 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include
 #include "utils/ms_utils.h"
 #include "common/trans.h"
 #include "utils/utils.h"
@@ -841,7 +842,8 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
   return true;
 }
 
-void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
+void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &,
+                                              const std::unordered_set &, const std::vector &) {
   MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
 }
 
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h
index cc33c5646e..2b262e13a7 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include
-
+#include
 #include "runtime/device/device_address.h"
 #include "ir/tensor.h"
 #include "utils/convert_utils.h"
@@ -69,7 +69,9 @@ class KernelRuntime {
                                      const AddressPtrList &kernel_workspaces) const;
   virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
   virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
-  virtual void ClearGraphRuntimeResource(uint32_t graph_id);
+  virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &inputs,
+                                         const std::unordered_set &value_nodes,
+                                         const std::vector &execution_order);
   virtual bool SyncStream() = 0;
 
 #ifdef ENABLE_DUMP_E2E
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
index 626259f9ce..0c7c66e3c8 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
@@ -29,7 +29,9 @@ void KernelRuntimeManager::ClearRuntimeResource() {
   runtime_map_.clear();
 }
 
-void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) {
+void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id, const std::vector &inputs,
+                                              const std::unordered_set &value_nodes,
+                                              const std::vector &execution_order) {
   std::lock_guard guard(lock_);
   for (auto &iter : runtime_map_) {
     MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource";
@@ -37,7 +39,7 @@ void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) {
       MS_LOG(ERROR) << "Kernel runtime is nullptr";
       continue;
     }
-    iter.second->ClearGraphRuntimeResource(graph_id);
+    iter.second->ClearGraphRuntimeResource(graph_id, inputs, value_nodes, execution_order);
   }
 }
 
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h
index bf88f53087..26e0ff8804 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h
@@ -22,6 +22,8 @@
 #include
 #include
 #include
+#include
+#include
 #include "utils/ms_utils.h"
 #include "runtime/device/kernel_runtime.h"
 namespace mindspore {
@@ -38,7 +40,9 @@ class KernelRuntimeManager {
   KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
   KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id);
   void ClearRuntimeResource();
-  void ClearGraphResource(uint32_t graph_id);
+  void ClearGraphResource(uint32_t graph_id, const std::vector &inputs,
+                          const std::unordered_set &value_nodes,
+                          const std::vector &execution_order);
 
  private:
   KernelRuntimeManager() = default;