| @@ -15,12 +15,13 @@ | |||||
| """Providing decorators.""" | """Providing decorators.""" | ||||
| def deprecated(version, substitute): | |||||
| def deprecated(version, substitute, use_substitute_name=False): | |||||
| """deprecated warning | """deprecated warning | ||||
| Args: | Args: | ||||
| version (str): version that the operator or function is deprecated. | version (str): version that the operator or function is deprecated. | ||||
| substitute (str): the substitute name for deprecated operator or function. | substitute (str): the substitute name for deprecated operator or function. | ||||
| use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function | |||||
| """ | """ | ||||
| def decorate(func): | def decorate(func): | ||||
| @@ -29,6 +30,8 @@ def deprecated(version, substitute): | |||||
| name = cls.__name__ if cls else func.__name__ | name = cls.__name__ if cls else func.__name__ | ||||
| print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " | print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " | ||||
| f"use '{substitute}' instead.") | f"use '{substitute}' instead.") | ||||
| if cls and use_substitute_name: | |||||
| cls.substitute_name = substitute | |||||
| ret = func(*args, **kwargs) | ret = func(*args, **kwargs) | ||||
| return ret | return ret | ||||
| @@ -33,6 +33,7 @@ from .. import signature as sig | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common._decorator import deprecated | |||||
| from ...common.parameter import Parameter | from ...common.parameter import Parameter | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| @@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck): | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | ||||
| def GatherV2(): | |||||
| """Warning: This will be changed later""" | |||||
| logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.") | |||||
| return Gather() | |||||
| class GatherV2(PrimitiveWithCheck): | |||||
| """ | |||||
| Same as operator Gather. GatherV2 will be deprecated in the future. | |||||
| Please use Gather instead. | |||||
| """ | |||||
| #deprecate_new_name = "Gather" | |||||
| @deprecated("1.1", "Gather", True) | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """Initialize index_select""" | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | |||||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||||
| def __check__(self, params, indices, axis): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||||
| validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) | |||||
| axis_v = axis['value'] | |||||
| validator.check_value_type('axis', axis_v, [int], self.name) | |||||
| rank = len(params['shape']) | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||||
| class SparseGatherV2(Gather): | class SparseGatherV2(Gather): | ||||
| """ | """ | ||||
| @@ -18,13 +18,13 @@ | |||||
| import copy | import copy | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import log as logger | |||||
| from ... import context | from ... import context | ||||
| from .. import signature as sig | from .. import signature as sig | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from ...common._decorator import deprecated | |||||
| from .._utils import get_broadcast_shape | from .._utils import get_broadcast_shape | ||||
| from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | ||||
| @@ -161,10 +161,28 @@ class Add(_MathBinaryOp): | |||||
| return Tensor(out) | return Tensor(out) | ||||
| return None | return None | ||||
| def TensorAdd(): | |||||
| """Warning: This will be changed later""" | |||||
| logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.") | |||||
| return Add() | |||||
| class TensorAdd(_MathBinaryOp): | |||||
| """ | |||||
| Same as operator Add. TensorAdd will be deprecated in the future. | |||||
| Please use Add instead. | |||||
| """ | |||||
| #deprecate_new_name = "Add" | |||||
| @deprecated("1.1", "Add", True) | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| _MathBinaryOp.__init__(self) | |||||
| def infer_value(self, x, y): | |||||
| if x is not None and y is not None: | |||||
| x = x.asnumpy() | |||||
| y = y.asnumpy() | |||||
| out = x + y | |||||
| out = np.array(out, x.dtype) | |||||
| return Tensor(out) | |||||
| return None | |||||
| class AssignAdd(PrimitiveWithInfer): | class AssignAdd(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -466,10 +466,13 @@ def prim_attr_register(fn): | |||||
| """ | """ | ||||
| def deco(self, *args, **kwargs): | def deco(self, *args, **kwargs): | ||||
| class_name = self.__class__.__name__ | |||||
| if hasattr(self.__class__, "substitute_name"): | |||||
| class_name = self.__class__.substitute_name | |||||
| if isinstance(self, PrimitiveWithInfer): | if isinstance(self, PrimitiveWithInfer): | ||||
| PrimitiveWithInfer.__init__(self, self.__class__.__name__) | |||||
| PrimitiveWithInfer.__init__(self, class_name) | |||||
| elif isinstance(self, PrimitiveWithCheck): | elif isinstance(self, PrimitiveWithCheck): | ||||
| PrimitiveWithCheck.__init__(self, self.__class__.__name__) | |||||
| PrimitiveWithCheck.__init__(self, class_name) | |||||
| else: | else: | ||||
| Primitive.__init__(self, self.__class__.__name__) | Primitive.__init__(self, self.__class__.__name__) | ||||
| bound_args = inspect.signature(fn).bind(self, *args, **kwargs) | bound_args = inspect.signature(fn).bind(self, *args, **kwargs) | ||||