Browse Source

Add deprecated function

tags/v1.2.0-rc1
l00591931 4 years ago
parent
commit
edbe3bfd3b
4 changed files with 56 additions and 12 deletions
  1. +4
    -1
      mindspore/common/_decorator.py
  2. +24
    -4
      mindspore/ops/operations/array_ops.py
  3. +23
    -5
      mindspore/ops/operations/math_ops.py
  4. +5
    -2
      mindspore/ops/primitive.py

+ 4
- 1
mindspore/common/_decorator.py View File

@@ -15,12 +15,13 @@
"""Providing decorators.""" """Providing decorators."""




def deprecated(version, substitute):
def deprecated(version, substitute, use_substitute_name=False):
"""deprecated warning """deprecated warning


Args: Args:
version (str): version that the operator or function is deprecated. version (str): version that the operator or function is deprecated.
substitute (str): the substitute name for deprecated operator or function. substitute (str): the substitute name for deprecated operator or function.
use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function
""" """


def decorate(func): def decorate(func):
@@ -29,6 +30,8 @@ def deprecated(version, substitute):
name = cls.__name__ if cls else func.__name__ name = cls.__name__ if cls else func.__name__
print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, "
f"use '{substitute}' instead.") f"use '{substitute}' instead.")
if cls and use_substitute_name:
cls.substitute_name = substitute
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
return ret return ret




+ 24
- 4
mindspore/ops/operations/array_ops.py View File

@@ -33,6 +33,7 @@ from .. import signature as sig
from ..._checkparam import Rel from ..._checkparam import Rel
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common._decorator import deprecated
from ...common.parameter import Parameter from ...common.parameter import Parameter
from ...common.tensor import Tensor from ...common.tensor import Tensor


@@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck):
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)




def GatherV2():
"""Warning: This will be changed later"""
logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.")
return Gather()
class GatherV2(PrimitiveWithCheck):
"""
Same as operator Gather. GatherV2 will be deprecated in the future.
Please use Gather instead.
"""
#deprecate_new_name = "Gather"

@deprecated("1.1", "Gather", True)
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)



class SparseGatherV2(Gather): class SparseGatherV2(Gather):
""" """


+ 23
- 5
mindspore/ops/operations/math_ops.py View File

@@ -18,13 +18,13 @@
import copy import copy


import numpy as np import numpy as np
from mindspore import log as logger
from ... import context from ... import context
from .. import signature as sig from .. import signature as sig
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
from ...common._decorator import deprecated
from .._utils import get_broadcast_shape from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op


@@ -161,10 +161,28 @@ class Add(_MathBinaryOp):
return Tensor(out) return Tensor(out)
return None return None


def TensorAdd():
"""Warning: This will be changed later"""
logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.")
return Add()

class TensorAdd(_MathBinaryOp):
"""
Same as operator Add. TensorAdd will be deprecated in the future.
Please use Add instead.
"""
#deprecate_new_name = "Add"

@deprecated("1.1", "Add", True)
@prim_attr_register
def __init__(self):
_MathBinaryOp.__init__(self)

def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()
y = y.asnumpy()
out = x + y
out = np.array(out, x.dtype)
return Tensor(out)
return None



class AssignAdd(PrimitiveWithInfer): class AssignAdd(PrimitiveWithInfer):
""" """


+ 5
- 2
mindspore/ops/primitive.py View File

@@ -466,10 +466,13 @@ def prim_attr_register(fn):
""" """


def deco(self, *args, **kwargs): def deco(self, *args, **kwargs):
class_name = self.__class__.__name__
if hasattr(self.__class__, "substitute_name"):
class_name = self.__class__.substitute_name
if isinstance(self, PrimitiveWithInfer): if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
PrimitiveWithInfer.__init__(self, class_name)
elif isinstance(self, PrimitiveWithCheck): elif isinstance(self, PrimitiveWithCheck):
PrimitiveWithCheck.__init__(self, self.__class__.__name__)
PrimitiveWithCheck.__init__(self, class_name)
else: else:
Primitive.__init__(self, self.__class__.__name__) Primitive.__init__(self, self.__class__.__name__)
bound_args = inspect.signature(fn).bind(self, *args, **kwargs) bound_args = inspect.signature(fn).bind(self, *args, **kwargs)


Loading…
Cancel
Save