GitOrigin-RevId: 1b41e1042c
tags/v1.10.0
| @@ -2,6 +2,7 @@ import mprop | |||
| from ..core.tensor.amp import * | |||
| from .autocast import autocast | |||
| from .convert_format import convert_module_format, convert_tensor_format | |||
| from .grad_scaler import GradScaler | |||
| mprop.init() | |||
| @@ -1,5 +1,6 @@ | |||
| import functools | |||
| from ..core import _config | |||
| from ..core.tensor import amp | |||
| @@ -50,24 +51,37 @@ class autocast: | |||
| self._origin_high = None | |||
| self._origin_low = None | |||
| self._origin_configs = None | |||
| def __enter__(self): | |||
| self._origin_enabled = amp._enabled | |||
| self._origin_high = amp._get_amp_high_prec_dtype() | |||
| self._origin_low = amp._get_amp_low_prec_dtype() | |||
| amp._enabled = self.enabled | |||
| amp._set_amp_dtype_autocast(self.enabled) | |||
| if not self.enabled: | |||
| return | |||
| self._origin_high = amp._get_amp_high_prec_dtype() | |||
| self._origin_low = amp._get_amp_low_prec_dtype() | |||
| amp._set_amp_high_prec_dtype(self.high_prec_dtype) | |||
| amp._set_amp_low_prec_dtype(self.low_prec_dtype) | |||
| self._origin_configs = _config._reset_execution_config(compute_mode="float32") | |||
| def __exit__(self, *args): | |||
| amp._enabled = self._origin_enabled | |||
| amp._set_amp_dtype_autocast(self._origin_enabled) | |||
| if not self.enabled: | |||
| return | |||
| amp._set_amp_high_prec_dtype(self._origin_high) | |||
| amp._set_amp_low_prec_dtype(self._origin_low) | |||
| _config._reset_execution_config(*self._origin_configs) | |||
| def __call__(self, func): | |||
| @functools.wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| if not self.enabled: | |||
| return func(*args, **kwargs) | |||
| with self: | |||
| return func(*args, **kwargs) | |||
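A minimal usage sketch of the `autocast` behaviour shown above, as both a context manager and a decorator. It assumes the constructor accepts `enabled`, `low_prec_dtype` and `high_prec_dtype` keywords matching the attributes used in `__enter__`; the shapes and the conv call are illustrative only.

```python
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.amp import autocast

inp = mge.tensor(np.random.randn(2, 16, 32, 32), dtype="float32")
weight = mge.tensor(np.random.randn(16, 16, 3, 3), dtype="float32")

# As a context manager: amp is enabled on __enter__ and the previous state
# (enabled flag, high/low precision dtypes, compute_mode) is restored on
# __exit__, as implemented above.
with autocast(enabled=True):
    out = F.conv2d(inp, weight)   # inputs are cast to the low-precision dtype

# As a decorator: with enabled=False the wrapper falls through to the plain
# function without entering the context.
@autocast(enabled=True)
def forward(x, w):
    return F.conv2d(x, w)
```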
| @@ -0,0 +1,45 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from copy import deepcopy | |||
| from .. import functional as F | |||
| from ..module import Module | |||
| from ..tensor import Tensor | |||
| def _is_nchw_format(param: Tensor): | |||
| # TODO: use better condition | |||
| return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc" | |||
| def convert_tensor_format(x: Tensor, inplace: bool = True): | |||
| """Convert NCHW Tensor to NHWC Tensor.""" | |||
| if x.ndim == 4: | |||
| pattern = (0, 2, 3, 1) | |||
| elif x.ndim == 5: | |||
| pattern = (0, 1, 3, 4, 2) | |||
| else: | |||
| raise ValueError("Unsupport tensor ndim {}".format(x.ndim)) | |||
| # TODO: use initialization from tensor after fixing format setting | |||
| if inplace: | |||
| x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc") | |||
| else: | |||
| x = Tensor(x.numpy().transpose(*pattern), format="nhwc") | |||
| return x | |||
| def convert_module_format(module: Module, inplace: bool = True): | |||
| """Convert NCHW Module to NHWC Module.""" | |||
| if not inplace: | |||
| module = deepcopy(module) | |||
| for name, param in module.named_tensors(): | |||
| if _is_nchw_format(param): | |||
| # hostvalue should still be valid, so no d2h cost. | |||
| convert_tensor_format(param, inplace=True) | |||
| return module | |||
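A short sketch of how the helpers above are meant to be used, with a toy module; the shapes are illustrative and `inplace=False` is chosen so the originals are left untouched.

```python
import numpy as np
import megengine as mge
import megengine.module as M
from megengine import amp

m = M.Conv2d(4, 4, 3)
m_nhwc = amp.convert_module_format(m, inplace=False)   # deep-copied, params converted
for name, param in m_nhwc.named_tensors():
    print(name, param.format)                          # 4D/5D params report "nhwc"

x = mge.Tensor(np.ones((1, 4, 8, 8), dtype=np.float32))
x_nhwc = amp.convert_tensor_format(x, inplace=False)   # NCHW data reordered to NHWC
print(x_nhwc.format)                                   # "nhwc"
```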
| @@ -1,7 +1,13 @@ | |||
| import weakref | |||
| from typing import Callable, Iterable, List, Union | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | |||
| from ..core._imperative_rt.core2 import ( | |||
| get_auto_format_convert, | |||
| pop_scope, | |||
| push_scope, | |||
| set_auto_format_convert, | |||
| set_option, | |||
| ) | |||
| from ..core.autodiff.grad import Grad | |||
| from ..core.tensor.dtype import is_differentible_dtype | |||
| from ..logger import get_logger | |||
| @@ -253,6 +259,8 @@ class GradManager: | |||
| """ | |||
| push_scope("backward") | |||
| set_option("record_computing_path", 0) | |||
| _origin_auto_format = get_auto_format_convert() | |||
| set_auto_format_convert(False) | |||
| from ..functional import ones_like | |||
| global backwarding_grad_manager | |||
| @@ -296,6 +304,7 @@ class GradManager: | |||
| self.release() | |||
| backwarding_grad_manager = cache | |||
| set_option("record_computing_path", 1) | |||
| set_auto_format_convert(_origin_auto_format) | |||
| pop_scope("backward") | |||
| def record(self): | |||
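The `backward` change above follows a save/disable/restore pattern so that automatic NCHW/NHWC conversion does not interfere with gradient bookkeeping. A generic sketch of that pattern is below; `run_without_auto_format_convert` is a hypothetical helper written only to illustrate it, while the two imported functions come from the diff itself.

```python
from megengine.core._imperative_rt.core2 import (
    get_auto_format_convert,
    set_auto_format_convert,
)

def run_without_auto_format_convert(fn, *args, **kwargs):
    """Hypothetical helper: run fn with auto format conversion disabled."""
    origin = get_auto_format_convert()   # remember the current global flag
    set_auto_format_convert(False)       # disable conversion, as backward() does
    try:
        return fn(*args, **kwargs)
    finally:
        set_auto_format_convert(origin)  # always restore the previous state
```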
| @@ -10,8 +10,10 @@ from ._imperative_rt.core2 import ( | |||
| set_option, | |||
| ) | |||
| # use "default" to distinguish it from None in _reset_execution_config | |||
| __compute_mode = "default" | |||
| __conv_format = "default" | |||
| __bn_format = "default" | |||
| _benchmark_kernel = False | |||
| _deterministic_kernel = False | |||
| @@ -22,6 +24,8 @@ __all__ = [ | |||
| "disable_memory_forwarding", | |||
| "_compute_mode", | |||
| "_conv_format", | |||
| "_bn_format", | |||
| "_auto_format_convert", | |||
| "_override", | |||
| ] | |||
| @@ -32,6 +36,7 @@ def benchmark_kernel(mod): | |||
| which means use heuristic to choose the fastest algorithm. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -55,6 +60,7 @@ def deterministic_kernel(mod): | |||
| which means the algorithm is not reproducible. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -75,6 +81,7 @@ def async_level(mod) -> int: | |||
| which means both device and user side errors are async. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -110,16 +117,17 @@ def disable_memory_forwarding(mod, disable: bool): | |||
| @property | |||
| def _compute_mode(mod): | |||
| r"""Get or set the precision of intermediate results. The default option is "default", | |||
| which means that no special requirements will be placed on. When set to 'float32', it | |||
| would be used for accumulator and intermediate result, but only effective when input and | |||
| r"""Get or set the precision of intermediate results for conv, matmul. The default | |||
| option is None and will fallback to "default". When set to "float32", it will | |||
| trigger mixed precision computation on TensorCore, but only effective when input and | |||
| output are of float16 dtype. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| mge.config._compute_mode = "default" | |||
| mge.config._compute_mode = "float32" | |||
| """ | |||
| return __compute_mode | |||
| @@ -132,7 +140,7 @@ def _compute_mode(mod, _compute_mode: str): | |||
| @property | |||
| def _conv_format(mod): | |||
| r"""Get or set convolution data/filter/output layout format. The default option is "default", | |||
| r"""Get or set convolution data/filter/output layout format. The default option is None, | |||
| which means that no particular format is enforced. The supported layout definitions are: | |||
| ``NCHW`` layout: ``{N, C, H, W}`` | |||
| @@ -145,6 +153,7 @@ def _conv_format(mod): | |||
| ``NCHW64`` layout: ``{N, C/64, H, W, 64}`` | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -159,12 +168,35 @@ def _conv_format(mod, format: str): | |||
| __conv_format = format | |||
| @property | |||
| def _bn_format(mod): | |||
| r"""Get or set batchnorm param layout format. The default option is None and will | |||
| fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c", | |||
| param format of batchnorm will be changed to NHWC. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| mge.config._bn_format = "dim_111c" | |||
| """ | |||
| return __bn_format | |||
| @_bn_format.setter | |||
| def _bn_format(mod, format: str): | |||
| global __bn_format | |||
| __bn_format = format | |||
| @property | |||
| def _auto_format_convert(mod): | |||
| r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order. | |||
| The default value is False, which means no convert. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -184,15 +216,17 @@ def _reset_execution_config( | |||
| async_level=None, | |||
| compute_mode=None, | |||
| conv_format=None, | |||
| bn_format=None, | |||
| auto_format_convert=None, | |||
| ): | |||
| global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format | |||
| global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format | |||
| orig_flags = ( | |||
| _benchmark_kernel, | |||
| _deterministic_kernel, | |||
| get_option("async_level"), | |||
| __compute_mode, | |||
| __conv_format, | |||
| __bn_format, | |||
| get_auto_format_convert(), | |||
| ) | |||
| if benchmark_kernel is not None: | |||
| @@ -205,6 +239,8 @@ def _reset_execution_config( | |||
| __compute_mode = compute_mode | |||
| if conv_format is not None: | |||
| __conv_format = conv_format | |||
| if bn_format is not None: | |||
| __bn_format = bn_format | |||
| if auto_format_convert is not None: | |||
| set_auto_format_convert(auto_format_convert) | |||
| @@ -218,12 +254,14 @@ def _override( | |||
| async_level=None, | |||
| compute_mode=None, | |||
| conv_format=None, | |||
| bn_format=None, | |||
| auto_format_convert=None, | |||
| ): | |||
| r"""A context manager that users can opt in by attaching the decorator to set | |||
| the config of the global variable. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -234,6 +272,7 @@ def _override( | |||
| async_level=2, | |||
| compute_mode="float32", | |||
| conv_format="NHWC", | |||
| bn_format="dim_111c", | |||
| auto_format_convert=True, | |||
| ) | |||
| def train(): | |||
| @@ -244,6 +283,7 @@ def _override( | |||
| async_level, | |||
| compute_mode, | |||
| conv_format, | |||
| bn_format, | |||
| auto_format_convert, | |||
| ) | |||
| try: | |||
| @@ -254,4 +294,4 @@ def _override( | |||
| def _get_actual_op_param(function_param, config_param): | |||
| return function_param if config_param == "default" else config_param | |||
| return function_param if config_param == "default" else config_param | |||
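How the new knobs combine with `_override` and `_get_actual_op_param`: an op's per-call argument is used only while the corresponding global config is left at "default"; once overridden, the config value takes precedence. A small sketch following the docstring and test usage in this diff:

```python
import megengine as mge

print(mge.config._compute_mode)        # "default" unless something changed it earlier

with mge.config._override(
    compute_mode="float32",
    conv_format="NHWC",
    bn_format="dim_111c",
    auto_format_convert=True,
):
    # inside the context, ops resolving their parameters through
    # _get_actual_op_param pick up the overridden values
    print(mge.config._compute_mode)    # "float32"

print(mge.config._compute_mode)        # restored to its previous value
```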
| @@ -10,13 +10,19 @@ from .._imperative_rt.core2 import ( | |||
| _enabled = False | |||
| _set_amp_dtype_autocast(_enabled) | |||
| __all__ = [ | |||
| "enabled", | |||
| "high_prec_dtype", | |||
| "low_prec_dtype", | |||
| ] | |||
| @property | |||
| def enabled(mod): | |||
| r"""Get or set amp autocast mode enabled or not. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -36,9 +42,9 @@ def enabled(mod, enabled: bool): | |||
| def high_prec_dtype(mod): | |||
| r"""Get or set amp autocast mode's higher precision dtype. It will change the | |||
| target dtype in tensor casting for better precision. Default: float32. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
| @@ -56,9 +62,9 @@ def high_prec_dtype(mod, dtype: str): | |||
| def low_prec_dtype(mod): | |||
| r"""Get or set amp autocast mode's lower precision dtype. It will change the | |||
| target dtype in tensor casting for better speed and memory. Default: float16. | |||
| Examples: | |||
| .. code-block:: | |||
| import megengine as mge | |||
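A small sketch of the module-level amp properties documented above; the dtype names are the defaults stated in the docstrings.

```python
from megengine import amp

amp.enabled = True                # turn global autocast on
print(amp.high_prec_dtype)        # "float32" by default
print(amp.low_prec_dtype)         # "float16" by default
amp.enabled = False               # restore the default state
```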
| @@ -63,6 +63,7 @@ def _matmul( | |||
| assert dim1 > 0 and dim2 > 0 | |||
| maxdim = dim1 if dim1 > dim2 else dim2 | |||
| compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
| if dim1 == 1 and dim2 == 1: # dispatch to Dot | |||
| (result,) = apply(builtin.Dot(), inp1, inp2) | |||
| return result | |||
| @@ -441,7 +441,6 @@ def deformable_conv2d( | |||
| or conv_mode.name == "CROSS_CORRELATION" | |||
| ) | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias) | |||
| else: | |||
| offset = offset.astype("float32") | |||
| @@ -1182,7 +1181,6 @@ def batch_norm( | |||
| momentum: float = 0.9, | |||
| eps: float = 1e-5, | |||
| inplace: bool = True, | |||
| compute_mode="default", | |||
| param_dim="dim_1c11" | |||
| ): | |||
| r"""Applies batch normalization to the input. | |||
| @@ -19,7 +19,6 @@ class _BatchNorm(Module): | |||
| affine=True, | |||
| track_running_stats=True, | |||
| freeze=False, | |||
| compute_mode="default", | |||
| param_dim="dim_1c11", | |||
| **kwargs | |||
| ): | |||
| @@ -31,7 +30,6 @@ class _BatchNorm(Module): | |||
| self.track_running_stats = track_running_stats | |||
| self._track_running_stats_saved = track_running_stats | |||
| self.freeze = freeze | |||
| self.compute_mode = compute_mode | |||
| self.param_dim = param_dim | |||
| if self.freeze: | |||
| assert ( | |||
| @@ -106,7 +104,6 @@ class _BatchNorm(Module): | |||
| or ((self.running_mean is None) and (self.running_var is None)), | |||
| momentum=exponential_average_factor, | |||
| eps=self.eps, | |||
| compute_mode=self.compute_mode, | |||
| param_dim=self.param_dim, | |||
| ) | |||
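With `compute_mode` removed from `_BatchNorm` above, batch-norm precision now follows the global amp/config state instead of a per-module argument. A minimal sketch under that assumption; the module size, input shape and the dtype comment are illustrative (cf. `test_batchnorm2d_autocast` later in this diff).

```python
import numpy as np
import megengine as mge
import megengine.module as M
from megengine import amp

bn = M.BatchNorm2d(4)                              # no compute_mode argument anymore
x = mge.Tensor(np.random.randn(2, 4, 8, 8).astype("float32"))

with amp.autocast(enabled=True):                   # precision is driven by amp / global config
    y = bn(x)
print(y.dtype)                                     # expected float16 under autocast
```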
| @@ -8,7 +8,13 @@ from typing import Union | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | |||
| from ..core._imperative_rt.core2 import ( | |||
| get_auto_format_convert, | |||
| pop_scope, | |||
| push_scope, | |||
| set_auto_format_convert, | |||
| set_option, | |||
| ) | |||
| from ..core.tensor.utils import set_convert_inputs | |||
| from ..tensor import Parameter, Tensor | |||
| from ..utils.deprecation import deprecated | |||
| @@ -90,7 +96,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| "optimizer can only optimize Parameters, but one of the params is " | |||
| + str(type(param)) | |||
| ) | |||
| param._reset(Tensor(param.numpy(), no_cache=True)) | |||
| param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) | |||
| for name, default in self._defaults.items(): | |||
| if default is required and name not in param_group: | |||
| @@ -139,6 +145,8 @@ class Optimizer(metaclass=ABCMeta): | |||
| # set the global state `_enable_convert_inputs` to `False` to disable | |||
| # the `convert_inputs` for param updates | |||
| set_option("record_computing_path", 0) | |||
| _origin_auto_format = get_auto_format_convert() | |||
| set_auto_format_convert(False) | |||
| if self._disable_type_convert: | |||
| backup = set_convert_inputs(False) | |||
| for group in self.param_groups: | |||
| @@ -155,6 +163,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| # restore the global state `_enable_convert_inputs` | |||
| set_convert_inputs(backup) | |||
| set_option("record_computing_path", 1) | |||
| set_auto_format_convert(_origin_auto_format) | |||
| return self | |||
| @deprecated(version="1.0", reason="use clear_grad instead") | |||
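The `Optimizer` change above re-wraps each parameter with its original memory format, so NHWC-converted modules keep their layout after an optimizer is attached. A sketch of why that matters, assuming the conversion helper from earlier in this diff:

```python
import megengine.module as M
import megengine.optimizer as optim
from megengine import amp

m = amp.convert_module_format(M.Conv2d(4, 4, 3), inplace=False)
opt = optim.SGD(m.parameters(), lr=0.1)            # parameters are re-wrapped here

for p in m.parameters():
    if p.ndim in (4, 5):
        assert p.format == "nhwc"                  # format survives the reset above
```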
| @@ -0,0 +1,44 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| import pytest | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine import Parameter, Tensor, amp, tensor | |||
| class MyModule(M.Module): | |||
| class InnerModule(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bn = M.BatchNorm2d(4) | |||
| def forward(self, x): | |||
| return self.bn(x) | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.i = self.InnerModule() | |||
| self.conv = M.Conv2d(4, 4, 4, groups=2) | |||
| self.bn = M.BatchNorm2d(4) | |||
| self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32)) | |||
| self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32)) | |||
| def forward(self, x): | |||
| x = self.i(x) | |||
| x = self.bn(x) | |||
| return x | |||
| @pytest.mark.parametrize("is_inplace", [False, True]) | |||
| def test_convert_module(is_inplace): | |||
| m = MyModule() | |||
| m = amp.convert_module_format(m, is_inplace) | |||
| for name, param in m.named_tensors(): | |||
| assert param.format == "nhwc" | |||
| @@ -8,14 +8,27 @@ from megengine.autodiff import GradManager | |||
| def test_basic(): | |||
| a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc") | |||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||
| # init from numpy | |||
| a = tensor(data, format="nhwc") | |||
| assert a.format == "nhwc" | |||
| # init from tensor | |||
| b = tensor(a) | |||
| assert b.format == "nhwc" | |||
| # TODO: fix Tensor init bug for another Tensor | |||
| # TODO: init from tensor with new format | |||
| # c = tensor(a, format="nchw") | |||
| # assert c.format == "nchw" | |||
| # TODO: reset from numpy | |||
| # b[...] = data | |||
| # assert b.format == "nhwc" | |||
| # reset from tensor | |||
| b[...] = tensor(data, format="nchw") | |||
| assert b.format == "nchw" | |||
| def _compare_nchw_nhwc(data, func): | |||
| x1 = tensor(data, format="nchw") | |||
| @@ -23,7 +36,7 @@ def _compare_nchw_nhwc(data, func): | |||
| out1 = func(x1) | |||
| with mge.config._override(auto_format_convert=True): | |||
| out2 = func(x2) | |||
| np.testing.assert_equal(out1, out2) | |||
| np.testing.assert_almost_equal(out1, out2, decimal=5) | |||
| def test_dimshuffle(): | |||
| @@ -296,8 +309,10 @@ def test_backward(): | |||
| with gm: | |||
| with mge.config._override(auto_format_convert=True, conv_format="NHWC"): | |||
| x = F.conv2d(x, w, b) | |||
| # TODO: fix manual conversion to NHWC, commonly used in detection heads | |||
| # x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2) | |||
| gm.backward(x) | |||
| # TODO: backward grad has no format yet | |||
| # backward grad has no format | |||
| np.testing.assert_equal( | |||
| w.grad.numpy(), | |||
| np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), | |||
| @@ -921,12 +921,7 @@ def test_batchnorm2d_autocast(): | |||
| amp.enabled = False | |||
| expected = F.batch_norm( | |||
| inp.astype("float16"), | |||
| weight=weight, | |||
| bias=bias, | |||
| training=True, | |||
| inplace=False, | |||
| compute_mode="float32", | |||
| inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False, | |||
| ) | |||
| assert out.dtype == np.float16 | |||
| assert expected.dtype == np.float16 | |||