@@ -508,16 +508,40 @@ def get_bprop_sigmoid_grad(self):
     return bprop
 
 
+@constexpr
+def _get_transpose_axis(x_shp, axis):
+    rank = len(x_shp)
+    if axis < 0:
+        axis += rank
+    reverse_axis = [i for i in range(rank)]
+    reverse_axis[axis] = rank - 1
+    reverse_axis[rank - 1] = axis
+    return tuple(reverse_axis)
+
+
 @bprop_getters.register(P.Softmax)
 def get_bprop_softmax(self):
     """Grad definition for `Softmax` operation."""
     sum_func = P.ReduceSum(keep_dims=True)
     sub = P.Sub()
     mul = P.Mul()
+    get_shape = P.Shape()
+    transpose = P.Transpose()
     axis = self.axis
+    if not isinstance(axis, int):
+        axis = axis[0]
 
     def bprop(x, out, dout):
-        dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
+        # dx = (dout - sum(dout * out)) * out
+        # This formula is correct only when the `axis` is the last dimension.
+        # In order to support the scenario where the `axis` is other values,
+        # we transpose the data of the `axis` dimension to the last dimension for calculation,
+        # and then transpose it back after the calculation.
+        reverse_axis = _get_transpose_axis(get_shape(x), axis)
+        out = transpose(out, reverse_axis)
+        dout = transpose(dout, reverse_axis)
+        dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
+        dx = transpose(dx, reverse_axis)
+        return (dx,)
 
     return bprop
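For reference, the math behind the patch: along the softmax axis the vector-Jacobian product is dx = out * (dout - sum(out * dout)), and the original one-liner applies it correctly only when that axis is last. The patch therefore swaps `axis` with the last dimension, applies the last-axis formula, and swaps back; the swap permutation built by `_get_transpose_axis` is its own inverse, which is why the same tuple is reused for all three transposes. Below is a minimal NumPy sketch of the same idea with a finite-difference check; the helper names `softmax_np` and `softmax_bprop_np` are illustrative, not MindSpore APIs.

import numpy as np

def softmax_np(x, axis):
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_bprop_np(x, dout, axis):
    # Mirror the patched bprop: swap `axis` with the last dimension,
    # apply dx = out * (dout - sum(out * dout, -1)), then swap back.
    perm = list(range(x.ndim))
    axis = axis % x.ndim  # normalize a negative axis, like _get_transpose_axis
    perm[axis], perm[-1] = perm[-1], perm[axis]
    out = softmax_np(x, axis).transpose(perm)
    d = dout.transpose(perm)
    dx = out * (d - (out * d).sum(axis=-1, keepdims=True))
    return dx.transpose(perm)  # the swap permutation undoes itself

# Finite-difference check on a non-last axis.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
dout = rng.normal(size=x.shape)
ax, eps = 1, 1e-6
num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    step = np.zeros_like(x)
    step[idx] = eps
    num[idx] = ((softmax_np(x + step, ax) - softmax_np(x - step, ax)) * dout).sum() / (2 * eps)
assert np.allclose(softmax_bprop_np(x, dout, ax), num, atol=1e-4)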