|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455 |
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
#include "runtime/device/memory_scheduler.h"
#include <algorithm>
#include <chrono>
#include <queue>
#ifdef _MSC_VER
#include <time.h>
#else
#include <sys/time.h>
#endif
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
-
- namespace mindspore {
- namespace device {
namespace {
// Fraction of the available device memory that Optimize() tries first.
constexpr float kMaxMemReuseFactor = 1.0;
// Lowest fraction of the available device memory Optimize() falls back to.
constexpr float kMinMemReuseFactor = 0.5;
// Amount the fraction is lowered by after a failed attempt in Optimize().
constexpr float kRetryFactor = 0.1;
// Number of Mock() runs that must all succeed before a fraction is accepted.
constexpr size_t kMockTimes = 5;

// Returns a timestamp in microseconds.
//
// The value is only meaningful as a difference between two calls (it is used
// to measure how long a compute step took). A steady clock is used so that
// the difference is never negative when the wall clock is adjusted, and so
// that every platform gets sub-second resolution.
double GetCurrentTime() {
  const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
  return std::chrono::duration<double, std::micro>(since_epoch).count();
}
}  // namespace
-
- void MemScheduler::Clear() {
- if (mem_handler_ == nullptr) {
- return;
- }
- for (auto &item : mem_result_) {
- mem_handler_->FreeDevice(item.second);
- }
- mem_result_.clear();
- }
-
- void MemScheduler::ClearAllocatedMem() {
- if (mem_handler_ == nullptr) {
- return;
- }
- for (auto &item : mem_result_) {
- const auto device_ptr = item.second;
- if (device_ptr != nullptr) {
- mem_handler_->FreeDevice(device_ptr);
- }
- }
- mem_result_.clear();
- for (const auto &item : swap_host_ptr_) {
- const auto host_ptr = item.second;
- if (host_ptr != nullptr) {
- mem_handler_->FreeHost(host_ptr);
- }
- }
- swap_host_ptr_.clear();
- }
-
- void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
- if (key == nullptr) {
- return;
- }
- auto event = std::make_shared<MemEvent>(event_type, current_step_);
- event->mem_size = mem_size;
- event->key = key;
- (void)mem_events_[key].emplace_back(event);
- if (step_keys_.size() < current_step_ + 1) {
- step_keys_.resize(current_step_ + 1);
- }
- if (event->type == kGet) {
- (void)step_keys_[current_step_].insert(event->key);
- }
- }
-
- void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
- if (need_record_event_) {
- mem_priority_[key] = priority;
- Record(key, kInit, mem_size);
- }
- init_host_ptr_[key] = host_ptr;
- }
-
- void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
- if (need_record_event_) {
- if (mem_priority_.find(key) == mem_priority_.end()) {
- mem_priority_[key] = priority;
- Record(key, kMalloc, mem_size);
- }
- Record(key, kGet, mem_size);
- return nullptr;
- }
- if (strategy_ == nullptr) {
- return nullptr;
- }
- auto iter = mem_result_.find(key);
- if (iter != mem_result_.end()) {
- auto ptr = iter->second;
- MS_EXCEPTION_IF_NULL(ptr);
- return ptr;
- }
- return nullptr;
- }
-
- bool MemScheduler::PreComputeInit(const std::shared_ptr<MemEvent> &event, void *stream) {
- const auto &iter = mem_result_.find(event->key);
- const bool new_malloc = iter == mem_result_.end();
- void *device_ptr = nullptr;
- if (new_malloc) {
- device_ptr = MallocDevice(event->mem_size, stream);
- if (device_ptr == nullptr) {
- return false;
- }
- } else {
- device_ptr = iter->second;
- }
- if (new_malloc || high_priority_mem_need_init_.count(event->key) != 0) {
- MS_LOG(DEBUG) << "Init input data from host, key: " << event->key;
- auto host_ptr = init_host_ptr_[event->key];
- MS_EXCEPTION_IF_NULL(host_ptr);
- mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
- }
- mem_result_[event->key] = device_ptr;
- return true;
- }
-
- bool MemScheduler::PreComputeMalloc(const std::shared_ptr<MemEvent> &event, void *stream) {
- const auto &iter = mem_result_.find(event->key);
- const bool new_malloc = iter == mem_result_.end();
- void *device_ptr = nullptr;
- if (new_malloc) {
- device_ptr = MallocDevice(event->mem_size, stream);
- if (device_ptr == nullptr) {
- return false;
- }
- } else {
- device_ptr = iter->second;
- }
- mem_result_[event->key] = device_ptr;
- return true;
- }
-
- bool MemScheduler::PreComputeSwapIn(const std::shared_ptr<MemEvent> &event, void *stream) {
- bool from_init = true;
- void *host_ptr = nullptr;
- GetHostPtr(event->key, &host_ptr, &from_init);
- auto device_ptr = MallocDevice(event->mem_size, stream);
- if (device_ptr == nullptr) {
- return false;
- }
- MS_EXCEPTION_IF_NULL(host_ptr);
- mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
- mem_result_[event->key] = device_ptr;
- if (!from_init) {
- mem_handler_->FreeHost(host_ptr);
- (void)swap_host_ptr_.erase(event->key);
- }
- return true;
- }
-
- bool MemScheduler::PreComputeGet(const std::shared_ptr<MemEvent> &event, void *stream) {
- const auto key = event->key;
- const auto mem_size = event->mem_size;
- auto iter = mem_result_.find(key);
- if (iter != mem_result_.end()) {
- auto ptr = iter->second;
- MS_EXCEPTION_IF_NULL(ptr);
- return true;
- }
- if (!optimized_ || stream == nullptr) {
- return false;
- }
- void *host_ptr = nullptr;
- bool from_init = false;
- GetHostPtr(key, &host_ptr, &from_init);
- if (host_ptr == nullptr) {
- return false;
- }
- auto device_ptr = MallocDevice(mem_size, stream);
- mem_handler_->SwapIn(host_ptr, device_ptr, mem_size, stream);
- if (!from_init) {
- (void)swap_host_ptr_.erase(host_ptr);
- mem_handler_->FreeHost(host_ptr);
- }
- mem_result_[key] = device_ptr;
- return true;
- }
-
- bool MemScheduler::PreCompute(void *stream) {
- if (strategy_ == nullptr) {
- return true;
- }
- MS_EXCEPTION_IF_NULL(mem_handler_);
- auto &events = strategy_->GetPreComputeEvents(current_step_);
- for (auto &event : events) {
- MS_EXCEPTION_IF_NULL(event);
- MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
- bool ret = true;
- if (event->type == kInit) {
- ret = PreComputeInit(event, stream);
- } else if (event->type == kMalloc) {
- ret = PreComputeMalloc(event, stream);
- } else if (event->type == kSwapIn) {
- ret = PreComputeSwapIn(event, stream);
- } else if (event->type == kGet) {
- ret = PreComputeGet(event, stream);
- }
- if (!ret) {
- return false;
- }
- }
- if (record_compute_time_ && !updated_) {
- compute_start_time_ = GetCurrentTime();
- }
- return true;
- }
-
- bool MemScheduler::PostCompute(void *stream) {
- if (strategy_ == nullptr) {
- ++current_step_;
- return true;
- }
-
- if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
- compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
- }
-
- auto &events = strategy_->GetPostComputeEvents(current_step_);
- for (auto &event : events) {
- MS_EXCEPTION_IF_NULL(event);
- MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
- if (event->type == kFree) {
- auto ptr = mem_result_[event->key];
- if (ptr == nullptr) {
- return false;
- }
- mem_handler_->FreeDevice(ptr);
- (void)mem_result_.erase(event->key);
- } else if (event->type == kSwapOut) {
- auto device_ptr = mem_result_[event->key];
- if (device_ptr == nullptr) {
- return false;
- }
- SwapOutAndFreeDevice(event->key, device_ptr, event->mem_size, stream);
- }
- }
- ++current_step_;
- return true;
- }
-
- void MemScheduler::OptMemUsage(float mem_used_factor) {
- MS_EXCEPTION_IF_NULL(mem_handler_);
-
- if (strategy_ == nullptr) {
- strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
- high_priority_updated_step_, total_step_);
- if (manual_offload_keys_.empty()) {
- compute_time_.resize(total_step_);
- } else {
- updated_ = true;
- }
- }
-
- auto available_mem_size = mem_handler_->GetAvailableMemSize();
- available_mem_size = FloatToSize(available_mem_size * mem_used_factor);
- strategy_->set_mem_size(available_mem_size);
- strategy_->Execute();
- }
-
- bool MemScheduler::Optimize() {
- AdjustFirstEventIndex();
- float mem_used_factor = kMaxMemReuseFactor;
- while (mem_used_factor >= kMinMemReuseFactor) {
- bool ret = true;
- OptMemUsage(mem_used_factor);
- for (size_t mock_time = 0; mock_time < kMockTimes; ++mock_time) {
- ret = Mock();
- if (!ret) {
- break;
- }
- }
- if (ret) {
- optimized_ = true;
- return true;
- }
- ClearAllocatedMem();
- mem_used_factor -= kRetryFactor;
- }
- return false;
- }
-
- bool MemScheduler::Mock() {
- current_step_ = 0;
- for (size_t step = 0; step < total_step_; ++step) {
- bool ret = PreCompute(nullptr);
- if (!ret) {
- return false;
- }
- auto &step_keys = step_keys_[step];
- for (auto &key : step_keys) {
- auto ptr = GetOrMalloc(key, 0);
- if (ptr == nullptr) {
- return false;
- }
- }
- ret = PostCompute(nullptr);
- if (!ret) {
- return false;
- }
- }
- return true;
- }
-
- void MemScheduler::AdjustFirstEventIndex() {
- for (const auto &item : mem_events_) {
- const auto &mem_events = item.second;
- if (mem_events.empty()) {
- continue;
- }
- auto &first_event = mem_events[0];
- MS_EXCEPTION_IF_NULL(first_event);
- const auto &priority_iter = mem_priority_.find(item.first);
- const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
- if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
- const auto &second_event = mem_events[1];
- MS_EXCEPTION_IF_NULL(second_event);
- first_event->index = second_event->index;
- }
- }
- }
-
- void *MemScheduler::MallocDevice(size_t mem_size, void *stream) {
- const auto &no_reuse_key = step_keys_[current_step_];
- auto device_ptr = mem_handler_->MallocDevice(mem_size);
- if (device_ptr != nullptr || !optimized_) {
- return device_ptr;
- }
- auto iter = mem_result_.begin();
- using KeySizePair = std::pair<const void *, size_t>;
- auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
- std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_swap(less);
- while (iter != mem_result_.end()) {
- const auto key = iter->first;
- if (no_reuse_key.count(key) != 0) {
- ++iter;
- continue;
- }
- const auto device_mem_size = GetMemSize(key);
- mem_can_swap.push({key, device_mem_size});
- if (device_mem_size >= mem_size) {
- SwapOutAndFreeDevice(key, iter->second, device_mem_size, stream);
- device_ptr = mem_handler_->MallocDevice(mem_size);
- return device_ptr;
- }
- ++iter;
- }
- while (!mem_can_swap.empty()) {
- const auto &max_mem_in_device = mem_can_swap.top();
- mem_can_swap.pop();
- const auto key = max_mem_in_device.first;
- const auto swap_mem_size = max_mem_in_device.second;
- auto swap_device_ptr = mem_result_[key];
- MS_EXCEPTION_IF_NULL(swap_device_ptr);
- SwapOutAndFreeDevice(key, swap_device_ptr, swap_mem_size, stream);
- device_ptr = mem_handler_->MallocDevice(mem_size);
- if (device_ptr != nullptr) {
- return device_ptr;
- }
- }
- return nullptr;
- }
-
- void MemScheduler::SwapOutAndFreeDevice(const void *key, void *device_ptr, size_t mem_size, void *stream) {
- auto host_ptr = GetOrMallocHostPtr(key, mem_size);
- MS_EXCEPTION_IF_NULL(host_ptr);
- mem_handler_->SwapOut(device_ptr, host_ptr, mem_size, stream);
- mem_handler_->FreeDevice(device_ptr);
- (void)mem_result_.erase(key);
- }
-
- size_t MemScheduler::GetMemSize(const void *key) {
- const auto &iter = mem_events_.find(key);
- if (iter == mem_events_.end() || iter->second.empty()) {
- MS_LOG(EXCEPTION) << "Get mem size for device address key[" << key << "] failed.";
- }
- return iter->second[0]->mem_size;
- }
-
- void *MemScheduler::GetOrMallocHostPtr(const void *key, size_t mem_size) {
- void *host_ptr = nullptr;
- bool from_init = false;
- GetHostPtr(key, &host_ptr, &from_init);
- if (host_ptr != nullptr) {
- return host_ptr;
- }
- host_ptr = mem_handler_->MallocHost(mem_size);
- swap_host_ptr_[key] = host_ptr;
- return host_ptr;
- }
-
- void MemScheduler::GetHostPtr(const void *key, void **host_ptr, bool *from_init) {
- auto iter = init_host_ptr_.find(key);
- if (iter != init_host_ptr_.end()) {
- *host_ptr = iter->second;
- *from_init = true;
- return;
- }
- iter = swap_host_ptr_.find(key);
- if (iter != swap_host_ptr_.end()) {
- *host_ptr = iter->second;
- *from_init = false;
- }
- }
-
- void MemScheduler::Update() {
- if (!optimized_) {
- return;
- }
-
- if (strategy_ == nullptr || !strategy_->need_swap()) {
- return;
- }
-
- if (updated_) {
- return;
- }
-
- if (!record_compute_time_) {
- record_compute_time_ = true;
- return;
- }
-
- strategy_->SetComputeTime(compute_time_);
- strategy_->Execute();
- updated_ = true;
- }
- } // namespace device
- } // namespace mindspore
|