@@ -344,7 +344,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 all_reduce_fusion_config=list, pipeline_stages=int)
+                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
 def set_auto_parallel_context(**kwargs):
     r"""
     Set auto parallel context, which is valid only for Ascend and GPU target.
@@ -371,6 +371,7 @@ def set_auto_parallel_context(**kwargs):
     all_reduce_fusion_config     strategy_ckpt_save_file
     enable_parallel_optimizer    full_batch
               \                  pipeline_stages
+              \                  grad_accumulation_step
     =========================== ===========================

     Args:
@@ -420,6 +421,8 @@ def set_auto_parallel_context(**kwargs):
             the devices are distributed alone the pipeline. The total devices will be divided into
             'pipeline_stags' stages. This currently could only be used when
             parallel mode semi_auto_parallel is enabled. Default: 1.
+        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
+            This should be a positive int. Default: 1.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
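For reference, a minimal usage sketch of the extended interface follows. It assumes the function is importable as mindspore.context.set_auto_parallel_context (the module this diff edits) and that the new key is also readable through get_auto_parallel_context; the keyword names and types come from the @args_type_check decorator above, while the concrete values (8 devices, 2 stages, 4 accumulation steps) are purely illustrative.

# Hedged sketch, not part of the patch: exercises the kwargs validated by @args_type_check.
from mindspore import context

context.set_auto_parallel_context(
    parallel_mode="semi_auto_parallel",  # pipeline_stages is documented for semi_auto_parallel only
    device_num=8,                        # total devices participating in parallel training
    pipeline_stages=2,                   # divide the 8 devices into 2 pipeline stages
    grad_accumulation_step=4,            # new kwarg: accumulate gradients over 4 steps before each update
    full_batch=True,
)

# The decorator enforces the declared types, so a non-int value such as "4" for
# grad_accumulation_step should be rejected before the context is modified.
print(context.get_auto_parallel_context("grad_accumulation_step"))  # expected to print 4

Pairing grad_accumulation_step with pipeline_stages follows the usual micro-batch pattern for pipeline parallelism, where each stage accumulates gradients over several micro-batches before the optimizer step.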