You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_runtime_manager.cc 5.7 kB

4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "runtime/device/kernel_runtime_manager.h"
  #include <utility>
  #include "utils/log_adapter.h"
  #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  #include "ps/ps_cache/ps_cache_manager.h"
  #endif
  #include "backend/session/pynative_task_manager.h"
  22. namespace mindspore {
  23. namespace device {
  24. void KernelRuntimeManager::ClearRuntimeResource() {
  25. // Just remove PyNative tasks before runtime resource release.
  26. session::PynativeTaskManager::GetInstance().Reset();
  27. #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  28. if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
  29. ps::ps_cache_instance.SyncEmbeddingTable();
  30. }
  31. #endif
  32. std::lock_guard<std::mutex> guard(lock_);
  33. for (auto &iter : runtime_map_) {
  34. MS_LOG(INFO) << "Release device " << iter.first;
  35. MS_EXCEPTION_IF_NULL(iter.second);
  36. iter.second->ReleaseDeviceRes();
  37. }
  38. runtime_map_.clear();
  39. }
  40. void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) {
  41. std::lock_guard<std::mutex> guard(lock_);
  42. for (auto &iter : runtime_map_) {
  43. MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource";
  44. if (!iter.second) {
  45. MS_LOG(ERROR) << "Kernel runtime is nullptr";
  46. continue;
  47. }
  48. iter.second->ClearGraphRuntimeResource(graph_id);
  49. }
  50. }
  51. KernelRuntimeManager &KernelRuntimeManager::Instance() {
  52. static KernelRuntimeManager instance{};
  53. return instance;
  54. }
  55. void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) {
  56. if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
  57. (void)runtime_creators_.emplace(device_name, runtime_creator);
  58. }
  59. }
  60. std::string KernelRuntimeManager::GetDeviceKey(const std::string &device_name, uint32_t device_id) {
  61. std::string device_key = device_name + "_" + std::to_string(device_id);
  62. return device_key;
  63. }
  64. KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) {
  65. auto runtime_key = GetDeviceKey(device_name, device_id);
  66. auto runtime_iter = runtime_map_.find(runtime_key);
  67. if (runtime_iter != runtime_map_.end()) {
  68. return runtime_iter->second.get();
  69. } else if (!runtime_map_.empty()) {
  70. auto cur_runtime_key = runtime_map_.begin()->first;
  71. auto find_pos = cur_runtime_key.rfind('_');
  72. if (find_pos != std::string::npos) {
  73. if (cur_runtime_key.size() > find_pos + 1) {
  74. auto cur_device_id = cur_runtime_key.substr(find_pos + 1);
  75. MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id
  76. << ", set device id: " << device_id << " failed";
  77. } else {
  78. MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: "
  79. << device_id << " failed";
  80. }
  81. }
  82. }
  83. return GetKernelRuntime(device_name, device_id);
  84. }
  85. KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) {
  86. std::string runtime_key = GetDeviceKey(device_name, device_id);
  87. std::lock_guard<std::mutex> guard(lock_);
  88. auto runtime_iter = runtime_map_.find(runtime_key);
  89. if (runtime_iter != runtime_map_.end()) {
  90. return runtime_iter->second.get();
  91. }
  92. std::shared_ptr<KernelRuntime> kernel_runtime;
  93. auto creator_iter = runtime_creators_.find(device_name);
  94. if (creator_iter != runtime_creators_.end()) {
  95. MS_EXCEPTION_IF_NULL(creator_iter->second);
  96. kernel_runtime = (creator_iter->second)();
  97. MS_EXCEPTION_IF_NULL(kernel_runtime);
  98. kernel_runtime->set_device_id(device_id);
  99. runtime_map_[runtime_key] = kernel_runtime;
  100. } else {
  101. MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id;
  102. }
  103. return kernel_runtime.get();
  104. }
  105. KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() {
  106. auto ms_context = MsContext::GetInstance();
  107. MS_EXCEPTION_IF_NULL(ms_context);
  108. uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  109. std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  110. return GetKernelRuntime(device_name, device_id);
  111. }
  112. void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
  113. session::PynativeTaskManager::GetInstance().Reset();
  114. std::string runtime_key = GetDeviceKey(device_name, device_id);
  115. std::lock_guard<std::mutex> guard(lock_);
  116. auto runtime_iter = runtime_map_.find(runtime_key);
  117. if (runtime_iter == runtime_map_.end()) {
  118. return;
  119. }
  120. auto runtime = runtime_iter->second.get();
  121. if (runtime == nullptr) {
  122. return;
  123. }
  124. #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  125. if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
  126. ps::ps_cache_instance.SyncEmbeddingTable();
  127. }
  128. #endif
  129. runtime->ReleaseDeviceRes();
  130. runtime_map_.erase(runtime_iter);
  131. }
  132. } // namespace device
  133. } // namespace mindspore