/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_ #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_ #include #include #include #include #include #include "runtime/device/memory_offload_strategy.h" namespace mindspore { namespace device { class MemHandler { public: virtual size_t GetAvailableMemSize() = 0; virtual void *MallocDevice(size_t mem_size) = 0; virtual void FreeDevice(void *ptr) = 0; virtual void *MallocHost(size_t mem_size) = 0; virtual void FreeHost(void *ptr) = 0; virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0; virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0; }; class MemScheduler { public: MemScheduler() = default; ~MemScheduler() = default; bool need_record_event() const { return need_record_event_; } void set_need_record_event(bool flag) { need_record_event_ = flag; } bool optimized() const { return optimized_; } void Update(); void SetMemHandler(const std::shared_ptr &handler) { mem_handler_ = handler; } void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow); void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow); void SetTotalStep(size_t step) { total_step_ = step; step_events_.resize(total_step_); } void ResetCurrentStep() { current_step_ = 0; } bool PreCompute(void *stream); bool PostCompute(void *stream); void Optimize(); void Clear(); void ClearTempMem(); void SetMemPriority(const void *key, MemPriority priority); private: void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0); void OptMemUsage(float mem_used_factor = 1.0f); void AdjustFirstEventIndex(); std::map mem_priority_; std::map>> mem_events_; std::vector>> step_events_; std::map mem_result_; std::map init_host_ptr_; std::map swap_host_ptr_; std::map high_priority_device_ptr_; size_t total_step_{0}; size_t current_step_{0}; bool need_record_event_{true}; bool optimized_{false}; float mem_used_factor_{0.9}; double compute_start_time_{0}; std::vector compute_time_; bool record_compute_time_{false}; bool updated_{false}; std::shared_ptr mem_handler_{nullptr}; std::shared_ptr strategy_{nullptr}; }; class MemSchedulerManager { public: MemSchedulerManager() = default; ~MemSchedulerManager() = default; std::shared_ptr GetOrCreateMemScheduler(uint64_t uid) { auto scheduler = GetMemScheduler(uid); if (scheduler == nullptr) { scheduler = std::make_shared(); graph_mem_scheduler_map_[uid] = scheduler; } return scheduler; } std::shared_ptr GetMemScheduler(uint64_t uid) { auto iter = graph_mem_scheduler_map_.find(uid); if (iter != graph_mem_scheduler_map_.end()) { return iter->second; } return nullptr; } private: std::map> graph_mem_scheduler_map_; }; } // namespace device } // namespace mindspore #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_