You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

memory_scheduler.h 4.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "runtime/device/memory_offload_strategy.h"
  24. namespace mindspore {
  25. namespace device {
  26. class MemHandler {
  27. public:
  28. virtual size_t GetAvailableMemSize() = 0;
  29. virtual void *MallocDevice(size_t mem_size) = 0;
  30. virtual void FreeDevice(void *ptr) = 0;
  31. virtual void *MallocHost(size_t mem_size) = 0;
  32. virtual void FreeHost(void *ptr) = 0;
  33. virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
  34. virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
  35. };
  36. class MemScheduler {
  37. public:
  38. MemScheduler() = default;
  39. ~MemScheduler() = default;
  40. bool need_record_event() const { return need_record_event_; }
  41. void set_need_record_event(bool flag) { need_record_event_ = flag; }
  42. bool optimized() const { return optimized_; }
  43. void Update();
  44. void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
  45. void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow);
  46. void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
  47. void SetTotalStep(size_t step) {
  48. total_step_ = step;
  49. step_events_.resize(total_step_);
  50. }
  51. void ResetCurrentStep() { current_step_ = 0; }
  52. bool PreCompute(void *stream);
  53. bool PostCompute(void *stream);
  54. void Optimize();
  55. void Clear();
  56. void ClearTempMem();
  57. void SetMemPriority(const void *key, MemPriority priority);
  58. private:
  59. void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
  60. void OptMemUsage(float mem_used_factor = 1.0f);
  61. void AdjustFirstEventIndex();
  62. std::map<const void *, MemPriority> mem_priority_;
  63. std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
  64. std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
  65. std::map<const void *, void *> mem_result_;
  66. std::map<const void *, void *> init_host_ptr_;
  67. std::map<const void *, void *> swap_host_ptr_;
  68. std::map<const void *, void *> high_priority_device_ptr_;
  69. size_t total_step_{0};
  70. size_t current_step_{0};
  71. bool need_record_event_{true};
  72. bool optimized_{false};
  73. float mem_used_factor_{0.9};
  74. double compute_start_time_{0};
  75. std::vector<double> compute_time_;
  76. bool record_compute_time_{false};
  77. bool updated_{false};
  78. std::shared_ptr<MemHandler> mem_handler_{nullptr};
  79. std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
  80. };
  81. class MemSchedulerManager {
  82. public:
  83. MemSchedulerManager() = default;
  84. ~MemSchedulerManager() = default;
  85. std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) {
  86. auto scheduler = GetMemScheduler(uid);
  87. if (scheduler == nullptr) {
  88. scheduler = std::make_shared<MemScheduler>();
  89. graph_mem_scheduler_map_[uid] = scheduler;
  90. }
  91. return scheduler;
  92. }
  93. std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) {
  94. auto iter = graph_mem_scheduler_map_.find(uid);
  95. if (iter != graph_mem_scheduler_map_.end()) {
  96. return iter->second;
  97. }
  98. return nullptr;
  99. }
  100. private:
  101. std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_;
  102. };
  103. } // namespace device
  104. } // namespace mindspore
  105. #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_