|
|
|
@@ -219,7 +219,7 @@ class _Context: |
|
|
|
self.set_param(ms_ctx_param.profiling_options, option) |
|
|
|
|
|
|
|
def set_variable_memory_max_size(self, variable_memory_max_size): |
|
|
|
if not check_input_format(variable_memory_max_size): |
|
|
|
if not _check_input_format(variable_memory_max_size): |
|
|
|
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") |
|
|
|
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: |
|
|
|
raise ValueError("Context param variable_memory_max_size should be less than 31GB.") |
|
|
|
@@ -230,7 +230,7 @@ class _Context: |
|
|
|
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) |
|
|
|
|
|
|
|
def set_max_device_memory(self, max_device_memory): |
|
|
|
if not check_input_format(max_device_memory): |
|
|
|
if not _check_input_format(max_device_memory): |
|
|
|
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") |
|
|
|
max_device_memory_value = float(max_device_memory[:-2]) |
|
|
|
if max_device_memory_value == 0: |
|
|
|
@@ -289,7 +289,7 @@ class _Context: |
|
|
|
thread_info.debug_runtime = enable |
|
|
|
|
|
|
|
|
|
|
|
def check_input_format(x): |
|
|
|
def _check_input_format(x): |
|
|
|
import re |
|
|
|
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' |
|
|
|
result = re.match(pattern, x) |
|
|
|
|