@@ -5,10 +5,10 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import mprop

from ..core.tensor.amp import *
from .autocast import autocast
from .grad_scaler import GradScaler

mprop.init()
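With ``GradScaler`` exported next to ``autocast``, both AMP entry points come from the same namespace. A minimal sketch of the resulting import path, assuming a MegEngine build that contains this module::

    from megengine.amp import GradScaler, autocast

    scaler = GradScaler()       # dynamic loss scaling; init_scale defaults to 2.0 ** 4
    print(scaler.scale_factor)  # 16.0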
| @@ -0,0 +1,185 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Iterable, List, Union | |||||
| import numpy as np | |||||
| from ..autodiff import GradManager | |||||
| from ..functional import full_like | |||||
| from ..functional.math import _has_inf | |||||
| from ..tensor import Tensor | |||||
class GradScaler:
    r"""
    A helper class that performs grad scaling to prevent data overflow in
    :class:`~.autocast` mode.

    :param init_scale: Initial scale factor.
    :param growth_factor: Factor by which the scale is multiplied in the
        :meth:`update` stage.
    :param backoff_factor: Factor by which the scale is multiplied when an
        overflowed gradient is encountered.
    :param growth_interval: Number of consecutive non-overflow steps between two
        scale increases. If it is 0, the scale factor is never updated.

    Example::

        gm = GradManager()
        opt = ...
        scaler = GradScaler()

        gm.attach(model.parameters())

        @autocast()
        def train_step(image, label):
            with gm:
                logits = model(image)
                loss = F.nn.cross_entropy(logits, label)
                scaler.backward(gm, loss)
            opt.step().clear_grad()
            return loss

    If more flexibility is needed, ``scaler.backward`` can be split into three steps:

    .. code-block::

        @autocast()
        def train_step(image, label):
            with gm:
                logits = model(image)
                loss = F.nn.cross_entropy(logits, label)
                gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
                scaler.unscale(gm.attached_tensors())
                scaler.update()
            opt.step().clear_grad()
            return loss

    This is useful when gradients need to be accumulated over multiple batches.
    """
    def __init__(
        self,
        init_scale: float = 2.0 ** 4,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ):
        self.scale_factor = float(init_scale)
        self.growth_factor = float(growth_factor)
        self.backoff_factor = float(backoff_factor)
        self.growth_interval = growth_interval

        self._growth_tracker = 0
        self._found_inf = False
    def backward(
        self,
        gm: GradManager,
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
        *,
        unscale_grad: bool = True,
        update_scale: bool = "if_unscale_grad"
    ):
        r"""
        A wrapper around GradManager's :meth:`~.GradManager.backward` that scales
        ``y``'s gradient and unscales the attached parameters' gradients.

        :param gm: The GradManager to be wrapped.
        :param y: Same as GradManager backward's ``y``.
        :param dy: Same as GradManager backward's ``dy``. Will be multiplied
            by ``scale_factor``.
        :param unscale_grad: Whether to call :meth:`unscale` in the same step. Set it
            to ``False`` if gradients need to be accumulated over several iterations.
        :param update_scale: Whether to call :meth:`update` after unscaling. Ignored
            if ``unscale_grad`` is ``False``.
        """
        # These checks should be consistent with GradManager's
        if y is None:
            ys = []
        elif isinstance(y, (tuple, list)):
            ys = y
        else:
            ys = [y]
        if dy is None:
            dys = [full_like(y, self.scale_factor) for y in ys]
        elif isinstance(dy, (tuple, list)):
            dys = [dy_ * self.scale_factor for dy_ in dy]
        else:
            dys = [dy * self.scale_factor]

        gm.backward(y=ys, dy=dys)

        if unscale_grad:
            self.unscale(gm.attached_tensors())
            if update_scale:
                self.update()
    def unscale(self, grad_tensors: Iterable[Tensor]):
        r"""
        Unscale the gradients of all ``grad_tensors``.

        :param grad_tensors: Tensors whose gradients need to be unscaled. This should
            cover every tensor that is affected by the target tensor in GradManager's
            backward pass.
        """
        # compute the reciprocal in float64 for better precision
        inv_scale = Tensor(1.0 / self.scale_factor)
        for tensor in grad_tensors:
            if tensor is None or getattr(tensor, "grad", None) is None:
                continue
            # to support tracing, _check_gradients should be applied to every grad
            if self._check_gradients(tensor.grad):
                self._found_inf = True
            tensor.grad *= inv_scale
        if self._found_inf:
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad = None
        return self
    def _check_gradients(self, grad):
        if self.growth_interval == 0:
            return False
        return _has_inf(grad)

    def update(self, new_scale: float = None):
        r"""Update the scale factor according to whether an overflow was encountered
        in the gradients. If ``new_scale`` is provided, the internal update mechanism
        is bypassed."""
        if self.growth_interval == 0:
            return

        if new_scale is not None:
            self.scale_factor = float(new_scale)
        else:
            if self._found_inf:
                self.scale_factor *= self.backoff_factor
                self._growth_tracker = 0
            else:
                self._growth_tracker += 1
                if self._growth_tracker >= self.growth_interval:
                    self.scale_factor *= self.growth_factor
                    self._growth_tracker = 0
        self._found_inf = False
    def state_dict(self):
        return {
            "scale_factor": self.scale_factor,
            "growth_factor": self.growth_factor,
            "backoff_factor": self.backoff_factor,
            "growth_interval": self.growth_interval,
            "_growth_tracker": self._growth_tracker,
        }

    def load_state_dict(self, state):
        self.scale_factor = state["scale_factor"]
        self.growth_factor = state["growth_factor"]
        self.backoff_factor = state["backoff_factor"]
        self.growth_interval = state["growth_interval"]
        self._growth_tracker = state["_growth_tracker"]
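To make the scaling schedule implemented by ``update`` concrete, the following standalone sketch mirrors the logic above with plain Python floats; the overflow pattern is invented purely for illustration and nothing here depends on MegEngine::

    # Mirrors GradScaler.update() with the defaults above:
    # init_scale=16, growth_factor=2, backoff_factor=0.5, growth_interval=2000.
    scale, tracker = 16.0, 0
    GROWTH, BACKOFF, INTERVAL = 2.0, 0.5, 2000

    def step(found_inf):
        global scale, tracker
        if found_inf:                # overflow: back off and restart the counter
            scale *= BACKOFF
            tracker = 0
        else:                        # clean step: grow after INTERVAL clean steps in a row
            tracker += 1
            if tracker >= INTERVAL:
                scale *= GROWTH
                tracker = 0

    for _ in range(2000):            # 2000 consecutive clean steps ...
        step(found_inf=False)
    print(scale)                     # ... double the scale: 32.0
    step(found_inf=True)             # a single overflow ...
    print(scale)                     # ... halves it back: 16.0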
| @@ -1,5 +1,13 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import weakref | import weakref | ||||
| from typing import Callable, Iterable | |||||
| from collections import OrderedDict | |||||
| from typing import Callable, Iterable, List, Union | |||||
| from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | ||||
| from ..core.autodiff.grad import Grad | from ..core.autodiff.grad import Grad | ||||
| @@ -123,6 +131,10 @@ class GradManager: | |||||
| self._gradients = {} | self._gradients = {} | ||||
| self._priority = None | self._priority = None | ||||
| def attached_tensors(self): | |||||
| r"""Return attached tensor list from :meth:`attach`.""" | |||||
| return [spec.tensor() for spec in self._attach_specs.values()] | |||||
| def attach(self, tensors: Iterable[Tensor], callbacks=None): | def attach(self, tensors: Iterable[Tensor], callbacks=None): | ||||
| r""" | r""" | ||||
| Instruct GradManager to track operations on tensors, so that gradients with respect | Instruct GradManager to track operations on tensors, so that gradients with respect | ||||
| @@ -210,13 +222,18 @@ class GradManager: | |||||
| spec.callbacks.extend(callbacks) | spec.callbacks.extend(callbacks) | ||||
| if new_attach and self._recording: | if new_attach and self._recording: | ||||
| self._do_record(spec) | self._do_record(spec) | ||||
| return self | return self | ||||
| def _register_after_backward_callback(self, callback): | def _register_after_backward_callback(self, callback): | ||||
| self._after_backward_callback.append(callback) | self._after_backward_callback.append(callback) | ||||
| return self | return self | ||||
    def backward(self, y=None, dy=None):
    def backward(
        self,
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
    ):
        r"""
        Compute gradients (or vector-Jacobian product) for all attached tensors, accumulate to
        corresponding .grad attribute, and release resources along the way.

@@ -257,6 +274,7 @@ class GradManager:
                "call a method that clears the history?"
            )
        assert self._grad is not None
        # These checks should be consistent with GradScaler's
        if y is None:
            ys = []
        elif isinstance(y, (tuple, list)):
| @@ -1019,7 +1019,7 @@ def batch_norm( | |||||
| momentum: float = 0.9, | momentum: float = 0.9, | ||||
| eps: float = 1e-5, | eps: float = 1e-5, | ||||
| inplace: bool = True, | inplace: bool = True, | ||||
| compute_mode="default", | |||||
| compute_mode="default" | |||||
| ): | ): | ||||
| r""" | r""" | ||||
| Applies batch normalization to the input. | Applies batch normalization to the input. | ||||
| @@ -0,0 +1,30 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| from megengine.amp import GradScaler | |||||
| from megengine.autodiff import GradManager | |||||
| def test_grad_scaler(): | |||||
| gm = GradManager() | |||||
| scaler = GradScaler() | |||||
| x = mge.tensor(1.0) | |||||
| for _ in range(3): | |||||
| with gm: | |||||
| y = x + 1 | |||||
| gm.attach(y) | |||||
| loss = y + 1 | |||||
| scaler.backward(gm, loss, unscale_grad=False) | |||||
| np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor) | |||||
| scaler.unscale(gm.attached_tensors()) | |||||
| np.testing.assert_equal(y.grad.numpy(), 1) | |||||
| # test handle None elements | |||||
| scaler.unscale(gm.attached_tensors()) | |||||
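The ``unscale_grad=False`` path exercised in this test is what enables gradient accumulation. Below is a hedged sketch of how it is meant to combine with ``unscale`` and ``update``, following the pattern in the GradScaler docstring; ``model``, ``opt``, and the ``accumulate`` flag are placeholders rather than part of the tested API::

    import megengine.functional as F
    from megengine.amp import GradScaler, autocast
    from megengine.autodiff import GradManager

    gm = GradManager().attach(model.parameters())  # `model` and `opt` are assumed to exist
    scaler = GradScaler()

    @autocast()
    def train_step(image, label, accumulate):
        with gm:
            loss = F.nn.cross_entropy(model(image), label)
            # keep grads scaled; skip unscale/update while accumulating
            scaler.backward(gm, loss, unscale_grad=False)
        if not accumulate:
            scaler.unscale(gm.attached_tensors())  # divide accumulated grads by scale_factor
            scaler.update()                        # adjust scale_factor for later steps
            opt.step().clear_grad()
        return loss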
| @@ -49,6 +49,32 @@ def test_basic(): | |||||
| np.testing.assert_equal(b.grad.numpy(), [1]) | np.testing.assert_equal(b.grad.numpy(), [1]) | ||||
| def test_dy(): | |||||
| x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3) | |||||
| w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1) | |||||
| b = mge.tensor(-1.0) | |||||
| gm = GradManager().attach([w, b]) | |||||
| def get_grad(grad, dy, idx): | |||||
| if isinstance(dy, (list, tuple)): | |||||
| return np.array(grad) * dy[idx] | |||||
| else: | |||||
| return np.array(grad) * dy | |||||
| # dy's shape should be the same as y's | |||||
| dy = mge.tensor(2.5).reshape(1, 1) | |||||
| w.grad = None | |||||
| b.grad = None | |||||
| with gm: | |||||
| p = F.matmul(x, w) | |||||
| y = p + b | |||||
| gm.backward(y, dy=dy) | |||||
| np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]] * dy.numpy()) | |||||
| np.testing.assert_equal(b.grad.numpy(), [1] * dy.numpy()) | |||||
def test_attach_in_with_block():
    a = mge.Parameter([1.0])
    gm = GradManager()

@@ -93,6 +119,25 @@ def test_attach_temporary():
    # gm.backward(y)


def test_attached_tensors():
    w1 = mge.Parameter(2.0)
    w2 = mge.Parameter(2.0)
    gm = GradManager()

    def check(expected):
        actual = gm.attached_tensors()
        assert len(expected) == len(actual)
        for exp, act in zip(expected, actual):
            assert exp is act

    gm.attach(w1)
    check([w1])
    gm.attach(w2)
    check([w1, w2])
    gm.attach(w1)
    check([w1, w2])


def test_no_dependency():
    x = mge.tensor(3)