@@ -16,18 +16,22 @@
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
 import mindspore.common.dtype as mstype

 reduce_opt = C.MultitypeFuncGraph("reduce_opt")

 _all_reduce = AllReduce()
+_all_gather = None


-def _init_optimizer_allreduce():
+def _init_optimizer_communication():
     global _all_reduce
+    global _all_gather
     _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
     _all_reduce.add_prim_attr('fusion', 1)
+    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)


 @reduce_opt.register("Function", "Number", "Bool", "Tensor")
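For context: `reduce_opt` is a `C.MultitypeFuncGraph`, so a single call site dispatches to a different implementation depending on the argument types. The hunks below add `"Tuple"` registrations for sparse gradients alongside the existing `"Tensor"` ones. A minimal sketch of the dispatch pattern (the `scale` graph here is hypothetical, not part of this patch):

```python
from mindspore.ops import composite as C

scale = C.MultitypeFuncGraph("scale")

@scale.register("Number", "Tensor")
def _scale_dense(factor, grad):
    # Dense branch: plain elementwise scaling of the gradient tensor.
    return grad * factor

@scale.register("Number", "Tuple")
def _scale_sparse(factor, grad):
    # Sparse branch: a sparse gradient travels as (indices, values, dense_shape);
    # only the values tensor is scaled, indices and shape pass through untouched.
    return (grad[0], grad[1] * factor, grad[2])
```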
@@ -72,8 +76,8 @@ def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
         degree = F.scalar_cast(degree, F.dtype(grad[1]))
         dout = _all_gather(grad[1])
         cast_op = P.Cast()
-        dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
-        grad = (indices, dout, dout[2])
+        dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
+        grad = (indices, dout, grad[2])
     return grad
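Two fixes land in this hunk: cosmetic spacing around the division, and a real bug in the rebuilt tuple. Since a sparse gradient is `(indices, values, dense_shape)`, the third element of the result must be the original `grad[2]` (the dense shape); the old `dout[2]` instead indexed row 2 of the gathered values. A plain-`numpy` stand-in for what this mean branch computes (an analogy, no MindSpore required):

```python
import numpy as np

def allgather(chunks):
    # Stand-in for AllGather: concatenate each worker's rows along axis 0.
    return np.concatenate(chunks, axis=0)

degree = 2  # number of devices
worker_grads = [
    (np.array([0, 3]), np.array([[2.0, 2.0], [4.0, 4.0]]), (5, 2)),
    (np.array([1, 3]), np.array([[6.0, 6.0], [8.0, 8.0]]), (5, 2)),
]

indices = allgather([g[0] for g in worker_grads])
# Mean over workers: scale the gathered values by 1/degree.
dout = allgather([g[1] for g in worker_grads]) * (1.0 / degree)
# The dense shape passes through unchanged -- the bug fixed above.
grad = (indices, dout, worker_grads[0][2])
```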
@@ -110,7 +114,7 @@ def _tensors_allreduce_with_sparse(allreduce_filter, grad):
     if allreduce_filter:
         indices = _all_gather(grad[0])
         dout = _all_gather(grad[1])
-        grad = (indices, dout, dout[2])
+        grad = (indices, dout, grad[2])
     return grad
@@ -131,6 +135,20 @@ def _tensors_get_datatype(grad):
     return F.dtype(grad)


+@_get_datatype.register("Tuple")
+def _tensors_get_datatype_with_sparse(grad):
+    """
+    Acquire gradient datatype.
+
+    Args:
+        grad (Tuple): The gradient tuple before operation.
+
+    Returns:
+        mstype, the datatype of gradient.
+    """
+    return F.dtype(grad[1])
+
+
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
@@ -149,6 +167,22 @@ def _tensors_cast_datatype(datatype, grad):
     return F.cast(grad, datatype)


+@_cast_datatype.register("TypeType", "Tuple")
+def _tensors_cast_datatype_with_sparse(datatype, grad):
+    """
+    Cast gradient to datatype.
+
+    Args:
+        datatype (mstype): The destination datatype of the gradient.
+        grad (Tuple): The gradient tuple before operation.
+
+    Returns:
+        Tuple, the gradient tuple after operation.
+    """
+    dout = F.cast(grad[1], datatype)
+    return (grad[0], dout, grad[2])
+
+
 class DistributedGradReducer(Cell):
     """
     A distributed optimizer.
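Together with the `_get_datatype` registration above, this gives sparse tuples the same float32 round-trip that dense gradients already get: read the dtype from the values tensor, upcast for the reduction, restore afterwards. A plain-`numpy` sketch of that round-trip (an analogy, not the MindSpore API):

```python
import numpy as np

def get_datatype(grad):
    # For a sparse (indices, values, dense_shape) tuple, only the values
    # tensor carries a float dtype; indices stay integer, shape is metadata.
    return grad[1].dtype if isinstance(grad, tuple) else grad.dtype

def cast_datatype(dtype, grad):
    if isinstance(grad, tuple):
        return (grad[0], grad[1].astype(dtype), grad[2])
    return grad.astype(dtype)

g = (np.array([0, 3]), np.ones((2, 4), dtype=np.float16), (10, 4))
orig = get_datatype(g)                  # float16
g32 = cast_datatype(np.float32, g)      # upcast values before the reduction
g_back = cast_datatype(orig, g32)       # restore the original precision
```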
@@ -224,7 +258,7 @@ class DistributedGradReducer(Cell):
     def __init__(self, parameters, mean=True, degree=None):
         super(DistributedGradReducer, self).__init__(auto_prefix=False)
-        self.hyper_map = C.HyperMap()
+        self.map_ = C.Map()
         self.mul = P.Mul()
         if degree is None:
             self.degree = get_group_size()
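The switch from `C.HyperMap` to `C.Map` is what lets the `"Tuple"` registrations fire: a sparse gradient is itself a tuple, and a hyper map would recurse into it and apply the function leaf by leaf, while a plain map hands each top-level element, tensor or tuple, to `reduce_opt` whole. A plain-Python analogy of the difference (not the MindSpore implementation):

```python
def hyper_map(fn, xs):
    # Recurses into nested tuples, applying fn only at the leaves.
    return tuple(hyper_map(fn, x) if isinstance(x, tuple) else fn(x) for x in xs)

def map_(fn, xs):
    # Applies fn once per top-level element, tuples included.
    return tuple(fn(x) for x in xs)

def process(g):
    # Dispatches on the whole element, as a MultitypeFuncGraph would.
    if isinstance(g, tuple):
        return ("gathered", g[1], g[2])
    return "reduced"

grads = ("dense_grad", ("indices", "values", "shape"))
print(map_(process, grads))       # the sparse tuple is seen as one unit
print(hyper_map(process, grads))  # fn never sees the tuple as a whole
```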
@@ -234,19 +268,27 @@ class DistributedGradReducer(Cell):
             self.degree = degree
         self.mean = mean
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce()
+        _init_optimizer_communication()

     def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
+        """
+        In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
+        result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
+        and cast back after the operation.
+
+        Args:
+            grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.
+
+        Returns:
+            new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
+        """
+        datatypes = self.map_(F.partial(_get_datatype), grads)
+        grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
         if self.mean:
-            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
+            new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
         else:
-            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
+            new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
+        new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
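For completeness, a hedged usage sketch of the reducer inside a custom training step. Everything here is illustrative: `network` is assumed to already compute a loss (e.g. wrapped in a `WithLossCell`), distributed communication is assumed to be initialized via `mindspore.communication.init()`, and the exact `GradOperation` signature varies across MindSpore versions:

```python
from mindspore import nn
from mindspore.ops import composite as C, functional as F

class TrainOneStep(nn.Cell):
    """Illustrative training step that averages gradients across devices."""
    def __init__(self, network, optimizer, degree):
        super(TrainOneStep, self).__init__(auto_prefix=False)
        self.network = network
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        # Some MindSpore releases require a name argument: C.GradOperation('grad', ...).
        self.grad = C.GradOperation(get_by_list=True)
        self.grad_reducer = DistributedGradReducer(self.weights, mean=True, degree=degree)

    def construct(self, data, label):
        loss = self.network(data, label)
        grads = self.grad(self.network, self.weights)(data, label)
        # Dense grads go through AllReduce, sparse tuples through AllGather.
        grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
```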