@@ -14,6 +14,7 @@
 # ============================================================================
 """Define the grad rules of neural network related operations."""
+import math
 import numpy as np
 from mindspore.ops import _selected_grad_ops as SG
 from mindspore.ops.primitive import constexpr
@@ -628,19 +629,62 @@ def get_bprop_onehot(self):
     return bprop
 
 
+@constexpr
+def _range_op(start, limit, delta, dtype):
+    """Helper for the TopK grad: build the 1-D tensor [start, start + delta, ..., limit) with the given dtype."""
+    range_op = inner.Range(float(start), float(limit), float(delta))
+    length_input = math.ceil((limit - start) / delta)
+    input_tensor = Tensor(list(range(length_input)), dtype)
+    range_out = range_op(input_tensor)
+    return range_out
+
+
+@constexpr
+def _get_1d_shape(in_shape):
+    """Helper for the TopK grad: flatten `in_shape` into the 1-D shape (n_1*...*n_p,)."""
+    out_shape = 1
+    for i in in_shape:
+        out_shape *= i
+    return (out_shape,)
+
+
 @bprop_getters.register(P.TopK)
 def get_bprop_top_kv2(self):
     """Grad definition for `TopK` operation."""
     scatter = P.ScatterNd()
     expand_dims = P.ExpandDims()
     shape_op = P.Shape()
+    reshape_op = P.Reshape()
+    dtype = P.DType()
 
     def bprop(input_x, k, out, dout):
+        # input_x has shape (n_1, ..., n_p); in_lastdim = n_p
+        in_shape = shape_op(input_x)
+        in_lastdim = in_shape[-1]
+
+        # indices of the top-k values have shape (n_1, ..., n_(p-1), k); ind_lastdim = k
         indices = out[1]
-        indices = expand_dims(indices, -1)
-        updates = dout[0]
-        shapes = shape_op(input_x)
-        return scatter(indices, updates, shapes), zeros_like(k)
+        ind_shape = shape_op(indices)
+        ind_lastdim = ind_shape[-1]
+
+        # reshape indices to (n_1*...*n_(p-1), k); outerdim = n_1*...*n_(p-1)
+        ind_2d = reshape_op(indices, (-1, ind_lastdim))
+        outerdim = shape_op(ind_2d)[0]
+
+        # per-row offsets into the flattened input:
+        # [0, in_lastdim, 2*in_lastdim, ..., (outerdim-1)*in_lastdim]
+        indices_dtype = dtype(indices)
+        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
+
+        # expand the offsets to (outerdim, 1), broadcast-add them to ind_2d, then flatten
+        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
+        in_shape_1d = _get_1d_shape(in_shape)
+
+        # scatter the flattened gradient into a zero tensor of the flattened
+        # input shape, then restore the original input shape
+        out_grad = reshape_op(
+            scatter(
+                expand_dims(ind, -1),
+                reshape_op(dout[0], (-1,)),
+                in_shape_1d),
+            in_shape)
+        return out_grad, zeros_like(k)
 
     return bprop
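
For reference, the index arithmetic in the new bprop can be reproduced outside MindSpore. The NumPy sketch below (hypothetical name topk_bprop_sketch; np.argsort stands in for the TopK forward pass and plain fancy indexing stands in for ScatterNd) scatters the top-k gradient back to the input shape with the same flattened-offset trick; it is an illustration under those assumptions, not MindSpore code.

import numpy as np

def topk_bprop_sketch(x, k, dout):
    """Scatter the gradient of the top-k values (dout) back to x's shape."""
    in_shape = x.shape                                   # (n_1, ..., n_p)
    in_lastdim = in_shape[-1]                            # n_p
    # forward-pass stand-in: indices of the k largest entries along the last axis
    indices = np.argsort(-x, axis=-1)[..., :k]           # (n_1, ..., n_(p-1), k)
    ind_2d = indices.reshape(-1, k)                      # (outerdim, k)
    outerdim = ind_2d.shape[0]
    # per-row offsets [0, in_lastdim, ..., (outerdim-1)*in_lastdim], broadcast against ind_2d
    row_offset = np.arange(0, outerdim * in_lastdim, in_lastdim).reshape(-1, 1)
    flat_ind = (ind_2d + row_offset).reshape(-1)
    # scatter the flattened gradient into a zero buffer of the flattened input shape
    out_grad = np.zeros(int(np.prod(in_shape)), dtype=dout.dtype)
    out_grad[flat_ind] = dout.reshape(-1)
    return out_grad.reshape(in_shape)

x = np.array([[1., 3., 2.],
              [6., 4., 5.]])
dout = np.ones((2, 2))          # gradient w.r.t. the two largest values per row
print(topk_bprop_sketch(x, 2, dout))
# [[0. 1. 1.]
#  [1. 0. 1.]]

The (outerdim, 1) offset column is what expand_dims(range_flatten_index, -1) produces in the bprop; adding it to the per-row top-k indices turns them into positions in the flattened input, so a single 1-D scatter followed by a reshape recovers the gradient with respect to input_x.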