From 5cef74a77e23d32b5047b4da6febee234bd2e418 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 29 Jun 2021 18:17:44 +0800
Subject: [PATCH] feat(mge/amp): add GradScaler support

GitOrigin-RevId: 0ab4910360757d783f132be7041297c846cf513f
---
 imperative/python/megengine/amp/__init__.py    |   2 +-
 .../python/megengine/amp/grad_scaler.py        | 185 ++++++++++++++++++
 .../python/megengine/autodiff/grad_manager.py  |  22 ++-
 imperative/python/megengine/functional/nn.py   |   2 +-
 .../python/test/unit/amp/test_grad_scaler.py   |  30 +++
 .../test/unit/autodiff/test_grad_manger.py     |  45 +++++
 6 files changed, 282 insertions(+), 4 deletions(-)
 create mode 100644 imperative/python/megengine/amp/grad_scaler.py
 create mode 100644 imperative/python/test/unit/amp/test_grad_scaler.py

diff --git a/imperative/python/megengine/amp/__init__.py b/imperative/python/megengine/amp/__init__.py
index 29be0ddd..c78f9d86 100644
--- a/imperative/python/megengine/amp/__init__.py
+++ b/imperative/python/megengine/amp/__init__.py
@@ -5,10 +5,10 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
 import mprop

 from ..core.tensor.amp import *
 from .autocast import autocast
+from .grad_scaler import GradScaler

 mprop.init()
diff --git a/imperative/python/megengine/amp/grad_scaler.py b/imperative/python/megengine/amp/grad_scaler.py
new file mode 100644
index 00000000..2af68a3f
--- /dev/null
+++ b/imperative/python/megengine/amp/grad_scaler.py
@@ -0,0 +1,185 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Iterable, List, Union
+
+import numpy as np
+
+from ..autodiff import GradManager
+from ..functional import full_like
+from ..functional.math import _has_inf
+from ..tensor import Tensor
+
+
+class GradScaler:
+    r"""
+    A helper class that performs grad scaling to prevent data overflow in
+    :class:`~.autocast` mode.
+
+    :param init_scale: Initial scale factor.
+    :param growth_factor: Factor that the scale is multiplied by in the
+        :meth:`update` stage.
+    :param backoff_factor: Factor that the scale is multiplied by when an
+        overflow grad is encountered.
+    :param growth_interval: The interval between two scale update stages. If
+        growth_interval is 0, the scale factor will never be updated.
+
+    Example::
+
+        gm = GradManager()
+        opt = ...
+        scaler = GradScaler()
+
+        gm.attach(model.parameters())
+
+        @autocast()
+        def train_step(image, label):
+            with gm:
+                logits = model(image)
+                loss = F.nn.cross_entropy(logits, label)
+                scaler.backward(gm, loss)
+            opt.step().clear_grad()
+            return loss
+
+    For more flexible usage, ``scaler.backward`` can be split into three lines:
+
+    .. code-block::
+
+        @autocast()
+        def train_step(image, label):
+            with gm:
+                logits = model(image)
+                loss = F.nn.cross_entropy(logits, label)
+                gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
+                scaler.unscale(gm.attached_tensors())
+                scaler.update()
+            opt.step().clear_grad()
+            return loss

+    This is useful when gradients need to be accumulated over multiple batches.
+    """
+
+    def __init__(
+        self,
+        init_scale: float = 2.0 ** 4,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+    ):
+        self.scale_factor = float(init_scale)
+        self.growth_factor = float(growth_factor)
+        self.backoff_factor = float(backoff_factor)
+        self.growth_interval = growth_interval
+
+        self._growth_tracker = 0
+        self._found_inf = False
+
+    def backward(
+        self,
+        gm: GradManager,
+        y: Union[Tensor, List[Tensor]] = None,
+        dy: Union[Tensor, List[Tensor]] = None,
+        *,
+        unscale_grad: bool = True,
+        update_scale: bool = "if_unscale_grad"
+    ):
+        r"""
+        A wrapper of GradManager's :meth:`~.GradManager.backward` that scales
+        ``y``'s grad and unscales the parameters' grads.
+
+        :param gm: The GradManager to be wrapped.
+        :param y: Same as GradManager backward's ``y``.
+        :param dy: Same as GradManager backward's ``dy``. Will be multiplied
+            by ``scale_factor``.
+        :param unscale_grad: Whether to call :meth:`unscale` at the same time. Can be
+            ``False`` if gradients need to be accumulated.
+        :param update_scale: Whether to call :meth:`update` after unscaling. Will be
+            ignored if ``unscale_grad`` is ``False``.
+        """
+        # These checks should be consistent with GradManager's
+        if y is None:
+            ys = []
+        elif isinstance(y, (tuple, list)):
+            ys = y
+        else:
+            ys = [y]
+        if dy is None:
+            dys = [full_like(y, self.scale_factor) for y in ys]
+        elif isinstance(dy, (tuple, list)):
+            dys = [dy_ * self.scale_factor for dy_ in dy]
+        else:
+            dys = [dy * self.scale_factor]
+
+        gm.backward(y=ys, dy=dys)
+
+        if unscale_grad:
+            self.unscale(gm.attached_tensors())
+            if update_scale:
+                self.update()
+
+    def unscale(self, grad_tensors: Iterable[Tensor]):
+        r"""
+        Unscale the grads of all ``grad_tensors``.
+
+        :param grad_tensors: Tensors whose grads need to be unscaled. Should include
+            all tensors that are affected by the ``target`` tensor in GradManager's
+            backward.
+        """
+        # use float64 for better precision
+        inv_scale = Tensor(1.0 / self.scale_factor)
+        for tensor in grad_tensors:
+            if tensor is None or getattr(tensor, "grad", None) is None:
+                continue
+            # to support tracing, _check_gradients should be applied to every grad.
+            if self._check_gradients(tensor.grad):
+                self._found_inf = True
+            tensor.grad *= inv_scale
+
+        if self._found_inf:
+            for tensor in grad_tensors:
+                if tensor is None or getattr(tensor, "grad", None) is None:
+                    continue
+                tensor.grad = None
+        return self
+
+    def _check_gradients(self, grad):
+        if self.growth_interval == 0:
+            return False
+        return _has_inf(grad)
+
+    def update(self, new_scale: float = None):
+        r"""Update the scale factor according to whether an overflow grad was encountered.
+        If ``new_scale`` is provided, the internal update mechanism will be ignored."""
+        if self.growth_interval == 0:
+            return
+
+        if new_scale is not None:
+            self.scale_factor = float(new_scale)
+        else:
+            if self._found_inf:
+                self.scale_factor *= self.backoff_factor
+                self._growth_tracker = 0
+            else:
+                self._growth_tracker += 1
+                if self._growth_tracker >= self.growth_interval:
+                    self.scale_factor *= self.growth_factor
+                    self._growth_tracker = 0
+        self._found_inf = False
+
+    def state_dict(self):
+        return {
+            "scale_factor": self.scale_factor,
+            "growth_factor": self.growth_factor,
+            "backoff_factor": self.backoff_factor,
+            "growth_interval": self.growth_interval,
+            "_growth_tracker": self._growth_tracker,
+        }
+
+    def load_state_dict(self, state):
+        self.scale_factor = state["scale_factor"]
+        self.growth_factor = state["growth_factor"]
+        self.backoff_factor = state["backoff_factor"]
+        self.growth_interval = state["growth_interval"]
+        self._growth_tracker = state["_growth_tracker"]
diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index 19652fa3..923f1f40 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -1,5 +1,13 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import weakref
-from typing import Callable, Iterable
+from collections import OrderedDict
+from typing import Callable, Iterable, List, Union

 from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
 from ..core.autodiff.grad import Grad
@@ -123,6 +131,10 @@ class GradManager:
         self._gradients = {}
         self._priority = None

+    def attached_tensors(self):
+        r"""Return the list of tensors attached via :meth:`attach`."""
+        return [spec.tensor() for spec in self._attach_specs.values()]
+
     def attach(self, tensors: Iterable[Tensor], callbacks=None):
         r"""
         Instruct GradManager to track operations on tensors, so that gradients with respect
@@ -210,13 +222,18 @@ class GradManager:
             spec.callbacks.extend(callbacks)
             if new_attach and self._recording:
                 self._do_record(spec)
+        return self

     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self

-    def backward(self, y=None, dy=None):
+    def backward(
+        self,
+        y: Union[Tensor, List[Tensor]] = None,
+        dy: Union[Tensor, List[Tensor]] = None,
+    ):
         r"""
         Compute gradients (or vector-Jacobian product) for all attached tensors, accumulate to
         corresponding .grad attribute, and release resources along the way.

@@ -257,6 +274,7 @@ class GradManager:
                 "call a method that clears the history?"
             )
         assert self._grad is not None
+        # These checks should be consistent with GradScaler's
         if y is None:
             ys = []
         elif isinstance(y, (tuple, list)):
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 5dab1d87..c6c00aad 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -1019,7 +1019,7 @@ def batch_norm(
     momentum: float = 0.9,
     eps: float = 1e-5,
     inplace: bool = True,
-    compute_mode="default",
+    compute_mode="default"
 ):
     r"""
     Applies batch normalization to the input.
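The patch adds ``state_dict``/``load_state_dict`` to ``GradScaler`` without a usage example; the sketch below shows one way to checkpoint the scaler state alongside other training state. It is only an illustration: the ``mge.save``/``mge.load`` pairing and the ``"checkpoint.pkl"`` path are assumptions, not part of this patch::

    import megengine as mge
    from megengine.amp import GradScaler

    scaler = GradScaler()

    # Persist the scale factor and growth tracker together with whatever
    # else is checkpointed ("checkpoint.pkl" is a hypothetical path).
    mge.save({"scaler": scaler.state_dict()}, "checkpoint.pkl")

    # Later, restore the saved state into a fresh scaler before resuming training.
    new_scaler = GradScaler()
    new_scaler.load_state_dict(mge.load("checkpoint.pkl")["scaler"])
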
diff --git a/imperative/python/test/unit/amp/test_grad_scaler.py b/imperative/python/test/unit/amp/test_grad_scaler.py
new file mode 100644
index 00000000..9303b516
--- /dev/null
+++ b/imperative/python/test/unit/amp/test_grad_scaler.py
@@ -0,0 +1,30 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+import megengine as mge
+from megengine.amp import GradScaler
+from megengine.autodiff import GradManager
+
+
+def test_grad_scaler():
+    gm = GradManager()
+    scaler = GradScaler()
+
+    x = mge.tensor(1.0)
+    for _ in range(3):
+        with gm:
+            y = x + 1
+            gm.attach(y)
+            loss = y + 1
+            scaler.backward(gm, loss, unscale_grad=False)
+        np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor)
+        scaler.unscale(gm.attached_tensors())
+        np.testing.assert_equal(y.grad.numpy(), 1)
+    # test handling of None elements
+    scaler.unscale(gm.attached_tensors())
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 112c37ba..8511567d 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -49,6 +49,32 @@ def test_basic():
     np.testing.assert_equal(b.grad.numpy(), [1])


+def test_dy():
+    x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
+    w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
+    b = mge.tensor(-1.0)
+
+    gm = GradManager().attach([w, b])
+
+    def get_grad(grad, dy, idx):
+        if isinstance(dy, (list, tuple)):
+            return np.array(grad) * dy[idx]
+        else:
+            return np.array(grad) * dy
+
+    # dy's shape should be the same as y's
+    dy = mge.tensor(2.5).reshape(1, 1)
+    w.grad = None
+    b.grad = None
+    with gm:
+        p = F.matmul(x, w)
+        y = p + b
+        gm.backward(y, dy=dy)
+
+    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]] * dy.numpy())
+    np.testing.assert_equal(b.grad.numpy(), [1] * dy.numpy())
+
+
 def test_attach_in_with_block():
     a = mge.Parameter([1.0])
     gm = GradManager()
@@ -93,6 +119,25 @@ def test_attach_temporary():
     #     gm.backward(y)


+def test_attached_tensors():
+    w1 = mge.Parameter(2.0)
+    w2 = mge.Parameter(2.0)
+    gm = GradManager()
+
+    def check(expected):
+        actual = gm.attached_tensors()
+        assert len(expected) == len(actual)
+        for exp, act in zip(expected, actual):
+            assert exp is act
+
+    gm.attach(w1)
+    check([w1])
+    gm.attach(w2)
+    check([w1, w2])
+    gm.attach(w1)
+    check([w1, w2])
+
+
 def test_no_dependency():
     x = mge.tensor(3)
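For reference, below is a self-contained sketch of the end-to-end usage this patch enables, mirroring the ``GradScaler`` docstring example above. The tiny ``Linear`` model, the SGD optimizer, and the random data are illustrative assumptions, not part of the patch::

    import numpy as np

    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine.amp import GradScaler, autocast
    from megengine.autodiff import GradManager

    model = M.Linear(16, 4)
    gm = GradManager()
    gm.attach(model.parameters())
    opt = optim.SGD(model.parameters(), lr=0.01)
    scaler = GradScaler()

    @autocast()
    def train_step(image, label):
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            # scales the loss grad, runs backward, unscales the parameter
            # grads, and updates the scale factor in a single call
            scaler.backward(gm, loss)
        opt.step().clear_grad()
        return loss

    for _ in range(3):
        image = mge.tensor(np.random.randn(8, 16).astype("float32"))
        label = mge.tensor(np.random.randint(0, 4, size=(8,)).astype("int32"))
        train_step(image, label)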