GitOrigin-RevId: 086e2871e8
tags/v1.0.0-rc1
| @@ -0,0 +1,52 @@ | |||||
| from contextlib import contextmanager | |||||
| from ..core.autodiff.grad import Grad | |||||
| from ..tensor import tensor | |||||
class GradManager:
    """Records forward computation so gradients can be computed later.

    Typical use::

        gm = GradManager()
        gm.register(params, callback=cb)
        with gm.record():
            loss = forward(...)
            gm.backward(loss)
    """

    def __init__(self):
        # Each entry is ``[params, callback]`` exactly as passed to register().
        self._call_back_pair = []
        # True only while inside an active record() context.
        self._recording = False
        # The live Grad object while recording; None otherwise.
        self._grad = None

    def register(self, params, callback=None):
        """Register *params* to be watched; *callback* receives their grads."""
        self._call_back_pair.append([params, callback])

    def backward(self, ys, dys=None):
        """Compute gradients of *ys*, seeded with output grads *dys*.

        :param ys: output tensor or list/tuple of output tensors.
        :param dys: matching output gradients; defaults to ones shaped
            like each ``y``.
        :raises RuntimeError: if called outside an active ``record()``
            context, or after the recorded history was already consumed.
        """
        if not self._recording or self._grad is None:
            # FIX: the original guarded the second condition with a bare
            # `assert`, which is stripped under `python -O` and otherwise
            # surfaces as an unhelpful AssertionError when backward() is
            # called twice; report the same descriptive error instead.
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            # Default seed gradient: ones broadcast to each output's shape.
            dys = [tensor(1).broadcast(y.shape) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
        finally:
            # FIX: also clear _recording (the original left it True), so a
            # repeat backward() reports "no computation history" instead of
            # failing on the cleared _grad.
            self._recording = False
            self._grad = None

    def record(self):
        """Return a context manager that records computation for backward()."""

        @contextmanager
        def recorder():
            if self._recording:
                raise RuntimeError("already recording!")
            # Construct Grad only once we know we are not nested.
            grad = Grad()
            try:
                self._recording = True
                self._grad = grad
                for params, callback in self._call_back_pair:
                    grad.wrt(*params, callback=callback)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None

        return recorder()
| @@ -260,9 +260,13 @@ class Grad: | |||||
| cache[v] = g | cache[v] = g | ||||
| if last_written_to[v] == (seqno, i): | if last_written_to[v] == (seqno, i): | ||||
| if v.callback: | if v.callback: | ||||
| v.callback( | |||||
| grad = v.callback( | |||||
| v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | ||||
| ) | ) | ||||
| if getattr(v.owner(), "grad", None) is None: | |||||
| v.owner().grad = grad | |||||
| else: | |||||
| v.owner().grad += grad | |||||
| if v.opnode is None: | if v.opnode is None: | ||||
| # won't read by backward, mark consumed | # won't read by backward, mark consumed | ||||
| cache[v] = None | cache[v] = None | ||||
| @@ -19,7 +19,7 @@ from .group import ( | |||||
| is_distributed, | is_distributed, | ||||
| new_group, | new_group, | ||||
| ) | ) | ||||
| from .helper import synchronized | |||||
| from .helper import bcast_params_, make_allreduce_cb, synchronized | |||||
| from .launcher import launcher | from .launcher import launcher | ||||
| from .server import Client, Server | from .server import Client, Server | ||||
| from .util import get_free_ports | from .util import get_free_ports | ||||
| @@ -12,7 +12,8 @@ from typing import Callable | |||||
| from megengine.device import get_device_count | from megengine.device import get_device_count | ||||
| from .group import group_barrier, is_distributed | |||||
| from .functional import all_reduce_sum, broadcast | |||||
| from .group import WORLD, group_barrier, is_distributed | |||||
| def synchronized(func: Callable): | def synchronized(func: Callable): | ||||
| @@ -42,3 +43,23 @@ def get_device_count_by_fork(device_type: str): | |||||
| p.start() | p.start() | ||||
| p.join() | p.join() | ||||
| return q.get() | return q.get() | ||||
def bcast_params_(params, group):
    """Broadcast every parameter in *params* in place across *group*.

    Each parameter is overwritten (via ``_reset``) with the value
    broadcast from the group's root rank.
    """
    for param in params:
        param._reset(broadcast(param, group))
class AllreduceCallback:
    """Gradient callback that all-reduces a gradient across a process group.

    :param reduce_method: ``"MEAN"`` divides the summed gradient by the
        group size; any other value leaves the plain sum.
    :param group: process group to reduce over (defaults to ``WORLD``).
    """

    def __init__(self, reduce_method, group=WORLD):
        self._reduce_method = reduce_method
        self._group = group

    def __call__(self, param, grad):
        """Return *grad* summed (or averaged) over the group; *param* is unused."""
        reduced = all_reduce_sum(grad, self._group)
        if self._reduce_method == "MEAN":
            reduced = reduced / self._group.size
        return reduced


# Backward-compatible factory alias.
make_allreduce_cb = AllreduceCallback