| @@ -27,32 +27,26 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt") | |||
| def _init_allreduce_operators(length): | |||
| """ initialize allreduce communication operators""" | |||
| is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") | |||
| split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() | |||
| if is_parallel_optimizer and split_indices: | |||
| group = 1 | |||
| fusion = () | |||
| for i in range(length): | |||
| fusion = fusion + (group,) | |||
| if split_indices[group - 1] <= i + 1: | |||
| if group >= len(split_indices): | |||
| continue | |||
| group = group + 1 | |||
| index = tuple(range(1, length + 1)) | |||
| else: | |||
| fusion = (1,) * length | |||
| index = (0,) * length | |||
| opt_list = () | |||
| group = 1 | |||
| fusion = () | |||
| for i in range(length): | |||
| opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) | |||
| opt.add_prim_attr('fusion', fusion[i]) | |||
| opt.add_prim_attr('index', index[i]) | |||
| opt_list = opt_list + (opt,) | |||
| return opt_list | |||
| fusion = fusion + (group,) | |||
| if split_indices[group - 1] <= i + 1: | |||
| if group >= len(split_indices): | |||
| continue | |||
| group = group + 1 | |||
| index = tuple(range(1, length + 1)) | |||
| op_list = () | |||
| for i in range(length): | |||
| op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) | |||
| op.add_prim_attr('fusion', fusion[i]) | |||
| op.add_prim_attr('index', index[i]) | |||
| op_list = op_list + (op,) | |||
| return op_list | |||
| @reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool") | |||
| def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter): | |||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") | |||
| def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||
| """ | |||
| Apply allreduce on gradient. | |||
| @@ -60,9 +54,10 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc | |||
| degree (int): The mean coefficient. | |||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. | |||
| allgather (Primitive): The communication operator for sparse gradients. | |||
| allreduce (Primitive): The communication operator for gradients. | |||
| allreduce_filter (bool): When it is true, allreduce would apply. | |||
| grad (Tensor): The gradient tensor before operation. | |||
| allreduce (Primitive): The communication operator for gradients. | |||
| ps_parameter(Bool): Use parameter server or not. | |||
| Returns: | |||
| Tensor, the gradient tensor after operation. | |||
| @@ -78,8 +73,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc | |||
| return grad | |||
| @reduce_opt.register("Number", "Bool", "Function", "Bool", "IndexedSlices", "Function") | |||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce): | |||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") | |||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): | |||
| """ | |||
| Apply allgather on gradient instead of allreduce for sparse feature. | |||
| Allgather is a communication operation used for distributed deep learning. | |||
| @@ -88,9 +83,9 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr | |||
| degree (int): The mean coefficient. | |||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. | |||
| allgather (Primitive): The communication operator for sparse gradients. | |||
| allreduce_filter (bool): When it is true, allgather would apply. | |||
| grad (IndexedSlices): The gradient before operation. | |||
| allreduce (Primitive): The communication operator for gradients. | |||
| allreduce_filter (bool): When it is true, allgather would apply. | |||
| grad (tuple): The indices, gradient tensor and tensor_shape before operation. | |||
| Returns: | |||
| IndexedSlices, the gradient after operation. | |||
| @@ -256,7 +251,14 @@ class DistributedGradReducer(Cell): | |||
| self.degree = degree | |||
| self.mean = mean | |||
| self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) | |||
| self.opt_list = _init_allreduce_operators(len(parameters)) | |||
| is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") | |||
| split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() | |||
| if is_parallel_optimizer and split_indices: | |||
| self.split_fusion = True | |||
| self.op_list = _init_allreduce_operators(len(parameters)) | |||
| else: | |||
| self.split_fusion = False | |||
| self.allreduce = AllReduce().add_prim_attr('fusion', 1) | |||
| self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) | |||
| ps_filter = lambda x: x.is_param_ps | |||
| self.ps_parameters = tuple(ps_filter(x) for x in parameters) | |||
| @@ -275,8 +277,11 @@ class DistributedGradReducer(Cell): | |||
| """ | |||
| datatypes = self.map_(F.partial(_get_datatype), grads) | |||
| grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||
| self.allreduce_filter, grads, self.opt_list, self.ps_parameters) | |||
| if self.split_fusion: | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||
| self.opt_list, self.allreduce_filter, grads, self.ps_parameters) | |||
| else: | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, | |||
| self.allreduce), self.allreduce_filter, grads, self.ps_parameters) | |||
| new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) | |||
| return new_grad | |||