GitOrigin-RevId: 6fbffc4845
tags/v1.5.0
@@ -117,6 +117,7 @@ def _atexit(handler):
 # subpackages
+import megengine.amp
 import megengine.autodiff
 import megengine.data
 import megengine.distributed
@@ -0,0 +1,14 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import mprop
+
+from ..core.tensor.amp import *
+from .autocast import autocast
+
+mprop.init()
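Note: mprop.init() replaces this module's entry in sys.modules with a property-aware proxy, which is what lets the @property functions defined in megengine.core.tensor.amp (re-exported here via the star import) behave as module-level attributes. A minimal sketch of the resulting API, assuming megengine is installed:

    import megengine as mge

    mge.amp.enabled = True                        # routed through the module-level setter
    assert mge.amp.low_prec_dtype == "float16"    # default value
    mge.amp.enabled = False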
@@ -0,0 +1,79 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
+
+from ..core.tensor import amp
+
+
+class autocast:
+    r"""
+    A class to control autocast mode for amp as a context manager or a decorator.
+
+    :param enabled: whether autocast mode is enabled.
+    :param low_prec_dtype: set amp autocast mode's lower precision dtype. It will
+        change the target dtype in tensor casting for better speed and memory.
+        Default: float16.
+    :param high_prec_dtype: set amp autocast mode's higher precision dtype. It will
+        change the target dtype in tensor casting for better precision.
+        Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       # used as decorator
+       @autocast()
+       def train_step(image, label):
+           with gm:
+               logits = model(image)
+               loss = F.nn.cross_entropy(logits, label)
+               gm.backward(loss)
+           opt.step().clear_grad()
+           return loss
+
+       # used as context manager
+       def train_step(image, label):
+           with autocast():
+               with gm:
+                   logits = model(image)
+                   loss = F.nn.cross_entropy(logits, label)
+                   gm.backward(loss)
+           opt.step().clear_grad()
+           return loss
+    """
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        low_prec_dtype: str = "float16",
+        high_prec_dtype: str = "float32",
+    ):
+        self.enabled = enabled
+        self.high_prec_dtype = high_prec_dtype
+        self.low_prec_dtype = low_prec_dtype
+        self._origin_enabled = None
+        self._origin_high = None
+        self._origin_low = None
+
+    def __enter__(self):
+        self._origin_enabled, amp._enabled = amp._enabled, self.enabled
+        self._origin_high = amp._high_prec_dtype
+        amp._high_prec_dtype = self.high_prec_dtype
+        self._origin_low = amp._low_prec_dtype
+        amp._low_prec_dtype = self.low_prec_dtype
+
+    def __exit__(self, *args):
+        amp._enabled = self._origin_enabled
+        amp._high_prec_dtype = self._origin_high
+        amp._low_prec_dtype = self._origin_low
+
+    def __call__(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        return wrapper
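Because __exit__ restores the values saved in __enter__, autocast contexts nest safely; a minimal sketch of the expected state transitions, assuming the default global state (disabled):

    from megengine.amp import autocast
    from megengine.core.tensor import amp

    with autocast():
        assert amp._enabled                # set by __enter__
        with autocast(enabled=False):
            assert not amp._enabled        # inner context overrides
        assert amp._enabled                # inner __exit__ restored the outer state
    assert not amp._enabled                # back to the global default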
@@ -49,7 +49,7 @@ class Device:
         return self._cn == rhs._cn


-def device(obj):
+def as_device(obj):
     if isinstance(obj, Device):
         return obj
     return Device(obj)
@@ -0,0 +1,78 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+_enabled = False
+_high_prec_dtype = "float32"
+_low_prec_dtype = "float16"
+
+
+@property
+def enabled(mod):
+    r"""
+    Get or set whether amp autocast mode is enabled.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.enabled = True
+    """
+    return _enabled
+
+
+@enabled.setter
+def enabled(mod, enabled: bool):
+    global _enabled
+    _enabled = enabled
+
+
+@property
+def high_prec_dtype(mod):
+    r"""
+    Get or set amp autocast mode's higher precision dtype. It will change the
+    target dtype in tensor casting for better precision. Default: float32.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.high_prec_dtype = "float32"
+    """
+    return _high_prec_dtype
+
+
+@high_prec_dtype.setter
+def high_prec_dtype(mod, dtype: str):
+    global _high_prec_dtype
+    _high_prec_dtype = dtype
+
+
+@property
+def low_prec_dtype(mod):
+    r"""
+    Get or set amp autocast mode's lower precision dtype. It will change the
+    target dtype in tensor casting for better speed and memory. Default: float16.
+
+    Examples:
+
+    .. code-block::
+
+       import megengine as mge
+       mge.amp.low_prec_dtype = "float16"
+    """
+    return _low_prec_dtype
+
+
+@low_prec_dtype.setter
+def low_prec_dtype(mod, dtype: str):
+    global _low_prec_dtype
+    _low_prec_dtype = dtype
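These globals are what cast_tensors (added to core/tensor/utils.py below) reads, so changing them rewires every amp-aware cast; a short sketch, with "bfloat16" as a purely hypothetical stand-in for whatever dtype string the backend accepts:

    import megengine as mge

    mge.amp.low_prec_dtype = "bfloat16"    # hypothetical dtype, for illustration only
    assert mge.amp.low_prec_dtype == "bfloat16"
    mge.amp.low_prec_dtype = "float16"     # restore the default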
@@ -15,15 +15,20 @@ import numpy as np
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from ..ops import builtin
-from ..ops.builtin import Elemwise, GetVarShape
-from . import utils
-from .indexing import getitem as _getitem
-from .indexing import setitem as _setitem
-from .utils import isscalar
-from .utils import make_shape_tuple as _make_shape_tuple
-from .utils import setscalar
-
-_ElwMod = Elemwise.Mode
+from . import amp
+from .indexing import getitem, setitem
+from .utils import (
+    _normalize_axis,
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    isscalar,
+    make_shape_tuple,
+    setscalar,
+)
+
+_ElwMod = builtin.Elemwise.Mode


 def _elwise_apply(args, mode):
@@ -40,47 +45,59 @@ def _elwise_apply(args, mode):
 def _elwise(*args, mode):
+    args = convert_inputs(*args)
     if mode in (
         _ElwMod.TRUE_DIV,
+        _ElwMod.EXP,
         _ElwMod.POW,
-        _ElwMod.CEIL,
-        _ElwMod.FLOOR,
-        _ElwMod.ROUND,
+        _ElwMod.LOG,
+        _ElwMod.EXPM1,
+        _ElwMod.LOG1P,
+        _ElwMod.TANH,
+        _ElwMod.ACOS,
+        _ElwMod.ASIN,
+        _ElwMod.ATAN2,
+        _ElwMod.COS,
+        _ElwMod.H_SWISH,
+        _ElwMod.SIGMOID,
+        _ElwMod.SIN,
+    ) and (
+        amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
     ):
-        if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
-            args[0].dtype, np.integer
-        ):
-            return args[0]
-        args = tuple(
-            map(
-                lambda x: x.astype("float32")
-                if hasattr(x, "dtype") and x.dtype != np.float32
-                else x,
-                args,
-            )
-        )
-    args = utils.convert_inputs(*args)
+        # autocast to FP32 to maintain precision,
+        # or to avoid ops that do not support int args
+        args = cast_tensors(*args, promote=True)
+    if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
+        args[0].dtype, np.integer
+    ):
+        return args[0]
     return _elwise_apply(args, mode)


 def _matmul(inp1, inp2):
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        compute_mode = "default"
+        inp1, inp2 = convert_inputs(inp1, inp2)
     op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode="default", format="default"
+        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
     )
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
     (result,) = apply(op, inp1, inp2)
     return result


 def _transpose(data, axes):
     op = builtin.Dimshuffle(axes)
-    (data,) = utils.convert_inputs(data)
+    (data,) = convert_inputs(data)
     (result,) = apply(op, data)
     return result


 def _broadcast(inp, shape):
-    shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
+    shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
@@ -88,7 +105,7 @@ def _broadcast(inp, shape):
 def _reshape(x, shape):
     unspec_axis = None
     try:
-        shape_tuple = _make_shape_tuple(shape)
+        shape_tuple = make_shape_tuple(shape)
     except ValueError:
         pass
     else:
@@ -102,7 +119,7 @@ def _reshape(x, shape):
                     "multiple -1 in shape: {} & {}".format(unspec_axis, i)
                 )
             unspec_axis = i
-    shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
+    shape = astensor1d(shape, x, dtype="int32", device=x.device)
     if unspec_axis is None:
         op = builtin.Reshape()
     else:
@@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
         return list(map(int, axis))

     axis = get_axes()
-    axis = utils._normalize_axis(inp.ndim, axis)
+    axis = _normalize_axis(inp.ndim, axis)
     axis = [a - i for i, a in enumerate(axis)]
     op = builtin.RemoveAxis(axis=axis)
@@ -184,7 +201,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
-        (data,) = utils.convert_inputs(data)
+        (data,) = convert_inputs(data)
         if mode == "mean":
             data = data.astype("float32")
         elif self.dtype == np.bool_:
@@ -196,7 +213,7 @@ def _reduce(mode):
             op = builtin.Reduce(mode=mode, axis=0)
             (result,) = apply(op, data)
         elif isinstance(axis, collections.abc.Iterable):
-            axis = utils._normalize_axis(self.ndim, axis, reverse=True)
+            axis = _normalize_axis(self.ndim, axis, reverse=True)
             for ai in axis:
                 op = builtin.Reduce(mode=mode, axis=ai)
                 (data,) = apply(op, data)
@@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC):
             yield self[i]

     def __getitem__(self, index):
-        return _getitem(self, index)
+        return getitem(self, index)

     def __setitem__(self, index, value):
         if index is not Ellipsis:
-            value = _setitem(self, index, value)
+            value = setitem(self, index, value)
             self._reset(value)

     __contains__ = _todo
@@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC):
         Returns a :class:`Tensor` with the same data and number of elements
         with the specified :attr:`~.Tensor.dtype`.
         """
-        return utils.astype(self, dtype)
+        return astype(self, dtype)

     def reshape(self, *args):
         r"""
@@ -18,7 +18,7 @@ import numpy as np
 from .. import _imperative_rt
 from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops.builtin import OpDef
 from .core import TensorBase
@@ -13,9 +13,10 @@ import numpy as np
 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
-from .._wrap import device as as_device
+from .._wrap import as_device
 from ..ops import builtin
 from ..ops.special import Const
+from .amp import _high_prec_dtype, _low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize

 _enable_convert_inputs = True
@@ -98,6 +99,14 @@ def convert_inputs(*args, device=None):
     return tuple(map(convert, args))


+def cast_tensors(*args, promote=False):
+    if promote:
+        dtype = _high_prec_dtype
+    else:
+        dtype = _low_prec_dtype
+    return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
+
+
 def result_type(*args):
     dtypes = []
     for i in args:
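cast_tensors casts every argument to the active amp dtype and passes None through untouched, which lets callers cast optional arguments such as bias in a single call; a sketch with placeholder tensors:

    inp, weight, bias = cast_tensors(inp, weight, None)       # low-precision (float16 by default); None stays None
    weight, bias = cast_tensors(weight, bias, promote=True)   # high-precision (float32 by default)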
@@ -12,10 +12,8 @@ import numpy as np
 from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
-from ..core.tensor import utils
-from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astype
-from ..device import get_default_device
+from ..core.tensor.array_method import _elwise
+from ..core.tensor.utils import astype, convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
@@ -69,46 +67,9 @@ __all__ = [
 ]


-def _elwise(*args, mode):
-    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
-    if len(tensor_args) == 0:
-        dtype = utils.dtype_promotion(args)
-        first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
-        args = utils.convert_inputs(first_arg, *args[1:])
-    else:
-        args = utils.convert_inputs(*args)
-    if mode in (
-        Elemwise.Mode.TRUE_DIV,
-        Elemwise.Mode.EXP,
-        Elemwise.Mode.POW,
-        Elemwise.Mode.LOG,
-        Elemwise.Mode.EXPM1,
-        Elemwise.Mode.LOG1P,
-        Elemwise.Mode.TANH,
-        Elemwise.Mode.ACOS,
-        Elemwise.Mode.ASIN,
-        Elemwise.Mode.ATAN2,
-        Elemwise.Mode.CEIL,
-        Elemwise.Mode.COS,
-        Elemwise.Mode.FLOOR,
-        Elemwise.Mode.H_SWISH,
-        Elemwise.Mode.ROUND,
-        Elemwise.Mode.SIGMOID,
-        Elemwise.Mode.SIN,
-    ):
-        if mode in (
-            Elemwise.Mode.CEIL,
-            Elemwise.Mode.FLOOR,
-            Elemwise.Mode.ROUND,
-        ) and np.issubdtype(args[0].dtype, np.integer):
-            return args[0]
-        args = tuple(map(lambda x: astype(x, "float32"), args))
-    return _elwise_apply(args, mode)
-
-
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = utils.convert_inputs(*args)
+    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result
@@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply
 from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.special import Const
-from ..core.tensor import utils
+from ..core.tensor import amp
+from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
 from ..tensor import Tensor
 from .debug_param import get_execution_strategy
 from .elemwise import clip, exp, log, log1p
@@ -471,7 +472,7 @@ def argmin(
         inp = inp.flatten()
         axis = 0

-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)
     if isinstance(axis, collections.abc.Iterable):
         for ai in axis:
@@ -528,7 +529,7 @@ def argmax(
         assert not keepdims, "can not set axis=None and keepdims=True"
         inp = inp.flatten()
         axis = 0

-    axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
+    axis = _normalize_axis(inp.ndim, axis, reverse=True)
     if isinstance(axis, collections.abc.Iterable):
@@ -807,8 +808,13 @@ def matmul(
          [28. 40.]]

     """
+    if amp._enabled:
+        compute_mode = "float32"
+        inp1, inp2 = cast_tensors(inp1, inp2)
+    else:
+        inp1, inp2 = convert_inputs(inp1, inp2)
+
     remove_row, remove_col = False, False
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)

     dim1, dim2 = inp1.ndim, inp2.ndim
     # handle dim=1 cases, dot and matrix-vector multiplication
@@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
     """
     op = builtin.Dot()
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
+    inp1, inp2 = convert_inputs(inp1, inp2)
     assert (
         inp1.ndim <= 1 and inp2.ndim <= 1
     ), "Input tensors for dot must be 1-dimensional or scalar"
     (result,) = apply(op, inp1, inp2)
-    utils.setscalar(result)
+    setscalar(result)
     return result
@@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
 from ..core.ops.builtin import BatchNorm, Elemwise
 from ..core.ops.special import Const
-from ..core.tensor import megbrain_graph, utils
+from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import astensor1d, astype, setscalar
+from ..core.tensor.utils import (
+    astensor1d,
+    astype,
+    cast_tensors,
+    convert_inputs,
+    convert_single_value,
+    setscalar,
+)
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
 from ..random import uniform
@@ -91,7 +98,9 @@ def expand_hw(x):
     return int(h), int(w)


-def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+def linear(
+    inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
+) -> Tensor:
     """
     Applies a linear transformation to the input tensor.

@@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
     :param bias: bias with shape `(out_features,)`.
         Default: None
     """
-    ret = matmul(inp, weight, transpose_b=True)
+    ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
+        if amp._enabled:
+            bias = bias.astype("float16")
         ret += bias
     return ret
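What the amp branch implies for callers, sketched under the default dtypes: with autocast enabled, matmul receives float16 operands with float32 accumulation and the bias is cast to match, so the result comes back as float16:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.randn(4, 16), dtype="float32")
    w = mge.tensor(np.random.randn(8, 16), dtype="float32")
    b = mge.tensor(np.zeros(8), dtype="float32")
    with mge.amp.autocast():
        out = F.nn.linear(x, w, b)
    assert out.dtype == np.float16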
@@ -153,6 +164,11 @@ def conv1d(
     assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
     assert inp.ndim == 3, "the input dimension of conv1d should be 3"
     assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)

     inp = expand_dims(inp, 3)
     weight = expand_dims(weight, 3)
@@ -177,7 +193,6 @@ def conv1d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -228,7 +243,11 @@ def conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)

     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -247,7 +266,6 @@ def conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -286,6 +304,7 @@ def conv3d(
     :return: output tensor.
     """
     assert conv_mode.lower() == "cross_correlation"
+    inp, weight = convert_inputs(inp, weight)

     D, H, W = 0, 1, 2

@@ -308,7 +327,6 @@ def conv3d(
         mode=conv_mode,
         sparse=sparse_type,
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -358,7 +376,11 @@ def conv_transpose2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, bias = cast_tensors(inp, weight, bias)
+    else:
+        inp, weight = convert_inputs(inp, weight)

     if groups != 1:
         raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -375,8 +397,8 @@ def conv_transpose2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         strategy=get_execution_strategy(),
+        compute_mode=compute_mode,
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
@@ -428,7 +450,11 @@ def deformable_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    if amp._enabled:
+        compute_mode = "float32"
+        inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
+    else:
+        inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)

     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -447,7 +473,6 @@ def deformable_conv2d(
         compute_mode=compute_mode,
         sparse=sparse_type,
     )
-    inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
     (output,) = apply(op, inp, weight, offset, mask)
     if bias is not None:
         output += bias
@@ -468,6 +493,7 @@ def local_conv2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
+    inp, weight = convert_inputs(inp, weight)

     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -481,10 +507,8 @@ def local_conv2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         mode=conv_mode,
-        compute_mode="default",
         sparse="dense",
     )
-    inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
@@ -515,8 +539,9 @@ def conv_transpose3d(
     :param dilation: dilation of the 3D convolution operation. Default: 1
     :return: output tensor.
     """
-    D, H, W = 0, 1, 2
+    inp, weight = convert_inputs(inp, weight)

+    D, H, W = 0, 1, 2
     pad = _triple(padding)
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
@@ -533,7 +558,6 @@ def conv_transpose3d(
         dilate_w=dilate[W],
         strategy=get_execution_strategy(),
     )
-    weight, inp = utils.convert_inputs(weight, inp)
     (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
@@ -994,7 +1018,8 @@ def batch_norm(
     training: bool = False,
     momentum: float = 0.9,
     eps: float = 1e-5,
-    inplace: bool = True
+    inplace: bool = True,
+    compute_mode="default",
 ):
     r"""
     Applies batch normalization to the input.
@@ -1027,15 +1052,11 @@ def batch_norm(
     def make_full_if_none(x, value):
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
         elif x.ndim == 1:
-            shape = utils.astensor1d(
-                (1, C, 1, 1), inp, dtype="int32", device=inp.device
-            )
+            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
         return x
@@ -1052,10 +1073,15 @@ def batch_norm(
     if has_var and running_var.ndim != 4:
         raise ValueError

-    inp, weight, bias, running_mean, running_var = utils.convert_inputs(
-        inp, weight, bias, running_mean, running_var
-    )
+    if amp._enabled:
+        inp = inp.astype("float16")
+        weight, bias, running_mean, running_var = cast_tensors(
+            weight, bias, running_mean, running_var, promote=True
+        )
+    elif compute_mode != "float32":
+        inp, weight, bias, running_mean, running_var = convert_inputs(
+            inp, weight, bias, running_mean, running_var
+        )

     weight = make_full_if_none(weight, 1)
     bias = make_full_if_none(bias, 0)
@@ -1352,7 +1378,7 @@ def indexing_one_hot(
     """
     assert isinstance(src, Tensor), "src must be of Tensor type"
    op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, dtype="int32", device=src.device)
+    index = convert_single_value(index, dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = squeeze(result, axis)
@@ -13,7 +13,7 @@ import numpy as np
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import SymbolVar, apply
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
@@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import (
     RemoteSend,
 )
 from ..core._trace_option import set_symbolic_shape
-from ..core._wrap import device as as_device
+from ..core._wrap import as_device
 from ..core.ops.builtin import BatchNorm, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
@@ -26,6 +26,7 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        compute_mode="default",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)
@@ -36,6 +37,7 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        self.compute_mode = compute_mode
         if self.freeze:
             assert (
                 self._track_running_stats_saved
@@ -123,6 +125,7 @@ class _BatchNorm(Module):
                 or ((self.running_mean is None) and (self.running_var is None)),
                 momentum=exponential_average_factor,
                 eps=self.eps,
+                compute_mode=self.compute_mode,
             )

         if _ndims != 4:
@@ -51,7 +51,12 @@ class Linear(Module):
     """

     def __init__(
-        self, in_features: int, out_features: int, bias: bool = True, **kwargs
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        compute_mode: str = "default",
+        **kwargs
     ):
         super().__init__(**kwargs)
         self.out_features = out_features
@@ -62,6 +67,7 @@ class Linear(Module):
         if bias:
             b_shape = (out_features,)
             self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
+        self.compute_mode = compute_mode
         self.reset_parameters()

     def _get_fanin(self):
@@ -75,7 +81,7 @@ class Linear(Module):
             init.zeros_(self.bias)

     def _calc_linear(self, x, weight, bias):
-        return linear(x, weight, bias)
+        return linear(x, weight, bias, compute_mode=self.compute_mode)

     def forward(self, x):
         return self._calc_linear(x, self.weight, self.bias)
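The new constructor argument lets a layer request float32 accumulation without touching the global autocast state; a usage sketch (the layer itself is illustrative):

    import megengine.module as M

    fc = M.Linear(16, 8, compute_mode="float32")   # forwarded to F.nn.linear by _calc_linear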
@@ -5,8 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import copy
-import math
 from functools import partial

 from .. import functional as F
@@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode
 from .core._imperative_rt.core2 import Tensor as _Tensor
 from .core._imperative_rt.core2 import apply
 from .core._trace_option import use_symbolic_shape
-from .core._wrap import device as as_device
+from .core._wrap import as_device
 from .core.ops.builtin import Copy, GetVarShape
 from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
@@ -0,0 +1,34 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine import amp
+from megengine.core.tensor import amp as origin_amp
+
+
+def test_autocast():
+    def check(enabled, low, high):
+        assert amp.enabled == enabled
+        assert origin_amp._enabled == enabled
+        assert amp.low_prec_dtype == low
+        assert origin_amp._low_prec_dtype == low
+        assert amp.high_prec_dtype == high
+        assert origin_amp._high_prec_dtype == high
+
+    origin_enabled = amp.enabled
+    origin_high = amp.high_prec_dtype
+    origin_low = amp.low_prec_dtype
+    with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
+        check(True, "low", "high")
+    check(origin_enabled, origin_low, origin_high)
+
+    amp.enabled = True
+    amp.high_prec_dtype = "high"
+    amp.low_prec_dtype = "low"
+    check(True, "low", "high")
+    amp.enabled = origin_enabled
+    amp.high_prec_dtype = origin_high
+    amp.low_prec_dtype = origin_low
+    check(origin_enabled, origin_low, origin_high)
@@ -14,6 +14,7 @@ import numpy as np
 import pytest
 from utils import opr_test

+import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
@@ -767,6 +768,27 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


+def test_conv2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
+    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
+    amp.enabled = False
+    expected = F.conv2d(
+        inp.astype("float16"),
+        weight.astype("float16"),
+        None,
+        (2, 2),
+        (3, 3),
+        (1, 1),
+        1,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv2d_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]
@@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array():

 def test_conv1d():
-    inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
-    weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
+    inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
     out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
     np.testing.assert_equal(
         out.numpy(),
@@ -798,9 +820,31 @@ def test_conv1d():
     )


+def test_batchnorm2d_io16c32():
+    amp.enabled = True
+    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
+    weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
+    bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
+    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
+    amp.enabled = False
+    expected = F.batch_norm(
+        inp.astype("float16"),
+        weight=weight,
+        bias=bias,
+        training=True,
+        inplace=False,
+        compute_mode="float32",
+    )
+    assert out.dtype == np.float16
+    assert expected.dtype == np.float16
+    np.testing.assert_allclose(out.numpy(), expected.numpy())
+
+
 def test_conv3d():
-    inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
-    weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
+    inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
+    weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
     out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
     print(out.numpy().shape)
     np.testing.assert_equal(
@@ -473,39 +473,6 @@ def test_pickle_module():
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)


-def test_load_quantized():
-    from megengine.core.tensor import dtype
-
-    data_shape = (2, 28)
-    data = tensor(np.random.random(data_shape), dtype="float32")
-    data = data.astype(dtype.qint8(0.1))
-    mlp = MLP()
-    quantize_qat(mlp)
-    quantize(mlp)
-    mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
-    mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
-    mlp.eval()
-    pred0 = mlp(data)
-
-    with BytesIO() as fout:
-        mge.save(mlp.state_dict(), fout)
-        fout.seek(0)
-        checkpoint = mge.load(fout)
-        # change mlp weight.
-        mlp.dense0.weight = Parameter(
-            mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
-        )
-        mlp.dense1.weight = Parameter(
-            mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
-        )
-        mlp.load_state_dict(checkpoint)
-        pred1 = mlp(data)
-        np.testing.assert_allclose(
-            pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6
-        )
-
-
 def test_repr_basic():
     # test whether __repr__ can output correct information
     class ConvModel(Module):