You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

memory_scheduler.cc 9.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "runtime/device/memory_scheduler.h"
  #include <algorithm>
  #include <chrono>
  #include "utils/log_adapter.h"
  #ifdef _MSC_VER
  #include <time.h>
  #else
  #include <sys/time.h>
  #endif
  24. namespace mindspore {
  25. namespace device {
  namespace {
  // Fractions of the available device memory tried by MemScheduler::Optimize(), from the largest down to the
  // smallest, decreasing by kRetryFactor per retry.
  constexpr float kMaxMemReuseFactor = 1.0f;
  constexpr float kMinMemReuseFactor = 0.5f;
  constexpr float kRetryFactor = 0.1f;

  // Returns a timestamp in microseconds. It is only meaningful as the difference of two calls, which is how
  // PreCompute/PostCompute use it to time a step. std::chrono::steady_clock is monotonic and has sub-second
  // resolution on every platform. The previous MSVC branch used time(), whose one-second resolution made
  // per-step compute times almost always zero, and gettimeofday() is wall-clock time that can jump.
  double GetCurrentTime() {
    using Microseconds = std::chrono::duration<double, std::micro>;
    const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
    return std::chrono::duration_cast<Microseconds>(since_epoch).count();
  }
  }  // namespace
  40. void MemScheduler::Clear() {
  41. if (mem_handler_ == nullptr) {
  42. return;
  43. }
  44. for (auto &item : high_priority_device_ptr_) {
  45. mem_handler_->FreeDevice(item.second);
  46. }
  47. high_priority_device_ptr_.clear();
  48. }
  49. void MemScheduler::ClearTempMem() {
  50. if (mem_handler_ == nullptr) {
  51. return;
  52. }
  53. for (auto &item : mem_result_) {
  54. const auto device_ptr = item.second;
  55. if (device_ptr == nullptr) {
  56. mem_handler_->FreeDevice(device_ptr);
  57. }
  58. }
  59. mem_result_.clear();
  60. high_priority_device_ptr_.clear();
  61. for (const auto &item : swap_host_ptr_) {
  62. const auto host_ptr = item.second;
  63. if (host_ptr != nullptr) {
  64. mem_handler_->FreeHost(host_ptr);
  65. }
  66. }
  67. swap_host_ptr_.clear();
  68. }
  69. void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
  70. void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
  71. if (key == nullptr) {
  72. return;
  73. }
  74. auto event = std::make_shared<MemEvent>(event_type, current_step_);
  75. event->mem_size = mem_size;
  76. event->key = key;
  77. mem_events_[key].emplace_back(event);
  78. if (step_events_.size() < current_step_ + 1) {
  79. step_events_.resize(current_step_ + 1);
  80. }
  81. step_events_[current_step_].emplace_back(event);
  82. }
  83. void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
  84. if (need_record_event_) {
  85. mem_priority_[key] = priority;
  86. Record(key, kInit, mem_size);
  87. }
  88. init_host_ptr_[key] = host_ptr;
  89. }
  90. void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
  91. if (need_record_event_) {
  92. if (mem_priority_.find(key) == mem_priority_.end()) {
  93. mem_priority_[key] = priority;
  94. Record(key, kMalloc, mem_size);
  95. }
  96. Record(key, kGet, mem_size);
  97. return nullptr;
  98. }
  99. if (strategy_ == nullptr) {
  100. return nullptr;
  101. }
  102. auto iter = mem_result_.find(key);
  103. if (iter != mem_result_.end()) {
  104. auto ptr = iter->second;
  105. MS_EXCEPTION_IF_NULL(ptr);
  106. return ptr;
  107. }
  108. return nullptr;
  109. }
  // Executes the memory events the offload strategy scheduled before the compute of step current_step_:
  //  - kInit / kMalloc: obtain a device buffer. A non-low-priority key that already has one is reused;
  //    otherwise a new buffer is allocated and, for kInit, filled from the host buffer registered via Init().
  //  - kSwapIn: allocate a device buffer and copy the key's data back from host memory.
  // Every resulting device pointer is published in mem_result_ for GetOrMalloc().
  // @param stream Opaque stream handle forwarded to mem_handler_->SwapIn (Optimize() passes nullptr).
  // @return true on success or when no strategy exists; false as soon as a device allocation fails. Buffers
  //         already allocated in this step then stay in mem_result_, and Optimize() releases them via
  //         ClearTempMem().
  bool MemScheduler::PreCompute(void *stream) {
    if (strategy_ == nullptr) {
      return true;
    }
    MS_EXCEPTION_IF_NULL(mem_handler_);
    auto &events = strategy_->GetPreComputeEvents(current_step_);
    for (auto &event : events) {
      MS_EXCEPTION_IF_NULL(event);
      MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
      if (event->type == kInit || event->type == kMalloc) {
        // operator[] default-inserts a priority for keys that never registered one.
        auto priority = mem_priority_[event->key];
        auto iter = high_priority_device_ptr_.find(event->key);
        // Non-low-priority buffers that are still resident are reused as-is; for kInit this also skips
        // copying the host data in again.
        if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
          MS_EXCEPTION_IF_NULL(iter->second);
          mem_result_[event->key] = iter->second;
          continue;
        }
        auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
        if (device_ptr == nullptr) {
          return false;
        }
        // Remember non-low-priority buffers so later kInit/kMalloc events for the same key can reuse them.
        if (priority != kMemPriorityLow) {
          high_priority_device_ptr_[event->key] = device_ptr;
        }
        if (event->type == kInit) {
          auto host_ptr = init_host_ptr_[event->key];
          MS_EXCEPTION_IF_NULL(host_ptr);
          mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
        }
        mem_result_[event->key] = device_ptr;
      } else if (event->type == kSwapIn) {
        // Swap-in source: the host buffer registered via Init() if there is one, otherwise the temporary host
        // buffer that PostCompute allocated on swap-out (swap_host_ptr_).
        bool from_init = true;
        auto host_ptr = init_host_ptr_[event->key];
        if (host_ptr == nullptr) {
          host_ptr = swap_host_ptr_[event->key];
          from_init = false;
        }
        auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
        if (device_ptr == nullptr) {
          return false;
        }
        // NOTE(review): host_ptr is only null-checked after MallocDevice succeeded, so a missing host buffer
        // leaks the fresh device allocation when this throws. Consider checking before allocating.
        MS_EXCEPTION_IF_NULL(host_ptr);
        mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
        mem_result_[event->key] = device_ptr;
        // NOTE(review): this path tracks only kMemPriorityHigh buffers in high_priority_device_ptr_, while the
        // kInit/kMalloc path tracks every non-low priority. Confirm the asymmetry is intended.
        if (mem_priority_[event->key] == kMemPriorityHigh) {
          high_priority_device_ptr_[event->key] = device_ptr;
        }
        // The temporary swap buffer is no longer needed once its data is back on the device; Init() buffers
        // are not owned here and are kept.
        if (!from_init) {
          mem_handler_->FreeHost(host_ptr);
          (void)swap_host_ptr_.erase(event->key);
        }
      }
    }
    // Start timing the step's compute; PostCompute stores the elapsed time in compute_time_.
    if (record_compute_time_ && !updated_) {
      compute_start_time_ = GetCurrentTime();
    }
    return true;
  }
  168. bool MemScheduler::PostCompute(void *stream) {
  169. if (strategy_ == nullptr) {
  170. ++current_step_;
  171. return true;
  172. }
  173. if (record_compute_time_ && !updated_) {
  174. compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
  175. }
  176. auto &events = strategy_->GetPostComputeEvents(current_step_);
  177. for (auto &event : events) {
  178. MS_EXCEPTION_IF_NULL(event);
  179. MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
  180. if (event->type == kFree) {
  181. auto ptr = mem_result_[event->key];
  182. if (ptr == nullptr) {
  183. return false;
  184. }
  185. mem_handler_->FreeDevice(ptr);
  186. (void)mem_result_.erase(event->key);
  187. } else if (event->type == kSwapOut) {
  188. auto device_ptr = mem_result_[event->key];
  189. if (device_ptr == nullptr) {
  190. return false;
  191. }
  192. auto host_ptr = init_host_ptr_[event->key];
  193. if (host_ptr == nullptr) {
  194. host_ptr = mem_handler_->MallocHost(event->mem_size);
  195. swap_host_ptr_[event->key] = host_ptr;
  196. }
  197. MS_EXCEPTION_IF_NULL(host_ptr);
  198. mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
  199. mem_handler_->FreeDevice(device_ptr);
  200. (void)mem_result_.erase(event->key);
  201. if (mem_priority_[event->key] == kMemPriorityHigh) {
  202. high_priority_device_ptr_.erase(event->key);
  203. }
  204. }
  205. }
  206. ++current_step_;
  207. return true;
  208. }
  209. void MemScheduler::OptMemUsage(float mem_used_factor) {
  210. mem_used_factor_ = mem_used_factor;
  211. MS_EXCEPTION_IF_NULL(mem_handler_);
  212. if (strategy_ == nullptr) {
  213. strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
  214. compute_time_.resize(total_step_);
  215. }
  216. auto available_mem_size = mem_handler_->GetAvailableMemSize();
  217. available_mem_size = available_mem_size * mem_used_factor_;
  218. strategy_->set_mem_size(available_mem_size);
  219. strategy_->Execute();
  220. }
  221. void MemScheduler::Optimize() {
  222. AdjustFirstEventIndex();
  223. float mem_used_factor = kMaxMemReuseFactor;
  224. while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
  225. OptMemUsage(mem_used_factor);
  226. current_step_ = 0;
  227. bool ret = true;
  228. for (size_t step = 0; step < total_step_; ++step) {
  229. ret = PreCompute(nullptr);
  230. auto &step_events = step_events_[step];
  231. for (auto &event : step_events) {
  232. if (event->type != kGet) {
  233. continue;
  234. }
  235. auto ptr = GetOrMalloc(event->key, event->mem_size);
  236. if (ptr == nullptr) {
  237. ret = false;
  238. break;
  239. }
  240. }
  241. if (!ret) {
  242. break;
  243. }
  244. PostCompute(nullptr);
  245. }
  246. if (ret) {
  247. optimized_ = true;
  248. } else {
  249. ClearTempMem();
  250. mem_used_factor -= kRetryFactor;
  251. }
  252. }
  253. }
  254. void MemScheduler::AdjustFirstEventIndex() {
  255. for (const auto &item : mem_events_) {
  256. const auto &mem_events = item.second;
  257. if (mem_events.empty()) {
  258. continue;
  259. }
  260. auto &first_event = mem_events[0];
  261. MS_EXCEPTION_IF_NULL(first_event);
  262. const auto &priority_iter = mem_priority_.find(item.first);
  263. const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
  264. if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
  265. const auto &second_event = mem_events[1];
  266. MS_EXCEPTION_IF_NULL(second_event);
  267. first_event->index = second_event->index;
  268. }
  269. }
  270. }
  271. void MemScheduler::Update() {
  272. if (!optimized_) {
  273. return;
  274. }
  275. if (strategy_ == nullptr || !strategy_->need_swap()) {
  276. return;
  277. }
  278. if (updated_) {
  279. return;
  280. }
  281. if (!record_compute_time_) {
  282. record_compute_time_ = true;
  283. return;
  284. }
  285. strategy_->SetComputeTime(compute_time_);
  286. strategy_->Execute();
  287. updated_ = true;
  288. }
  289. } // namespace device
  290. } // namespace mindspore