|
|
|
@@ -10,8 +10,8 @@ class GradManager: |
|
|
|
self._recording = False |
|
|
|
self._grad = None |
|
|
|
|
|
|
|
def register(self, params, callback=None): |
|
|
|
self._call_back_pair.append([params, callback]) |
|
|
|
def register(self, params, callbacks=None): |
|
|
|
self._call_back_pair.append([list(params), callbacks or []]) |
|
|
|
|
|
|
|
def backward(self, ys, dys=None): |
|
|
|
if not self._recording: |
|
|
|
@@ -24,7 +24,7 @@ class GradManager: |
|
|
|
if not isinstance(ys, (tuple, list)): |
|
|
|
ys = [ys] |
|
|
|
if dys is None: |
|
|
|
dys = [tensor(1).broadcast(y.shape) for y in ys] |
|
|
|
dys = [tensor(1.0) for y in ys] |
|
|
|
if not isinstance(dys, (tuple, list)): |
|
|
|
dys = [dys] |
|
|
|
try: |
|
|
|
@@ -42,7 +42,14 @@ class GradManager: |
|
|
|
self._recording = True |
|
|
|
self._grad = grad |
|
|
|
for params, callbacks in self._call_back_pair: |
|
|
|
grad.wrt(*params, callback=callbacks) |
|
|
|
|
|
|
|
def callback(param, grad, callbacks=callbacks): |
|
|
|
ret = grad |
|
|
|
for cb in callbacks: |
|
|
|
ret = cb(param, ret) |
|
|
|
param.grad = ret |
|
|
|
|
|
|
|
grad.wrt(*params, callback=callback) |
|
|
|
with grad: |
|
|
|
yield |
|
|
|
finally: |
|
|
|
|