# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Define the grad rules of math related operations."""

from functools import reduce
import numpy as np
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
from .grad_base import bprop_getters
from ..primitive import constexpr

# Module-level primitive instances shared by the helpers below.
shape_op = P.Shape()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
tile = P.Tile()


def binop_grad_common(x, y, dx, dy):
    """
    Common grad definition for binary operations.

    The function is usually used in backprop op to reduce additional
    dimensions created by broadcasting: each incoming gradient is summed
    over the broadcast axes and reshaped back to its input's shape.
    """
    shape_of_x = shape_op(x)
    shape_of_y = shape_op(y)
    rx = broadcast_gradient_args(shape_of_x, shape_of_y)
    # if input shape is the same as dout shape, do not need to reduce
    reduce_dx = dx
    reduce_dy = dy
    if rx[0]:
        # if dx is scalar whose shape is (), do not need reduce
        if shape_op(dx):
            dx = reduce_sum(dx, rx[0])
        reduce_dx = reshape(dx, shape_of_x)
    if rx[1]:
        # if dy is scalar whose shape is (), do not need reduce
        if shape_op(dy):
            dy = reduce_sum(dy, rx[1])
        reduce_dy = reshape(dy, shape_of_y)
    return reduce_dx, reduce_dy


def _sum_grad(x, axis, dout):
    """Grad definition for `Sum` operation: broadcast dout back to x's shape."""
    # input_shape = [2, 3] axis = [1]
    input_shape = shape_op(x)
    # output_shape_kept_dims = [2, 1]
    output_shape_kept_dims = reduced_shape(input_shape, axis)
    # tile_scaling = [1, 3]
    tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    return tile(grad, tile_scaling)


def _min_or_max_grad(x, axis, out, dout):
    """Grad definition for `Min` and `Max` operations.

    Routes the gradient only to the element(s) equal to the reduced result,
    splitting it evenly when there are ties.
    """
    # input_shape = [2, 3] axis = [1]
    input_shape = shape_op(x)
    # output_shape_kept_dims = [2, 1]
    output_shape_kept_dims = reduced_shape(input_shape, axis)
    y = reshape(out, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    indicators = F.cast(F.equal(y, x), F.dtype(grad))
    # min_num guards against division by zero when no indicator fires.
    min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
    num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
    return indicators / num_selected * grad


def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """ArgMinWithValue and ArgMaxWithValue grad."""
    expand = P.ExpandDims()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    # Normalize a negative axis to its positive equivalent.
    if x_axis < 0:
        x_axis = axis + x_dim
    onehot_axis = x_axis
    depth = x_shape[x_axis]
    if keep_dims:
        dout_expand = dout[1]
        # Recompute without keep_dims so out[0] has the reduced rank expected
        # by the one-hot scatter below.
        out = op(x)
    else:
        dout_expand = expand(dout[1], onehot_axis)
    if onehot_axis >= len(shape_op(out[0])):
        onehot_axis = -1
    onehot = P.OneHot(onehot_axis)
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_array(1.0), type_x)
    off_value = F.cast(F.scalar_to_array(0.0), type_x)
    dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    return dx


@bprop_getters.register(P.MatMul)
def bprop_matmul(self):
    """Grad definition for `MatMul` operation."""
    ta = self.transpose_a
    tb = self.transpose_b
    mul1 = P.MatMul(transpose_a=(ta and tb),
                    transpose_b=(ta or (not tb)))
    mul2 = P.MatMul(transpose_a=((not ta) or tb),
                    transpose_b=(ta and tb))

    def bprop(x, w, out, dout):
        if ta:
            dx = mul1(w, dout)
        else:
            dx = mul1(dout, w)
        if tb:
            dw = mul2(dout, x)
        else:
            dw = mul2(x, dout)
        return dx, dw

    return bprop


@bprop_getters.register(P.BatchMatMul)
def bprop_batchmatmul(self):
    """Grad definition for `BatchMatMul` operation."""
    ta = self.transpose_a
    tb = self.transpose_b
    mul1 = P.BatchMatMul(transpose_a=(ta and tb),
                         transpose_b=(ta or (not tb)))
    mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
                         transpose_b=(ta and tb))

    def bprop(x, w, out, dout):
        if ta:
            dx = mul1(w, dout)
        else:
            dx = mul1(dout, w)
        if tb:
            dw = mul2(dout, x)
        else:
            dw = mul2(x, dout)
        return dx, dw

    return bprop


@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
    """Grad definition for `TensorAdd` operation."""

    def bprop(x, y, out, dout):
        return binop_grad_common(x, y, dout, dout)

    return bprop


@bprop_getters.register(P.Neg)
def get_bprop_neg(self):
    """Grad definition for `Neg` operation."""
    neg_grad = P.Neg()

    def bprop(x, out, dout):
        dx = neg_grad(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Sub)
def get_bprop_sub(self):
    """Grad definition for `Sub` operation."""
    neg_func = P.Neg()

    def bprop(x, y, out, dout):
        return binop_grad_common(x, y, dout, neg_func(dout))

    return bprop


@bprop_getters.register(P.Mul)
def get_bprop_mul(self):
    """Grad definition for `Mul` operation."""
    mul_func = P.Mul()

    def bprop(x, y, out, dout):
        bc_dx = mul_func(dout, y)
        bc_dy = mul_func(dout, x)
        return binop_grad_common(x, y, bc_dx, bc_dy)

    return bprop


@bprop_getters.register(P.RealDiv)
def get_bprop_real_div(self):
    """Grad definition for `RealDiv` operation."""
    div_op = P.RealDiv()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        # d(x/y)/dy = -x/y^2 = -(dout/y) * out
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Div)
def get_bprop_div(self):
    """Grad definition for `Div` operation."""
    div_op = P.Div()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Floor)
def get_bprop_floor(self):
    """Grad definition for `floor` operation."""
    fill_ = P.Fill()
    shape_ = P.Shape()
    dtype_ = P.DType()

    def bprop(x, out, dout):
        # Floor is piecewise constant, so the gradient is zero everywhere.
        bc_x = fill_(dtype_(x), shape_(x), 0.)
        return (bc_x,)

    return bprop


@bprop_getters.register(P.FloorDiv)
def get_bprop_floordiv(self):
    """Grad definition for `FloorDiv` operation."""
    div_op = P.FloorDiv()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.FloorMod)
def get_bprop_floormod(self):
    """Grad definition for `FloorMod` operation."""
    div_op = P.FloorMod()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Square)
def get_bprop_square(self):
    """Grad definition for `Square` operation."""
    mul_func = P.Mul()
    fill_func = P.Fill()
    dtype = P.DType()

    def bprop(x, out, dout):
        # d(x^2)/dx = 2x, so dx = 2 * x * dout
        temp = mul_func(dout, x)
        dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
        return (dx,)

    return bprop


@bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self):
    """Grad definition for `Sqrt` operation."""
    mul_func = P.Mul()
    fill_func = P.Fill()
    div_op = P.RealDiv()
    sqrt = P.Sqrt()
    dtype = P.DType()

    def bprop(x, out, dout):
        # d(sqrt(x))/dx = 0.5 / sqrt(x)
        temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
        dx = mul_func(dout, temp)
        return (dx,)

    return bprop


@bprop_getters.register(P.Rsqrt)
def get_bprop_rsqrt(self):
    """Grad definition for `Rsqrt` operation."""

    def bprop(x, out, dout):
        # d(x^{-1/2})/dx = -0.5 * x^{-3/2}
        grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x)*x)
        dx = dout * grad
        return (dx,)

    return bprop


@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
    """Grad definition for `Reciprocal` operation."""
    neg = P.Neg()
    mul = P.Mul()
    square = P.Square()
    reciprocal = P.Reciprocal()

    def bprop(x, out, dout):
        # d(1/x)/dx = -1/x^2
        g = neg(reciprocal(square(x)))
        dx = mul(dout, g)
        return (dx,)

    return bprop


@bprop_getters.register(P.Log)
def get_bprop_log(self):
    """Grad definition for `Log` operation."""
    reciprocal = P.Reciprocal()

    def bprop(x, out, dout):
        g = reciprocal(x)
        dx = g * dout
        # Log has a single input, so bprop must return a 1-tuple like the
        # other unary ops in this file (the old `return dx, 0` produced a
        # spurious second gradient).
        return (dx,)

    return bprop


@bprop_getters.register(P.Erf)
def get_bprop_erf(self):
    """Grad definition for `Erf` operation."""
    exp = P.Exp()
    square = P.Square()
    sqrt = P.Sqrt()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, out, dout):
        # d(erf(x))/dx = (2/sqrt(pi)) * exp(-x^2)
        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
        x_square = square(x)
        dx = dout * half_root_pi * exp(-x_square)
        return (dx,)

    return bprop


@bprop_getters.register(P.Pow)
def get_bprop_pow(self):
    """Grad definition for `Pow` operation."""
    pow_op = P.Pow()
    ln = P.Log()

    def bprop(x, power, out, dout):
        bc_dx = power * pow_op(x, power - 1.0) * dout
        # d(x^p)/dp = x^p * ln(x)
        bc_dpower = out * ln(x) * dout
        return binop_grad_common(x, power, bc_dx, bc_dpower)

    return bprop


@bprop_getters.register(P.Exp)
def get_bprop_exp(self):
    """Grad definition for `Exp` operation."""
    exp_ = P.Exp()

    def bprop(x, out, dout):
        g = exp_(x)
        dx = g * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Minimum)
def get_bprop_minimum(self):
    """Grad definition for `Minimum` operation."""
    input_grad = G.MinimumGrad()

    def bprop(x, y, out, dout):
        dx, dy = input_grad(x, y, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.Maximum)
def get_bprop_maximum(self):
    """Grad definition for `Maximum` operation."""
    input_grad = G.MaximumGrad()

    def bprop(x, y, out, dout):
        dx, dy = input_grad(x, y, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.ReduceSum)
def get_bprop_reducesum(self):
    """Grad definition for `ReduceSum` operation."""

    def bprop(x, axis, out, dout):
        dx = _sum_grad(x, axis, dout)
        return dx, zeros_like(axis)

    return bprop


@bprop_getters.register(P.CumSum)
def get_bprop_cumsum(self):
    """Grad definition for `CumSum` operation."""
    # The gradient of a cumulative sum is a cumulative sum in the
    # opposite direction.
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        return cumsum(dout, axis), zeros_like(axis)

    return bprop


@constexpr
def _split_shape_index(input_shape, axis):
    """Calculate reduce_prod grad transpose indices and perm shape."""
    rank = len(input_shape)
    if isinstance(axis, int):
        axis = tuple([axis])
    # Normalize negative axes, then split dims into reduced vs. kept.
    reduction_indices = tuple([(i + rank) % rank for i in axis])
    other_indices = tuple(set(range(rank)) - set(reduction_indices))
    reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
    other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
    perm = reduction_indices + other_indices
    return tuple([reduced_num, other_num]), perm


@constexpr
def _invert_permutation(perm):
    """Calculate invert permutation."""
    out = [0] * len(perm)
    for i, value in enumerate(perm):
        out[value] = i
    return tuple(out)


@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
    """Grad definition for `ReduceProd` operation."""
    transpose = P.Transpose()
    left_cumprod = P.CumProd(exclusive=True)
    right_cumprod = P.CumProd(exclusive=True, reverse=True)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # Expand dout to full input shape
        input_shape = shape_op(x)
        output_shape_kept_dims = reduced_shape(input_shape, axis)
        dout = reshape(dout, output_shape_kept_dims)
        tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
        grad = tile(dout, tile_scaling)

        # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
        pack_shape, perm = _split_shape_index(input_shape, axis)
        permuted = transpose(x, perm)
        permuted_shape = shape_op(permuted)
        reshaped = reshape(permuted, pack_shape)

        # Calculate product, leaving out the current entry
        left = left_cumprod(reshaped, 0)
        right = right_cumprod(reshaped, 0)
        y = reshape(left * right, permuted_shape)

        # Invert the transpose and reshape operations.
        # Make sure to set the statically known shape information through a reshape.
        out = transpose(y, _invert_permutation(perm)) * grad
        dx = reshape(out, input_shape)
        return dx, zeros_like(axis)

    return bprop


@bprop_getters.register(P.CumProd)
def get_bprop_cumprod(self):
    """Grad definition for `CumProd` operation."""
    cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # This will fail when x contains 0 (division by x below).
        prod = cumprod(x, axis)
        out = cumsum(prod * dout, axis)
        return out / x, zeros_like(axis)

    return bprop


@bprop_getters.register(P.ReduceAll)
def get_bprop_reduceall(self):
    """Grad definition for `ReduceAll` operation."""

    def bprop(x, axis, out, dout):
        return zeros_like(x), zeros_like(axis)

    return bprop


@bprop_getters.register(P.ReduceMax)
def get_bprop_reducemax(self):
    """Grad definition for `Max` operation."""

    def bprop(x, axis, out, dout):
        dx = _min_or_max_grad(x, axis, out, dout)
        return (dx, zeros_like(axis))

    return bprop


@bprop_getters.register(P.ArgMaxWithValue)
def get_bprop_argmaxwithvalue(self):
    """Grad definition for `ArgMaxWithValue` operation."""
    axis = self.axis
    keep_dims = self.keep_dims
    op = P.ArgMaxWithValue(axis)

    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReduceMin)
def get_bprop_reducemin(self):
    """Grad definition for `ReduceMin` operation."""

    def bprop(x, axis, out, dout):
        dx = _min_or_max_grad(x, axis, out, dout)
        return (dx, zeros_like(axis))

    return bprop


@bprop_getters.register(P.ArgMinWithValue)
def get_bprop_argminwithvalue(self):
    """Generate bprop for ArgMinWithValue"""
    axis = self.axis
    keep_dims = self.keep_dims
    op = P.ArgMinWithValue(axis)

    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReduceMean)
def get_bprop_reduce_mean(self):
    """Grad definition for `ReduceMean` operation."""
    div_op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, axis, out, dout):
        grad = _sum_grad(x, axis, dout)
        # Divide by the number of reduced elements.
        div_shape = F.shape_mul(shape_op(x)) / F.shape_mul(shape_op(out))
        dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
        return dx, zeros_like(axis)

    return bprop


@bprop_getters.register(P.Equal)
def get_bprop_equal(self):
    """Grad definition for `Equal` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.NotEqual)
def get_bprop_not_equal(self):
    """Grad definition for `NotEqual` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.Greater)
def get_bprop_greater(self):
    """Grad definition for `Greater` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.GreaterEqual)
def get_bprop_greater_equal(self):
    """Grad definition for `GreaterEqual` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.Less)
def get_bprop_less(self):
    """Grad definition for `Less` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.LessEqual)
def get_bprop_less_equal(self):
    """Grad definition for `LessEqual` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.LogicalNot)
def get_bprop_logical_not(self):
    """Grad definition for `LogicalNot` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.LogicalAnd)
def get_bprop_logical_and(self):
    """Grad definition for `LogicalAnd` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.LogicalOr)
def get_bprop_logical_or(self):
    """Grad definition for `LogicalOr` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.NPUAllocFloatStatus)
def get_bprop_npu_alloc_float_status(self):
    """Grad definition for `NPUAllocFloatStatus` operation."""

    def bprop(out, dout):
        return ()

    return bprop


@bprop_getters.register(P.NPUGetFloatStatus)
def get_bprop_npu_get_float_status(self):
    """Grad definition for `NPUGetFloatStatus` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.NPUClearFloatStatus)
def get_bprop_npu_clear_float_status(self):
    """Grad definition for `NPUClearFloatStatus` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.AssignAdd)
def get_bprop_assign_add(self):
    """Grad definition for `AssignAdd` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.AssignSub)
def get_bprop_assign_sub(self):
    """Grad definition for `AssignSub` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.Sin)
def get_bprop_sin(self):
    """Grad definition for `Sin` operation."""
    cos = P.Cos()

    def bprop(x, out, dout):
        dx = dout*cos(x)
        return (dx,)

    return bprop


@bprop_getters.register(P.Cos)
def get_bprop_cos(self):
    """Grad definition for `Cos` operation."""
    sin = P.Sin()
    neg = P.Neg()

    def bprop(x, out, dout):
        dx = dout*neg(sin(x))
        return (dx,)

    return bprop


@bprop_getters.register(P.ACos)
def get_bprop_acos(self):
    """Grad definition for `ACos` operation."""
    input_grad = G.ACosGrad()

    def bprop(x, out, dout):
        dx = input_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Acosh)
def get_bprop_acosh(self):
    """Grad definition for `Acosh` operation."""
    input_grad = G.AcoshGrad()

    def bprop(x, out, dout):
        dx = input_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Abs)
def get_bprop_abs(self):
    """Grad definition for `Abs` operation."""
    abs_grad = G.AbsGrad()

    def bprop(x, out, dout):
        dx = abs_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.ScalarCast)
def get_bprop_scalar_cast(self):
    """Generate bprop for ScalarCast"""

    def bprop(x, t, out, dout):
        return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)

    return bprop


@bprop_getters.register(P.AddN)
def get_bprop_scalar_addn(self):
    """Generate bprop for AddN"""

    def bprop(x, out, dout):
        # AddN takes a tuple of inputs; each receives dout unchanged.
        dx = ()
        for _ in range(len(x)):
            dx = dx + (dout,)
        return dx

    return bprop


@bprop_getters.register(P.Sign)
def get_bprop_sign(self):
    """Generate bprop for Sign"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Round)
def get_bprop_round(self):
    """Generate bprop for Round"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Atan2)
def get_bprop_atan2(self):
    """Generate bprop for Atan2"""
    square = P.Square()

    def bprop(x, y, out, dout):
        # d(atan2(x, y)) = (y*dx - x*dy) / (x^2 + y^2)
        tmp = dout / (square(x) + square(y))
        dx = tmp * y
        dy = tmp * (-x)
        return (dx, dy)

    return bprop