@@ -383,9 +383,9 @@ def set_auto_parallel_context(**kwargs):
         full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
                        should be set with True. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
-                       data parallel training in the benefit of time and memory saving. For now, auto parallel mode
-                       supports all optimizers. Data parallel mode only supports `Lamb` and `AdamWeightDecay`.
-                       Default: False.
+                       data parallel training to save time and memory. Currently, the auto and semi-auto parallel
+                       modes support all optimizers on both Ascend and GPU. The data parallel mode only supports
+                       `Lamb` and `AdamWeightDecay` on Ascend. Default: False.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
                        and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
         pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
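For reference, a minimal usage sketch of the combination the updated docstring allows; it assumes a distributed job has already been launched and that `net` is an existing nn.Cell, and the optimizer choice and learning rate are illustrative only:

from mindspore import context, nn

# Sketch only: assumes the distributed environment is initialized and `net` exists.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_auto_parallel_context(parallel_mode="data_parallel",
                                  enable_parallel_optimizer=True)
# In data parallel mode only Lamb and AdamWeightDecay are sharded.
optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)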
@@ -164,8 +164,10 @@ class Optimizer(Cell):
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
         if context.get_auto_parallel_context("enable_parallel_optimizer"):
-            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL:
+            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                 self.use_parallel = True
+            elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") != "Ascend":
+                raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
             elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                 raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
             else:
@@ -174,10 +176,10 @@ class Optimizer(Cell):
             self.use_parallel = False
         if self.use_parallel:
             if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
-                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
+                raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
             self.dev_num = _get_device_num()
             if self.dev_num > self.param_length:
-                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
+                raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
                                    " less than the number of devices {}".format(self.param_length, self.dev_num))
             self.param_rank = self._get_parameter_group_id()
             self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
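Taken together, the two hunks above gate the weight-update sharding path. The following is an illustrative paraphrase of those checks, not the MindSpore source: the function name is hypothetical and plain strings stand in for the ParallelMode enum.

def parallel_optimizer_allowed(parallel_mode, device_target, cls_name, dev_num, param_length):
    """Return True when weight-update sharding is used; raise on unsupported setups."""
    if parallel_mode == "data_parallel" and device_target == "Ascend":
        use_parallel = True
    elif parallel_mode == "data_parallel" and device_target != "Ascend":
        raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
    elif parallel_mode in ("stand_alone", "hybrid_parallel"):
        raise RuntimeError("Parallel optimizer is not supported in {}.".format(parallel_mode))
    else:
        use_parallel = False  # auto/semi-auto parallel: this Python-level sharding path is not used
    if use_parallel:
        if cls_name not in ("Lamb", "AdamWeightDecay"):
            raise RuntimeError("Parallel optimizer does not support optimizer {}".format(cls_name))
        if dev_num > param_length:
            raise RuntimeError("Parallel optimizer can not be applied when the number of "
                               "parameters {} is less than the number of devices {}"
                               .format(param_length, dev_num))
    return use_parallel

# Example: data parallel on GPU is rejected, matching the new test below.
# parallel_optimizer_allowed("data_parallel", "GPU", "Lamb", 8, 100)  -> RuntimeError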
@@ -164,6 +164,12 @@ def test_edge_case():
         context.set_auto_parallel_context(parallel_mode="stand_alone")
         Lamb(net.trainable_params(), learning_rate=0.1)
     with pytest.raises(RuntimeError):
+        context.set_context(device_target="GPU")
+        context.set_auto_parallel_context(parallel_mode="data_parallel")
+        Lamb(net.trainable_params(), learning_rate=0.1)
+    with pytest.raises(RuntimeError):
+        context.set_context(device_target="Ascend")
+        context.set_auto_parallel_context(parallel_mode="data_parallel")
         Adam(net.trainable_params(), learning_rate=0.1)
     with pytest.raises(RuntimeError):
         context.set_auto_parallel_context(device_num=16)
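If a positive check were wanted, a few lines appended inside test_edge_case could assert that the supported combination constructs cleanly. This is a hypothetical addition, not part of the PR; it reuses the test's existing `net`, and `device_num=2` assumes the net has at least two trainable parameters.

    # Hypothetical positive check: Ascend + data parallel + Lamb should not raise.
    context.set_context(device_target="Ascend")
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      enable_parallel_optimizer=True,
                                      device_num=2)
    Lamb(net.trainable_params(), learning_rate=0.1)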