Merge pull request !1580 from caifubi/fix-multi-graph-device-resource-bug (tags/v0.5.0-beta)
| @@ -64,6 +64,21 @@ void AscendKernelRuntime::ClearGraphModelMap() { | |||||
| } | } | ||||
| } | } | ||||
| void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { | |||||
| MS_LOG(INFO) << "clear graph:" << graph_id << " runtime resource"; | |||||
| auto iter = graph_model_map_.find(graph_id); | |||||
| if (iter == graph_model_map_.end()) { | |||||
| MS_LOG(WARNING) << "GraphId:" << graph_id << " not found"; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Ge UnloadModel " << iter->first; | |||||
| auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter->first); | |||||
| if (!ret) { | |||||
| MS_LOG(ERROR) << "UnloadModel failed"; | |||||
| } | |||||
| graph_model_map_.erase(iter); | |||||
| } | |||||
| bool AscendKernelRuntime::NeedDestroyHccl() { | bool AscendKernelRuntime::NeedDestroyHccl() { | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| @@ -40,6 +40,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||||
| bool GenTask(const session::KernelGraph *graph) override; | bool GenTask(const session::KernelGraph *graph) override; | ||||
| bool RunTask(const session::KernelGraph *graph) override; | bool RunTask(const session::KernelGraph *graph) override; | ||||
| bool LoadTask(const session::KernelGraph *graph) override; | bool LoadTask(const session::KernelGraph *graph) override; | ||||
| void ClearGraphRuntimeResource(uint32_t graph_id) override; | |||||
| protected: | protected: | ||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | ||||
| @@ -680,6 +680,10 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { | |||||
| MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; | |||||
| } | |||||
| #ifdef ENABLE_DUMP_E2E | #ifdef ENABLE_DUMP_E2E | ||||
| bool KernelRuntime::SetDumpConf() { | bool KernelRuntime::SetDumpConf() { | ||||
| dump_conf_ptr_ = std::make_shared<Dump>(); | dump_conf_ptr_ = std::make_shared<Dump>(); | ||||
| @@ -54,6 +54,7 @@ class KernelRuntime { | |||||
| bool LaunchKernel(const session::KernelGraph *graph); | bool LaunchKernel(const session::KernelGraph *graph); | ||||
| virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); | virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); | ||||
| virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); | virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); | ||||
| virtual void ClearGraphRuntimeResource(uint32_t graph_id); | |||||
| #ifdef ENABLE_DUMP_E2E | #ifdef ENABLE_DUMP_E2E | ||||
| DumpConfPtr GetDumpConf(); | DumpConfPtr GetDumpConf(); | ||||
| @@ -29,6 +29,18 @@ void KernelRuntimeManager::ClearRuntimeResource() { | |||||
| runtime_map_.clear(); | runtime_map_.clear(); | ||||
| } | } | ||||
| void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) { | |||||
| std::lock_guard<std::mutex> guard(lock_); | |||||
| for (auto &iter : runtime_map_) { | |||||
| MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource"; | |||||
| if (!iter.second) { | |||||
| MS_LOG(ERROR) << "Kernel runtime is nullptr"; | |||||
| continue; | |||||
| } | |||||
| iter.second->ClearGraphRuntimeResource(graph_id); | |||||
| } | |||||
| } | |||||
| void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { | void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { | ||||
| if (runtime_creators_.find(device_name) == runtime_creators_.end()) { | if (runtime_creators_.find(device_name) == runtime_creators_.end()) { | ||||
| (void)runtime_creators_.emplace(device_name, runtime_creator); | (void)runtime_creators_.emplace(device_name, runtime_creator); | ||||
| @@ -38,6 +38,7 @@ class KernelRuntimeManager { | |||||
| KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id); | KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id); | ||||
| KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id); | KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id); | ||||
| void ClearRuntimeResource(); | void ClearRuntimeResource(); | ||||
| void ClearGraphResource(uint32_t graph_id); | |||||
| private: | private: | ||||
| KernelRuntimeManager() = default; | KernelRuntimeManager() = default; | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "device/kernel_info.h" | #include "device/kernel_info.h" | ||||
| #include "kernel/kernel_build_info.h" | #include "kernel/kernel_build_info.h" | ||||
| #include "device/kernel_runtime_manager.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -717,5 +718,7 @@ void KernelGraph::UpdateCallRealInput() { | |||||
| } | } | ||||
| std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | ||||
| KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,7 +42,7 @@ class KernelGraph : public FuncGraph { | |||||
| executable_ = true; | executable_ = true; | ||||
| stream_distinction_label_ = kInvalidDistincLabel; | stream_distinction_label_ = kInvalidDistincLabel; | ||||
| } | } | ||||
| ~KernelGraph() override = default; | |||||
| ~KernelGraph() override; | |||||
| MS_DECLARE_PARENT(KernelGraph, FuncGraph); | MS_DECLARE_PARENT(KernelGraph, FuncGraph); | ||||