# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Generate bprop for comm ops""" from mindspore import Tensor import mindspore.common.dtype as mstype from mindspore.ops import functional as F from mindspore.communication import get_rank, get_group_size from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard from .. import operations as P from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2, Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap, _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather) from .grad_base import bprop_getters from ..operations._inner_ops import Send, Receive from ..operations import _grad_ops as G @bprop_getters.register(AllReduce) def get_bprop_all_reduce(self): """Generate bprop for AllReduce, do allreduce or allgather, allgather for sparse feature.""" all_reduce_grad = AllReduce(ReduceOp.SUM, self.group) all_gather = AllGather(group=self.group) if self.instance_name: instance_name = "grad" + self.instance_name all_reduce_grad.set_prim_instance_name(instance_name) equal = P.Equal() cast = P.Cast() mul = P.Mul() div = P.RealDiv() dtype = P.DType() if self.op == ReduceOp.PROD: def bprop(x, out, dout): dy1 = mul(dout, out) dy2 = all_reduce_grad(dy1) dx = div(dy2, x) return (dx,) elif self.op == ReduceOp.SUM: def bprop(x, out, dout): if F.issubclass_(F.typeof(dout), mstype.tensor): dx = all_reduce_grad(dout) else: indices = all_gather(dout.indices) grad = all_gather(dout.values) dx = RowTensor(indices, grad, dout.dense_shape) return (dx,) else: def bprop(x, out, dout): if F.issubclass_(F.typeof(dout), mstype.tensor): dx = all_reduce_grad(dout) z = equal(x, out) z = cast(z, dtype(dx)) dx = mul(dx, z) else: indices = all_gather(dout.indices) grad = all_gather(dout.values) z = equal(x, out) z = cast(z, dtype(grad)) grad = mul(grad, z) dx = RowTensor(indices, grad, dout.dense_shape) return (dx,) return bprop @bprop_getters.register(Send) def get_bprop_send(self): """Generate bprop for Send.""" shape = self.get_attr_dict()["shape"] dtype = self.get_attr_dict()["dtype"] send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group_back) virtual_input = Tensor(0.0, dtype) def bprop(x, out, dout): dx = send_grad(virtual_input) return (dx,) return bprop @bprop_getters.register(Receive) def get_bprop_receive(self): """Generate bprop for Receive.""" receive_grad = Send(self.tag, self.rank, self.group_back) depend = P.Depend() cast = P.Cast() out_tensor = Tensor(0.0, mstype.float16) is_opt_shard = _get_enable_parallel_optimizer() def bprop(x, out, dout): send_out = receive_grad(dout) if is_opt_shard: dx = depend(F.zeros_like(x), send_out) else: dx = depend(cast(out_tensor, F.dtype(x)), send_out) return (dx,) return bprop @bprop_getters.register(_VirtualAdd) def get_bprop_virtual_add(self): """Generate bprop for _VirtualAdd""" def bprop(x, grad_accu, out, dout): return (dout + grad_accu, zeros_like(grad_accu)) return bprop @bprop_getters.register(_VirtualAssignAdd) def get_bprop_virtual_assign_add(self): """Generate bprop for VirtualAssignAdd.""" assign_add = P.AssignAdd() cast = P.Cast() dtype = P.DType() out_tensor = Tensor(0.0, mstype.float16) reduce_scatter = None group = self.get_attr_dict().get("group", None) fusion = self.get_attr_dict().get("fusion", 0) if group: reduce_scatter = ReduceScatter(ReduceOp.SUM, group).add_prim_attr("fusion", fusion) if self.instance_name: instance_name = "_grad_accumulation_shard_grad" + self.instance_name reduce_scatter.set_prim_instance_name(instance_name) # For pipeline training, as the fused communication will be visited later # this may make memory increase, so we need to add a tag to let the # fused communication not be effective. reduce_scatter.add_prim_attr("not_delay_fusion", True) def bprop(x, y, out, dout): if reduce_scatter: dout = reduce_scatter(dout) temp = assign_add(y, dout) return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp) return bprop @bprop_getters.register(_VirtualAccuGrad) def get_bprop_virtual_accu_grad(self): """Generate bprop for VirtualAccuGrad.""" cast = P.Cast() dtype = P.DType() out_tensor = Tensor(0.0, mstype.float16) def bprop(x, y, out, dout): return (F.depend(y, dout), cast(out_tensor, dtype(y))) return bprop @bprop_getters.register(_MirrorMicroStepOperator) def get_bprop_mirror_micro_step_operator(self): """ Backpropagator for _MirrorMicroStepOperator, do allreduce or allgather for the devices in the group, allgather for sparse feature. """ group = self.group dev_num = self.dev_num mean_flag = self.mean_flag scale = 1 / dev_num all_reduce = AllReduce(group=group) fusion = self.get_attr_dict()["fusion"] all_reduce.add_prim_attr("fusion", fusion) if hasattr(self, 'parameter'): parameter = self.parameter all_reduce.add_prim_attr("parameter", parameter) if self.instance_name: instance_name = "grad_mirror" + self.instance_name all_reduce.set_prim_instance_name(instance_name) cast = P.Cast() dtype = P.DType() assign = P.Assign() if "parameter_micro" in self.get_attr_dict(): assign.add_prim_attr("parameter_micro", 0) out_tensor = Tensor(1.0, mstype.float16) opt_shard = _get_enable_parallel_optimizer() def bprop(x, z, out, dout): real_grad = z assign_out = dout if mean_flag: if F.issubclass_(F.typeof(dout), mstype.tensor): z = F.depend(z, dout) real_grad = all_reduce(z) real_grad = F.tensor_mul(real_grad, scale) assign_out = assign(z, real_grad) else: if F.issubclass_(F.typeof(dout), mstype.tensor): z = F.depend(z, dout) real_grad = all_reduce(z) assign_out = assign(z, real_grad) if opt_shard: return (real_grad, cast(out_tensor, dtype(z))) return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out) return bprop @bprop_getters.register(Broadcast) def get_bprop_broad_cast(self): """Generate bprop for Broadcast.""" def bprop(x, out, dout): return (dout,) return bprop @bprop_getters.register(AllGather) def get_bprop_all_gather(self): """Generate bprop for AllGather""" fusion = self.get_attr_dict()["fusion"] reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) if self.instance_name: instance_name = "grad_" + self.instance_name reduce_scatter.set_prim_instance_name(instance_name) mean_flag = self.get_attr_dict()["mean_flag"] scale = 1 / self.rank_size def bprop(x, out, dout): dx = reduce_scatter(dout) if mean_flag: dx = F.tensor_mul(dx, scale) return (dx,) return bprop @bprop_getters.register(_MiniStepAllGather) def get_bprop_mini_step_all_gather(self): """Generate bprop for _MiniStepAllGather""" fusion = self.get_attr_dict()["fusion"] mean_flag = self.get_attr_dict()["mean_flag"] do_mirror = self.get_attr_dict()["do_mirror"] add_accu = self.get_attr_dict().get("add_accu", False) gradient_shard = _get_grad_accumulation_shard() scale = 1 / self.rank_size all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) assign_add = P.AssignAdd() if self.instance_name: instance_name = "grad_" + self.instance_name all_reduce.set_prim_instance_name(instance_name) rank = get_rank(self.group) dev_num = get_group_size(self.group) split = P.Split(output_num=dev_num) def bprop(x, z, out, dout): if do_mirror: if not gradient_shard: z = F.depend(z, F.assign_add(z, dout)) grad = all_reduce(z) dx = split(grad)[rank] if mean_flag: dx = F.tensor_mul(dx, scale) else: dout = F.depend(dout, z) grad = all_reduce(dout) dx = split(grad)[rank] if mean_flag: dx = F.tensor_mul(dx, scale) if add_accu: z = assign_add(z, dx) dx = F.depend(dx, z) else: dx = dout return (dx, zeros_like(z)) return bprop @bprop_getters.register(_MicroStepAllGather) def get_bprop_micro_step_all_gather(self): """Generate bprop for _MicroStepAllGather""" fusion = self.get_attr_dict()["fusion"] mean_flag = self.get_attr_dict()["mean_flag"] do_mirror = self.get_attr_dict()["do_mirror"] scale = 1 / self.rank_size all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) rank = get_rank(self.group) dev_num = get_group_size(self.group) split = P.Split(output_num=dev_num) if self.instance_name: instance_name = "grad_" + self.instance_name all_reduce.set_prim_instance_name(instance_name) cast = P.Cast() dtype = P.DType() out_tensor = Tensor(1.0, mstype.float16) # z: accu_grad def bprop(x, z, out, dout): z = F.depend(z, dout) if not do_mirror: return (z, cast(out_tensor, dtype(z))) real_grad = all_reduce(z) real_grad = split(real_grad)[rank] if mean_flag: real_grad = F.tensor_mul(real_grad, scale) return (real_grad, cast(out_tensor, dtype(z))) return bprop @bprop_getters.register(_HostAllGather) def get_bprop_host_all_gather(self): """Generate bprop for _HostAllGather""" host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group) if self.instance_name: instance_name = "grad" + self.instance_name host_all_gather_grad.set_prim_instance_name(instance_name) def bprop(x, out, dout): dx = host_all_gather_grad(dout) return (dx,) return bprop @bprop_getters.register(ReduceScatter) def get_bprop_reduce_scatter(self): """Generate bprop for ReduceScatter""" reduce_scatter_grad = AllGather(self.group) if self.instance_name: instance_name = "grad" + self.instance_name reduce_scatter_grad.set_prim_instance_name(instance_name) if self.op != ReduceOp.SUM: raise RuntimeError("The reducescatter bprop only support ReduceOp.SUM until now.") def bprop(x, out, dout): dx = reduce_scatter_grad(dout) return (dx,) return bprop @bprop_getters.register(AllSwap) def get_bprop_allswap(self): """Generate bprop for AllSwap.""" all_swap_grad = AllSwap(self.group) if self.instance_name: instance_name = "grad" + self.instance_name all_swap_grad.set_prim_instance_name(instance_name) def bprop(x, send_size, recv_size, out, dout): dx = all_swap_grad(dout, recv_size, send_size) return (dx, zeros_like(send_size), zeros_like(recv_size)) return bprop @bprop_getters.register(_HostReduceScatter) def get_bprop_host_reduce_scatter(self): """Generate bprop for _HostReduceScatter""" host_reduce_scatter_grad = _HostAllGather(self.group) if self.instance_name: instance_name = "grad" + self.instance_name host_reduce_scatter_grad.set_prim_instance_name(instance_name) if self.op != ReduceOp.SUM: raise RuntimeError("The hostreducescatter bprop only support ReduceOp.SUM until now.") def bprop(x, out, dout): dx = host_reduce_scatter_grad(dout) return (dx,) return bprop @bprop_getters.register(NeighborExchange) def get_bprop_neighborexchange(self): """Generate bprop for NeighborExchange.""" group = self.group send_rank_ids = self.recv_rank_ids recv_rank_ids = self.send_rank_ids recv_shapes = self.send_shapes send_shapes = self.recv_shapes recv_type = self.recv_type neighborexchange_grad = NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type, group) def bprop(x, out, dout): return (neighborexchange_grad(dout),) return bprop @bprop_getters.register(AlltoAll) def get_bprop_all_to_all(self): """Generate bprop for AlltoAll.""" all_to_all_grad = AlltoAll(self.split_count, self.concat_dim, self.split_dim, self.group) if self.instance_name: instance_name = "grad" + self.instance_name all_to_all_grad.set_prim_instance_name(instance_name) def bprop(x, out, dout): dx = all_to_all_grad(dout) return (dx,) return bprop @bprop_getters.register(NeighborExchangeV2) def get_bprop_neighborexchangev2(self): """Generate bprop for NeighborExchangeV2.""" group = self.group send_rank_ids = self.recv_rank_ids recv_rank_ids = self.send_rank_ids send_lens = self.recv_lens recv_lens = self.send_lens data_format = self.data_format neighborexchangev2_grad = G.NeighborExchangeV2Grad(send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group) def bprop(x, out, dout): return (neighborexchangev2_grad(dout),) return bprop @bprop_getters.register(_MirrorOperator) def get_bprop_mirror_operator(self): """ Backpropagator for _MirrorOperator, do allreduce or allgather for the devices in group(only for one group), allgather for sparse feature. """ group = self.group dev_num = self.dev_num mean_flag = self.mean_flag all_reduce = AllReduce(group=group) all_gather = AllGather(group=group) mul = P.Mul() cast = P.Cast() fusion = self.get_attr_dict()["fusion"] all_reduce.add_prim_attr("fusion", fusion) if hasattr(self, 'parameter'): parameter = self.parameter all_reduce.add_prim_attr("parameter", parameter) if self.instance_name: instance_name = "grad_mirror" + self.instance_name all_reduce.set_prim_instance_name(instance_name) def bprop(x, out, dout): if mean_flag: if F.issubclass_(F.typeof(dout), mstype.tensor): dx = all_reduce(dout) float_one = F.scalar_cast(1.0, F.dtype(dx)) num = F.scalar_cast(dev_num, F.dtype(dx)) dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) else: indices = all_gather(dout.indices) grad = all_gather(dout.values) float_one = F.scalar_cast(1.0, F.dtype(grad)) num = F.scalar_cast(dev_num, F.dtype(grad)) grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) dx = RowTensor(indices, grad, dout.dense_shape) else: if F.issubclass_(F.typeof(dout), mstype.tensor): dx = all_reduce(dout) else: indices = all_gather(dout.indices) grad = all_gather(dout.values) dx = RowTensor(indices, grad, dout.dense_shape) return (dx,) return bprop @bprop_getters.register(_MirrorMiniStepOperator) def get_bprop_mirror_mini_step_operator(self): """ Backpropagator for _MirrorMiniStepOperator, do allreduce or allgather for the devices in the group, allgather for sparse feature. """ group = self.group dev_num = self.dev_num mean_flag = self.mean_flag all_reduce = AllReduce(group=group) mul = P.Mul() cast = P.Cast() fusion = self.get_attr_dict()["fusion"] all_reduce.add_prim_attr("fusion", fusion) if hasattr(self, 'parameter'): parameter = self.parameter all_reduce.add_prim_attr("parameter", parameter) if self.instance_name: instance_name = "grad_mirror" + self.instance_name all_reduce.set_prim_instance_name(instance_name) do_mirror = self.get_attr_dict()["do_mirror"] def bprop(x, z, out, dout): if mean_flag: if F.issubclass_(F.typeof(dout), mstype.tensor): if do_mirror: z = F.depend(z, F.assign_add(z, dout)) real_grad = all_reduce(z) dx = real_grad else: dx = dout float_one = F.scalar_cast(1.0, F.dtype(dx)) num = F.scalar_cast(dev_num, F.dtype(dx)) dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) else: dx = zeros_like(x) # The grad accumulation do not support row tensor now else: if F.issubclass_(F.typeof(dout), mstype.tensor): if do_mirror: z = F.depend(z, F.assign_add(z, dout)) real_grad = all_reduce(z) dx = real_grad else: dx = dout else: dx = zeros_like(x) # The grad accumulation do not support row tensor now return (dx, zeros_like(z)) return bprop @bprop_getters.register(_VirtualDiv) def get_bprop_virtual_div_operator(self): """Backpropagator for _VirtualDiv, do Div for the divisor.""" divisor = self.divisor op = P.RealDiv() cast = P.Cast() dtype = P.DType() def bprop(x, out, dout): if F.issubclass_(F.typeof(dout), mstype.tensor): if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \ or F.issubclass_(F.dtype(dout), mstype.int16): return (dout,) dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout))) return (dx,) if F.issubclass_(F.typeof(dout), mstype.tuple_): dx = () input_nums = F.tuple_len(dout) for i in range(input_nums): ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i]))) dx = dx + (ele_grad,) return (dx,) dx = [] input_nums = F.list_len(dout) for i in range(input_nums): ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i]))) dx.append(ele_grad) return (dx,) return bprop @bprop_getters.register(_GetTensorSlice) def get_bprop_get_tensor_slice_operator(self): """Backpropagator for _GetTensorSlice""" def bprop(x, dev_mat, tensor_map, out, dout): return (zeros_like(x),) return bprop