|
|
|
@@ -164,8 +164,10 @@ class Optimizer(Cell): |
|
|
|
self.param_length = len(self.parameters) |
|
|
|
self.map_ = C.Map() |
|
|
|
if context.get_auto_parallel_context("enable_parallel_optimizer"): |
|
|
|
if _get_parallel_mode() == ParallelMode.DATA_PARALLEL: |
|
|
|
if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend": |
|
|
|
self.use_parallel = True |
|
|
|
elif context.get_context("device_target") != "Ascend": |
|
|
|
raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.") |
|
|
|
elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL): |
|
|
|
raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode())) |
|
|
|
else: |
|
|
|
@@ -174,10 +176,10 @@ class Optimizer(Cell): |
|
|
|
self.use_parallel = False |
|
|
|
if self.use_parallel: |
|
|
|
if self.cls_name not in ["Lamb", "AdamWeightDecay"]: |
|
|
|
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) |
|
|
|
raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name)) |
|
|
|
self.dev_num = _get_device_num() |
|
|
|
if self.dev_num > self.param_length: |
|
|
|
raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is" |
|
|
|
raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is" |
|
|
|
" less than the number of devices {}".format(self.param_length, self.dev_num)) |
|
|
|
self.param_rank = self._get_parameter_group_id() |
|
|
|
self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank)) |
|
|
|
|