From 50f313b71fe1c504276f004a339c9a45ae647f67 Mon Sep 17 00:00:00 2001 From: mindspore-ci-bot <314202276@qq.com> Date: Wed, 3 Feb 2021 20:31:45 +0800 Subject: [PATCH] Add deprecator operators --- mindspore/common/_decorator.py | 5 +- mindspore/ops/_grad/grad_array_ops.py | 63 ++++++++++++++++ mindspore/ops/_grad/grad_math_ops.py | 12 +++- mindspore/ops/_grad/grad_nn_ops.py | 24 +++++++ mindspore/ops/operations/array_ops.py | 100 ++++++++++++++++++++------ mindspore/ops/operations/math_ops.py | 25 +++++-- mindspore/ops/operations/nn_ops.py | 46 ++++++++---- mindspore/ops/primitive.py | 7 +- 8 files changed, 238 insertions(+), 44 deletions(-) diff --git a/mindspore/common/_decorator.py b/mindspore/common/_decorator.py index 892d76548f..7f75fd17ae 100644 --- a/mindspore/common/_decorator.py +++ b/mindspore/common/_decorator.py @@ -15,12 +15,13 @@ """Providing decorators.""" -def deprecated(version, substitute): +def deprecated(version, substitute, use_substitute_name=False): """deprecated warning Args: version (str): version that the operator or function is deprecated. substitute (str): the substitute name for deprecated operator or function. + use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function """ def decorate(func): @@ -29,6 +30,8 @@ def deprecated(version, substitute): name = cls.__name__ if cls else func.__name__ print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " f"use '{substitute}' instead.") + if cls and use_substitute_name: + cls.substitute_name = substitute ret = func(*args, **kwargs) return ret diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 745c215f6d..fa5a1c2279 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -393,6 +393,39 @@ def _regenerate_output_shape(x_shp, ind_shp, axis): @bprop_getters.register(P.Gather) +def get_bprop_gather(self): + """Generate bprop for Gather""" + + def bprop(x, indices, axis, out, dout): + orig_indices = indices + if F.rank(dout) == 0: + dout = P.ExpandDims()(dout, -1) + if F.rank(indices) == 0: + indices = P.ExpandDims()(indices, -1) + x_shp = shape_op(x) + ind_shp = shape_op(indices) + out_shp = _regenerate_output_shape(x_shp, ind_shp, axis) + dout = reshape(dout, out_shp) + + x_shp = shape_op(x) + out_shp = shape_op(dout) + ind_shp = shape_op(indices) + # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) + perm_1 = _generate_shape_index(out_shp, ind_shp, axis) + values_transpose = transpose(dout, perm_1) + if -1 in shape_op(x): + params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) + else: + params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) + # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) + perm_2 = _generate_inverse_index(x_shp, axis) + params_grad = transpose(params_grad, perm_2) + return params_grad, zeros_like(orig_indices), zeros_like(axis) + + return bprop + + +@bprop_getters.register(P.GatherV2) def get_bprop_gather_v2(self): """Generate bprop for GatherV2""" @@ -593,6 +626,23 @@ def get_bprop_stack(self): return bprop +@bprop_getters.register(P.Pack) +def get_bprop_pack(self): + """Generate bprop for pack""" + axis = self.axis + + def bprop(x, out, dout): + stack_grad = P.Unstack(axis) + out = stack_grad(dout) + if is_sub_class(F.typeof(x), ms.list_): + ret = [] + for item in out: + ret.append(item) + return (ret,) + return (out,) + + return bprop + @bprop_getters.register(P.ReverseV2) def get_bprop_reverse_v2(self): """Generate bprop for ReverseV2""" @@ -619,6 +669,19 @@ def get_bprop_unstack(self): return bprop +@bprop_getters.register(P.Unpack) +def get_bprop_unpack(self): + """Generate bprop for Unpack""" + axis = self.axis + + def bprop(x, out, dout): + unstack_grad = P.Stack(axis) + out = unstack_grad(dout) + return (out,) + + return bprop + + @bprop_getters.register(P.StridedSlice) def get_bprop_strided_slice(self): """Generate bprop for StridedSlice""" diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index cc1a0be978..e585e22ecd 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -156,7 +156,7 @@ def bprop_batchmatmul(self): @bprop_getters.register(P.Add) -def get_bprop_tensor_add(self): +def get_bprop_add(self): """Grad definition for `Add` operation.""" def bprop(x, y, out, dout): @@ -165,6 +165,16 @@ def get_bprop_tensor_add(self): return bprop +@bprop_getters.register(P.TensorAdd) +def get_bprop_tensor_add(self): + """Grad definition for `TensorAdd` operation.""" + + def bprop(x, y, out, dout): + return binop_grad_common(x, y, dout, dout) + + return bprop + + @bprop_getters.register(P.Neg) def get_bprop_neg(self): """Grad definition for `Neg` operation.""" diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 129fd0fd46..89630435c0 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -616,6 +616,18 @@ def get_bprop_gelu(self): return bprop +@bprop_getters.register(P.Gelu) +def get_bprop_gelu_2(self): + """Grad definition for `Gelu` operation.""" + input_grad = G.GeLUGrad() + + def bprop(x, out, dout): + dx = input_grad(dout, x, out) + return (dx,) + + return bprop + + @bprop_getters.register(P.FastGeLU) def get_bprop_fast_gelu(self): """Grad definition for `FastGeLU` operation.""" @@ -628,6 +640,18 @@ def get_bprop_fast_gelu(self): return bprop +@bprop_getters.register(P.FastGelu) +def get_bprop_fast_gelu_2(self): + """Grad definition for `FastGelu` operation.""" + input_grad = G.FastGeLUGrad() + + def bprop(x, out, dout): + dx = input_grad(dout, x) + return (dx,) + + return bprop + + @bprop_getters.register(P.FusedBatchNorm) def get_bprop_fused_batch_norm(self): """Grad definition for `FusedBatchNorm` operation.""" diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index eacff3534f..d3f94ab57a 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -33,6 +33,7 @@ from .. import signature as sig from ..._checkparam import Rel from ..._checkparam import Validator as validator from ...common import dtype as mstype +from ...common._decorator import deprecated from ...common.parameter import Parameter from ...common.tensor import Tensor @@ -815,14 +816,27 @@ class Gather(PrimitiveWithCheck): validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) -def GatherV2(): +class GatherV2(PrimitiveWithCheck): """ - Returns a slice of the input tensor based on the specified indices and axis. - - The usage of GatherV2 is deprecated. Please use Gather. + Same as operator Gather. GatherV2 will be deprecated in the future. + Please use Gather instead. """ - logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.") - return Gather() + + @deprecated("1.1", "Gather", True) + @prim_attr_register + def __init__(self): + """Initialize index_select""" + self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) + self.add_prim_attr("dynamic_shape_depends", [2]) + + def __check__(self, params, indices, axis): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) + validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) + axis_v = axis['value'] + validator.check_value_type('axis', axis_v, [int], self.name) + rank = len(params['shape']) + validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) class SparseGatherV2(Gather): @@ -2292,26 +2306,29 @@ class Stack(PrimitiveWithInfer): 'value': None} return out -def Pack(axis=0): - """ - Packs a list of tensors in specified axis. - - The usage of Pack is deprecated. Please use Stack. +class Pack(PrimitiveWithInfer): """ - logger.warning("WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.") - return Stack(axis) - - -def Unpack(axis=0): + Same as operator Stack. Pack will be deprecated in the future. + Please use Stack instead. """ - Unpacks tensor in specified axis. - The usage of Unpack is deprecated. Please use Unstack. + @deprecated("1.1", "Stack", True) + @prim_attr_register + def __init__(self, axis=0): + """Initialize Stack""" + validator.check_value_type("axis", axis, [int], self.name) + self.axis = axis - """ - logger.warning("WARN_DEPRECATED: The usage of Unpack is deprecated. Please use Unstack.") - return Unstack(axis) + def __infer__(self, value): + x_shape = value['shape'] + x_type = value['dtype'] + self.add_prim_attr('num', len(x_shape)) + all_shape = _get_stack_shape(x_shape, x_type, self.axis, self.name) + out = {'shape': all_shape, + 'dtype': x_type[0], + 'value': None} + return out class Unstack(PrimitiveWithInfer): @@ -2385,6 +2402,47 @@ class Unstack(PrimitiveWithInfer): return out +class Unpack(PrimitiveWithInfer): + """ + Same as operator Unstack. Unpack will be deprecated in the future. + Please use Unstack instead. + """ + + @deprecated("1.1", "Unstack", True) + @prim_attr_register + def __init__(self, axis=0): + """Initialize Unstack""" + validator.check_value_type("axis", axis, [int], self.name) + self.axis = axis + + def __infer__(self, x): + validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) + x_shape = list(x['shape']) + dim = len(x_shape) + validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) + if self.axis < 0: + self.axis = self.axis + dim + output_num = x_shape[self.axis] + validator.check_value_type("num", output_num, [int], self.name) + validator.check_positive_int(output_num, "output_num", self.name) + self.add_prim_attr('num', output_num) + output_valid_check = x_shape[self.axis] - output_num + validator.check_int(output_valid_check, 0, Rel.EQ, + "The dimension which to unstack divides output_num", self.name) + out_shapes = [] + out_dtypes = [] + out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:] + for _ in range(output_num): + out_shapes.append(tuple(out_shape)) + out_dtypes.append(x['dtype']) + out_shapes = tuple(out_shapes) + out_dtypes = tuple(out_dtypes) + out = {'shape': out_shapes, + 'dtype': out_dtypes, + 'value': None} + return out + + class Slice(PrimitiveWithInfer): """ Slices a tensor in the specified shape. diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 5ae582e83f..542cfaf7eb 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -18,13 +18,13 @@ import copy import numpy as np -from mindspore import log as logger from ... import context from .. import signature as sig from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor, MetaTensor +from ...common._decorator import deprecated from .._utils import get_broadcast_shape from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op @@ -162,14 +162,25 @@ class Add(_MathBinaryOp): return None -def TensorAdd(): +class TensorAdd(_MathBinaryOp): """ - Adds two input tensors element-wise. - - The usage of TensorAdd is deprecated. Please use Add. + Same as operator Add. TensorAdd will be deprecated in the future. + Please use Add instead. """ - logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.") - return Add() + + @deprecated("1.1", "Add", True) + @prim_attr_register + def __init__(self): + _MathBinaryOp.__init__(self) + + def infer_value(self, x, y): + if x is not None and y is not None: + x = x.asnumpy() + y = y.asnumpy() + out = x + y + out = np.array(out, x.dtype) + return Tensor(out) + return None class AssignAdd(PrimitiveWithInfer): diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2fb458faa0..3171039aea 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -26,6 +26,7 @@ from .. import signature as sig from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype +from ...common._decorator import deprecated from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register @@ -2970,14 +2971,24 @@ class GeLU(PrimitiveWithInfer): return input_x -def Gelu(): +class Gelu(PrimitiveWithInfer): """ - Gaussian Error Linear Units activation function. - - The usage of Gelu is deprecated. Please use GeLU. + Same as operator GeLU. Gelu will be deprecated in the future. + Please use GeLU instead. """ - logger.warning("WARN_DEPRECATED: The usage of Gelu is deprecated. Please use GeLU.") - return GeLU() + + @deprecated("1.1", "GeLU", True) + @prim_attr_register + def __init__(self): + """Initialize GeLU""" + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def infer_shape(self, input_x): + return input_x + + def infer_dtype(self, input_x): + validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) + return input_x class FastGeLU(PrimitiveWithInfer): @@ -3022,14 +3033,25 @@ class FastGeLU(PrimitiveWithInfer): return input_x -def FastGelu(): +class FastGelu(PrimitiveWithInfer): """ - Fast Gaussian Error Linear Units activation function. - - The usage of FastGelu is deprecated. Please use FastGeLU. + Same as operator FastGeLU. FastGelu will be deprecated in the future. + Please use FastGeLU instead. """ - logger.warning("WARN_DEPRECATED: The usage of FastGelu is deprecated. Please use FastGeLU.") - return FastGeLU() + + @deprecated("1.1", "FastGeLU", True) + @prim_attr_register + def __init__(self): + """init FastGeLU""" + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def infer_shape(self, input_x): + return input_x + + def infer_dtype(self, input_x): + validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) + return input_x + class GetNext(PrimitiveWithInfer): diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 68683ffcd9..c3f2b107f2 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -454,10 +454,13 @@ def prim_attr_register(fn): """ def deco(self, *args, **kwargs): + class_name = self.__class__.__name__ + if hasattr(self.__class__, "substitute_name"): + class_name = self.__class__.substitute_name if isinstance(self, PrimitiveWithInfer): - PrimitiveWithInfer.__init__(self, self.__class__.__name__) + PrimitiveWithInfer.__init__(self, class_name) elif isinstance(self, PrimitiveWithCheck): - PrimitiveWithCheck.__init__(self, self.__class__.__name__) + PrimitiveWithCheck.__init__(self, class_name) else: Primitive.__init__(self, self.__class__.__name__) bound_args = inspect.signature(fn).bind(self, *args, **kwargs)