|
|
|
@@ -185,13 +185,20 @@ class _AutoParallelContext: |
|
|
|
self.check_context_handle() |
|
|
|
return self._context_handle.get_parallel_mode() |
|
|
|
|
|
|
|
def set_strategy_search_mode(self, strategy_search_mode): |
|
|
|
def set_strategy_search_mode(self, auto_parallel_search_mode): |
|
|
|
""" |
|
|
|
Set search mode of strategy. |
|
|
|
|
|
|
|
Args: |
|
|
|
auto_parallel_search_mode (str): The search mode of strategy. |
|
|
|
""" |
|
|
|
self.check_context_handle() |
|
|
|
ret = self._context_handle.set_strategy_search_mode(strategy_search_mode) |
|
|
|
ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode) |
|
|
|
if ret is False: |
|
|
|
raise ValueError("Strategy search mode does not support {}".format(strategy_search_mode)) |
|
|
|
raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode)) |
|
|
|
|
|
|
|
def get_strategy_search_mode(self): |
|
|
|
"""Get search mode of strategy.""" |
|
|
|
self.check_context_handle() |
|
|
|
return self._context_handle.get_strategy_search_mode() |
|
|
|
|
|
|
|
@@ -422,6 +429,7 @@ _set_auto_parallel_context_func_map = { |
|
|
|
"cast_before_mirror": auto_parallel_context().set_cast_before_mirror, |
|
|
|
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, |
|
|
|
"parallel_mode": auto_parallel_context().set_parallel_mode, |
|
|
|
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode, |
|
|
|
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast, |
|
|
|
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, |
|
|
|
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, |
|
|
|
@@ -435,6 +443,7 @@ _get_auto_parallel_context_func_map = { |
|
|
|
"cast_before_mirror": auto_parallel_context().get_cast_before_mirror, |
|
|
|
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, |
|
|
|
"parallel_mode": auto_parallel_context().get_parallel_mode, |
|
|
|
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode, |
|
|
|
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast, |
|
|
|
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, |
|
|
|
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, |
|
|
|
@@ -442,8 +451,9 @@ _get_auto_parallel_context_func_map = { |
|
|
|
|
|
|
|
|
|
|
|
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, |
|
|
|
loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool, |
|
|
|
strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool) |
|
|
|
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, |
|
|
|
parameter_broadcast=bool, strategy_ckpt_load_file=str, |
|
|
|
strategy_ckpt_save_file=str, full_batch=bool) |
|
|
|
def _set_auto_parallel_context(**kwargs): |
|
|
|
""" |
|
|
|
Set auto parallel context. |
|
|
|
@@ -471,6 +481,12 @@ def _set_auto_parallel_context(**kwargs): |
|
|
|
setting parallel strategies. |
|
|
|
|
|
|
|
- auto_parallel: Achieving parallelism automatically. |
|
|
|
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming" |
|
|
|
and "dynamic_programming". |
|
|
|
|
|
|
|
- recursive_programming: Recursive programming search mode. |
|
|
|
|
|
|
|
- dynamic_programming: Dynamic programming search mode. |
|
|
|
parameter_broadcast (bool): Indicating whether to broadcast parameters before training. |
|
|
|
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter |
|
|
|
broadcast. Default: False. |
|
|
|
|