@@ -288,20 +288,20 @@ class Parameter(Tensor_):
     @property
     def comm_fusion(self):
+        """Get the fusion type for communication operators corresponding to this parameter."""
+        return self.param_info.comm_fusion
+
+    @comm_fusion.setter
+    def comm_fusion(self, comm_fusion_):
         """
-        Get and Set the fusion type (int) for communication operators corresponding to this parameter.
-
         In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
-        gradients aggregation are inserted automatically.Set the fusion type for communication operators generated
-        for this parameter. Only `Ascend` and `Graph` mode is supported.
+        gradients aggregation are inserted automatically. Set the fusion type for communication operators generated
+        for this parameter. The value of fusion must be greater than or equal to 0. When the value of fusion is 0,
+        operators will not be fused together.
-
-        Args:
-            comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
-                When the value of fusion is 0, operators will not be fused together.
+
+        Only `Ascend` and `Graph` mode is supported.
         """
-        return self.param_info.comm_fusion
-
-    @comm_fusion.setter
-    def comm_fusion(self, comm_fusion_):
         if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
             raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
         Validator.check_non_negative_int(comm_fusion_)
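With this split, reading `param.comm_fusion` stays a plain attribute lookup while writes go through the mode check and non-negative validation above. A minimal usage sketch, assuming a Graph-mode Ascend environment; the parameter name and fusion index are illustrative and not taken from the patch:

    import numpy as np
    from mindspore import Tensor, Parameter, context

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    # Parameters that share the same non-zero fusion index may have their
    # gradient-aggregation communication operators fused into one; 0 disables fusion.
    weight = Parameter(Tensor(np.ones((16, 16), np.float32)), name="fc.weight")
    weight.comm_fusion = 2          # validated by check_non_negative_int in the setter
    print(weight.comm_fusion)       # -> 2
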
@@ -344,7 +344,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 all_reduce_fusion_config=list, pipeline_stages=int)
+                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
 def set_auto_parallel_context(**kwargs):
     r"""
     Set auto parallel context, which is valid only for Ascend and GPU target.
@@ -371,6 +371,7 @@ def set_auto_parallel_context(**kwargs):
            all_reduce_fusion_config     strategy_ckpt_save_file
            enable_parallel_optimizer    full_batch
            \                            pipeline_stages
+           \                            grad_accumulation_step
            ===========================  ===========================
     Args:
@@ -420,6 +421,8 @@ def set_auto_parallel_context(**kwargs):
                         the devices are distributed alone the pipeline. The total devices will be divided into
                         'pipeline_stags' stages. This currently could only be used when
                         parallel mode semi_auto_parallel is enabled. Default: 1.
+        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
+                        This should be a positive int. Default: 1.
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
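For illustration, the new keyword is passed like any other auto-parallel option; the step count below is arbitrary and the surrounding training script is assumed:

    from mindspore import context

    # Accumulate gradients over 4 steps before each optimizer update when
    # running in semi-auto-parallel mode; the value must be a positive int (default 1).
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      grad_accumulation_step=4)

A non-int value is rejected up front by args_type_check, and 0 or a negative value is rejected by the check added in _AutoParallelContext (see the last hunk below).
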
@@ -18,7 +18,7 @@ import mindspore.context as context
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check, Validator
 _MAX_GROUP_NAME_LEN = 127
 _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
@@ -257,6 +257,7 @@ class _AutoParallelContext:
             grad_accumulation_step (int): The grad accumulation step.
         """
         self.check_context_handle()
+        Validator.check_positive_int(grad_accumulation_step)
         self._context_handle.set_grad_accumulation_step(grad_accumulation_step)

     def get_grad_accumulation_step(self):
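The added call makes an invalid accumulation step fail at configuration time rather than later in graph compilation. A small sketch of the check itself; note that mindspore._checkparam is a private module and is imported directly here only to show what the setter now does (it raises a ValueError for non-positive values in current releases):

    from mindspore._checkparam import Validator

    Validator.check_positive_int(4)       # accepted
    try:
        Validator.check_positive_int(0)   # 0 is not a positive int
    except ValueError as err:
        print(err)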