| @@ -15,12 +15,13 @@ | |||
| """Providing decorators.""" | |||
| def deprecated(version, substitute): | |||
| def deprecated(version, substitute, use_substitute_name=False): | |||
| """deprecated warning | |||
| Args: | |||
| version (str): version that the operator or function is deprecated. | |||
| substitute (str): the substitute name for deprecated operator or function. | |||
| use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function | |||
| """ | |||
| def decorate(func): | |||
| @@ -29,6 +30,8 @@ def deprecated(version, substitute): | |||
| name = cls.__name__ if cls else func.__name__ | |||
| print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " | |||
| f"use '{substitute}' instead.") | |||
| if cls and use_substitute_name: | |||
| cls.substitute_name = substitute | |||
| ret = func(*args, **kwargs) | |||
| return ret | |||
| @@ -33,6 +33,7 @@ from .. import signature as sig | |||
| from ..._checkparam import Rel | |||
| from ..._checkparam import Validator as validator | |||
| from ...common import dtype as mstype | |||
| from ...common._decorator import deprecated | |||
| from ...common.parameter import Parameter | |||
| from ...common.tensor import Tensor | |||
| @@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck): | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| def GatherV2(): | |||
| """Warning: This will be changed later""" | |||
| logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.") | |||
| return Gather() | |||
| class GatherV2(PrimitiveWithCheck): | |||
| """ | |||
| Same as operator Gather. GatherV2 will be deprecated in the future. | |||
| Please use Gather instead. | |||
| """ | |||
| #deprecate_new_name = "Gather" | |||
| @deprecated("1.1", "Gather", True) | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize index_select""" | |||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | |||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||
| def __check__(self, params, indices, axis): | |||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||
| validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) | |||
| axis_v = axis['value'] | |||
| validator.check_value_type('axis', axis_v, [int], self.name) | |||
| rank = len(params['shape']) | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| class SparseGatherV2(Gather): | |||
| """ | |||
| @@ -18,13 +18,13 @@ | |||
| import copy | |||
| import numpy as np | |||
| from mindspore import log as logger | |||
| from ... import context | |||
| from .. import signature as sig | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor | |||
| from ...common._decorator import deprecated | |||
| from .._utils import get_broadcast_shape | |||
| from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | |||
| @@ -161,10 +161,28 @@ class Add(_MathBinaryOp): | |||
| return Tensor(out) | |||
| return None | |||
| def TensorAdd(): | |||
| """Warning: This will be changed later""" | |||
| logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.") | |||
| return Add() | |||
| class TensorAdd(_MathBinaryOp): | |||
| """ | |||
| Same as operator Add. TensorAdd will be deprecated in the future. | |||
| Please use Add instead. | |||
| """ | |||
| #deprecate_new_name = "Add" | |||
| @deprecated("1.1", "Add", True) | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| _MathBinaryOp.__init__(self) | |||
| def infer_value(self, x, y): | |||
| if x is not None and y is not None: | |||
| x = x.asnumpy() | |||
| y = y.asnumpy() | |||
| out = x + y | |||
| out = np.array(out, x.dtype) | |||
| return Tensor(out) | |||
| return None | |||
| class AssignAdd(PrimitiveWithInfer): | |||
| """ | |||
| @@ -466,10 +466,13 @@ def prim_attr_register(fn): | |||
| """ | |||
| def deco(self, *args, **kwargs): | |||
| class_name = self.__class__.__name__ | |||
| if hasattr(self.__class__, "substitute_name"): | |||
| class_name = self.__class__.substitute_name | |||
| if isinstance(self, PrimitiveWithInfer): | |||
| PrimitiveWithInfer.__init__(self, self.__class__.__name__) | |||
| PrimitiveWithInfer.__init__(self, class_name) | |||
| elif isinstance(self, PrimitiveWithCheck): | |||
| PrimitiveWithCheck.__init__(self, self.__class__.__name__) | |||
| PrimitiveWithCheck.__init__(self, class_name) | |||
| else: | |||
| Primitive.__init__(self, self.__class__.__name__) | |||
| bound_args = inspect.signature(fn).bind(self, *args, **kwargs) | |||