Browse Source

Add deprecated function

tags/v1.2.0-rc1
l00591931 4 years ago
parent
commit
edbe3bfd3b
4 changed files with 56 additions and 12 deletions
  1. +4
    -1
      mindspore/common/_decorator.py
  2. +24
    -4
      mindspore/ops/operations/array_ops.py
  3. +23
    -5
      mindspore/ops/operations/math_ops.py
  4. +5
    -2
      mindspore/ops/primitive.py

+ 4
- 1
mindspore/common/_decorator.py View File

@@ -15,12 +15,13 @@
"""Providing decorators."""


def deprecated(version, substitute):
def deprecated(version, substitute, use_substitute_name=False):
"""deprecated warning

Args:
version (str): version that the operator or function is deprecated.
substitute (str): the substitute name for deprecated operator or function.
use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function
"""

def decorate(func):
@@ -29,6 +30,8 @@ def deprecated(version, substitute):
name = cls.__name__ if cls else func.__name__
print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, "
f"use '{substitute}' instead.")
if cls and use_substitute_name:
cls.substitute_name = substitute
ret = func(*args, **kwargs)
return ret



+ 24
- 4
mindspore/ops/operations/array_ops.py View File

@@ -33,6 +33,7 @@ from .. import signature as sig
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ...common._decorator import deprecated
from ...common.parameter import Parameter
from ...common.tensor import Tensor

@@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck):
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)


def GatherV2():
"""Warning: This will be changed later"""
logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.")
return Gather()
class GatherV2(PrimitiveWithCheck):
"""
Same as operator Gather. GatherV2 will be deprecated in the future.
Please use Gather instead.
"""
#deprecate_new_name = "Gather"

@deprecated("1.1", "Gather", True)
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)


class SparseGatherV2(Gather):
"""


+ 23
- 5
mindspore/ops/operations/math_ops.py View File

@@ -18,13 +18,13 @@
import copy

import numpy as np
from mindspore import log as logger
from ... import context
from .. import signature as sig
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ...common._decorator import deprecated
from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op

@@ -161,10 +161,28 @@ class Add(_MathBinaryOp):
return Tensor(out)
return None

def TensorAdd():
"""Warning: This will be changed later"""
logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.")
return Add()

class TensorAdd(_MathBinaryOp):
"""
Same as operator Add. TensorAdd will be deprecated in the future.
Please use Add instead.
"""
#deprecate_new_name = "Add"

@deprecated("1.1", "Add", True)
@prim_attr_register
def __init__(self):
_MathBinaryOp.__init__(self)

def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()
y = y.asnumpy()
out = x + y
out = np.array(out, x.dtype)
return Tensor(out)
return None


class AssignAdd(PrimitiveWithInfer):
"""


+ 5
- 2
mindspore/ops/primitive.py View File

@@ -466,10 +466,13 @@ def prim_attr_register(fn):
"""

def deco(self, *args, **kwargs):
class_name = self.__class__.__name__
if hasattr(self.__class__, "substitute_name"):
class_name = self.__class__.substitute_name
if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
PrimitiveWithInfer.__init__(self, class_name)
elif isinstance(self, PrimitiveWithCheck):
PrimitiveWithCheck.__init__(self, self.__class__.__name__)
PrimitiveWithCheck.__init__(self, class_name)
else:
Primitive.__init__(self, self.__class__.__name__)
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)


Loading…
Cancel
Save