| @@ -0,0 +1,9 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .grad_manager import GradManager | |||||
| @@ -10,8 +10,8 @@ class GradManager: | |||||
| self._recording = False | self._recording = False | ||||
| self._grad = None | self._grad = None | ||||
| def register(self, params, callback=None): | |||||
| self._call_back_pair.append([params, callback]) | |||||
| def register(self, params, callbacks=None): | |||||
| self._call_back_pair.append([list(params), callbacks or []]) | |||||
| def backward(self, ys, dys=None): | def backward(self, ys, dys=None): | ||||
| if not self._recording: | if not self._recording: | ||||
| @@ -24,7 +24,7 @@ class GradManager: | |||||
| if not isinstance(ys, (tuple, list)): | if not isinstance(ys, (tuple, list)): | ||||
| ys = [ys] | ys = [ys] | ||||
| if dys is None: | if dys is None: | ||||
| dys = [tensor(1).broadcast(y.shape) for y in ys] | |||||
| dys = [tensor(1.0) for y in ys] | |||||
| if not isinstance(dys, (tuple, list)): | if not isinstance(dys, (tuple, list)): | ||||
| dys = [dys] | dys = [dys] | ||||
| try: | try: | ||||
| @@ -42,7 +42,14 @@ class GradManager: | |||||
| self._recording = True | self._recording = True | ||||
| self._grad = grad | self._grad = grad | ||||
| for params, callbacks in self._call_back_pair: | for params, callbacks in self._call_back_pair: | ||||
| grad.wrt(*params, callback=callbacks) | |||||
| def callback(param, grad, callbacks=callbacks): | |||||
| ret = grad | |||||
| for cb in callbacks: | |||||
| ret = cb(param, ret) | |||||
| param.grad = ret | |||||
| grad.wrt(*params, callback=callback) | |||||
| with grad: | with grad: | ||||
| yield | yield | ||||
| finally: | finally: | ||||
| @@ -260,13 +260,9 @@ class Grad: | |||||
| cache[v] = g | cache[v] = g | ||||
| if last_written_to[v] == (seqno, i): | if last_written_to[v] == (seqno, i): | ||||
| if v.callback: | if v.callback: | ||||
| grad = v.callback( | |||||
| v.callback( | |||||
| v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | ||||
| ) | ) | ||||
| if getattr(v.owner(), "grad", None) is None: | |||||
| v.owner().grad = grad | |||||
| else: | |||||
| v.owner().grad += grad | |||||
| if v.opnode is None: | if v.opnode is None: | ||||
| # won't read by backward, mark consumed | # won't read by backward, mark consumed | ||||
| cache[v] = None | cache[v] = None | ||||
| @@ -9,8 +9,8 @@ | |||||
| from bisect import bisect_right | from bisect import bisect_right | ||||
| from typing import Iterable as Iter | from typing import Iterable as Iter | ||||
| from .optimizer import Optimizer | |||||
| from .lr_scheduler import LRScheduler | from .lr_scheduler import LRScheduler | ||||
| from .optimizer import Optimizer | |||||
| class MultiStepLR(LRScheduler): | class MultiStepLR(LRScheduler): | ||||
| @@ -53,10 +53,6 @@ class SGD(Optimizer): | |||||
| for param in param_group["params"]: | for param in param_group["params"]: | ||||
| if param.__wrapped__ in self._grad_skip: | |||||
| self._grad_skip.remove(param.__wrapped__) | |||||
| continue | |||||
| if not isinstance(param.grad, Buffer): | if not isinstance(param.grad, Buffer): | ||||
| raise TypeError( | raise TypeError( | ||||
| "grad must be a Buffer, maybe you forget to call backward()?" | "grad must be a Buffer, maybe you forget to call backward()?" | ||||
| @@ -76,5 +72,3 @@ class SGD(Optimizer): | |||||
| self._state[param]["momentum_buffer"]._reset(v) | self._state[param]["momentum_buffer"]._reset(v) | ||||
| else: | else: | ||||
| param -= lr * grad | param -= lr * grad | ||||
| assert len(self._grad_skip) == 0 | |||||