Browse Source

!10219 Fix softmax prop when axis is not the last dimension.

From: @liangchenghui
Reviewed-by: @c_34,@wuxuejian
Signed-off-by: @wuxuejian
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2599aefad0
1 changed files with 25 additions and 1 deletions
  1. +25
    -1
      mindspore/ops/_grad/grad_nn_ops.py

+ 25
- 1
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -508,16 +508,40 @@ def get_bprop_sigmoid_grad(self):
return bprop


@constexpr
def _get_transpose_axis(x_shp, axis):
rank = len(x_shp)
if axis < 0:
axis += rank
reverse_axis = [i for i in range(rank)]
reverse_axis[axis] = rank - 1
reverse_axis[rank - 1] = axis
return tuple(reverse_axis)


@bprop_getters.register(P.Softmax)
def get_bprop_softmax(self):
"""Grad definition for `Softmax` operation."""
sum_func = P.ReduceSum(keep_dims=True)
sub = P.Sub()
mul = P.Mul()
get_shape = P.Shape()
transpose = P.Transpose()
axis = self.axis
if not isinstance(axis, int):
axis = axis[0]

def bprop(x, out, dout):
dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
# dx = (dout - sum(dout * out)) * out
# This formula is correct only when the `axis` is the last dimension.
# In order to support the scenario where the `axis` is other values,
# we transpose the data of the `axis` dimension to the last dimension for calculation,
# and then transpose it back after the calculation.
reverse_axis = _get_transpose_axis(get_shape(x), axis)
out = transpose(out, reverse_axis)
dout = transpose(dout, reverse_axis)
dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
dx = transpose(dx, reverse_axis)
return (dx,)

return bprop


Loading…
Cancel
Save