@@ -15,8 +15,10 @@
"""array_ops"""

import numpy as np
import mindspore as ms
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
@@ -459,6 +461,87 @@ def get_bprop_sparse_gather_v2(self):
    return bprop


@constexpr
def _range_op(start, limit, delta, dtype):
    """helper function for grad of Sort: build a 1-D Tensor holding range(start, limit, delta)"""
    output_tensor = Tensor(list(range(start, limit, delta)), dtype)
    return output_tensor


@constexpr
def _get_1d_shape(in_shape):
    """helper function for grad of Sort: the flattened 1-D shape corresponding to in_shape"""
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)


@constexpr
def _get_transposition(axis, rank):
    """helper function for grad of Sort: permutation that swaps `axis` with the last dimension"""
    if axis < 0:
        axis += rank
    # e.g. axis=1, rank=4 gives the permutation (0, 3, 2, 1)
    transposition = np.r_[np.arange(axis), [rank - 1], np.arange(axis + 1, rank - 1), [axis]]
    trans = tuple(transposition.tolist())
    return trans


@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
    """Grad definition for `Sort` operation."""
    axis = self.axis
    descending = self.descending
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    reshape_op = P.Reshape()
    dtype = P.DType()
    topk = P.TopK()
    neg = P.Neg()
    transpose = P.Transpose()

    def bprop(input_x, out, dout):
        x_shape = input_x.shape
        k = x_shape[axis]
        rank = F.rank(input_x)
        dvalue = dout[0]
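        # TopK sorts in descending order, so for an ascending Sort the input is negated
        # to reproduce the same ordering; the gradient is negated to match and negated
        # back at the very end.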
        if not descending:
            input_x = neg(input_x)
            dvalue = neg(dvalue)
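        # TopK works along the last axis, so move the sort axis to the end when necessary;
        # the gradient is transposed back below.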
        if axis == -1 or (axis + 1) == rank:
            transposition = None
            top_k_input = input_x
        else:
            transposition = _get_transposition(axis, rank)
            top_k_input = transpose(input_x, transposition)

        _, indices = topk(top_k_input, k)
        ind_shape = indices.shape
        top_k_input_shape = top_k_input.shape
        in_lastdim = top_k_input_shape[-1]
        ind_lastdim = ind_shape[-1]
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outer_dim = ind_2d.shape[0]
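
        # The TopK indices address positions within each row of the 2-D view; adding a
        # per-row offset turns them into indices into the flattened input, which is what
        # ScatterNd expects below.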
        # [0, in_lastdim, 2*in_lastdim, ..., (outer_dim - 1)*in_lastdim]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outer_dim * in_lastdim, in_lastdim, indices_dtype)

        # expand_dims to (outer_dim, 1), then broadcast against ind_2d
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        x_shape_1d = _get_1d_shape(top_k_input_shape)

        if transposition is not None:
            dvalue = transpose(dvalue, invert_permutation(transposition))
            out_grad = reshape_op(
                scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
            dx = transpose(out_grad, invert_permutation(transposition))
        else:
            dx = reshape_op(scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
        if not descending:
            dx = neg(dx)
        return (dx,)

    return bprop
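

# Illustrative sketch (not part of the original change) of how the bprop above is reached,
# assuming PyNative mode; `sort_op`, `x` and `grad_fn` below are hypothetical names.
#
#     sort_op = P.Sort(axis=-1, descending=False)
#     x = Tensor(np.array([[3.0, 1.0, 2.0]], np.float32))
#     grad_fn = C.GradOperation()(lambda t: P.ReduceSum()(sort_op(t)[0]))
#     dx = grad_fn(x)  # same shape as x; each gradient is routed back to the position
#                      # its sorted value came from

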
@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
    """Generate bprop for Identity"""
@@ -475,6 +558,7 @@ def get_bprop_range(self):

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@@ -506,7 +590,7 @@ def get_bprop_reverse_v2(self):
        dx = reverse_grad(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Unpack)
def get_bprop_unpack(self):