From: @liangzhibo Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh pull/12113/MERGE
| @@ -15,12 +15,13 @@ | |||
| """Providing decorators.""" | |||
| def deprecated(version, substitute): | |||
| def deprecated(version, substitute, use_substitute_name=False): | |||
| """deprecated warning | |||
| Args: | |||
| version (str): version that the operator or function is deprecated. | |||
| substitute (str): the substitute name for deprecated operator or function. | |||
| use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function | |||
| """ | |||
| def decorate(func): | |||
| @@ -29,6 +30,8 @@ def deprecated(version, substitute): | |||
| name = cls.__name__ if cls else func.__name__ | |||
| print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " | |||
| f"use '{substitute}' instead.") | |||
| if cls and use_substitute_name: | |||
| cls.substitute_name = substitute | |||
| ret = func(*args, **kwargs) | |||
| return ret | |||
| @@ -393,6 +393,39 @@ def _regenerate_output_shape(x_shp, ind_shp, axis): | |||
@bprop_getters.register(P.Gather)
def get_bprop_gather(self):
    """Build the backward function registered for `P.Gather`.

    The returned `bprop` yields a gradient for `x` plus `zeros_like`
    placeholders for `indices` and `axis`.
    """

    def bprop(x, indices, axis, out, dout):
        # Keep the indices exactly as passed in; the placeholder gradient
        # returned for them is built from this untouched value.
        orig_indices = indices

        # Rank-0 inputs get a trailing dimension before any shape arithmetic.
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)

        x_shp = shape_op(x)
        ind_shp = shape_op(indices)
        dout = reshape(dout, _regenerate_output_shape(x_shp, ind_shp, axis))
        out_shp = shape_op(dout)

        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)

        # When the shape of x contains -1, the segment count is read through
        # dyn_shape_op instead of the static shape tuple.
        if -1 in x_shp:
            params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
        else:
            params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])

        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    return bprop
| @bprop_getters.register(P.GatherV2) | |||
| def get_bprop_gather_v2(self): | |||
| """Generate bprop for GatherV2""" | |||
| @@ -593,6 +626,23 @@ def get_bprop_stack(self): | |||
| return bprop | |||
@bprop_getters.register(P.Pack)
def get_bprop_pack(self):
    """Build the backward function registered for `P.Pack`.

    The gradient is obtained by applying `P.Unstack` along the primitive's
    own axis to `dout`.
    """
    axis = self.axis

    def bprop(x, out, dout):
        pieces = P.Unstack(axis)(dout)
        if not is_sub_class(F.typeof(x), ms.list_):
            return (pieces,)
        # x is a list type: return the gradient pieces as a list too.
        grads = []
        for piece in pieces:
            grads.append(piece)
        return (grads,)

    return bprop
| @bprop_getters.register(P.ReverseV2) | |||
| def get_bprop_reverse_v2(self): | |||
| """Generate bprop for ReverseV2""" | |||
| @@ -619,6 +669,19 @@ def get_bprop_unstack(self): | |||
| return bprop | |||
@bprop_getters.register(P.Unpack)
def get_bprop_unpack(self):
    """Build the backward function registered for `P.Unpack`.

    The gradient is obtained by applying `P.Stack` along the primitive's
    own axis to `dout`.
    """
    axis = self.axis

    def bprop(x, out, dout):
        dx = P.Stack(axis)(dout)
        return (dx,)

    return bprop
| @bprop_getters.register(P.StridedSlice) | |||
| def get_bprop_strided_slice(self): | |||
| """Generate bprop for StridedSlice""" | |||
| @@ -156,7 +156,7 @@ def bprop_batchmatmul(self): | |||
| @bprop_getters.register(P.Add) | |||
| def get_bprop_tensor_add(self): | |||
| def get_bprop_add(self): | |||
| """Grad definition for `Add` operation.""" | |||
| def bprop(x, y, out, dout): | |||
| @@ -165,6 +165,16 @@ def get_bprop_tensor_add(self): | |||
| return bprop | |||
@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
    """Build the backward function registered for `P.TensorAdd`."""

    def bprop(x, y, out, dout):
        # dout is passed unchanged as the incoming gradient for both operands.
        grads = binop_grad_common(x, y, dout, dout)
        return grads

    return bprop
| @bprop_getters.register(P.Neg) | |||
| def get_bprop_neg(self): | |||
| """Grad definition for `Neg` operation.""" | |||
| @@ -616,6 +616,18 @@ def get_bprop_gelu(self): | |||
| return bprop | |||
@bprop_getters.register(P.Gelu)
def get_bprop_gelu_2(self):
    """Build the backward function registered for `P.Gelu`.

    The gradient is computed by a single `G.GeLUGrad` instance created once
    when this getter runs.
    """
    gelu_grad = G.GeLUGrad()

    def bprop(x, out, dout):
        return (gelu_grad(dout, x, out),)

    return bprop
| @bprop_getters.register(P.FastGeLU) | |||
| def get_bprop_fast_gelu(self): | |||
| """Grad definition for `FastGeLU` operation.""" | |||
| @@ -628,6 +640,18 @@ def get_bprop_fast_gelu(self): | |||
| return bprop | |||
@bprop_getters.register(P.FastGelu)
def get_bprop_fast_gelu_2(self):
    """Build the backward function registered for `P.FastGelu`.

    The gradient is computed by a single `G.FastGeLUGrad` instance created
    once when this getter runs; it takes `dout` and `x` only.
    """
    fast_gelu_grad = G.FastGeLUGrad()

    def bprop(x, out, dout):
        return (fast_gelu_grad(dout, x),)

    return bprop
| @bprop_getters.register(P.FusedBatchNorm) | |||
| def get_bprop_fused_batch_norm(self): | |||
| """Grad definition for `FusedBatchNorm` operation.""" | |||
| @@ -33,6 +33,7 @@ from .. import signature as sig | |||
| from ..._checkparam import Rel | |||
| from ..._checkparam import Validator as validator | |||
| from ...common import dtype as mstype | |||
| from ...common._decorator import deprecated | |||
| from ...common.parameter import Parameter | |||
| from ...common.tensor import Tensor | |||
| @@ -815,14 +816,27 @@ class Gather(PrimitiveWithCheck): | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| def GatherV2(): | |||
| class GatherV2(PrimitiveWithCheck): | |||
| """ | |||
| Returns a slice of the input tensor based on the specified indices and axis. | |||
| The usage of GatherV2 is deprecated. Please use Gather. | |||
| Same as operator Gather. GatherV2 will be deprecated in the future. | |||
| Please use Gather instead. | |||
| """ | |||
| logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.") | |||
| return Gather() | |||
| @deprecated("1.1", "Gather", True) | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize index_select""" | |||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | |||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||
| def __check__(self, params, indices, axis): | |||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||
| validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) | |||
| axis_v = axis['value'] | |||
| validator.check_value_type('axis', axis_v, [int], self.name) | |||
| rank = len(params['shape']) | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| class SparseGatherV2(Gather): | |||
| @@ -2291,26 +2305,29 @@ class Stack(PrimitiveWithInfer): | |||
| 'value': None} | |||
| return out | |||
| def Pack(axis=0): | |||
| """ | |||
| Packs a list of tensors in specified axis. | |||
| The usage of Pack is deprecated. Please use Stack. | |||
| class Pack(PrimitiveWithInfer): | |||
| """ | |||
| logger.warning("WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.") | |||
| return Stack(axis) | |||
| def Unpack(axis=0): | |||
| Same as operator Stack. Pack will be deprecated in the future. | |||
| Please use Stack instead. | |||
| """ | |||
| Unpacks tensor in specified axis. | |||
| The usage of Unpack is deprecated. Please use Unstack. | |||
| @deprecated("1.1", "Stack", True) | |||
| @prim_attr_register | |||
| def __init__(self, axis=0): | |||
| """Initialize Stack""" | |||
| validator.check_value_type("axis", axis, [int], self.name) | |||
| self.axis = axis | |||
| """ | |||
| logger.warning("WARN_DEPRECATED: The usage of Unpack is deprecated. Please use Unstack.") | |||
| return Unstack(axis) | |||
| def __infer__(self, value): | |||
| x_shape = value['shape'] | |||
| x_type = value['dtype'] | |||
| self.add_prim_attr('num', len(x_shape)) | |||
| all_shape = _get_stack_shape(x_shape, x_type, self.axis, self.name) | |||
| out = {'shape': all_shape, | |||
| 'dtype': x_type[0], | |||
| 'value': None} | |||
| return out | |||
| class Unstack(PrimitiveWithInfer): | |||
| @@ -2384,6 +2401,47 @@ class Unstack(PrimitiveWithInfer): | |||
| return out | |||
class Unpack(PrimitiveWithInfer):
    """
    Same as operator Unstack. Unpack will be deprecated in the future.
    Please use Unstack instead.
    """

    @deprecated("1.1", "Unstack", True)
    @prim_attr_register
    def __init__(self, axis=0):
        """Check that `axis` is an int and store it on the primitive."""
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def __infer__(self, x):
        """Infer the output description for the input described by `x`.

        Args:
            x (dict): input description; the 'dtype' and 'shape' entries are read.

        Returns:
            dict, with 'shape' and 'dtype' each a tuple holding one entry per
            output and 'value' set to None.
        """
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        in_shape = list(x['shape'])
        rank = len(in_shape)
        validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis value', self.name)
        # Normalise a negative axis in place so later indexing uses a non-negative value.
        if self.axis < 0:
            self.axis += rank

        # One output is produced per element along the chosen axis.
        output_num = in_shape[self.axis]
        validator.check_value_type("num", output_num, [int], self.name)
        validator.check_positive_int(output_num, "output_num", self.name)
        self.add_prim_attr('num', output_num)
        validator.check_int(in_shape[self.axis] - output_num, 0, Rel.EQ,
                            "The dimension which to unstack divides output_num", self.name)

        # Every output has the input shape with the chosen axis removed.
        per_output_shape = tuple(in_shape[:self.axis] + in_shape[self.axis + 1:])
        return {'shape': (per_output_shape,) * output_num,
                'dtype': (x['dtype'],) * output_num,
                'value': None}
| class Slice(PrimitiveWithInfer): | |||
| """ | |||
| Slices a tensor in the specified shape. | |||
| @@ -18,13 +18,13 @@ | |||
| import copy | |||
| import numpy as np | |||
| from mindspore import log as logger | |||
| from ... import context | |||
| from .. import signature as sig | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor, MetaTensor | |||
| from ...common._decorator import deprecated | |||
| from .._utils import get_broadcast_shape | |||
| from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | |||
| @@ -162,14 +162,25 @@ class Add(_MathBinaryOp): | |||
| return None | |||
| def TensorAdd(): | |||
| class TensorAdd(_MathBinaryOp): | |||
| """ | |||
| Adds two input tensors element-wise. | |||
| The usage of TensorAdd is deprecated. Please use Add. | |||
| Same as operator Add. TensorAdd will be deprecated in the future. | |||
| Please use Add instead. | |||
| """ | |||
| logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.") | |||
| return Add() | |||
| @deprecated("1.1", "Add", True) | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| _MathBinaryOp.__init__(self) | |||
| def infer_value(self, x, y): | |||
| if x is not None and y is not None: | |||
| x = x.asnumpy() | |||
| y = y.asnumpy() | |||
| out = x + y | |||
| out = np.array(out, x.dtype) | |||
| return Tensor(out) | |||
| return None | |||
| class AssignAdd(PrimitiveWithInfer): | |||
| @@ -26,6 +26,7 @@ from .. import signature as sig | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common._decorator import deprecated | |||
| from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register | |||
| @@ -2970,14 +2971,24 @@ class GeLU(PrimitiveWithInfer): | |||
| return input_x | |||
| def Gelu(): | |||
| class Gelu(PrimitiveWithInfer): | |||
| """ | |||
| Gaussian Error Linear Units activation function. | |||
| The usage of Gelu is deprecated. Please use GeLU. | |||
| Same as operator GeLU. Gelu will be deprecated in the future. | |||
| Please use GeLU instead. | |||
| """ | |||
| logger.warning("WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.") | |||
| return GeLU() | |||
| @deprecated("1.1", "GeLU", True) | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize GeLU""" | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| def infer_shape(self, input_x): | |||
| return input_x | |||
| def infer_dtype(self, input_x): | |||
| validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) | |||
| return input_x | |||
| class FastGeLU(PrimitiveWithInfer): | |||
| @@ -3022,14 +3033,25 @@ class FastGeLU(PrimitiveWithInfer): | |||
| return input_x | |||
| def FastGelu(): | |||
| class FastGelu(PrimitiveWithInfer): | |||
| """ | |||
| Fast Gaussian Error Linear Units activation function. | |||
| The usage of FastGelu is deprecated. Please use FastGeLU. | |||
| Same as operator FastGeLU. FastGelu will be deprecated in the future. | |||
| Please use FastGeLU instead. | |||
| """ | |||
| logger.warning("WARN_DEPRECATED: The usage of FastGelu is deprecated. Please use FastGeLU.") | |||
| return FastGeLU() | |||
| @deprecated("1.1", "FastGeLU", True) | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init FastGeLU""" | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| def infer_shape(self, input_x): | |||
| return input_x | |||
| def infer_dtype(self, input_x): | |||
| validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) | |||
| return input_x | |||
| class GetNext(PrimitiveWithInfer): | |||
| @@ -454,10 +454,13 @@ def prim_attr_register(fn): | |||
| """ | |||
| def deco(self, *args, **kwargs): | |||
| class_name = self.__class__.__name__ | |||
| if hasattr(self.__class__, "substitute_name"): | |||
| class_name = self.__class__.substitute_name | |||
| if isinstance(self, PrimitiveWithInfer): | |||
| PrimitiveWithInfer.__init__(self, self.__class__.__name__) | |||
| PrimitiveWithInfer.__init__(self, class_name) | |||
| elif isinstance(self, PrimitiveWithCheck): | |||
| PrimitiveWithCheck.__init__(self, self.__class__.__name__) | |||
| PrimitiveWithCheck.__init__(self, class_name) | |||
| else: | |||
| Primitive.__init__(self, self.__class__.__name__) | |||
| bound_args = inspect.signature(fn).bind(self, *args, **kwargs) | |||