You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

memory_scheduler.h 5.3 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
  18. #include <vector>
  19. #include <map>
  20. #include <set>
  21. #include <memory>
  22. #include <utility>
  23. #include "runtime/device/memory_offload_strategy.h"
  24. namespace mindspore {
  25. namespace device {
  // Abstract backend interface for raw device/host memory operations.
  //
  // Every operation is pure virtual, so this header fixes only the signatures.
  // The notes below are inferred from the names and must be confirmed against
  // the concrete implementations.
  class MemHandler {
   public:
    MemHandler() = default;
    // Virtual so that a handler can be destroyed through a MemHandler pointer.
    virtual ~MemHandler() = default;

    // Returns a size in bytes; presumably the device memory still available.
    virtual size_t GetAvailableMemSize() = 0;

    // Device-side allocate / release pair. How failure is reported (e.g. a
    // null return) is not specified here -- TODO confirm with implementations.
    virtual void *MallocDevice(size_t mem_size) = 0;
    virtual void FreeDevice(void *ptr) = 0;

    // Host-side allocate / release pair, same caveats as the device pair.
    virtual void *MallocHost(size_t mem_size) = 0;
    virtual void FreeHost(void *ptr) = 0;

    // Transfers mem_size bytes between a host buffer and a device buffer.
    // SwapIn reads from host_ptr and writes to device_ptr; SwapOut is the
    // reverse direction (see the const qualifiers). `stream` is an opaque
    // handle; presumably the transfer is queued on it -- TODO confirm.
    virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
    virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
  };
  38. class MemScheduler {
  39. public:
  40. MemScheduler() = default;
  41. ~MemScheduler() = default;
  42. bool need_record_event() const { return need_record_event_; }
  43. void set_need_record_event(bool flag) { need_record_event_ = flag; }
  44. bool optimized() const { return optimized_; }
  45. void Update();
  46. void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
  47. void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow);
  48. void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
  49. bool HasDeviceMem(const void *key) const { return mem_result_.find(key) != mem_result_.end(); }
  50. void UpdateHighPriorityMem(const void *key) {
  51. if (need_record_event_) {
  52. (void)high_priority_updated_step_[key].emplace_back(current_step_);
  53. }
  54. }
  55. void SetTotalStep(size_t step) {
  56. total_step_ = step;
  57. step_keys_.resize(total_step_);
  58. }
  59. void Reset() { current_step_ = 0; }
  60. bool PreCompute(void *stream);
  61. bool PostCompute(void *stream);
  62. bool Optimize();
  63. void Clear();
  64. void ClearAllocatedMem();
  65. void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
  66. void AddMemNeedInit(const void *key) { (void)high_priority_mem_need_init_.insert(key); }
  67. void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }
  68. private:
  69. void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
  70. void OptMemUsage(float mem_used_factor = 1.0f);
  71. bool Mock();
  72. void AdjustFirstEventIndex();
  73. void *MallocDevice(size_t mem_size, void *stream);
  74. void SwapOutAndFreeDevice(const void *key, void *device_ptr, size_t mem_size, void *stream);
  75. size_t GetMemSize(const void *key);
  76. void *GetOrMallocHostPtr(const void *key, size_t mem_size);
  77. void GetHostPtr(const void *key, void **host_ptr, bool *from_init);
  78. bool PreComputeInit(const std::shared_ptr<MemEvent> &event, void *stream);
  79. bool PreComputeMalloc(const std::shared_ptr<MemEvent> &event, void *stream);
  80. bool PreComputeSwapIn(const std::shared_ptr<MemEvent> &event, void *stream);
  81. bool PreComputeGet(const std::shared_ptr<MemEvent> &event, void *stream);
  82. std::map<const void *, MemPriority> mem_priority_;
  83. std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
  84. std::set<const void *> manual_offload_keys_;
  85. std::vector<std::set<const void *>> step_keys_;
  86. std::map<const void *, void *> mem_result_;
  87. std::map<const void *, void *> init_host_ptr_;
  88. std::map<const void *, void *> swap_host_ptr_;
  89. std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
  90. std::set<const void *> high_priority_mem_need_init_;
  91. size_t total_step_{0};
  92. size_t current_step_{0};
  93. bool need_record_event_{true};
  94. bool optimized_{false};
  95. double compute_start_time_{0};
  96. std::vector<double> compute_time_;
  97. bool record_compute_time_{false};
  98. bool updated_{false};
  99. std::shared_ptr<MemHandler> mem_handler_{nullptr};
  100. std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
  101. };
  102. class MemSchedulerManager {
  103. public:
  104. MemSchedulerManager() = default;
  105. ~MemSchedulerManager() = default;
  106. std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) {
  107. auto scheduler = GetMemScheduler(uid);
  108. if (scheduler == nullptr) {
  109. scheduler = std::make_shared<MemScheduler>();
  110. graph_mem_scheduler_map_[uid] = scheduler;
  111. }
  112. return scheduler;
  113. }
  114. std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) {
  115. auto iter = graph_mem_scheduler_map_.find(uid);
  116. if (iter != graph_mem_scheduler_map_.end()) {
  117. return iter->second;
  118. }
  119. return nullptr;
  120. }
  121. private:
  122. std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_;
  123. };
  124. } // namespace device
  125. } // namespace mindspore
  126. #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_