@@ -219,6 +219,7 @@ class _Context:
         self.set_param(ms_ctx_param.profiling_options, option)
 
     def set_variable_memory_max_size(self, variable_memory_max_size):
+        """set values of variable_memory_max_size and graph_memory_max_size"""
         if not _check_input_format(variable_memory_max_size):
             raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
         if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
|
|
@@ -227,7 +228,8 @@ class _Context:
         graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
         graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
         self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
-        self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
+        # pylint: disable=protected-access
+        self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
 
     def set_max_device_memory(self, max_device_memory):
         if not _check_input_format(max_device_memory):
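For reference, `variable_memory_max_size[:-2]` strips the trailing "GB", and whatever is left of the whole-GB device budget goes to the graph. Below is a standalone sketch of that arithmetic; the 31 GB value of `_DEVICE_APP_MEMORY_SIZE` and the exact form of the elided `variable_memory_max_size_` assignment are assumptions, since neither appears in this hunk:

```python
# Standalone sketch of the split done by set_variable_memory_max_size.
# The 31 GB budget is an assumption for illustration; the real value of
# _DEVICE_APP_MEMORY_SIZE is defined elsewhere in context.py.
_DEVICE_APP_MEMORY_SIZE = 31

def split_memory(variable_memory_max_size):
    # "5GB"[:-2] -> "5"; the "GB" suffix is validated by _check_input_format
    variable_gb = int(variable_memory_max_size[:-2])
    if variable_gb >= _DEVICE_APP_MEMORY_SIZE:
        raise ValueError("variable_memory_max_size exceeds the device budget")
    graph_gb = _DEVICE_APP_MEMORY_SIZE - variable_gb
    # both values are handed to the backend as strings
    variable_memory_max_size_ = str(variable_gb * 1024 * 1024 * 1024)
    graph_memory_max_size_ = str(graph_gb) + " * 1024 * 1024 * 1024"
    return variable_memory_max_size_, graph_memory_max_size_

print(split_memory("5GB"))  # ('5368709120', '26 * 1024 * 1024 * 1024')
```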
|
|
@@ -425,6 +427,26 @@ def reset_auto_parallel_context():
     _reset_auto_parallel_context()
 
 
+def _check_target_specific_cfgs(device, arg_key):
+    """Check whether a config is suitable for a specified device."""
+    device_cfgs = {
+        'enable_auto_mixed_precision': ['Ascend'],
+        'enable_dump': ['Ascend'],
+        'enable_profiling': ['Ascend'],
+        'variable_memory_max_size': ['Ascend'],
+        'max_device_memory': ['GPU']
+    }
+    # configs not in device_cfgs are assumed to be suitable for all devices
+    if arg_key not in device_cfgs:
+        return True
+    supported_devices = device_cfgs[arg_key]
+    if device in supported_devices:
+        return True
+    logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
+                   ", ignore it.")
+    return False
+
+
 @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
                  save_graphs_path=str, enable_dump=bool,
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
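Since the helper is a pure lookup with a warning side effect, its behaviour can be sanity-checked directly. A hypothetical session against the `device_cfgs` table above, with expected results as comments:

```python
# Hypothetical calls to the new helper; results follow the device_cfgs table.
_check_target_specific_cfgs("Ascend", "enable_dump")      # True: Ascend-only key on Ascend
_check_target_specific_cfgs("GPU", "max_device_memory")   # True: GPU-only key on GPU
_check_target_specific_cfgs("CPU", "mode")                # True: not device specific
_check_target_specific_cfgs("GPU", "enable_profiling")    # False, after logging:
# "Config 'enable_profiling' only supports devices in ['Ascend'],
#  current device is 'GPU', ignore it."
```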
|
|
@@ -450,6 +472,26 @@ def set_context(**kwargs):
         The mode is not recommended to be changed after the net is initialized because the implementations of
         some operations differ in graph mode and pynative mode. Default: PYNATIVE_MODE.
 
+    Some configurations are device specific, see the table below for details:
+
+    ===========================  ===========================  =================
+    Common(CPU/GPU/Ascend)       Ascend                       GPU
+    ===========================  ===========================  =================
+    check_bprop                  enable_auto_mixed_precision  max_device_memory
+    device_id                    enable_dump
+    device_target                enable_profiling
+    enable_graph_kernel          variable_memory_max_size
+    enable_reduce_precision
+    enable_sparse
+    mode
+    print_file_path
+    profiling_options
+    reserve_class_name_in_scope
+    save_dump_path
+    save_graphs
+    save_graphs_path
+    ===========================  ===========================  =================
+
     Args:
         mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
         device_target (str): The target device to run, supports "Ascend", "GPU", "CPU". Default: "Ascend".
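As a usage note for the table: common keys combine freely with the keys of the current backend's column, while keys from another column are ignored with a warning rather than rejected. A hypothetical call, with parameter values chosen purely for illustration:

```python
from mindspore import context

# 'mode' and 'device_target' are common; 'max_device_memory' is GPU-only.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                    max_device_memory="3.5GB")
# With device_target="Ascend", the 'max_device_memory' key would be dropped
# with a warning by _check_target_specific_cfgs instead of raising.
```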
|
|
@@ -513,14 +555,21 @@ def set_context(**kwargs):
         >>> context.set_context(max_call_depth=80)
     """
     ctx = _context()
+    # set device target first
+    if 'device_target' in kwargs:
+        ctx.set_device_target(kwargs['device_target'])
+    device = ctx.get_param(ms_ctx_param.device_target)
     for key, value in kwargs.items():
+        if not _check_target_specific_cfgs(device, key):
+            continue
         if hasattr(ctx, key):
             setattr(ctx, key, value)
             continue
         if key in ctx.setters:
             ctx.setters[key](ctx, value)
             continue
-        if key in ms_ctx_param.__members__:
+        # enum variables beginning with '_' are for internal use
+        if key in ms_ctx_param.__members__ and key[0] != '_':
             ctx.set_param(ms_ctx_param.__members__[key], value)
             continue
         raise ValueError("Set context keyword %s is not recognized!" % key)
|
|
@@ -540,9 +589,12 @@ def get_context(attr_key):
         ValueError: If input key is not an attribute in context.
     """
     ctx = _context()
+    device = ctx.get_param(ms_ctx_param.device_target)
+    _ = _check_target_specific_cfgs(device, attr_key)
     if hasattr(ctx, attr_key):
         return getattr(ctx, attr_key)
-    if attr_key in ms_ctx_param.__members__:
+    # enum variables beginning with '_' are for internal use
+    if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
         return ctx.get_param(ms_ctx_param.__members__[attr_key])
     raise ValueError("Get context keyword %s is not recognized!" % attr_key)