You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

memory_scheduler.cc 13 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/memory_scheduler.h"
  17. #include <algorithm>
  18. #include <queue>
  19. #ifdef _MSC_VER
  20. #include <time.h>
  21. #else
  22. #include <sys/time.h>
  23. #endif
  24. #include "utils/log_adapter.h"
  25. #include "utils/convert_utils_base.h"
  26. namespace mindspore {
  27. namespace device {
  28. namespace {
  29. constexpr float kMaxMemReuseFactor = 1.0;
  30. constexpr float kMinMemReuseFactor = 0.5;
  31. constexpr float kRetryFactor = 0.1;
  32. constexpr size_t kMockTimes = 5;
  33. double GetCurrentTime() {
  34. #ifdef _MSC_VER
  35. return time(NULL) * 1.0e6;
  36. #else
  37. struct timeval tv;
  38. (void)gettimeofday(&tv, nullptr);
  39. return tv.tv_sec * 1.0e6 + tv.tv_usec;
  40. #endif
  41. }
  42. } // namespace
  43. void MemScheduler::Clear() {
  44. if (mem_handler_ == nullptr) {
  45. return;
  46. }
  47. for (auto &item : mem_result_) {
  48. mem_handler_->FreeDevice(item.second);
  49. }
  50. mem_result_.clear();
  51. }
  52. void MemScheduler::ClearAllocatedMem() {
  53. if (mem_handler_ == nullptr) {
  54. return;
  55. }
  56. for (auto &item : mem_result_) {
  57. const auto device_ptr = item.second;
  58. if (device_ptr != nullptr) {
  59. mem_handler_->FreeDevice(device_ptr);
  60. }
  61. }
  62. mem_result_.clear();
  63. for (const auto &item : swap_host_ptr_) {
  64. const auto host_ptr = item.second;
  65. if (host_ptr != nullptr) {
  66. mem_handler_->FreeHost(host_ptr);
  67. }
  68. }
  69. swap_host_ptr_.clear();
  70. }
  71. void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
  72. if (key == nullptr) {
  73. return;
  74. }
  75. auto event = std::make_shared<MemEvent>(event_type, current_step_);
  76. event->mem_size = mem_size;
  77. event->key = key;
  78. (void)mem_events_[key].emplace_back(event);
  79. if (step_keys_.size() < current_step_ + 1) {
  80. step_keys_.resize(current_step_ + 1);
  81. }
  82. if (event->type == kGet) {
  83. (void)step_keys_[current_step_].insert(event->key);
  84. }
  85. }
  86. void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
  87. if (need_record_event_) {
  88. mem_priority_[key] = priority;
  89. Record(key, kInit, mem_size);
  90. }
  91. init_host_ptr_[key] = host_ptr;
  92. }
  93. void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
  94. if (need_record_event_) {
  95. if (mem_priority_.find(key) == mem_priority_.end()) {
  96. mem_priority_[key] = priority;
  97. Record(key, kMalloc, mem_size);
  98. }
  99. Record(key, kGet, mem_size);
  100. return nullptr;
  101. }
  102. if (strategy_ == nullptr) {
  103. return nullptr;
  104. }
  105. auto iter = mem_result_.find(key);
  106. if (iter != mem_result_.end()) {
  107. auto ptr = iter->second;
  108. MS_EXCEPTION_IF_NULL(ptr);
  109. return ptr;
  110. }
  111. return nullptr;
  112. }
  113. bool MemScheduler::PreComputeInit(const std::shared_ptr<MemEvent> &event, void *stream) {
  114. const auto &iter = mem_result_.find(event->key);
  115. const bool new_malloc = iter == mem_result_.end();
  116. void *device_ptr = nullptr;
  117. if (new_malloc) {
  118. device_ptr = MallocDevice(event->mem_size, stream);
  119. if (device_ptr == nullptr) {
  120. return false;
  121. }
  122. } else {
  123. device_ptr = iter->second;
  124. }
  125. if (new_malloc || high_priority_mem_need_init_.count(event->key) != 0) {
  126. MS_LOG(DEBUG) << "Init input data from host, key: " << event->key;
  127. auto host_ptr = init_host_ptr_[event->key];
  128. MS_EXCEPTION_IF_NULL(host_ptr);
  129. mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
  130. }
  131. mem_result_[event->key] = device_ptr;
  132. return true;
  133. }
  134. bool MemScheduler::PreComputeMalloc(const std::shared_ptr<MemEvent> &event, void *stream) {
  135. const auto &iter = mem_result_.find(event->key);
  136. const bool new_malloc = iter == mem_result_.end();
  137. void *device_ptr = nullptr;
  138. if (new_malloc) {
  139. device_ptr = MallocDevice(event->mem_size, stream);
  140. if (device_ptr == nullptr) {
  141. return false;
  142. }
  143. } else {
  144. device_ptr = iter->second;
  145. }
  146. mem_result_[event->key] = device_ptr;
  147. return true;
  148. }
  149. bool MemScheduler::PreComputeSwapIn(const std::shared_ptr<MemEvent> &event, void *stream) {
  150. bool from_init = true;
  151. void *host_ptr = nullptr;
  152. GetHostPtr(event->key, &host_ptr, &from_init);
  153. auto device_ptr = MallocDevice(event->mem_size, stream);
  154. if (device_ptr == nullptr) {
  155. return false;
  156. }
  157. MS_EXCEPTION_IF_NULL(host_ptr);
  158. mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
  159. mem_result_[event->key] = device_ptr;
  160. if (!from_init) {
  161. mem_handler_->FreeHost(host_ptr);
  162. (void)swap_host_ptr_.erase(event->key);
  163. }
  164. return true;
  165. }
  166. bool MemScheduler::PreComputeGet(const std::shared_ptr<MemEvent> &event, void *stream) {
  167. const auto key = event->key;
  168. const auto mem_size = event->mem_size;
  169. auto iter = mem_result_.find(key);
  170. if (iter != mem_result_.end()) {
  171. auto ptr = iter->second;
  172. MS_EXCEPTION_IF_NULL(ptr);
  173. return true;
  174. }
  175. if (!optimized_ || stream == nullptr) {
  176. return false;
  177. }
  178. void *host_ptr = nullptr;
  179. bool from_init = false;
  180. GetHostPtr(key, &host_ptr, &from_init);
  181. if (host_ptr == nullptr) {
  182. return false;
  183. }
  184. auto device_ptr = MallocDevice(mem_size, stream);
  185. mem_handler_->SwapIn(host_ptr, device_ptr, mem_size, stream);
  186. if (!from_init) {
  187. (void)swap_host_ptr_.erase(host_ptr);
  188. mem_handler_->FreeHost(host_ptr);
  189. }
  190. mem_result_[key] = device_ptr;
  191. return true;
  192. }
  193. bool MemScheduler::PreCompute(void *stream) {
  194. if (strategy_ == nullptr) {
  195. return true;
  196. }
  197. MS_EXCEPTION_IF_NULL(mem_handler_);
  198. auto &events = strategy_->GetPreComputeEvents(current_step_);
  199. for (auto &event : events) {
  200. MS_EXCEPTION_IF_NULL(event);
  201. MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
  202. bool ret = true;
  203. if (event->type == kInit) {
  204. ret = PreComputeInit(event, stream);
  205. } else if (event->type == kMalloc) {
  206. ret = PreComputeMalloc(event, stream);
  207. } else if (event->type == kSwapIn) {
  208. ret = PreComputeSwapIn(event, stream);
  209. } else if (event->type == kGet) {
  210. ret = PreComputeGet(event, stream);
  211. }
  212. if (!ret) {
  213. return false;
  214. }
  215. }
  216. if (record_compute_time_ && !updated_) {
  217. compute_start_time_ = GetCurrentTime();
  218. }
  219. return true;
  220. }
  221. bool MemScheduler::PostCompute(void *stream) {
  222. if (strategy_ == nullptr) {
  223. ++current_step_;
  224. return true;
  225. }
  226. if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
  227. compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
  228. }
  229. auto &events = strategy_->GetPostComputeEvents(current_step_);
  230. for (auto &event : events) {
  231. MS_EXCEPTION_IF_NULL(event);
  232. MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
  233. if (event->type == kFree) {
  234. auto ptr = mem_result_[event->key];
  235. if (ptr == nullptr) {
  236. return false;
  237. }
  238. mem_handler_->FreeDevice(ptr);
  239. (void)mem_result_.erase(event->key);
  240. } else if (event->type == kSwapOut) {
  241. auto device_ptr = mem_result_[event->key];
  242. if (device_ptr == nullptr) {
  243. return false;
  244. }
  245. SwapOutAndFreeDevice(event->key, device_ptr, event->mem_size, stream);
  246. }
  247. }
  248. ++current_step_;
  249. return true;
  250. }
  251. void MemScheduler::OptMemUsage(float mem_used_factor) {
  252. MS_EXCEPTION_IF_NULL(mem_handler_);
  253. if (strategy_ == nullptr) {
  254. strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
  255. high_priority_updated_step_, total_step_);
  256. if (manual_offload_keys_.empty()) {
  257. compute_time_.resize(total_step_);
  258. } else {
  259. updated_ = true;
  260. }
  261. }
  262. auto available_mem_size = mem_handler_->GetAvailableMemSize();
  263. available_mem_size = FloatToSize(available_mem_size * mem_used_factor);
  264. strategy_->set_mem_size(available_mem_size);
  265. strategy_->Execute();
  266. }
  267. bool MemScheduler::Optimize() {
  268. AdjustFirstEventIndex();
  269. float mem_used_factor = kMaxMemReuseFactor;
  270. while (mem_used_factor >= kMinMemReuseFactor) {
  271. bool ret = true;
  272. OptMemUsage(mem_used_factor);
  273. for (size_t mock_time = 0; mock_time < kMockTimes; ++mock_time) {
  274. ret = Mock();
  275. if (!ret) {
  276. break;
  277. }
  278. }
  279. if (ret) {
  280. optimized_ = true;
  281. return true;
  282. }
  283. ClearAllocatedMem();
  284. mem_used_factor -= kRetryFactor;
  285. }
  286. return false;
  287. }
  288. bool MemScheduler::Mock() {
  289. current_step_ = 0;
  290. for (size_t step = 0; step < total_step_; ++step) {
  291. bool ret = PreCompute(nullptr);
  292. if (!ret) {
  293. return false;
  294. }
  295. auto &step_keys = step_keys_[step];
  296. for (auto &key : step_keys) {
  297. auto ptr = GetOrMalloc(key, 0);
  298. if (ptr == nullptr) {
  299. return false;
  300. }
  301. }
  302. ret = PostCompute(nullptr);
  303. if (!ret) {
  304. return false;
  305. }
  306. }
  307. return true;
  308. }
  309. void MemScheduler::AdjustFirstEventIndex() {
  310. for (const auto &item : mem_events_) {
  311. const auto &mem_events = item.second;
  312. if (mem_events.empty()) {
  313. continue;
  314. }
  315. auto &first_event = mem_events[0];
  316. MS_EXCEPTION_IF_NULL(first_event);
  317. const auto &priority_iter = mem_priority_.find(item.first);
  318. const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
  319. if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
  320. const auto &second_event = mem_events[1];
  321. MS_EXCEPTION_IF_NULL(second_event);
  322. first_event->index = second_event->index;
  323. }
  324. }
  325. }
  326. void *MemScheduler::MallocDevice(size_t mem_size, void *stream) {
  327. const auto &no_reuse_key = step_keys_[current_step_];
  328. auto device_ptr = mem_handler_->MallocDevice(mem_size);
  329. if (device_ptr != nullptr || !optimized_) {
  330. return device_ptr;
  331. }
  332. auto iter = mem_result_.begin();
  333. using KeySizePair = std::pair<const void *, size_t>;
  334. auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
  335. std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_swap(less);
  336. while (iter != mem_result_.end()) {
  337. const auto key = iter->first;
  338. if (no_reuse_key.count(key) != 0) {
  339. ++iter;
  340. continue;
  341. }
  342. const auto device_mem_size = GetMemSize(key);
  343. mem_can_swap.push({key, device_mem_size});
  344. if (device_mem_size >= mem_size) {
  345. SwapOutAndFreeDevice(key, iter->second, device_mem_size, stream);
  346. device_ptr = mem_handler_->MallocDevice(mem_size);
  347. return device_ptr;
  348. }
  349. ++iter;
  350. }
  351. while (!mem_can_swap.empty()) {
  352. const auto &max_mem_in_device = mem_can_swap.top();
  353. mem_can_swap.pop();
  354. const auto key = max_mem_in_device.first;
  355. const auto swap_mem_size = max_mem_in_device.second;
  356. auto swap_device_ptr = mem_result_[key];
  357. MS_EXCEPTION_IF_NULL(swap_device_ptr);
  358. SwapOutAndFreeDevice(key, swap_device_ptr, swap_mem_size, stream);
  359. device_ptr = mem_handler_->MallocDevice(mem_size);
  360. if (device_ptr != nullptr) {
  361. return device_ptr;
  362. }
  363. }
  364. return nullptr;
  365. }
  366. void MemScheduler::SwapOutAndFreeDevice(const void *key, void *device_ptr, size_t mem_size, void *stream) {
  367. auto host_ptr = GetOrMallocHostPtr(key, mem_size);
  368. MS_EXCEPTION_IF_NULL(host_ptr);
  369. mem_handler_->SwapOut(device_ptr, host_ptr, mem_size, stream);
  370. mem_handler_->FreeDevice(device_ptr);
  371. (void)mem_result_.erase(key);
  372. }
  373. size_t MemScheduler::GetMemSize(const void *key) {
  374. const auto &iter = mem_events_.find(key);
  375. if (iter == mem_events_.end() || iter->second.empty()) {
  376. MS_LOG(EXCEPTION) << "Get mem size for device address key[" << key << "] failed.";
  377. }
  378. return iter->second[0]->mem_size;
  379. }
  380. void *MemScheduler::GetOrMallocHostPtr(const void *key, size_t mem_size) {
  381. void *host_ptr = nullptr;
  382. bool from_init = false;
  383. GetHostPtr(key, &host_ptr, &from_init);
  384. if (host_ptr != nullptr) {
  385. return host_ptr;
  386. }
  387. host_ptr = mem_handler_->MallocHost(mem_size);
  388. swap_host_ptr_[key] = host_ptr;
  389. return host_ptr;
  390. }
  391. void MemScheduler::GetHostPtr(const void *key, void **host_ptr, bool *from_init) {
  392. auto iter = init_host_ptr_.find(key);
  393. if (iter != init_host_ptr_.end()) {
  394. *host_ptr = iter->second;
  395. *from_init = true;
  396. return;
  397. }
  398. iter = swap_host_ptr_.find(key);
  399. if (iter != swap_host_ptr_.end()) {
  400. *host_ptr = iter->second;
  401. *from_init = false;
  402. }
  403. }
  404. void MemScheduler::Update() {
  405. if (!optimized_) {
  406. return;
  407. }
  408. if (strategy_ == nullptr || !strategy_->need_swap()) {
  409. return;
  410. }
  411. if (updated_) {
  412. return;
  413. }
  414. if (!record_compute_time_) {
  415. record_compute_time_ = true;
  416. return;
  417. }
  418. strategy_->SetComputeTime(compute_time_);
  419. strategy_->Execute();
  420. updated_ = true;
  421. }
  422. } // namespace device
  423. } // namespace mindspore