GitOrigin-RevId: 3b2f829cc5
tags/v1.3.0
| @@ -12,6 +12,7 @@ from .elemwise import * | |||
| from .math import * | |||
| from .nn import * | |||
| from .tensor import * | |||
| from .utils import * | |||
| from . import distributed # isort:skip | |||
| @@ -19,6 +19,7 @@ from ..core.tensor.utils import astype | |||
| from ..device import get_default_device | |||
| from ..jit.tracing import is_tracing | |||
| from ..tensor import Tensor | |||
| from ..utils.deprecation import deprecated_func | |||
| __all__ = [ | |||
| "abs", | |||
| @@ -567,3 +568,10 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor: | |||
| return maximum(x, lower) | |||
| else: | |||
| return minimum(x, upper) | |||
| sigmoid = deprecated_func("1.3", "megengine.functional.nn", "sigmoid", True) | |||
| hsigmoid = deprecated_func("1.3", "megengine.functional.nn", "hsigmoid", True) | |||
| relu = deprecated_func("1.3", "megengine.functional.nn", "relu", True) | |||
| relu6 = deprecated_func("1.3", "megengine.functional.nn", "relu6", True) | |||
| hswish = deprecated_func("1.3", "megengine.functional.nn", "hswish", True) | |||
| @@ -22,10 +22,11 @@ from ..device import get_default_device | |||
| from ..distributed import WORLD, is_distributed | |||
| from ..random import uniform | |||
| from ..tensor import Tensor | |||
| from ..utils.deprecation import deprecated_func | |||
| from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero | |||
| from .debug_param import get_execution_strategy | |||
| from .distributed import all_reduce_sum | |||
| from .elemwise import exp, floor, log, log1p, maximum, minimum | |||
| from .elemwise import _elwise, exp, floor, log, log1p, maximum, minimum | |||
| from .math import argsort, matmul, max, prod, sum | |||
| from .tensor import ( | |||
| broadcast_to, | |||
| @@ -70,6 +71,10 @@ __all__ = [ | |||
| "relu", | |||
| "relu6", | |||
| "hswish", | |||
| "resize", | |||
| "remap", | |||
| "warp_affine", | |||
| "warp_perspective", | |||
| ] | |||
| @@ -1434,43 +1439,6 @@ def nvof(src: Tensor, precision: int = 1) -> Tensor: | |||
| return apply(op, src)[0] | |||
| def _elwise(*args, mode): | |||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) | |||
| if len(tensor_args) == 0: | |||
| dtype = utils.dtype_promotion(args) | |||
| first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | |||
| args = utils.convert_inputs(first_arg, *args[1:]) | |||
| else: | |||
| args = utils.convert_inputs(*args) | |||
| if mode in ( | |||
| Elemwise.Mode.TRUE_DIV, | |||
| Elemwise.Mode.EXP, | |||
| Elemwise.Mode.POW, | |||
| Elemwise.Mode.LOG, | |||
| Elemwise.Mode.EXPM1, | |||
| Elemwise.Mode.LOG1P, | |||
| Elemwise.Mode.TANH, | |||
| Elemwise.Mode.ACOS, | |||
| Elemwise.Mode.ASIN, | |||
| Elemwise.Mode.ATAN2, | |||
| Elemwise.Mode.CEIL, | |||
| Elemwise.Mode.COS, | |||
| Elemwise.Mode.FLOOR, | |||
| Elemwise.Mode.H_SWISH, | |||
| Elemwise.Mode.ROUND, | |||
| Elemwise.Mode.SIGMOID, | |||
| Elemwise.Mode.SIN, | |||
| ): | |||
| if mode in ( | |||
| Elemwise.Mode.CEIL, | |||
| Elemwise.Mode.FLOOR, | |||
| Elemwise.Mode.ROUND, | |||
| ) and np.issubdtype(args[0].dtype, np.integer): | |||
| return args[0] | |||
| args = tuple(map(lambda x: astype(x, "float32"), args)) | |||
| return _elwise_apply(args, mode) | |||
| def hswish(x): | |||
| """ | |||
| Element-wise `x * relu6(x + 3) / 6`. | |||
| @@ -1518,5 +1486,16 @@ def relu6(x): | |||
| return minimum(maximum(x, 0), 6) | |||
| interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True) | |||
| roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True) | |||
| roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True) | |||
| nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True) | |||
| resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True) | |||
| remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True) | |||
| warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True) | |||
| warp_perspective = deprecated_func( | |||
| "1.3", "megengine.functional.vision", "warp_perspective", True | |||
| ) | |||
| from .loss import * # isort:skip | |||
| from .quantized import conv_bias_activation # isort:skip | |||
| @@ -10,8 +10,11 @@ from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.core2 import sync as _sync | |||
| from ..core.ops.builtin import AssertEqual | |||
| from ..tensor import Tensor | |||
| from ..utils.deprecation import deprecated_func | |||
| from .elemwise import abs, maximum, minimum | |||
| __all__ = ["topk_accuracy"] | |||
| def _assert_equal( | |||
| expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False | |||
| @@ -55,3 +58,9 @@ def _assert_equal( | |||
| result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0] | |||
| _sync() # sync interpreter to get exception | |||
| return result | |||
| topk_accuracy = deprecated_func( | |||
| "1.3", "megengine.functional.metric", "topk_accuracy", True | |||
| ) | |||
| copy = deprecated_func("1.3", "megengine.functional.tensor", "copy", True) | |||
| @@ -5,4 +5,36 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import importlib | |||
| import warnings | |||
| from deprecated.sphinx import deprecated | |||
| def deprecated_func(version, origin, name, tbd): | |||
| """ | |||
| :param version: version to deprecate this function | |||
| :param origin: origin module path | |||
| :param name: function name | |||
| :param tbd: to be discussed, if true, ignore warnings | |||
| """ | |||
| should_warning = not tbd | |||
| def wrapper(*args, **kwargs): | |||
| nonlocal should_warning | |||
| module = importlib.import_module(origin) | |||
| func = module.__getattribute__(name) | |||
| if should_warning: | |||
| with warnings.catch_warnings(): | |||
| warnings.simplefilter(action="always") | |||
| warnings.warn( | |||
| "Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format( | |||
| name, origin, name, version | |||
| ), | |||
| category=DeprecationWarning, | |||
| stacklevel=2, | |||
| ) | |||
| should_warning = False | |||
| return func(*args, **kwargs) | |||
| return wrapper | |||