Merge pull request !1737 from lirongzhen1/sparse
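This merge request routes sparse feature gradients (embedding-style, row-sparse gradients carried as an (indices, gradient tensor, tensor_shape) tuple) through AllGather instead of AllReduce: the gradient reducer gains sparse overloads, the bprops for AllReduce and _MirrorOperator learn to gather indices and values across devices, and a new test file compiles both paths.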
@@ -52,6 +52,31 @@ def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
     return grad
+
+
+@reduce_opt.register("Function", "Number", "Bool", "Tuple")
+def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
+    """
+    Apply mean and allgather on the gradient instead of allreduce, for sparse features.
+    Allgather is a communication operation used for distributed deep learning.
+
+    Args:
+        mul (Primitive): Mul operation, used to scale the gradient by 1.0/degree.
+        degree (int): The mean coefficient.
+        allreduce_filter (bool): When it is True, allgather is applied.
+        grad (Tuple): The (indices, gradient tensor, tensor_shape) before the operation.
+
+    Returns:
+        Tuple, the (indices, gradient tensor, tensor_shape) after the operation.
+    """
+    if allreduce_filter:
+        indices = _all_gather(grad[0])
+        degree = F.scalar_cast(degree, F.dtype(grad[1]))
+        dout = _all_gather(grad[1])
+        cast_op = P.Cast()
+        dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
+        grad = (indices, dout, grad[2])
+    return grad
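Why allgather is a valid substitute here: each device's sparse gradient scatter-adds into the dense parameter, so concatenating (gathering) indices and values across devices yields a sparse gradient that densifies to the same sum an allreduce would produce; the 1.0/degree factor then turns that sum into a mean. A hedged numpy sketch of the equivalence (the densify helper is a hypothetical illustration, not a MindSpore API):

import numpy as np

# Hypothetical: two devices hold row-sparse gradients for a [4, 2]
# embedding table as (indices, values, dense_shape) tuples.
dense_shape = (4, 2)
g0 = (np.array([0, 2]), np.array([[1., 1.], [2., 2.]]), dense_shape)
g1 = (np.array([2, 3]), np.array([[3., 3.], [4., 4.]]), dense_shape)

def densify(indices, values, shape):
    """Scatter-add sparse rows into a dense gradient."""
    out = np.zeros(shape)
    np.add.at(out, indices, values)
    return out

# Allreduce(SUM) over densified gradients ...
dense_sum = densify(*g0) + densify(*g1)

# ... equals allgather (concatenation) of indices and values, densified once;
# the duplicated index 2 scatter-adds, reproducing the sum.
gathered = (np.concatenate([g0[0], g1[0]]),
            np.concatenate([g0[1], g1[1]]),
            dense_shape)
assert np.allclose(densify(*gathered), dense_sum)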


 @reduce_opt.register("Bool", "Tensor")
 def _tensors_allreduce(allreduce_filter, grad):
     """
@@ -69,6 +94,26 @@ def _tensors_allreduce(allreduce_filter, grad):
     return grad
+
+
+@reduce_opt.register("Bool", "Tuple")
+def _tensors_allreduce_with_sparse(allreduce_filter, grad):
+    """
+    Apply allgather on the gradient instead of allreduce, for sparse features.
+    Allgather is a communication operation used for distributed deep learning.
+
+    Args:
+        allreduce_filter (bool): When it is True, allgather is applied.
+        grad (Tuple): The (indices, gradient tensor, tensor_shape) before the operation.
+
+    Returns:
+        Tuple, the (indices, gradient tensor, tensor_shape) after the operation.
+    """
+    if allreduce_filter:
+        indices = _all_gather(grad[0])
+        dout = _all_gather(grad[1])
+        grad = (indices, dout, grad[2])
+    return grad
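For readers unfamiliar with the dispatch mechanism: reduce_opt is a MultitypeFuncGraph, and each register call keys an implementation to the abstract types of the arguments, so dense Tensor gradients reach _tensors_allreduce while (indices, values, shape) tuples reach the sparse overloads above. A minimal sketch of the same pattern, using a hypothetical graph name demo_opt:

from mindspore.ops import composite as C

# Hypothetical dispatch sketch: the registered type signature picks the overload
# when the graph compiler applies demo_opt (e.g. via a HyperMap) to arguments.
demo_opt = C.MultitypeFuncGraph("demo_opt")

@demo_opt.register("Bool", "Tensor")
def _demo_dense(flag, grad):
    # A dense gradient (Tensor) lands here; the real code allreduces it.
    return grad

@demo_opt.register("Bool", "Tuple")
def _demo_sparse(flag, grad):
    # A sparse (indices, values, shape) tuple lands here; the real code allgathers it.
    return grad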


 _get_datatype = C.MultitypeFuncGraph("_get_datatype")
@@ -26,9 +26,10 @@ from .grad_base import bprop_getters
 @bprop_getters.register(AllReduce)
 def get_bprop_all_reduce(self):
-    """Generate bprop for AllReduce."""
+    """Generate bprop for AllReduce: allreduce for dense gradients, allgather for sparse feature gradients."""
     all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
+    all_gather = AllGather(group=self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         all_reduce_grad.set_prim_instance_name(instance_name)
@@ -42,15 +43,28 @@ def get_bprop_all_reduce(self):
     if self.op == ReduceOp.SUM:
         def bprop(x, out, dout):
-            dx = all_reduce_grad(dout)
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce_grad(dout)
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                dx = (indices, grad, dout[2])
             return (dx,)
     else:
         def bprop(x, out, dout):
-            dx = all_reduce_grad(dout)
-            z = equal(x, out)
-            z = cast(z, dtype(dx))
-            dx = mul(dx, z)
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce_grad(dout)
+                z = equal(x, out)
+                z = cast(z, dtype(dx))
+                dx = mul(dx, z)
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                z = equal(x, out)
+                z = cast(z, dtype(grad))
+                grad = mul(grad, z)
+                dx = (indices, grad, dout[2])
             return (dx,)

     return bprop
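A note on the SUM branch: on every device k the forward computes out_k as the sum of all x_i, so each x_i feeds every out_k with coefficient 1 and dL/dx_i is the sum over k of dL/dout_k, which is exactly an allreduce(SUM) of the incoming gradients. The sparse branch reproduces this by gathering indices and values, since duplicated indices scatter-add. A hedged numpy check of the dense identity, with hypothetical 3-device values:

import numpy as np

# Hypothetical: 3 devices, one scalar parameter x_i each, out = allreduce_sum(x).
douts = np.array([0.1, 0.2, 0.3])   # dL/dout_k on each device k

# Every out_k depends on every x_i with coefficient 1, so
# dL/dx_i = sum_k dL/dout_k: the backward is an allreduce(SUM) of dout.
dx = np.full(3, douts.sum())
assert np.allclose(dx, [0.6, 0.6, 0.6])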
@@ -147,12 +161,16 @@ def get_bprop_all_to_all(self):
 @bprop_getters.register(_MirrorOperator)
 def get_bprop_mirror_operator(self):
-    """Backpropagator for _MirrorOperator, do allreduce for the devices in group(only for one group)."""
+    """
+    Backpropagator for _MirrorOperator: do allreduce or allgather for the devices in the group
+    (only one group is supported); allgather is used for sparse feature gradients.
+    """
     group = self.group
     dev_num = self.dev_num
     mean_flag = self.mean_flag

     all_reduce = AllReduce(group=group)
+    all_gather = AllGather(group=group)
     mul = P.Mul()
     cast = P.Cast()
@@ -170,12 +188,25 @@ def get_bprop_mirror_operator(self):
     def bprop(x, out, dout):
         if mean_flag:
-            dx = all_reduce(dout)
-            float_one = F.scalar_cast(1.0, F.dtype(dx))
-            num = F.scalar_cast(dev_num, F.dtype(dx))
-            dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce(dout)
+                float_one = F.scalar_cast(1.0, F.dtype(dx))
+                num = F.scalar_cast(dev_num, F.dtype(dx))
+                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                float_one = F.scalar_cast(1.0, F.dtype(grad))
+                num = F.scalar_cast(dev_num, F.dtype(grad))
+                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
+                dx = (indices, grad, dout[2])
         else:
-            dx = all_reduce(dout)
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                dx = all_reduce(dout)
+            else:
+                indices = all_gather(dout[0])
+                grad = all_gather(dout[1])
+                dx = (indices, grad, dout[2])

         return (dx,)
     return bprop
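When mean_flag is set, both branches scale the reduced gradient by 1.0/dev_num, turning the implicit sum over devices into a mean; scalar_to_array plus Cast just builds a 0-d tensor of that factor in the gradient's dtype. A hedged numeric sketch of the scaling, with hypothetical values:

import numpy as np

# Hypothetical: 8 mirrored devices each produce the same local gradient.
dev_num = 8
local_grad = np.full((2, 2), 1.0)

summed = local_grad * dev_num             # what AllReduce(SUM) would return
averaged = summed * (1.0 / dev_num)       # the mul(..., 1.0/dev_num) step
assert np.allclose(averaged, local_grad)  # identical grads: mean recovers the local value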
@@ -0,0 +1,118 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test sparse feature bprop """
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.common import dtype as mstype
+from mindspore.common.tensor import Tensor
+from mindspore.ops import composite as C
+from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
+from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
+from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
+from mindspore.common.api import _executor
+from mindspore.communication.management import HCCL_WORLD_COMM_GROUP
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x):
+        return C.grad_all(self.network)(x)
+
+
+class VirtualGatherV2(PrimitiveWithInfer):
+    @prim_attr_register
+    def __init__(self):
+        """init index_select"""
+        super(VirtualGatherV2, self).__init__('VirtualGatherV2')
+        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
+
+    def __infer__(self, params, indices, axis):
+        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
+        axis_v = axis['value']
+        params_shp = params['shape']
+        rank = len(params_shp)
+        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
+        if axis_v < 0:
+            axis_v += rank
+        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
+        out = {'shape': out_shape,
+               'dtype': params['dtype'],
+               'value': None}
+        return out
+
+
+@bprop_getters.register(VirtualGatherV2)
+def get_bprop_gather_v2(self):
+    """Generate bprop for GatherV2"""
+    def bprop(x, indices, axis, out, dout):
+        return (indices, dout, x), axis, out
+
+    return bprop
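VirtualGatherV2 and its registered bprop are pure test scaffolding: instead of a dense gradient, the bprop hands back the tuple (indices, dout, x), which is what forces the compiler down the new Tuple branches in the AllReduce and _MirrorOperator bprops above (the third slot, which the real path fills with the tensor shape, is stubbed with x here).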
+
+
+def test_bprop_with_sparse_feature_allreduce():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
+
+    class Net(nn.Cell):
+        def __init__(self, axis=0, shape=None):
+            super(Net, self).__init__()
+            if shape is None:
+                shape = [8, 8]
+            self.all_reduce = AllReduce()
+            self.gatherv2 = VirtualGatherV2()
+            self.index = Tensor(np.ones(shape), dtype=ms.int32)
+            self.axis = axis
+
+        def construct(self, x):
+            out = self.all_reduce(x)
+            out = self.gatherv2(out, self.index, self.axis)
+            return out
+
+    net = GradWrap(Net())
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    _executor.compile(net, x)
+
+
+def test_bprop_with_sparse_feature_mirror():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")
+
+    class Net(nn.Cell):
+        def __init__(self, axis=0, shape=None):
+            super(Net, self).__init__()
+            if shape is None:
+                shape = [8, 8]
+            self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
+            self.gatherv2 = VirtualGatherV2()
+            self.index = Tensor(np.ones(shape), dtype=ms.int32)
+            self.axis = axis
+
+        def construct(self, x):
+            out = self.mirror(x)
+            out = self.gatherv2(out, self.index, self.axis)
+            return out
+
+    net = GradWrap(Net())
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    _executor.compile(net, x)
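Both tests stop at _executor.compile: under parallel_mode="hybrid_parallel" with device_num=8 they only verify that bprop graphs carrying sparse (indices, values, shape) tuples build and type-check, so running them should not require an actual multi-device HCCL environment.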