diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py
index c66bfbe646..9354b42e55 100644
--- a/mindspore/nn/wrap/grad_reducer.py
+++ b/mindspore/nn/wrap/grad_reducer.py
@@ -16,18 +16,22 @@
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
 import mindspore.common.dtype as mstype
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
 _all_reduce = AllReduce()
+_all_gather = None
 
 
-def _init_optimizer_allreduce():
+def _init_optimizer_communication():
     global _all_reduce
+    global _all_gather
+
     _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
     _all_reduce.add_prim_attr('fusion', 1)
+    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
 
 
 @reduce_opt.register("Function", "Number", "Bool", "Tensor")
@@ -72,8 +76,8 @@ def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
         degree = F.scalar_cast(degree, F.dtype(grad[1]))
         dout = _all_gather(grad[1])
         cast_op = P.Cast()
-        dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
-        grad = (indices, dout, dout[2])
+        dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
+        grad = (indices, dout, grad[2])
     return grad
 
 
@@ -110,7 +114,7 @@ def _tensors_allreduce_with_sparse(allreduce_filter, grad):
     if allreduce_filter:
         indices = _all_gather(grad[0])
         dout = _all_gather(grad[1])
-        grad = (indices, dout, dout[2])
+        grad = (indices, dout, grad[2])
     return grad
 
 
@@ -131,6 +135,20 @@ def _tensors_get_datatype(grad):
     return F.dtype(grad)
 
 
+@_get_datatype.register("Tuple")
+def _tensors_get_datatype_with_sparse(grad):
+    """
+    Acquire gradient datatype.
+
+    Args:
+        grad (Tuple): The gradient tensor before operation.
+
+    Returns:
+        mstype, the datatype of gradient.
+    """
+    return F.dtype(grad[1])
+
+
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
 
 
@@ -149,6 +167,22 @@ def _tensors_cast_datatype(datatype, grad):
     return F.cast(grad, datatype)
 
 
+@_cast_datatype.register("TypeType", "Tuple")
+def _tensors_cast_datatype_with_sparse(datatype, grad):
+    """
+    Cast gradient to datatype.
+
+    Args:
+        datatype (mstype): the destination datatype of gradient.
+        grad (Tuple): The gradient tensor before operation.
+
+    Returns:
+        Tuple, the gradient tuple after operation.
+    """
+    dout = F.cast(grad[1], datatype)
+    return (grad[0], dout, grad[2])
+
+
 class DistributedGradReducer(Cell):
     """
     A distributed optimizer.
@@ -224,7 +258,7 @@ class DistributedGradReducer(Cell):
 
     def __init__(self, parameters, mean=True, degree=None):
         super(DistributedGradReducer, self).__init__(auto_prefix=False)
-        self.hyper_map = C.HyperMap()
+        self.map_ = C.Map()
         self.mul = P.Mul()
         if degree is None:
             self.degree = get_group_size()
@@ -234,19 +268,27 @@ class DistributedGradReducer(Cell):
             self.degree = degree
         self.mean = mean
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce()
+        _init_optimizer_communication()
 
     def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
+        """
+        In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
+        result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
+        and cast back after the operation.
+
+        Args:
+            grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.
+
+        Returns:
+            new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
+        """
+        datatypes = self.map_(F.partial(_get_datatype), grads)
+        grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
         if self.mean:
-            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
+            new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
         else:
-            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
+            new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
 
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
+        new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
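For context, DistributedGradReducer is typically created inside a training wrapper cell and applied to the gradient tuple before the optimizer step. A minimal usage sketch, assuming the communication group has already been set up with init() and that "network" is a user-supplied Cell (both are placeholders, not part of this diff):

    from mindspore.communication.management import init, get_group_size
    from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

    init()                                   # set up the default communication group
    degree = get_group_size()                # number of devices taking part in the reduction
    # mean=True divides the summed gradients by `degree` after AllReduce
    grad_reducer = DistributedGradReducer(network.trainable_params(), mean=True, degree=degree)

    # inside the training cell's construct():
    #     grads = grad_reducer(grads)        # dense grads go through the fused AllReduce;
    #                                        # sparse (indices, values, shape) tuples go through AllGather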