|
|
|
@@ -27,32 +27,26 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt") |
|
|
|
|
|
|
|
def _init_allreduce_operators(length): |
|
|
|
""" initialize allreduce communication operators""" |
|
|
|
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") |
|
|
|
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() |
|
|
|
if is_parallel_optimizer and split_indices: |
|
|
|
group = 1 |
|
|
|
fusion = () |
|
|
|
for i in range(length): |
|
|
|
fusion = fusion + (group,) |
|
|
|
if split_indices[group - 1] <= i + 1: |
|
|
|
if group >= len(split_indices): |
|
|
|
continue |
|
|
|
group = group + 1 |
|
|
|
index = tuple(range(1, length + 1)) |
|
|
|
else: |
|
|
|
fusion = (1,) * length |
|
|
|
index = (0,) * length |
|
|
|
opt_list = () |
|
|
|
group = 1 |
|
|
|
fusion = () |
|
|
|
for i in range(length): |
|
|
|
opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) |
|
|
|
opt.add_prim_attr('fusion', fusion[i]) |
|
|
|
opt.add_prim_attr('index', index[i]) |
|
|
|
opt_list = opt_list + (opt,) |
|
|
|
return opt_list |
|
|
|
fusion = fusion + (group,) |
|
|
|
if split_indices[group - 1] <= i + 1: |
|
|
|
if group >= len(split_indices): |
|
|
|
continue |
|
|
|
group = group + 1 |
|
|
|
index = tuple(range(1, length + 1)) |
|
|
|
op_list = () |
|
|
|
for i in range(length): |
|
|
|
op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) |
|
|
|
op.add_prim_attr('fusion', fusion[i]) |
|
|
|
op.add_prim_attr('index', index[i]) |
|
|
|
op_list = op_list + (op,) |
|
|
|
return op_list |
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool") |
|
|
|
def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter): |
|
|
|
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") |
|
|
|
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): |
|
|
|
""" |
|
|
|
Apply allreduce on gradient. |
|
|
|
|
|
|
|
@@ -60,9 +54,10 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc |
|
|
|
degree (int): The mean coefficient. |
|
|
|
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. |
|
|
|
allgather (Primitive): The communication operator for sparse gradients. |
|
|
|
allreduce (Primitive): The communication operator for gradients. |
|
|
|
allreduce_filter (bool): When it is true, allreduce would apply. |
|
|
|
grad (Tensor): The gradient tensor before operation. |
|
|
|
allreduce (Primitive): The communication operator for gradients. |
|
|
|
ps_parameter(Bool): Use parameter server or not. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Tensor, the gradient tensor after operation. |
|
|
|
@@ -78,8 +73,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc |
|
|
|
return grad |
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Number", "Bool", "Function", "Bool", "IndexedSlices", "Function") |
|
|
|
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce): |
|
|
|
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") |
|
|
|
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): |
|
|
|
""" |
|
|
|
Apply allgather on gradient instead of allreduce for sparse feature. |
|
|
|
Allgather is a communication operation used for distributed deep learning. |
|
|
|
@@ -88,9 +83,9 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr |
|
|
|
degree (int): The mean coefficient. |
|
|
|
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. |
|
|
|
allgather (Primitive): The communication operator for sparse gradients. |
|
|
|
allreduce_filter (bool): When it is true, allgather would apply. |
|
|
|
grad (IndexedSlices): The gradient before operation. |
|
|
|
allreduce (Primitive): The communication operator for gradients. |
|
|
|
allreduce_filter (bool): When it is true, allgather would apply. |
|
|
|
grad (tuple): The indices, gradient tensor and tensor_shape before operation. |
|
|
|
|
|
|
|
Returns: |
|
|
|
IndexedSlices, the gradient after operation. |
|
|
|
@@ -256,7 +251,14 @@ class DistributedGradReducer(Cell): |
|
|
|
self.degree = degree |
|
|
|
self.mean = mean |
|
|
|
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) |
|
|
|
self.opt_list = _init_allreduce_operators(len(parameters)) |
|
|
|
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") |
|
|
|
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() |
|
|
|
if is_parallel_optimizer and split_indices: |
|
|
|
self.split_fusion = True |
|
|
|
self.op_list = _init_allreduce_operators(len(parameters)) |
|
|
|
else: |
|
|
|
self.split_fusion = False |
|
|
|
self.allreduce = AllReduce().add_prim_attr('fusion', 1) |
|
|
|
self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) |
|
|
|
ps_filter = lambda x: x.is_param_ps |
|
|
|
self.ps_parameters = tuple(ps_filter(x) for x in parameters) |
|
|
|
@@ -275,8 +277,11 @@ class DistributedGradReducer(Cell): |
|
|
|
""" |
|
|
|
datatypes = self.map_(F.partial(_get_datatype), grads) |
|
|
|
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) |
|
|
|
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), |
|
|
|
self.allreduce_filter, grads, self.opt_list, self.ps_parameters) |
|
|
|
|
|
|
|
if self.split_fusion: |
|
|
|
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), |
|
|
|
self.opt_list, self.allreduce_filter, grads, self.ps_parameters) |
|
|
|
else: |
|
|
|
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, |
|
|
|
self.allreduce), self.allreduce_filter, grads, self.ps_parameters) |
|
|
|
new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) |
|
|
|
return new_grad |