GitOrigin-RevId: 673e11c5b6
tags/v1.0.0-rc1
| @@ -2,8 +2,8 @@ from collections import defaultdict | |||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||
| from ..core.autodiff.grad import Grad | from ..core.autodiff.grad import Grad | ||||
| from ..distributed.util import Future | |||||
| from ..tensor import tensor | from ..tensor import tensor | ||||
| from ..utils.future import Future | |||||
| backwarding_grad_manager = None | backwarding_grad_manager = None | ||||
| @@ -26,6 +26,7 @@ class GradManager: | |||||
| self._param_dict[id(p)] = p | self._param_dict[id(p)] = p | ||||
| for cb in callbacks: | for cb in callbacks: | ||||
| self._call_back_dict[id(p)].append(cb) | self._call_back_dict[id(p)].append(cb) | ||||
| return self | |||||
| def register_after_backward_callback(self, callback): | def register_after_backward_callback(self, callback): | ||||
| self._after_backward_callback.append(callback) | self._after_backward_callback.append(callback) | ||||
| @@ -45,7 +46,7 @@ class GradManager: | |||||
| if not isinstance(ys, (tuple, list)): | if not isinstance(ys, (tuple, list)): | ||||
| ys = [ys] | ys = [ys] | ||||
| if dys is None: | if dys is None: | ||||
| dys = [tensor(1.0) for y in ys] | |||||
| dys = [tensor(1.0).broadcast(y.shape) for y in ys] | |||||
| if not isinstance(dys, (tuple, list)): | if not isinstance(dys, (tuple, list)): | ||||
| dys = [dys] | dys = [dys] | ||||
| try: | try: | ||||
| @@ -178,8 +178,6 @@ class Grad: | |||||
| assert len(ys) == len(dys) | assert len(ys) == len(dys) | ||||
| ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] | ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] | ||||
| if len(ids) == 0: | |||||
| return | |||||
| ys = [y for i, y in enumerate(ys) if i in ids] | ys = [y for i, y in enumerate(ys) if i in ids] | ||||
| dys = [dy for i, dy in enumerate(dys) if i in ids] | dys = [dy for i, dy in enumerate(dys) if i in ids] | ||||
| @@ -18,9 +18,9 @@ from megengine.device import get_default_device, get_device_count | |||||
| from ..functional.param_pack import get_offsets, pack_allreduce_split | from ..functional.param_pack import get_offsets, pack_allreduce_split | ||||
| from ..functional.utils import copy | from ..functional.utils import copy | ||||
| from ..utils.future import Future | |||||
| from .functional import all_reduce_sum, broadcast | from .functional import all_reduce_sum, broadcast | ||||
| from .group import WORLD, group_barrier, is_distributed | from .group import WORLD, group_barrier, is_distributed | ||||
| from .util import Future | |||||
| class FakeTensor(Future): | class FakeTensor(Future): | ||||
| @@ -77,7 +77,7 @@ class AllreduceCallback: | |||||
| assert reduce_method in ["sum", "mean"] | assert reduce_method in ["sum", "mean"] | ||||
| self._reduce_method = reduce_method | self._reduce_method = reduce_method | ||||
| self._group = group | self._group = group | ||||
| self._gm_set = set() | |||||
| self._marked_gm = set() | |||||
| self._param_pack_thd = 10 * 1024 * 1024 | self._param_pack_thd = 10 * 1024 * 1024 | ||||
| self._reset() | self._reset() | ||||
| @@ -87,6 +87,7 @@ class AllreduceCallback: | |||||
| self._futures_dict = dict() | self._futures_dict = dict() | ||||
| self._packing_list = defaultdict(list) | self._packing_list = defaultdict(list) | ||||
| self._packing_size = defaultdict(int) | self._packing_size = defaultdict(int) | ||||
| self._grad_origin_device = dict() | |||||
| def _pack(self, dtype): | def _pack(self, dtype): | ||||
| grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] | grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] | ||||
| @@ -102,27 +103,28 @@ class AllreduceCallback: | |||||
| def __call__(self, param, grad): | def __call__(self, param, grad): | ||||
| gm = get_backwarding_grad_manager() | gm = get_backwarding_grad_manager() | ||||
| assert isinstance(gm, GradManager) | assert isinstance(gm, GradManager) | ||||
| if gm not in self._gm_set: | |||||
| if gm not in self._marked_gm: | |||||
| gm.register_after_backward_callback(self._flush) | gm.register_after_backward_callback(self._flush) | ||||
| self._gm_set.add(gm) | |||||
| self._marked_gm.add(gm) | |||||
| self._params.append(param) | self._params.append(param) | ||||
| self._futures_dict[param] = FakeTensor(ack=False) | self._futures_dict[param] = FakeTensor(ack=False) | ||||
| self._gradients_dict[param] = grad | self._gradients_dict[param] = grad | ||||
| self._packing_list[param.dtype].append(param) | |||||
| self._packing_size[param.dtype] += ( | |||||
| int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize | |||||
| ) | |||||
| if self._packing_size[param.dtype] > self._param_pack_thd: | |||||
| self._pack(param.dtype) | |||||
| self._grad_origin_device[param] = str(grad.device) | |||||
| dtype_str = str(np.dtype(param.dtype)) | |||||
| dtype_size = np.dtype(param.dtype).itemsize | |||||
| self._packing_list[dtype_str].append(param) | |||||
| self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size | |||||
| if self._packing_size[dtype_str] > self._param_pack_thd: | |||||
| self._pack(dtype_str) | |||||
| return self._futures_dict[param] | return self._futures_dict[param] | ||||
| def _flush(self): | def _flush(self): | ||||
| for dtype in self._packing_list.keys(): | |||||
| for dtype in sorted(self._packing_list.keys()): | |||||
| self._pack(dtype) | self._pack(dtype) | ||||
| for param in self._params: | for param in self._params: | ||||
| grad = self._gradients_dict[param] | grad = self._gradients_dict[param] | ||||
| grad = copy(grad, get_default_device()) | |||||
| grad = copy(grad, self._grad_origin_device[param]) | |||||
| self._futures_dict[param].set(grad) | self._futures_dict[param].set(grad) | ||||
| self._reset() | self._reset() | ||||
| @@ -16,7 +16,8 @@ from xmlrpc.client import ServerProxy | |||||
| from xmlrpc.server import SimpleXMLRPCServer | from xmlrpc.server import SimpleXMLRPCServer | ||||
| from ..core._imperative_rt.utils import create_mm_server | from ..core._imperative_rt.utils import create_mm_server | ||||
| from .util import Future, get_free_ports | |||||
| from ..utils.future import Future | |||||
| from .util import get_free_ports | |||||
| class Methods: | class Methods: | ||||
| @@ -8,28 +8,9 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import functools | import functools | ||||
| import socket | import socket | ||||
| import threading | |||||
| from typing import List | from typing import List | ||||
| class Future: | |||||
| def __init__(self, ack=True): | |||||
| self.ready = threading.Event() | |||||
| self.ack = threading.Event() if ack else None | |||||
| def set(self, value): | |||||
| self.value = value | |||||
| self.ready.set() | |||||
| if self.ack: | |||||
| self.ack.wait() | |||||
| def get(self): | |||||
| self.ready.wait() | |||||
| if self.ack: | |||||
| self.ack.set() | |||||
| return self.value | |||||
| def get_free_ports(num: int) -> List[int]: | def get_free_ports(num: int) -> List[int]: | ||||
| """Get one or more free ports. | """Get one or more free ports. | ||||
| """ | """ | ||||
| @@ -8,8 +8,8 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import numpy as np | import numpy as np | ||||
| from ..functional.distributed import all_reduce_sum | |||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from .distributed import all_reduce_sum | |||||
| from .tensor import param_pack_concat, param_pack_split | from .tensor import param_pack_concat, param_pack_split | ||||
| @@ -29,6 +29,6 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method): | |||||
| packed_grads = param_pack_concat(pack_list, offsets, offsets_val) | packed_grads = param_pack_concat(pack_list, offsets, offsets_val) | ||||
| packed_grads = all_reduce_sum(packed_grads, group, group.comp_node) | packed_grads = all_reduce_sum(packed_grads, group, group.comp_node) | ||||
| if reduce_method == "mean": | if reduce_method == "mean": | ||||
| packed_grads /= dist_group.size | |||||
| packed_grads /= group.size | |||||
| grads = param_pack_split(packed_grads, offsets_val, shapes) | grads = param_pack_split(packed_grads, offsets_val, shapes) | ||||
| return grads | return grads | ||||
| @@ -0,0 +1,26 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import threading | |||||
| class Future: | |||||
| def __init__(self, ack=True): | |||||
| self.ready = threading.Event() | |||||
| self.ack = threading.Event() if ack else None | |||||
| def set(self, value): | |||||
| self.value = value | |||||
| self.ready.set() | |||||
| if self.ack: | |||||
| self.ack.wait() | |||||
| def get(self): | |||||
| self.ready.wait() | |||||
| if self.ack: | |||||
| self.ack.set() | |||||
| return self.value | |||||