| @@ -21,9 +21,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| constexpr uint64_t kAscendDeviceMemGB = 30; | |||||
| constexpr uint64_t kAscendInitDeviceMemGB = 30; | |||||
| constexpr uint64_t kAscendMaxDeviceMemGB = 31; | |||||
| constexpr uint64_t kMemSizeGB = 30; | constexpr uint64_t kMemSizeGB = 30; | ||||
| constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB); | |||||
| constexpr uint64_t kAscendDeviceMemSize = (kAscendInitDeviceMemGB << kMemSizeGB); | |||||
| void AscendMemoryManager::MallocDeviceMemory() { | void AscendMemoryManager::MallocDeviceMemory() { | ||||
| auto context_mem = GetDeviceMemSizeFromContext(); | auto context_mem = GetDeviceMemSizeFromContext(); | ||||
| @@ -58,8 +59,8 @@ uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { | |||||
| auto gb_str = variable_memory_max_size.substr(0, pos); | auto gb_str = variable_memory_max_size.substr(0, pos); | ||||
| auto gb_var = std::stoull(gb_str); | auto gb_var = std::stoull(gb_str); | ||||
| MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var; | MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var; | ||||
| if (gb_var > kAscendDeviceMemGB || gb_var == 0) { | |||||
| MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB"; | |||||
| if (gb_var > kAscendMaxDeviceMemGB || gb_var == 0) { | |||||
| MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-31]GB"; | |||||
| } | } | ||||
| return gb_var << kMemSizeGB; | return gb_var << kMemSizeGB; | ||||
| } | } | ||||
| @@ -225,8 +225,8 @@ class _Context: | |||||
| """set values of variable_memory_max_size and graph_memory_max_size""" | """set values of variable_memory_max_size and graph_memory_max_size""" | ||||
| if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern): | if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern): | ||||
| raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") | raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") | ||||
| if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: | |||||
| raise ValueError("Context param variable_memory_max_size should be less than 31GB.") | |||||
| if int(variable_memory_max_size[:-2]) > _DEVICE_APP_MEMORY_SIZE: | |||||
| raise ValueError("Context param variable_memory_max_size should be not greater than 31GB.") | |||||
| variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" | variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" | ||||
| graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) | graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) | ||||
| graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" | graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" | ||||
| @@ -115,7 +115,7 @@ def test_variable_memory_max_size(): | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| context.set_context(variable_memory_max_size="1G") | context.set_context(variable_memory_max_size="1G") | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| context.set_context(variable_memory_max_size="31GB") | |||||
| context.set_context(variable_memory_max_size="32GB") | |||||
| context.set_context(variable_memory_max_size="3GB") | context.set_context(variable_memory_max_size="3GB") | ||||