| @@ -45,8 +45,35 @@ def _init_allreduce_operators(length): | |||||
| return op_list | return op_list | ||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor") | |||||
| def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad): | |||||
| """ | |||||
| Apply allreduce on gradient. | |||||
| Args: | |||||
| degree (int): The mean coefficient. | |||||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. | |||||
| allgather (Primitive): The communication operator for sparse gradients. | |||||
| allreduce (Primitive): The communication operator for gradients. | |||||
| allreduce_filter (bool): When it is true, allreduce would apply. | |||||
| grad (Tensor): The gradient tensor before operation. | |||||
| Returns: | |||||
| Tensor, the gradient tensor after operation. | |||||
| """ | |||||
| if allreduce_filter: | |||||
| grad = allreduce(grad) | |||||
| if mean: | |||||
| degree = F.scalar_cast(degree, F.dtype(grad)) | |||||
| cast_op = P.Cast() | |||||
| mul_op = P.Mul() | |||||
| grad = mul_op(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) | |||||
| return grad | |||||
| return grad | |||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") | @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") | ||||
| def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||||
| def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||||
| """ | """ | ||||
| Apply allreduce on gradient. | Apply allreduce on gradient. | ||||
| @@ -76,8 +103,37 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra | |||||
| return grad | return grad | ||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") | |||||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): | |||||
| """ | |||||
| Apply allgather on gradient instead of allreduce for sparse feature. | |||||
| Allgather is a communication operation used for distributed deep learning. | |||||
| Args: | |||||
| degree (int): The mean coefficient. | |||||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. | |||||
| allgather (Primitive): The communication operator for sparse gradients. | |||||
| allreduce (Primitive): The communication operator for gradients. | |||||
| allreduce_filter (bool): When it is true, allgather would apply. | |||||
| grad (tuple): The indices, gradient tensor and tensor_shape before operation. | |||||
| Returns: | |||||
| IndexedSlices, the gradient after operation. | |||||
| """ | |||||
| if allreduce_filter: | |||||
| indices = allgather(grad.indices()) | |||||
| dout = allgather(grad.values()) | |||||
| if mean: | |||||
| degree = F.scalar_cast(degree, F.dtype(grad.values())) | |||||
| cast_op = P.Cast() | |||||
| mul_op = P.Mul() | |||||
| dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) | |||||
| grad = IndexedSlices(indices, dout, grad.dense_shape()) | |||||
| return grad | |||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") | @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") | ||||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||||
| def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||||
| """ | """ | ||||
| Apply allgather on gradient instead of allreduce for sparse feature. | Apply allgather on gradient instead of allreduce for sparse feature. | ||||
| Allgather is a communication operation used for distributed deep learning. | Allgather is a communication operation used for distributed deep learning. | ||||
| @@ -269,6 +325,7 @@ class DistributedGradReducer(Cell): | |||||
| self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) | self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) | ||||
| ps_filter = lambda x: x.is_param_ps | ps_filter = lambda x: x.is_param_ps | ||||
| self.ps_parameters = tuple(ps_filter(x) for x in parameters) | self.ps_parameters = tuple(ps_filter(x) for x in parameters) | ||||
| self.enable_parameter_server = any(self.ps_parameters) | |||||
| def construct(self, grads): | def construct(self, grads): | ||||
| """ | """ | ||||
| @@ -285,10 +342,18 @@ class DistributedGradReducer(Cell): | |||||
| datatypes = self.map_(F.partial(_get_datatype), grads) | datatypes = self.map_(F.partial(_get_datatype), grads) | ||||
| grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) | grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) | ||||
| if self.split_fusion: | if self.split_fusion: | ||||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||||
| self.opt_list, self.allreduce_filter, grads, self.ps_parameters) | |||||
| if self.enable_parameter_server: | |||||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||||
| self.opt_list, self.allreduce_filter, grads, self.ps_parameters) | |||||
| else: | |||||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||||
| self.opt_list, self.allreduce_filter, grads) | |||||
| else: | else: | ||||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, | |||||
| self.allreduce), self.allreduce_filter, grads, self.ps_parameters) | |||||
| if self.enable_parameter_server: | |||||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, | |||||
| self.allreduce), self.allreduce_filter, grads, self.ps_parameters) | |||||
| else: | |||||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, | |||||
| self.allreduce), self.allreduce_filter, grads) | |||||
| new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) | new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) | ||||
| return new_grad | return new_grad | ||||