| @@ -36,10 +36,11 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut | |||||
| GRAPH_MODE = 0 | GRAPH_MODE = 0 | ||||
| PYNATIVE_MODE = 1 | PYNATIVE_MODE = 1 | ||||
| _DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. | |||||
| _DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. | |||||
| _re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | _re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | ||||
| _k_context = None | _k_context = None | ||||
| def _make_directory(path): | def _make_directory(path): | ||||
| """Make directory.""" | """Make directory.""" | ||||
| real_path = None | real_path = None | ||||
| @@ -432,6 +433,7 @@ def set_auto_parallel_context(**kwargs): | |||||
| """ | """ | ||||
| _set_auto_parallel_context(**kwargs) | _set_auto_parallel_context(**kwargs) | ||||
| def get_auto_parallel_context(attr_key): | def get_auto_parallel_context(attr_key): | ||||
| """ | """ | ||||
| Gets auto parallel context attribute value according to the key. | Gets auto parallel context attribute value according to the key. | ||||
| @@ -542,7 +544,13 @@ def set_context(**kwargs): | |||||
| device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1], | device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1], | ||||
| while device_num_per_host should be no more than 4096. Default: 0. | while device_num_per_host should be no more than 4096. Default: 0. | ||||
| save_graphs (bool): Whether to save graphs. Default: False. | save_graphs (bool): Whether to save graphs. Default: False. | ||||
| save_graphs_path (str): Path to save graphs. Default: "." | |||||
| save_graphs_path (str): Path to save graphs. Default: ".". | |||||
| If the program is executed in the parallel mode, `save_graphs_path` should consist of the path and the | |||||
| current device id, to ensure that writing file conflicts won't happen when the different processes try to | |||||
| create the files in the same directory. For example, the `device_id` can be generated by | |||||
| `device_id = os.getenv("DEVICE_ID")` and the `save_graphs_path` can be set by | |||||
| `context.set_context(save_graphs_path="path/to/ir/files"+device_id)`. | |||||
| enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be | enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be | ||||
| compiled into a fused kernel automatically. Default: False. | compiled into a fused kernel automatically. Default: False. | ||||
| reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True. | reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True. | ||||
| @@ -700,6 +708,7 @@ class ParallelMode: | |||||
| AUTO_PARALLEL = "auto_parallel" | AUTO_PARALLEL = "auto_parallel" | ||||
| MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL] | MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL] | ||||
| @args_type_check(enable_ps=bool) | @args_type_check(enable_ps=bool) | ||||
| def set_ps_context(**kwargs): | def set_ps_context(**kwargs): | ||||
| """ | """ | ||||
| @@ -750,6 +759,7 @@ def get_ps_context(attr_key): | |||||
| """ | """ | ||||
| return _get_ps_context(attr_key) | return _get_ps_context(attr_key) | ||||
| def reset_ps_context(): | def reset_ps_context(): | ||||
| """ | """ | ||||
| Reset parameter server training mode context attributes to the default values: | Reset parameter server training mode context attributes to the default values: | ||||