diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc index 154e46abe2..6025af40bc 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc @@ -343,6 +343,12 @@ bool MemSwapManager::RetreatSwapInfo() { if (!trigger_swap_) { trigger_swap_ = true; } + if (retreat_count_ > kRetreatCountMax) { + MS_LOG(ERROR) << "RetreatSwapInfo exceed upper bound of count"; + return false; + } + retreat_count_++; + if (swap_info_already_set_) { ResetSwapInfo(); RetreatSwapThreshold(); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h index b50000779e..e274746fd8 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h @@ -36,7 +36,8 @@ class MemSwapManager { tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1), - distance_decay_step_(1) { + distance_decay_step_(1), + retreat_count_(0) { mem_copy_manager_ = mem_copy_manager; } @@ -156,6 +157,7 @@ class MemSwapManager { size_t tensor_size_num_; size_t distance_threshold_; size_t distance_decay_step_; + size_t retreat_count_; MemCopyManagerPtr mem_copy_manager_{nullptr}; const mindspore::session::KernelGraph *kernel_graph_{nullptr}; @@ -165,6 +167,8 @@ class MemSwapManager { static constexpr size_t kDistanceInitFactor = 3; static constexpr size_t kDistanceLowerBound = 3; + // The upper bound of count for searching memory swap scheme recurrently. + static constexpr size_t kRetreatCountMax = 50; }; using MemSwapManagerPtr = std::shared_ptr; } // namespace memswap