diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 07b8c60252..f259361e05 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -508,16 +508,40 @@ def get_bprop_sigmoid_grad(self):
     return bprop
 
 
+@constexpr
+def _get_transpose_axis(x_shp, axis):
+    rank = len(x_shp)
+    if axis < 0:
+        axis += rank
+    reverse_axis = [i for i in range(rank)]
+    reverse_axis[axis] = rank - 1
+    reverse_axis[rank - 1] = axis
+    return tuple(reverse_axis)
+
+
 @bprop_getters.register(P.Softmax)
 def get_bprop_softmax(self):
     """Grad definition for `Softmax` operation."""
     sum_func = P.ReduceSum(keep_dims=True)
     sub = P.Sub()
     mul = P.Mul()
+    get_shape = P.Shape()
+    transpose = P.Transpose()
     axis = self.axis
+    if not isinstance(axis, int):
+        axis = axis[0]
 
     def bprop(x, out, dout):
-        dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
+        # dx = (dout - sum(dout * out)) * out
+        # This formula is correct only when the `axis` is the last dimension.
+        # In order to support the scenario where the `axis` is other values,
+        # we transpose the data of the `axis` dimension to the last dimension for calculation,
+        # and then transpose it back after the calculation.
+        reverse_axis = _get_transpose_axis(get_shape(x), axis)
+        out = transpose(out, reverse_axis)
+        dout = transpose(dout, reverse_axis)
+        dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
+        dx = transpose(dx, reverse_axis)
         return (dx,)
 
     return bprop
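
For reference (not part of the patch), below is a minimal NumPy sketch of the identity the bprop relies on: for y = softmax(x, axis), the vector-Jacobian product is dx = y * (dout - sum(y * dout, axis, keepdims=True)), and swapping `axis` with the last dimension, reducing over -1, then transposing back gives the same result. All names here (softmax_grad_direct, softmax_grad_via_transpose, ...) are illustrative and are not MindSpore APIs.

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_grad_direct(y, dout, axis):
    # VJP of softmax along an arbitrary axis.
    return y * (dout - np.sum(y * dout, axis=axis, keepdims=True))

def softmax_grad_via_transpose(y, dout, axis):
    # Same computation, but done over the last dimension after swapping
    # `axis` with the last dimension, mirroring _get_transpose_axis.
    rank = y.ndim
    axis = axis % rank
    perm = list(range(rank))
    perm[axis], perm[rank - 1] = perm[rank - 1], perm[axis]
    y_t, dout_t = y.transpose(perm), dout.transpose(perm)
    dx_t = y_t * (dout_t - np.sum(y_t * dout_t, axis=-1, keepdims=True))
    # The swap permutation is its own inverse, so reuse it to transpose back.
    return dx_t.transpose(perm)

x = np.random.randn(2, 3, 4)
dout = np.random.randn(2, 3, 4)
for ax in (-3, -2, -1, 0, 1, 2):
    y = softmax(x, ax)
    assert np.allclose(softmax_grad_direct(y, dout, ax),
                       softmax_grad_via_transpose(y, dout, ax))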