GitOrigin-RevId: 6fbffc4845
tags/v1.5.0
| @@ -117,6 +117,7 @@ def _atexit(handler): | |||
| # subpackages | |||
| import megengine.amp | |||
| import megengine.autodiff | |||
| import megengine.data | |||
| import megengine.distributed | |||
| @@ -0,0 +1,14 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import mprop | |||
| from ..core.tensor.amp import * | |||
| from .autocast import autocast | |||
| mprop.init() | |||
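For context, the `mprop` package used here turns module-level `@property` definitions (such as the ones in `megengine/core/tensor/amp.py` shown below) into attributes that can be read and assigned on the module itself; `mprop.init()` must run after the properties are defined. A minimal standalone sketch of the pattern, with a hypothetical module name `flags.py`:

    # flags.py -- hypothetical module illustrating the mprop pattern
    import mprop

    _verbose = False

    @property
    def verbose(mod):
        # the getter receives the module proxy as its first argument
        return _verbose

    @verbose.setter
    def verbose(mod, value):
        global _verbose
        _verbose = value

    mprop.init()  # replace this module in sys.modules with a property-aware proxy

Client code can then write `import flags; flags.verbose = True` instead of calling setter functions.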
| @@ -0,0 +1,79 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| from ..core.tensor import amp | |||
| class autocast: | |||
| r""" | |||
| A class to control autocast mode for amp as a context manager or a decorator. | |||
| :param enabled: Whether autocast mode is enabled. | |||
| :param low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change the | |||
| target dtype in tensor casting for better speed and memory. Default: float16. | |||
| :param high_prec_dtype: Set amp autocast mode's higher precision dtype. It will change the | |||
| target dtype in tensor casting for better precision. Default: float32. | |||
| Examples: | |||
| .. code-block:: | |||
| # used as decorator | |||
| @autocast() | |||
| def train_step(image, label): | |||
| with gm: | |||
| logits = model(image) | |||
| loss = F.nn.cross_entropy(logits, label) | |||
| gm.backward(loss) | |||
| opt.step().clear_grad() | |||
| return loss | |||
| # used as context manager | |||
| def train_step(image, label): | |||
| with autocast(): | |||
| with gm: | |||
| logits = model(image) | |||
| loss = F.nn.cross_entropy(logits, label) | |||
| gm.backward(loss) | |||
| opt.step().clear_grad() | |||
| return loss | |||
| """ | |||
| def __init__( | |||
| self, | |||
| enabled: bool = True, | |||
| low_prec_dtype: str = "float16", | |||
| high_prec_dtype: str = "float32", | |||
| ): | |||
| self.enabled = enabled | |||
| self.high_prec_dtype = high_prec_dtype | |||
| self.low_prec_dtype = low_prec_dtype | |||
| self._origin_enabled = None | |||
| self._origin_high = None | |||
| self._origin_low = None | |||
| def __enter__(self): | |||
| self._origin_enabled, amp._enabled = amp._enabled, self.enabled | |||
| self._origin_high = amp._high_prec_dtype | |||
| amp._high_prec_dtype = self.high_prec_dtype | |||
| self._origin_low = amp._low_prec_dtype | |||
| amp._low_prec_dtype = self.low_prec_dtype | |||
| def __exit__(self, *args): | |||
| amp._enabled = self._origin_enabled | |||
| amp._high_prec_dtype = self._origin_high | |||
| amp._low_prec_dtype = self._origin_low | |||
| def __call__(self, func): | |||
| @functools.wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| with self: | |||
| return func(*args, **kwargs) | |||
| return wrapper | |||
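Because `__enter__` saves the previous global state and `__exit__` restores it, `autocast` scopes can nest; a small hedged sketch mirroring the test further below (assuming autocast starts disabled, which is the default):

    from megengine import amp
    from megengine.core.tensor import amp as origin_amp

    with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
        assert origin_amp._enabled               # flipped on by __enter__
        with amp.autocast(enabled=False):        # nested scope overrides...
            assert not origin_amp._enabled
        assert origin_amp._enabled               # ...and is restored on exit
    assert not origin_amp._enabled               # original (disabled) state restored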
| @@ -49,7 +49,7 @@ class Device: | |||
| return self._cn == rhs._cn | |||
| def device(obj): | |||
| def as_device(obj): | |||
| if isinstance(obj, Device): | |||
| return obj | |||
| return Device(obj) | |||
| @@ -0,0 +1,78 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| _enabled = False | |||
| _high_prec_dtype = "float32" | |||
| _low_prec_dtype = "float16" | |||
| @property | |||
| def enabled(mod): | |||
| r""" | |||
| Get or set whether amp autocast mode is enabled. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| mge.amp.enabled = True | |||
| """ | |||
| return _enabled | |||
| @enabled.setter | |||
| def enabled(mod, enabled: bool): | |||
| global _enabled | |||
| _enabled = enabled | |||
| @property | |||
| def high_prec_dtype(mod): | |||
| r""" | |||
| Get or set amp autocast mode's higher precision dtype. It will change the | |||
| target dtype in tensor casting for better precision. Default: float32. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| mge.amp.high_prec_dtype = "float32" | |||
| """ | |||
| return _high_prec_dtype | |||
| @high_prec_dtype.setter | |||
| def high_prec_dtype(mod, dtype: str): | |||
| global _high_prec_dtype | |||
| _high_prec_dtype = dtype | |||
| @property | |||
| def low_prec_dtype(mod): | |||
| r""" | |||
| Get or set amp autocast mode's lower precision dtype. It will change the | |||
| target dtype in tensor casting for better speed and memory. Default: float16. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| mge.amp.low_prec_dtype = "float16" | |||
| """ | |||
| return _low_prec_dtype | |||
| @low_prec_dtype.setter | |||
| def low_prec_dtype(mod, dtype: str): | |||
| global _low_prec_dtype | |||
| _low_prec_dtype = dtype | |||
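Combined with the `mprop.init()` call in `megengine/amp/__init__.py` above, these properties are exposed as plain attributes on `megengine.amp`; a hedged usage sketch following the docstrings:

    import megengine as mge

    # toggle autocast globally and choose the casting dtypes
    mge.amp.enabled = True
    mge.amp.low_prec_dtype = "float16"
    mge.amp.high_prec_dtype = "float32"
    assert mge.amp.enabled

    mge.amp.enabled = False  # restore the default state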
| @@ -15,15 +15,20 @@ import numpy as np | |||
| from .._imperative_rt.common import CompNode | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||
| from ..ops import builtin | |||
| from ..ops.builtin import Elemwise, GetVarShape | |||
| from . import utils | |||
| from .indexing import getitem as _getitem | |||
| from .indexing import setitem as _setitem | |||
| from .utils import isscalar | |||
| from .utils import make_shape_tuple as _make_shape_tuple | |||
| from .utils import setscalar | |||
| _ElwMod = Elemwise.Mode | |||
| from . import amp | |||
| from .indexing import getitem, setitem | |||
| from .utils import ( | |||
| _normalize_axis, | |||
| astensor1d, | |||
| astype, | |||
| cast_tensors, | |||
| convert_inputs, | |||
| isscalar, | |||
| make_shape_tuple, | |||
| setscalar, | |||
| ) | |||
| _ElwMod = builtin.Elemwise.Mode | |||
| def _elwise_apply(args, mode): | |||
| @@ -40,47 +45,59 @@ def _elwise_apply(args, mode): | |||
| def _elwise(*args, mode): | |||
| args = convert_inputs(*args) | |||
| if mode in ( | |||
| _ElwMod.TRUE_DIV, | |||
| _ElwMod.EXP, | |||
| _ElwMod.POW, | |||
| _ElwMod.CEIL, | |||
| _ElwMod.FLOOR, | |||
| _ElwMod.ROUND, | |||
| _ElwMod.LOG, | |||
| _ElwMod.EXPM1, | |||
| _ElwMod.LOG1P, | |||
| _ElwMod.TANH, | |||
| _ElwMod.ACOS, | |||
| _ElwMod.ASIN, | |||
| _ElwMod.ATAN2, | |||
| _ElwMod.COS, | |||
| _ElwMod.H_SWISH, | |||
| _ElwMod.SIGMOID, | |||
| _ElwMod.SIN, | |||
| ) and ( | |||
| amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) | |||
| ): | |||
| if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype( | |||
| args[0].dtype, np.integer | |||
| ): | |||
| return args[0] | |||
| args = tuple( | |||
| map( | |||
| lambda x: x.astype("float32") | |||
| if hasattr(x, "dtype") and x.dtype != np.float32 | |||
| else x, | |||
| args, | |||
| ) | |||
| ) | |||
| args = utils.convert_inputs(*args) | |||
| # autocast to FP32 to maintain precision, | |||
| # or because some of these ops do not support integer inputs | |||
| args = cast_tensors(*args, promote=True) | |||
| if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype( | |||
| args[0].dtype, np.integer | |||
| ): | |||
| return args[0] | |||
| return _elwise_apply(args, mode) | |||
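A hedged illustration of the promotion branch above: for modes such as `TRUE_DIV`, integer inputs (or any inputs while autocast is enabled) are cast to the high-precision dtype before the elemwise op runs, so integer true division yields a float result (assuming the default amp settings):

    import numpy as np
    import megengine as mge

    a = mge.tensor(np.array([1, 2, 3], dtype=np.int32))
    b = mge.tensor(np.array([2, 2, 2], dtype=np.int32))

    out = a / b                      # dispatches _elwise with TRUE_DIV
    print(out.dtype, out.numpy())    # float32, [0.5 1.  1.5]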
| def _matmul(inp1, inp2): | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp1, inp2 = cast_tensors(inp1, inp2) | |||
| else: | |||
| compute_mode = "default" | |||
| inp1, inp2 = convert_inputs(inp1, inp2) | |||
| op = builtin.MatrixMul( | |||
| transposeA=False, transposeB=False, compute_mode="default", format="default" | |||
| transposeA=False, transposeB=False, compute_mode=compute_mode, format="default" | |||
| ) | |||
| inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
| (result,) = apply(op, inp1, inp2) | |||
| return result | |||
| def _transpose(data, axes): | |||
| op = builtin.Dimshuffle(axes) | |||
| (data,) = utils.convert_inputs(data) | |||
| (data,) = convert_inputs(data) | |||
| (result,) = apply(op, data) | |||
| return result | |||
| def _broadcast(inp, shape): | |||
| shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | |||
| shape = astensor1d(shape, inp, dtype="int32", device=inp.device) | |||
| (result,) = apply(builtin.Broadcast(), inp, shape) | |||
| return result | |||
| @@ -88,7 +105,7 @@ def _broadcast(inp, shape): | |||
| def _reshape(x, shape): | |||
| unspec_axis = None | |||
| try: | |||
| shape_tuple = _make_shape_tuple(shape) | |||
| shape_tuple = make_shape_tuple(shape) | |||
| except ValueError: | |||
| pass | |||
| else: | |||
| @@ -102,7 +119,7 @@ def _reshape(x, shape): | |||
| "multiple -1 in shape: {} & {}".format(unspec_axis, i) | |||
| ) | |||
| unspec_axis = i | |||
| shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | |||
| shape = astensor1d(shape, x, dtype="int32", device=x.device) | |||
| if unspec_axis is None: | |||
| op = builtin.Reshape() | |||
| else: | |||
| @@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
| return list(map(int, axis)) | |||
| axis = get_axes() | |||
| axis = utils._normalize_axis(inp.ndim, axis) | |||
| axis = _normalize_axis(inp.ndim, axis) | |||
| axis = [a - i for i, a in enumerate(axis)] | |||
| op = builtin.RemoveAxis(axis=axis) | |||
| @@ -184,7 +201,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
| def _reduce(mode): | |||
| def f(self, axis=None, keepdims: bool = False): | |||
| data = self | |||
| (data,) = utils.convert_inputs(data) | |||
| (data,) = convert_inputs(data) | |||
| if mode == "mean": | |||
| data = data.astype("float32") | |||
| elif self.dtype == np.bool_: | |||
| @@ -196,7 +213,7 @@ def _reduce(mode): | |||
| op = builtin.Reduce(mode=mode, axis=0) | |||
| (result,) = apply(op, data) | |||
| elif isinstance(axis, collections.abc.Iterable): | |||
| axis = utils._normalize_axis(self.ndim, axis, reverse=True) | |||
| axis = _normalize_axis(self.ndim, axis, reverse=True) | |||
| for ai in axis: | |||
| op = builtin.Reduce(mode=mode, axis=ai) | |||
| (data,) = apply(op, data) | |||
| @@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC): | |||
| yield self[i] | |||
| def __getitem__(self, index): | |||
| return _getitem(self, index) | |||
| return getitem(self, index) | |||
| def __setitem__(self, index, value): | |||
| if index is not Ellipsis: | |||
| value = _setitem(self, index, value) | |||
| value = setitem(self, index, value) | |||
| self._reset(value) | |||
| __contains__ = _todo | |||
| @@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC): | |||
| Returns a :class:`Tensor` with the same data and number of elements | |||
| with the specified :attr:`~.Tensor.dtype`. | |||
| """ | |||
| return utils.astype(self, dtype) | |||
| return astype(self, dtype) | |||
| def reshape(self, *args): | |||
| r""" | |||
| @@ -18,7 +18,7 @@ import numpy as np | |||
| from .. import _imperative_rt | |||
| from .._imperative_rt import GraphOptimizeOptions | |||
| from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | |||
| from .._wrap import device as as_device | |||
| from .._wrap import as_device | |||
| from ..ops.builtin import OpDef | |||
| from .core import TensorBase | |||
| @@ -13,9 +13,10 @@ import numpy as np | |||
| from .._imperative_rt import make_const | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device | |||
| from .._wrap import device as as_device | |||
| from .._wrap import as_device | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from .amp import _high_prec_dtype, _low_prec_dtype | |||
| from .dtype import is_dtype_equal, is_quantize | |||
| _enable_convert_inputs = True | |||
| @@ -98,6 +99,14 @@ def convert_inputs(*args, device=None): | |||
| return tuple(map(convert, args)) | |||
| def cast_tensors(*args, promote=False): | |||
| if promote: | |||
| dtype = _high_prec_dtype | |||
| else: | |||
| dtype = _low_prec_dtype | |||
| return tuple(arg.astype(dtype) if arg is not None else None for arg in args) | |||
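A hedged sketch of `cast_tensors` semantics, assuming the default amp dtypes (`float16` low, `float32` high); `None` entries pass through untouched, which is why callers can hand it an optional bias directly:

    import numpy as np
    import megengine as mge
    from megengine.core.tensor.utils import cast_tensors

    x = mge.tensor(np.ones((2, 2), dtype=np.float32))

    low, none_bias = cast_tensors(x, None)           # default: low-precision dtype
    high, _ = cast_tensors(x, None, promote=True)    # promote: high-precision dtype
    assert low.dtype == np.float16
    assert high.dtype == np.float32
    assert none_bias is None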
| def result_type(*args): | |||
| dtypes = [] | |||
| for i in args: | |||
| @@ -12,10 +12,8 @@ import numpy as np | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Elemwise | |||
| from ..core.tensor import utils | |||
| from ..core.tensor.array_method import _elwise_apply | |||
| from ..core.tensor.utils import astype | |||
| from ..device import get_default_device | |||
| from ..core.tensor.array_method import _elwise | |||
| from ..core.tensor.utils import astype, convert_inputs | |||
| from ..tensor import Tensor | |||
| from ..utils.deprecation import deprecated_func | |||
| @@ -69,46 +67,9 @@ __all__ = [ | |||
| ] | |||
| def _elwise(*args, mode): | |||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args)) | |||
| if len(tensor_args) == 0: | |||
| dtype = utils.dtype_promotion(args) | |||
| first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | |||
| args = utils.convert_inputs(first_arg, *args[1:]) | |||
| else: | |||
| args = utils.convert_inputs(*args) | |||
| if mode in ( | |||
| Elemwise.Mode.TRUE_DIV, | |||
| Elemwise.Mode.EXP, | |||
| Elemwise.Mode.POW, | |||
| Elemwise.Mode.LOG, | |||
| Elemwise.Mode.EXPM1, | |||
| Elemwise.Mode.LOG1P, | |||
| Elemwise.Mode.TANH, | |||
| Elemwise.Mode.ACOS, | |||
| Elemwise.Mode.ASIN, | |||
| Elemwise.Mode.ATAN2, | |||
| Elemwise.Mode.CEIL, | |||
| Elemwise.Mode.COS, | |||
| Elemwise.Mode.FLOOR, | |||
| Elemwise.Mode.H_SWISH, | |||
| Elemwise.Mode.ROUND, | |||
| Elemwise.Mode.SIGMOID, | |||
| Elemwise.Mode.SIN, | |||
| ): | |||
| if mode in ( | |||
| Elemwise.Mode.CEIL, | |||
| Elemwise.Mode.FLOOR, | |||
| Elemwise.Mode.ROUND, | |||
| ) and np.issubdtype(args[0].dtype, np.integer): | |||
| return args[0] | |||
| args = tuple(map(lambda x: astype(x, "float32"), args)) | |||
| return _elwise_apply(args, mode) | |||
| def _elemwise_multi_type(*args, mode, **kwargs): | |||
| op = builtin.ElemwiseMultiType(mode=mode, **kwargs) | |||
| args = utils.convert_inputs(*args) | |||
| args = convert_inputs(*args) | |||
| (result,) = apply(op, *args) | |||
| return result | |||
| @@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply | |||
| from ..core._trace_option import use_symbolic_shape | |||
| from ..core.ops import builtin | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import utils | |||
| from ..core.tensor import amp | |||
| from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar | |||
| from ..tensor import Tensor | |||
| from .debug_param import get_execution_strategy | |||
| from .elemwise import clip, exp, log, log1p | |||
| @@ -471,7 +472,7 @@ def argmin( | |||
| inp = inp.flatten() | |||
| axis = 0 | |||
| axis = utils._normalize_axis(inp.ndim, axis, reverse=True) | |||
| axis = _normalize_axis(inp.ndim, axis, reverse=True) | |||
| if isinstance(axis, collections.abc.Iterable): | |||
| for ai in axis: | |||
| @@ -528,7 +529,7 @@ def argmax( | |||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||
| inp = inp.flatten() | |||
| axis = 0 | |||
| axis = utils._normalize_axis(inp.ndim, axis, reverse=True) | |||
| axis = _normalize_axis(inp.ndim, axis, reverse=True) | |||
| if isinstance(axis, collections.abc.Iterable): | |||
| @@ -807,8 +808,13 @@ def matmul( | |||
| [28. 40.]] | |||
| """ | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp1, inp2 = cast_tensors(inp1, inp2) | |||
| else: | |||
| inp1, inp2 = convert_inputs(inp1, inp2) | |||
| remove_row, remove_col = False, False | |||
| inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
| dim1, dim2 = inp1.ndim, inp2.ndim | |||
| # handle dim=1 cases, dot and matrix-vector multiplication | |||
| @@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: | |||
| """ | |||
| op = builtin.Dot() | |||
| inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
| inp1, inp2 = convert_inputs(inp1, inp2) | |||
| assert ( | |||
| inp1.ndim <= 1 and inp2.ndim <= 1 | |||
| ), "Input tensors for dot must be 1-dimensional or scalar" | |||
| (result,) = apply(op, inp1, inp2) | |||
| utils.setscalar(result) | |||
| setscalar(result) | |||
| return result | |||
| @@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import BatchNorm, Elemwise | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph, utils | |||
| from ..core.tensor import amp, megbrain_graph | |||
| from ..core.tensor.array_method import _elwise_apply | |||
| from ..core.tensor.utils import astensor1d, astype, setscalar | |||
| from ..core.tensor.utils import ( | |||
| astensor1d, | |||
| astype, | |||
| cast_tensors, | |||
| convert_inputs, | |||
| convert_single_value, | |||
| setscalar, | |||
| ) | |||
| from ..device import get_default_device | |||
| from ..distributed import WORLD, is_distributed | |||
| from ..random import uniform | |||
| @@ -91,7 +98,9 @@ def expand_hw(x): | |||
| return int(h), int(w) | |||
| def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: | |||
| def linear( | |||
| inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default", | |||
| ) -> Tensor: | |||
| """ | |||
| Applies a linear transformation to the input tensor. | |||
| @@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor | |||
| :param bias: bias with shape `(out_features,)`. | |||
| Default: None | |||
| """ | |||
| ret = matmul(inp, weight, transpose_b=True) | |||
| ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode) | |||
| if bias is not None: | |||
| if amp._enabled: | |||
| bias = bias.astype("float16") | |||
| ret += bias | |||
| return ret | |||
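A hedged usage sketch of the new `compute_mode` argument (the `F.nn.linear` path and the float16-in/float16-out behavior are assumptions based on the io16c32 tests below; shapes are arbitrary):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.randn(4, 8).astype("float16"))
    w = mge.tensor(np.random.randn(16, 8).astype("float16"))

    # float16 inputs and outputs, float32 accumulation inside the matmul
    y = F.nn.linear(x, w, compute_mode="float32")
    print(y.dtype, y.shape)   # float16, (4, 16)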
| @@ -153,6 +164,11 @@ def conv1d( | |||
| assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | |||
| assert inp.ndim == 3, "the input dimension of conv1d should be 3" | |||
| assert weight.ndim == 3, "the weight dimension of conv1d should be 3" | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp, weight, bias = cast_tensors(inp, weight, bias) | |||
| else: | |||
| inp, weight = convert_inputs(inp, weight) | |||
| inp = expand_dims(inp, 3) | |||
| weight = expand_dims(weight, 3) | |||
| @@ -177,7 +193,6 @@ def conv1d( | |||
| compute_mode=compute_mode, | |||
| sparse=sparse_type, | |||
| ) | |||
| inp, weight = utils.convert_inputs(inp, weight) | |||
| (output,) = apply(op, inp, weight) | |||
| if bias is not None: | |||
| output += bias | |||
| @@ -228,7 +243,11 @@ def conv2d( | |||
| conv_mode.lower() == "cross_correlation" | |||
| or conv_mode.name == "CROSS_CORRELATION" | |||
| ) | |||
| assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp, weight, bias = cast_tensors(inp, weight, bias) | |||
| else: | |||
| inp, weight = convert_inputs(inp, weight) | |||
| stride_h, stride_w = expand_hw(stride) | |||
| pad_h, pad_w = expand_hw(padding) | |||
| @@ -247,7 +266,6 @@ def conv2d( | |||
| compute_mode=compute_mode, | |||
| sparse=sparse_type, | |||
| ) | |||
| inp, weight = utils.convert_inputs(inp, weight) | |||
| (output,) = apply(op, inp, weight) | |||
| if bias is not None: | |||
| output += bias | |||
| @@ -286,6 +304,7 @@ def conv3d( | |||
| :return: output tensor. | |||
| """ | |||
| assert conv_mode.lower() == "cross_correlation" | |||
| inp, weight = convert_inputs(inp, weight) | |||
| D, H, W = 0, 1, 2 | |||
| @@ -308,7 +327,6 @@ def conv3d( | |||
| mode=conv_mode, | |||
| sparse=sparse_type, | |||
| ) | |||
| inp, weight = utils.convert_inputs(inp, weight) | |||
| (output,) = apply(op, inp, weight) | |||
| if bias is not None: | |||
| output += bias | |||
| @@ -358,7 +376,11 @@ def conv_transpose2d( | |||
| conv_mode.lower() == "cross_correlation" | |||
| or conv_mode.name == "CROSS_CORRELATION" | |||
| ) | |||
| assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp, weight, bias = cast_tensors(inp, weight, bias) | |||
| else: | |||
| inp, weight = convert_inputs(inp, weight) | |||
| if groups != 1: | |||
| raise NotImplementedError("group transposed conv2d is not supported yet.") | |||
| @@ -375,8 +397,8 @@ def conv_transpose2d( | |||
| dilate_h=dilate_h, | |||
| dilate_w=dilate_w, | |||
| strategy=get_execution_strategy(), | |||
| compute_mode=compute_mode, | |||
| ) | |||
| weight, inp = utils.convert_inputs(weight, inp) | |||
| (output,) = apply(op, weight, inp) | |||
| if bias is not None: | |||
| output += bias | |||
| @@ -428,7 +450,11 @@ def deformable_conv2d( | |||
| conv_mode.lower() == "cross_correlation" | |||
| or conv_mode.name == "CROSS_CORRELATION" | |||
| ) | |||
| assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias) | |||
| else: | |||
| inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask) | |||
| stride_h, stride_w = expand_hw(stride) | |||
| pad_h, pad_w = expand_hw(padding) | |||
| @@ -447,7 +473,6 @@ def deformable_conv2d( | |||
| compute_mode=compute_mode, | |||
| sparse=sparse_type, | |||
| ) | |||
| inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask) | |||
| (output,) = apply(op, inp, weight, offset, mask) | |||
| if bias is not None: | |||
| output += bias | |||
| @@ -468,6 +493,7 @@ def local_conv2d( | |||
| conv_mode.lower() == "cross_correlation" | |||
| or conv_mode.name == "CROSS_CORRELATION" | |||
| ) | |||
| inp, weight = convert_inputs(inp, weight) | |||
| stride_h, stride_w = expand_hw(stride) | |||
| pad_h, pad_w = expand_hw(padding) | |||
| @@ -481,10 +507,8 @@ def local_conv2d( | |||
| dilate_h=dilate_h, | |||
| dilate_w=dilate_w, | |||
| mode=conv_mode, | |||
| compute_mode="default", | |||
| sparse="dense", | |||
| ) | |||
| inp, weight = utils.convert_inputs(inp, weight) | |||
| (output,) = apply(op, inp, weight) | |||
| if bias is not None: | |||
| output += bias | |||
| @@ -515,8 +539,9 @@ def conv_transpose3d( | |||
| :param dilation: dilation of the 3D convolution operation. Default: 1 | |||
| :return: output tensor. | |||
| """ | |||
| D, H, W = 0, 1, 2 | |||
| inp, weight = convert_inputs(inp, weight) | |||
| D, H, W = 0, 1, 2 | |||
| pad = _triple(padding) | |||
| stride = _triple_nonzero(stride) | |||
| dilate = _triple_nonzero(dilation) | |||
| @@ -533,7 +558,6 @@ def conv_transpose3d( | |||
| dilate_w=dilate[W], | |||
| strategy=get_execution_strategy(), | |||
| ) | |||
| weight, inp = utils.convert_inputs(weight, inp) | |||
| (output,) = apply(op, weight, inp) | |||
| if bias is not None: | |||
| output += bias | |||
| @@ -994,7 +1018,8 @@ def batch_norm( | |||
| training: bool = False, | |||
| momentum: float = 0.9, | |||
| eps: float = 1e-5, | |||
| inplace: bool = True | |||
| inplace: bool = True, | |||
| compute_mode="default", | |||
| ): | |||
| r""" | |||
| Applies batch normalization to the input. | |||
| @@ -1027,15 +1052,11 @@ def batch_norm( | |||
| def make_full_if_none(x, value): | |||
| if x is None: | |||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)() | |||
| shape = utils.astensor1d( | |||
| (1, C, 1, 1), inp, dtype="int32", device=inp.device | |||
| ) | |||
| shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device) | |||
| (result,) = apply(builtin.Broadcast(), x, shape) | |||
| return result | |||
| elif x.ndim == 1: | |||
| shape = utils.astensor1d( | |||
| (1, C, 1, 1), inp, dtype="int32", device=inp.device | |||
| ) | |||
| shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device) | |||
| (result,) = apply(builtin.Reshape(), x, shape) | |||
| return result | |||
| return x | |||
| @@ -1052,10 +1073,15 @@ def batch_norm( | |||
| if has_var and running_var.ndim != 4: | |||
| raise ValueError | |||
| inp, weight, bias, running_mean, running_var = utils.convert_inputs( | |||
| inp, weight, bias, running_mean, running_var | |||
| ) | |||
| if amp._enabled: | |||
| inp = inp.astype("float16") | |||
| weight, bias, running_mean, running_var = cast_tensors( | |||
| weight, bias, running_mean, running_var, promote=True | |||
| ) | |||
| elif compute_mode != "float32": | |||
| inp, weight, bias, running_mean, running_var = convert_inputs( | |||
| inp, weight, bias, running_mean, running_var | |||
| ) | |||
| weight = make_full_if_none(weight, 1) | |||
| bias = make_full_if_none(bias, 0) | |||
| @@ -1352,7 +1378,7 @@ def indexing_one_hot( | |||
| """ | |||
| assert isinstance(src, Tensor), "src must be of Tensor type" | |||
| op = builtin.IndexingOneHot(axis=axis) | |||
| index = utils.convert_single_value(index, dtype="int32", device=src.device) | |||
| index = convert_single_value(index, dtype="int32", device=src.device) | |||
| (result,) = apply(op, src, index) | |||
| if not keepdims: | |||
| result = squeeze(result, axis) | |||
| @@ -13,7 +13,7 @@ import numpy as np | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core._wrap import device as as_device | |||
| from ..core._wrap import as_device | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Copy, Identity | |||
| from ..core.ops.special import Const | |||
| @@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import ( | |||
| RemoteSend, | |||
| ) | |||
| from ..core._trace_option import set_symbolic_shape | |||
| from ..core._wrap import device as as_device | |||
| from ..core._wrap import as_device | |||
| from ..core.ops.builtin import BatchNorm, OpDef | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph as G | |||
| @@ -26,6 +26,7 @@ class _BatchNorm(Module): | |||
| affine=True, | |||
| track_running_stats=True, | |||
| freeze=False, | |||
| compute_mode="default", | |||
| **kwargs | |||
| ): | |||
| super(_BatchNorm, self).__init__(**kwargs) | |||
| @@ -36,6 +37,7 @@ class _BatchNorm(Module): | |||
| self.track_running_stats = track_running_stats | |||
| self._track_running_stats_saved = track_running_stats | |||
| self.freeze = freeze | |||
| self.compute_mode = compute_mode | |||
| if self.freeze: | |||
| assert ( | |||
| self._track_running_stats_saved | |||
| @@ -123,6 +125,7 @@ class _BatchNorm(Module): | |||
| or ((self.running_mean is None) and (self.running_var is None)), | |||
| momentum=exponential_average_factor, | |||
| eps=self.eps, | |||
| compute_mode=self.compute_mode, | |||
| ) | |||
| if _ndims != 4: | |||
| @@ -51,7 +51,12 @@ class Linear(Module): | |||
| """ | |||
| def __init__( | |||
| self, in_features: int, out_features: int, bias: bool = True, **kwargs | |||
| self, | |||
| in_features: int, | |||
| out_features: int, | |||
| bias: bool = True, | |||
| compute_mode: str = "default", | |||
| **kwargs | |||
| ): | |||
| super().__init__(**kwargs) | |||
| self.out_features = out_features | |||
| @@ -62,6 +67,7 @@ class Linear(Module): | |||
| if bias: | |||
| b_shape = (out_features,) | |||
| self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) | |||
| self.compute_mode = compute_mode | |||
| self.reset_parameters() | |||
| def _get_fanin(self): | |||
| @@ -75,7 +81,7 @@ class Linear(Module): | |||
| init.zeros_(self.bias) | |||
| def _calc_linear(self, x, weight, bias): | |||
| return linear(x, weight, bias) | |||
| return linear(x, weight, bias, compute_mode=self.compute_mode) | |||
| def forward(self, x): | |||
| return self._calc_linear(x, self.weight, self.bias) | |||
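Together with the `_BatchNorm` change above, the module-level API now accepts the same flag; a brief hedged sketch (module names assumed from `megengine.module`):

    import megengine.module as M

    # both modules forward compute_mode down to the functional ops
    fc = M.Linear(128, 64, compute_mode="float32")
    bn = M.BatchNorm2d(64, compute_mode="float32")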
| @@ -5,8 +5,6 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import copy | |||
| import math | |||
| from functools import partial | |||
| from .. import functional as F | |||
| @@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode | |||
| from .core._imperative_rt.core2 import Tensor as _Tensor | |||
| from .core._imperative_rt.core2 import apply | |||
| from .core._trace_option import use_symbolic_shape | |||
| from .core._wrap import device as as_device | |||
| from .core._wrap import as_device | |||
| from .core.ops.builtin import Copy, GetVarShape | |||
| from .core.tensor.array_method import ArrayMethodMixin | |||
| from .device import _valid_device, get_default_device | |||
| @@ -0,0 +1,34 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from megengine import amp | |||
| from megengine.core.tensor import amp as origin_amp | |||
| def test_grad_scaler(): | |||
| def check(enabled, low, high): | |||
| assert amp.enabled == enabled | |||
| assert origin_amp._enabled == enabled | |||
| assert amp.low_prec_dtype == low | |||
| assert origin_amp._low_prec_dtype == low | |||
| assert amp.high_prec_dtype == high | |||
| assert origin_amp._high_prec_dtype == high | |||
| origin_enabled = amp.enabled | |||
| origin_high = amp.high_prec_dtype | |||
| origin_low = amp.low_prec_dtype | |||
| with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"): | |||
| check(True, "low", "high") | |||
| check(origin_enabled, origin_low, origin_high) | |||
| amp.enabled = True | |||
| amp.high_prec_dtype = "high" | |||
| amp.low_prec_dtype = "low" | |||
| check(True, "low", "high") | |||
| amp.enabled = origin_enabled | |||
| amp.high_prec_dtype = origin_high | |||
| amp.low_prec_dtype = origin_low | |||
| check(origin_enabled, origin_low, origin_high) | |||
| @@ -14,6 +14,7 @@ import numpy as np | |||
| import pytest | |||
| from utils import opr_test | |||
| import megengine.amp as amp | |||
| import megengine.core.ops.builtin as builtin | |||
| import megengine.core.tensor.dtype as dtype | |||
| import megengine.functional as F | |||
| @@ -767,6 +768,27 @@ def test_batch_conv_bias(): | |||
| run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True) | |||
| def test_conv2d_io16c32(): | |||
| amp.enabled = True | |||
| inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32) | |||
| weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32) | |||
| out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1) | |||
| amp.enabled = False | |||
| expected = F.conv2d( | |||
| inp.astype("float16"), | |||
| weight.astype("float16"), | |||
| None, | |||
| (2, 2), | |||
| (3, 3), | |||
| (1, 1), | |||
| 1, | |||
| compute_mode="float32", | |||
| ) | |||
| assert out.dtype == np.float16 | |||
| assert expected.dtype == np.float16 | |||
| np.testing.assert_allclose(out.numpy(), expected.numpy()) | |||
| def test_conv2d_zero_stride_numpy_array(): | |||
| inp = np.random.randn(3, 224, 224).astype(np.float32) | |||
| inp = inp[np.newaxis, :] | |||
| @@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array(): | |||
| def test_conv1d(): | |||
| inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4)) | |||
| weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2)) | |||
| inp = tensor(np.ones((2, 2, 4), dtype=np.float32)) | |||
| weight = tensor(np.ones((3, 2, 2), dtype=np.float32)) | |||
| out = F.conv1d(inp, weight, None, 2, 0, 1, 1) | |||
| np.testing.assert_equal( | |||
| out.numpy(), | |||
| @@ -798,9 +820,31 @@ def test_conv1d(): | |||
| ) | |||
| def test_batchnorm2d_io16c32(): | |||
| amp.enabled = True | |||
| inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32) | |||
| weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32) | |||
| bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32) | |||
| out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False) | |||
| amp.enabled = False | |||
| expected = F.batch_norm( | |||
| inp.astype("float16"), | |||
| weight=weight, | |||
| bias=bias, | |||
| training=True, | |||
| inplace=False, | |||
| compute_mode="float32", | |||
| ) | |||
| assert out.dtype == np.float16 | |||
| assert expected.dtype == np.float16 | |||
| np.testing.assert_allclose(out.numpy(), expected.numpy()) | |||
| def test_conv3d(): | |||
| inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4)) | |||
| weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2)) | |||
| inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32)) | |||
| weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32)) | |||
| out = F.conv3d(inp, weight, None, 2, 0, 1, 1) | |||
| print(out.numpy().shape) | |||
| np.testing.assert_equal( | |||
| @@ -473,39 +473,6 @@ def test_pickle_module(): | |||
| np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6) | |||
| def test_load_quantized(): | |||
| from megengine.core.tensor import dtype | |||
| data_shape = (2, 28) | |||
| data = tensor(np.random.random(data_shape), dtype="float32") | |||
| data = data.astype(dtype.qint8(0.1)) | |||
| mlp = MLP() | |||
| quantize_qat(mlp) | |||
| quantize(mlp) | |||
| mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy()) | |||
| mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy()) | |||
| mlp.eval() | |||
| pred0 = mlp(data) | |||
| with BytesIO() as fout: | |||
| mge.save(mlp.state_dict(), fout) | |||
| fout.seek(0) | |||
| checkpoint = mge.load(fout) | |||
| # change mlp weight. | |||
| mlp.dense0.weight = Parameter( | |||
| mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy() | |||
| ) | |||
| mlp.dense1.weight = Parameter( | |||
| mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy() | |||
| ) | |||
| mlp.load_state_dict(checkpoint) | |||
| pred1 = mlp(data) | |||
| np.testing.assert_allclose( | |||
| pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6 | |||
| ) | |||
| def test_repr_basic(): | |||
| # test whether __repr__ can output correct information | |||
| class ConvModel(Module): | |||