@@ -14,6 +14,7 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
import math
import numpy as np
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
@@ -628,19 +629,62 @@ def get_bprop_onehot(self):
    return bprop


@constexpr
def _range_op(start, limit, delta, dtype):
    """Helper function for Grad TopK."""
    range_op = inner.Range(float(start), float(limit), float(delta))
    length_input = math.ceil((limit - start) / delta)
    input_tensor = Tensor(list(range(length_input)), dtype)
    range_out = range_op(input_tensor)
    return range_out


@constexpr
def _get_1d_shape(in_shape):
    """Helper function for Grad TopK."""
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)


@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation."""
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
        # input_x shape: (n_1, n_2, ..., n_p), in_lastdim = n_p
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]

        # indices shape: (n_1, ..., n_(p-1), k), ind_lastdim = k
        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]

        # flatten indices to (n_1*...*n_(p-1), k), outerdim = n_1*...*n_(p-1)
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]

        # row offsets into the flattened input:
        # [0, in_lastdim, 2*in_lastdim, ..., (outerdim-1)*in_lastdim]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)

        # expand offsets to (outerdim, 1), broadcast-add to ind_2d, then flatten
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        in_shape_1d = _get_1d_shape(in_shape)

        # scatter dout into a flat zero tensor at the TopK positions, then restore the input shape
        out_grad = reshape_op(
            scatter(
                expand_dims(ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        return out_grad, zeros_like(k)

    return bprop
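
For reviewers, the index arithmetic can be checked in isolation. The following is a minimal NumPy sketch, not part of the patch and not MindSpore API; the helper name topk_bprop_reference and the argsort-based indices are assumptions used only to illustrate how the flattened scatter reconstructs the gradient.

# --- illustrative reference only, not part of the patch ---
import numpy as np

def topk_bprop_reference(x, dout, indices):
    """Hypothetical NumPy mirror of the flattened-index scatter used in the bprop above."""
    in_shape = x.shape
    in_lastdim = in_shape[-1]
    k = indices.shape[-1]

    # (outerdim, k) view of the TopK indices, outerdim = product of leading dims
    ind_2d = indices.reshape(-1, k)
    outerdim = ind_2d.shape[0]

    # row offsets into the flattened input: [0, in_lastdim, ..., (outerdim-1)*in_lastdim]
    row_offsets = np.arange(0, outerdim * in_lastdim, in_lastdim).reshape(-1, 1)

    # absolute flat positions, then scatter dout there and restore the input shape
    flat_ind = (ind_2d + row_offsets).reshape(-1)
    out_grad = np.zeros(x.size, dtype=dout.dtype)
    out_grad[flat_ind] = dout.reshape(-1)
    return out_grad.reshape(in_shape)

# example: gradient of ones flowing back through a top-2 selection on each row
x = np.array([[1., 5., 3.], [4., 2., 6.]])
indices = np.argsort(-x, axis=-1)[..., :2]   # stand-in for the indices TopK returns
grad = topk_bprop_reference(x, np.ones((2, 2)), indices)
# grad == [[0., 1., 1.], [1., 0., 1.]]: ones at the top-2 positions of each row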