Merge pull request !1737 from lirongzhen1/sparse
@@ -52,6 +52,31 @@ def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
     return grad
| @reduce_opt.register("Function", "Number", "Bool", "Tuple") | |||||
| def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad): | |||||
| """ | |||||
| Apply mean and allgather on gradient instead of allreduce for sparse feature. | |||||
| Allgather is a communication operation used for distributed deep learning. | |||||
| Args: | |||||
| mul (Primitive): Div operation. | |||||
| degree (int): The mean coefficient. | |||||
| allreduce_filter (bool): When it is true, allgather would apply. | |||||
| grad (Tuple): The indices, gradient tensor and tensor_shape before operation. | |||||
| Returns: | |||||
| Tuple, include indices, the gradient tensor and tensor_shape after operation. | |||||
| """ | |||||
| if allreduce_filter: | |||||
| indices = _all_gather(grad[0]) | |||||
| degree = F.scalar_cast(degree, F.dtype(grad[1])) | |||||
| dout = _all_gather(grad[1]) | |||||
| cast_op = P.Cast() | |||||
| dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout))) | |||||
| grad = (indices, dout, dout[2]) | |||||
| return grad | |||||
| @reduce_opt.register("Bool", "Tensor") | @reduce_opt.register("Bool", "Tensor") | ||||
| def _tensors_allreduce(allreduce_filter, grad): | def _tensors_allreduce(allreduce_filter, grad): | ||||
| """ | """ | ||||
| @@ -69,6 +94,26 @@ def _tensors_allreduce(allreduce_filter, grad): | |||||
| return grad | return grad | ||||
| @reduce_opt.register("Bool", "Tuple") | |||||
| def _tensors_allreduce_with_sparse(allreduce_filter, grad): | |||||
| """ | |||||
| Apply mean and allgather on gradient instead of allreduce for sparse feature. | |||||
| Allgather is a communication operation used for distributed deep learning. | |||||
| Args: | |||||
| allreduce_filter (bool): When it is true, allgather would apply. | |||||
| grad (Tuple): The indices, gradient tensor and tensor_shape before operation. | |||||
| Returns: | |||||
| Tuple, include indices, the gradient tensor and tensor_shape after operation. | |||||
| """ | |||||
| if allreduce_filter: | |||||
| indices = _all_gather(grad[0]) | |||||
| dout = _all_gather(grad[1]) | |||||
| grad = (indices, dout, dout[2]) | |||||
| return grad | |||||
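For contrast with the mean variant above, a minimal sketch (plain Python, assumed tuple layout, with the collective passed in as a stand-in) of the path without averaging: indices and values are simply gathered and the dense shape is carried through unchanged.

```python
def allgather_sparse(grad, all_gather):
    """Mock of the no-mean path: grad is an (indices, values, dense_shape) tuple."""
    indices, values, dense_shape = grad
    return all_gather(indices), all_gather(values), dense_shape
```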
 _get_datatype = C.MultitypeFuncGraph("_get_datatype")

@@ -26,9 +26,10 @@ from .grad_base import bprop_getters
 @bprop_getters.register(AllReduce)
 def get_bprop_all_reduce(self):
-    """Generate bprop for AllReduce."""
+    """Generate bprop for AllReduce: allreduce for a dense gradient, allgather for a sparse feature."""
     all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
+    all_gather = AllGather(group=self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         all_reduce_grad.set_prim_instance_name(instance_name)

@@ -42,15 +43,28 @@ def get_bprop_all_reduce(self):
     if self.op == ReduceOp.SUM:
         def bprop(x, out, dout):
-            dx = all_reduce_grad(dout)
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce_grad(dout)
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                dx = (indices, grad, dout[2])
             return (dx,)
     else:
         def bprop(x, out, dout):
-            dx = all_reduce_grad(dout)
-            z = equal(x, out)
-            z = cast(z, dtype(dx))
-            dx = mul(dx, z)
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce_grad(dout)
+                z = equal(x, out)
+                z = cast(z, dtype(dx))
+                dx = mul(dx, z)
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                z = equal(x, out)
+                z = cast(z, dtype(grad))
+                grad = mul(grad, z)
+                dx = (indices, grad, dout[2])
             return (dx,)
     return bprop
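The non-SUM branch above encodes a step worth spelling out: for ReduceOp.MAX/MIN the gradient should only flow to the entries that actually produced the reduced result, which is what the equal/cast/mul sequence implements. A small numpy illustration of the masking (not MindSpore API):

```python
import numpy as np

x = np.array([1.0, 5.0, 3.0])          # this device's input to AllReduce(MAX)
out = np.array([4.0, 5.0, 7.0])        # reduced maximum across all devices
dout = np.array([0.1, 0.2, 0.3])       # incoming gradient

mask = (x == out).astype(dout.dtype)   # 1.0 only where this device supplied the maximum
dx = dout * mask                       # -> [0. , 0.2, 0. ]
```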

@@ -147,12 +161,16 @@ def get_bprop_all_to_all(self):
 @bprop_getters.register(_MirrorOperator)
 def get_bprop_mirror_operator(self):
-    """Backpropagator for _MirrorOperator, do allreduce for the devices in group(only for one group)."""
+    """
+    Backpropagator for _MirrorOperator: allreduce for a dense gradient and allgather for a sparse feature,
+    applied to the devices in the group (only for one group).
+    """
     group = self.group
     dev_num = self.dev_num
     mean_flag = self.mean_flag
     all_reduce = AllReduce(group=group)
+    all_gather = AllGather(group=group)
     mul = P.Mul()
     cast = P.Cast()

@@ -170,12 +188,25 @@ def get_bprop_mirror_operator(self):
     def bprop(x, out, dout):
         if mean_flag:
-            dx = all_reduce(dout)
-            float_one = F.scalar_cast(1.0, F.dtype(dx))
-            num = F.scalar_cast(dev_num, F.dtype(dx))
-            dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce(dout)
+                float_one = F.scalar_cast(1.0, F.dtype(dx))
+                num = F.scalar_cast(dev_num, F.dtype(dx))
+                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                float_one = F.scalar_cast(1.0, F.dtype(grad))
+                num = F.scalar_cast(dev_num, F.dtype(grad))
+                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
+                dx = (indices, grad, dout[2])
         else:
-            dx = all_reduce(dout)
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce(dout)
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                dx = (indices, grad, dout[2])
         return (dx,)
     return bprop
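A rough plain-Python rendering of the branching added above (mocked collectives, assumed (indices, values, dense_shape) tuple layout; not MindSpore API): a dense gradient is allreduced and, when mean_flag is set, scaled by 1/dev_num, while a sparse gradient keeps its tuple form and only the indices and values are gathered.

```python
def mirror_backward(dout, dev_num, mean_flag, all_reduce, all_gather):
    if not isinstance(dout, tuple):                  # dense gradient: a single tensor
        dx = all_reduce(dout)
        return dx * (1.0 / dev_num) if mean_flag else dx
    indices, values, dense_shape = dout              # sparse gradient tuple
    indices, values = all_gather(indices), all_gather(values)
    if mean_flag:
        values = values * (1.0 / dev_num)
    return indices, values, dense_shape
```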

@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test sparse feature bprop """
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.common.api import _executor
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP

class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad_all(self.network)(x)

class VirtualGatherV2(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init VirtualGatherV2"""
        super(VirtualGatherV2, self).__init__('VirtualGatherV2')
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

    def __infer__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
        axis_v = axis['value']
        params_shp = params['shape']
        rank = len(params_shp)
        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
        if axis_v < 0:
            axis_v += rank
        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out

@bprop_getters.register(VirtualGatherV2)
def get_bprop_gather_v2(self):
    """Generate bprop for VirtualGatherV2."""
    def bprop(x, indices, axis, out, dout):
        # Return the params gradient as a sparse (indices, values, ...) tuple so the
        # communication operators' bprops take their allgather path.
        return (indices, dout, x), axis, out
    return bprop
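As a worked instance of the shape rule in __infer__ (using the concrete shapes from the tests below), the output shape is params_shape[:axis] + indices_shape + params_shape[axis + 1:]:

```python
params_shape = (64, 64)     # x in the tests
indices_shape = (8, 8)      # self.index in the tests
axis = 0

out_shape = params_shape[:axis] + indices_shape + params_shape[axis + 1:]
assert out_shape == (8, 8, 64)
```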

def test_bprop_with_sparse_feature_allreduce():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.all_reduce = AllReduce()
            self.gatherv2 = VirtualGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            out = self.all_reduce(x)
            out = self.gatherv2(out, self.index, self.axis)
            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    _executor.compile(net, x)

def test_bprop_with_sparse_feature_mirror():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
            self.gatherv2 = VirtualGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            out = self.mirror(x)
            out = self.gatherv2(out, self.index, self.axis)
            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    _executor.compile(net, x)