@@ -95,23 +95,23 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_global_rank()
 
-    def set_mirror_mean(self, mirror_mean):
+    def set_gradients_mean(self, gradients_mean):
         """
-        Set mirror_mean flag.
+        Set gradients_mean flag.
 
         Note:
-            If mirror_mean is true, it will insert a div operator after parameter gradients allreduce.
+            If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.
 
         Args:
-            mirror_mean (bool): The mirror_mean flag.
+            gradients_mean (bool): The gradients_mean flag.
         """
         self.check_context_handle()
-        self._context_handle.set_mirror_mean(mirror_mean)
+        self._context_handle.set_gradients_mean(gradients_mean)
 
-    def get_mirror_mean(self):
-        """Get mirror_mean flag."""
+    def get_gradients_mean(self):
+        """Get gradients_mean flag."""
         self.check_context_handle()
-        return self._context_handle.get_mirror_mean()
+        return self._context_handle.get_gradients_mean()
 
     def set_gradient_fp32_sync(self, gradient_fp32_sync):
         """
@@ -453,7 +453,7 @@ def auto_parallel_context():
 _set_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().set_device_num,
     "global_rank": auto_parallel_context().set_global_rank,
-    "mirror_mean": auto_parallel_context().set_mirror_mean,
+    "gradients_mean": auto_parallel_context().set_gradients_mean,
     "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
@@ -468,7 +468,7 @@ _set_auto_parallel_context_func_map = {
 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,
-    "mirror_mean": auto_parallel_context().get_mirror_mean,
+    "gradients_mean": auto_parallel_context().get_gradients_mean,
     "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
@@ -480,7 +480,7 @@ _get_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
 
 
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
+@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
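
The decorator must move in lockstep: once gradients_mean=bool replaces mirror_mean=bool, a leftover mirror_mean=... call is no longer type-checked here and is rejected by the dispatch map as an unrecognized keyword. For illustration only, a hypothetical checker with the same shape as this decorator (not MindSpore's actual args_type_check):

    import functools

    def args_type_check(**expected_types):
        """Reject keyword arguments whose values do not match the declared types."""
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for name, value in kwargs.items():
                    expected = expected_types.get(name)
                    if expected is not None and not isinstance(value, expected):
                        raise TypeError("%s must be %s, but got %s"
                                        % (name, expected.__name__, type(value).__name__))
                return func(*args, **kwargs)
            return wrapper
        return decorator
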
@@ -495,7 +495,7 @@ def _set_auto_parallel_context(**kwargs):
     Args:
         device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-        mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
+        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
         loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
             calculations. Default: True.
         gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
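
For callers, the rename surfaces through the public mindspore.context.set_auto_parallel_context, which forwards its keyword arguments to this function. A usage sketch against the post-rename API; the string form of parallel_mode matches the parallel_mode=str check in the decorator above:

    from mindspore import context

    # data parallelism across 8 devices, averaging gradients after AllReduce
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      device_num=8,
                                      gradients_mean=True)
    print(context.get_auto_parallel_context("gradients_mean"))  # True
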
@@ -562,7 +562,7 @@ def _reset_auto_parallel_context():
 
     - device_num: 1.
     - global_rank: 0.
-    - mirror_mean: False.
+    - gradients_mean: False.
     - gradient_fp32_sync: True.
     - parallel_mode: "stand_alone".
     - parameter_broadcast: False.
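
Same one-word update in the documented defaults. Assuming the public wrapper mindspore.context.reset_auto_parallel_context forwards here, a quick check that a reset restores the renamed flag's default:

    from mindspore import context

    context.reset_auto_parallel_context()
    assert context.get_auto_parallel_context("gradients_mean") is False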