|
|
|
@@ -36,10 +36,11 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut |
|
|
|
|
|
|
|
GRAPH_MODE = 0 |
|
|
|
PYNATIVE_MODE = 1 |
|
|
|
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. |
|
|
|
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. |
|
|
|
_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' |
|
|
|
_k_context = None |
|
|
|
|
|
|
|
|
|
|
|
def _make_directory(path): |
|
|
|
"""Make directory.""" |
|
|
|
real_path = None |
|
|
|
@@ -432,6 +433,7 @@ def set_auto_parallel_context(**kwargs): |
|
|
|
""" |
|
|
|
_set_auto_parallel_context(**kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def get_auto_parallel_context(attr_key): |
|
|
|
""" |
|
|
|
Gets auto parallel context attribute value according to the key. |
|
|
|
@@ -542,7 +544,13 @@ def set_context(**kwargs): |
|
|
|
device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1], |
|
|
|
while device_num_per_host should be no more than 4096. Default: 0. |
|
|
|
save_graphs (bool): Whether to save graphs. Default: False. |
|
|
|
save_graphs_path (str): Path to save graphs. Default: "." |
|
|
|
save_graphs_path (str): Path to save graphs. Default: ".". |
|
|
|
|
|
|
|
If the program is executed in the parallel mode, `save_graphs_path` should consist of the path and the |
|
|
|
current device id, to ensure that writing file conflicts won't happen when the different processes try to |
|
|
|
create the files in the same directory. For example, the `device_id` can be generated by |
|
|
|
`device_id = os.getenv("DEVICE_ID")` and the `save_graphs_path` can be set by |
|
|
|
`context.set_context(save_graphs_path="path/to/ir/files"+device_id)`. |
|
|
|
enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be |
|
|
|
compiled into a fused kernel automatically. Default: False. |
|
|
|
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True. |
|
|
|
@@ -700,6 +708,7 @@ class ParallelMode: |
|
|
|
AUTO_PARALLEL = "auto_parallel" |
|
|
|
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL] |
|
|
|
|
|
|
|
|
|
|
|
@args_type_check(enable_ps=bool) |
|
|
|
def set_ps_context(**kwargs): |
|
|
|
""" |
|
|
|
@@ -750,6 +759,7 @@ def get_ps_context(attr_key): |
|
|
|
""" |
|
|
|
return _get_ps_context(attr_key) |
|
|
|
|
|
|
|
|
|
|
|
def reset_ps_context(): |
|
|
|
""" |
|
|
|
Reset parameter server training mode context attributes to the default values: |
|
|
|
|