@@ -16,18 +16,22 @@
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
 import mindspore.common.dtype as mstype
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
 _all_reduce = AllReduce()
+_all_gather = None
 
 
-def _init_optimizer_allreduce():
+def _init_optimizer_communication():
     global _all_reduce
+    global _all_gather
     _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
     _all_reduce.add_prim_attr('fusion', 1)
+    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
 
 
 @reduce_opt.register("Function", "Number", "Bool", "Tensor")
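For context on the new `_all_gather` primitive: a sparse gradient travels as an `(indices, values, dense_shape)` tuple, and each device touches a different set of rows, so the rows cannot be summed elementwise the way `AllReduce` sums dense tensors. `AllGather` instead leaves every device holding the concatenation of all devices' rows; duplicates are combined later. A minimal NumPy sketch (hypothetical two-rank data, no real communication):

```python
import numpy as np

# Per-rank sparse gradients over a 6-row embedding table: (indices, values).
rank0 = (np.array([0, 2]), np.array([[1.0, 1.0], [2.0, 2.0]]))
rank1 = (np.array([2, 5]), np.array([[3.0, 3.0], [4.0, 4.0]]))

# AllGather leaves every rank holding the concatenation along axis 0.
indices = np.concatenate([rank0[0], rank1[0]])  # [0, 2, 2, 5]
values = np.concatenate([rank0[1], rank1[1]])   # shape (4, 2)

# Index 2 appears twice; summing duplicate rows later reproduces the
# result AllReduce(SUM) would give on the densified gradient.
```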
@@ -72,8 +76,8 @@ def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
         degree = F.scalar_cast(degree, F.dtype(grad[1]))
         dout = _all_gather(grad[1])
         cast_op = P.Cast()
-        dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
-        grad = (indices, dout, dout[2])
+        dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
+        grad = (indices, dout, grad[2])
     return grad
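On the mean path above: the gathered values are scaled by `1.0 / degree` so that when duplicate rows are eventually summed, the total equals the cross-device mean. A sketch of the arithmetic, assuming `degree = 2` and hypothetical values:

```python
degree = 2
gathered = [6.0, 2.0]  # values gathered for the same row from both ranks
scaled = [v * (1.0 / degree) for v in gathered]
print(scaled)          # [3.0, 1.0]
# Summing the duplicates gives (6.0 + 2.0) / 2 = 4.0, the mean across
# ranks, matching AllReduce(SUM) followed by division on dense grads.
```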
@@ -110,7 +114,7 @@ def _tensors_allreduce_with_sparse(allreduce_filter, grad):
     if allreduce_filter:
         indices = _all_gather(grad[0])
         dout = _all_gather(grad[1])
-        grad = (indices, dout, dout[2])
+        grad = (indices, dout, grad[2])
     return grad
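Both hunks above apply the same fix: the third tuple element is the gradient's dense shape, which is identical on every rank, so the reduced tuple must carry `grad[2]`. Indexing `dout[2]` instead would slice row 2 of the freshly gathered values, since `dout` was just rebound to the `AllGather` output. A hypothetical illustration of the layout:

```python
# The sparse-gradient layout assumed throughout this file:
#   grad = (indices, values, dense_shape)
grad = ([0, 2], [[0.1, 0.2], [0.3, 0.4]], (6, 2))
dout = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]  # gathered values
wrong = dout[2]   # a row of values, not a shape: [0.5, 0.6]
right = grad[2]   # the dense shape carried through: (6, 2)
```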
@@ -131,6 +135,20 @@ def _tensors_get_datatype(grad):
     return F.dtype(grad)
 
 
+@_get_datatype.register("Tuple")
+def _tensors_get_datatype_with_sparse(grad):
+    """
+    Acquire gradient datatype.
+
+    Args:
+        grad (Tuple): The gradient tensor before operation.
+
+    Returns:
+        mstype, the datatype of gradient.
+    """
+    return F.dtype(grad[1])
+
+
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
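The new overload registers on `"Tuple"` so the same `_get_datatype` graph handles both dense tensors and sparse triples, reading the dtype from the values entry. A plain-Python analogy of that dispatch (hypothetical stand-in; the real `C.MultitypeFuncGraph` dispatches on MindSpore type tags, not `isinstance`):

```python
import numpy as np

def get_datatype(grad):
    # Sparse grads arrive as (indices, values, dense_shape); the dtype
    # that matters for the cast round-trip is the values entry's.
    if isinstance(grad, tuple):
        return grad[1].dtype
    return grad.dtype

dense = np.zeros(3, dtype=np.float16)
sparse = (np.array([1]), np.zeros((1, 4), dtype=np.float32), (6, 4))
print(get_datatype(dense), get_datatype(sparse))  # float16 float32
```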
@@ -149,6 +167,22 @@ def _tensors_cast_datatype(datatype, grad):
     return F.cast(grad, datatype)
 
 
+@_cast_datatype.register("TypeType", "Tuple")
+def _tensors_cast_datatype_with_sparse(datatype, grad):
+    """
+    Cast gradient to datatype.
+
+    Args:
+        datatype (mstype): the destination datatype of gradient.
+        grad (Tuple): The gradient tensor before operation.
+
+    Returns:
+        Tuple, the gradient tuple after operation.
+    """
+    dout = F.cast(grad[1], datatype)
+    return (grad[0], dout, grad[2])
+
+
 class DistributedGradReducer(Cell):
     """
     A distributed optimizer.
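The sparse cast overload above touches only the values entry: indices must stay integral and `dense_shape` is plain metadata, so neither is cast. A NumPy sketch of the float32 round-trip that `construct` performs (hypothetical stand-in for `F.cast`):

```python
import numpy as np

grad = (np.array([0, 2]), np.ones((2, 2), dtype=np.float16), (6, 2))
orig_dtype = grad[1].dtype                                # float16
grad = (grad[0], grad[1].astype(np.float32), grad[2])     # cast up for comm
grad = (grad[0], grad[1].astype(orig_dtype), grad[2])     # cast back after
```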
@@ -224,7 +258,7 @@ class DistributedGradReducer(Cell):
     def __init__(self, parameters, mean=True, degree=None):
         super(DistributedGradReducer, self).__init__(auto_prefix=False)
-        self.hyper_map = C.HyperMap()
+        self.map_ = C.Map()
         self.mul = P.Mul()
         if degree is None:
             self.degree = get_group_size()
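The `C.HyperMap()` to `C.Map()` switch appears to exist because `HyperMap` recurses into nested structures: it would descend into a sparse `(indices, values, dense_shape)` tuple and apply the function per element, while `Map` applies it once per top-level gradient, letting the new `"Tuple"` overloads receive the triple intact. A plain-Python analogy (hypothetical stand-ins, not the real composite ops):

```python
def map_(fn, seq):
    return tuple(fn(x) for x in seq)  # one call per gradient, tuples intact

def hyper_map(fn, seq):
    # Recurses into tuples, so a sparse triple would be torn apart.
    return tuple(hyper_map(fn, x) if isinstance(x, tuple) else fn(x)
                 for x in seq)
```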
@@ -234,19 +268,27 @@ class DistributedGradReducer(Cell):
             self.degree = degree
         self.mean = mean
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce()
+        _init_optimizer_communication()
 
     def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
+        """
+        In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
+        result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
+        and cast back after the operation.
+
+        Args:
+            grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.
+
+        Returns:
+            new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
+        """
+        datatypes = self.map_(F.partial(_get_datatype), grads)
+        grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
         if self.mean:
-            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
+            new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
         else:
-            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
+            new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
+        new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
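For reference, a hedged usage sketch of how `DistributedGradReducer` is typically wired into a training step; `TrainingWrapper` and its arguments are illustrative, and `network` is assumed to return the loss:

```python
from mindspore.nn.cell import Cell
from mindspore.ops import functional as F, composite as C
from mindspore.communication.management import get_group_size

class TrainingWrapper(Cell):
    def __init__(self, network, optimizer):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.grad = C.GradOperation(get_by_list=True)
        # mean=True averages gradients across all devices in the group.
        self.grad_reducer = DistributedGradReducer(self.weights, mean=True,
                                                   degree=get_group_size())

    def construct(self, data, label):
        loss = self.network(data, label)
        grads = self.grad(self.network, self.weights)(data, label)
        grads = self.grad_reducer(grads)  # AllReduce / AllGather happens here
        return F.depend(loss, self.optimizer(grads))
```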