From d2d4fdeb049062ab2cce2c5a13f74eff9e491804 Mon Sep 17 00:00:00 2001
From: caifubi
Date: Fri, 12 Jun 2020 17:01:19 +0800
Subject: [PATCH] use context memory variable in ascend memory allocate

AscendMemoryManager::MallocDeviceMemory() now asks the new helper
GetDeviceMemSizeFromContext() for the device memory size instead of always
using the built-in 26 GB constant.

GetDeviceMemSizeFromContext() reads MsContext::variable_memory_max_size():
  * "0" returns 0, so the built-in kAscendDeviceMemSize is used and the
    separate 4 GB memory pool is allocated and registered with
    AscendMemoryPool, as before.
  * otherwise the text before the first '*' is parsed as a number of GB,
    which must be in (0, 30]; the result is returned in bytes. A missing
    '*' or an out-of-range value raises via MS_LOG(EXCEPTION). In this case
    no separate memory pool is allocated by MallocDeviceMemory().

MsContext gains read accessors variable_memory_max_size() and
graph_memory_max_size().
---
Review notes (ignored by git-am):
  * std::stoull() throws std::invalid_argument / std::out_of_range when the
    text before '*' is not a valid number; that is not converted into an
    MS_LOG(EXCEPTION) here -- confirm this is acceptable to callers.
  * kMemSizeGB is used as a shift count (x << 30 converts GB to bytes), not
    as a size; the name may mislead.
  * The range error message hard-codes "30" instead of using kMaxMemSizeGB.
  * When the context value is non-zero, the AscendMemoryPool base/size
    setters are never called in this function -- confirm the pool is
    initialised elsewhere on that path.

 .../device/ascend/ascend_memory_manager.cc    | 51 ++++++++++++++-----
 .../device/ascend/ascend_memory_manager.h     |  2 +
 mindspore/ccsrc/utils/context/ms_context.h    |  4 ++
 3 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
index 4c7b897cac..42c611c3af 100644
--- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
+#include <string>
 #include "device/ascend/ascend_memory_manager.h"
 #include "device/ascend/ascend_memory_pool.h"
 #include "utils/context/ms_context.h"
@@ -21,25 +21,52 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-const uint64_t kAscendDeviceMemGB = 26;
-const uint64_t kAscendMemPoolGB = 4;
-const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
-const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
+constexpr uint64_t kAscendDeviceMemGB = 26;
+constexpr uint64_t kAscendMemPoolGB = 4;
+constexpr uint64_t kMemSizeGB = 30;
+constexpr uint64_t kMaxMemSizeGB = 30;
+constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB);
+constexpr uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << kMemSizeGB);
 
 void AscendMemoryManager::MallocDeviceMemory() {
-  device_mem_size_ = kAscendDeviceMemSize;
+  auto context_mem = GetDeviceMemSizeFromContext();
+  device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem;
   static_mem_offset_ = device_mem_size_;
   auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM);
   if (ret != RT_ERROR_NONE) {
     MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]";
   }
-  device_mem_pool_size_ = kAscendMemPoolSize;
-  ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
-  if (ret != RT_ERROR_NONE) {
-    MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
+
+  if (context_mem == 0) {
+    device_mem_pool_size_ = kAscendMemPoolSize;
+    ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
+    if (ret != RT_ERROR_NONE) {
+      MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
+    }
+    AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
+    AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
+  }
+}
+
+uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  auto variable_memory_max_size = context->variable_memory_max_size();
+  if (variable_memory_max_size == "0") {
+    return 0;
+  }
+  MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size;
+  auto pos = variable_memory_max_size.find('*');
+  if (pos == std::string::npos) {
+    MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size";
+  }
+  auto gb_str = variable_memory_max_size.substr(0, pos);
+  auto gb_var = std::stoull(gb_str);
+  MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var;
+  if (gb_var > kMaxMemSizeGB || gb_var == 0) {
+    MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB";
   }
-  AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
-  AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
+  return gb_var << kMemSizeGB;
 }
 
 void AscendMemoryManager::FreeDeviceMemory() {
diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/device/ascend/ascend_memory_manager.h
index 90c8b2dfca..7fdd8f553e 100644
--- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h
+++ b/mindspore/ccsrc/device/ascend/ascend_memory_manager.h
@@ -32,6 +32,8 @@ class AscendMemoryManager : public MemoryManager {
  private:
   uint8_t *device_mem_pool_base_{nullptr};
   uint64_t device_mem_pool_size_{0};
+
+  uint64_t GetDeviceMemSizeFromContext();
 };
 }  // namespace ascend
 }  // namespace device
diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h
index cfedefe3d5..0d5406fc79 100644
--- a/mindspore/ccsrc/utils/context/ms_context.h
+++ b/mindspore/ccsrc/utils/context/ms_context.h
@@ -140,6 +140,10 @@ class MsContext {
     variable_memory_max_size_ = variable_memory_max_size;
   }
 
+  const std::string &variable_memory_max_size() const { return variable_memory_max_size_; }
+
+  const std::string &graph_memory_max_size() const { return graph_memory_max_size_; }
+
   void set_enable_profiling(bool flag) { profiling_mode_ = flag; }
   bool enable_profiling() const { return profiling_mode_; }
 