GitOrigin-RevId: a890c206a5
tags/v1.1.0
@@ -127,7 +127,7 @@ class GradManager:
         self._after_backward_callback.append(callback)
         return self

-    def backward(self, ys, dys=None):
+    def backward(self, ys=None, dys=None):
         r"""
         Performs back-propagation and computes gradients.
@@ -146,6 +146,8 @@ class GradManager:
                 "call a method that clears the history?"
             )
         assert self._grad is not None
+        if ys is None:
+            ys = []
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
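For context on the relaxed signature, here is a rough single-process sketch (not part of the patch) showing both call forms. Locally the no-argument call simply has nothing to propagate; its real use is on pipeline ranks whose output was handed off via `remote_send`, as exercised by the test at the end of this diff.

```python
import numpy as np
import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager

m = M.Linear(4, 3)
gm = GradManager().attach(m.parameters())
x = mge.tensor(np.random.randn(2, 4), dtype=np.float32)

# usual form: start back-propagation from an explicit output
with gm:
    loss = m(x).mean()
    gm.backward(loss)

# new form: ys may be omitted; intended for ranks whose local graph ends in a
# remote_send, so the gradient arrives from another process instead of a loss
with gm:
    _ = m(x)
    gm.backward()
```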
@@ -14,6 +14,8 @@ import weakref
 import numpy as np
+import megengine as mge
 from ..ops.builtin import Elemwise, OpDef
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
             for i in dys:
                 if isinstance(i, TensorWrapperBase):
                     return type(i)
+            # use Tensor as default wrapper
+            return mge.Tensor

         Wrapper = check_wrapper()
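The fallback above matters because `backward()` may now be called with no `ys`, leaving `dys` empty so no wrapper class can be inferred from it. A standalone illustration of that selection rule (hypothetical names, not MegEngine API):

```python
from typing import Sequence, Type

# stand-in classes, only to make the selection rule concrete
class WrapperBase: ...
class MyWrapper(WrapperBase): ...

def pick_wrapper(dys: Sequence, default: Type) -> Type:
    """Mimic check_wrapper: first wrapped gradient wins, else fall back."""
    for d in dys:
        if isinstance(d, WrapperBase):
            return type(d)
    return default  # reached when dys is empty, e.g. backward() with no ys

assert pick_wrapper([MyWrapper()], default=float) is MyWrapper
assert pick_wrapper([], default=float) is float  # the new fallback path
```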
@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
     def backward(*args):
         return [
             remote_recv(
-                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
+                op.rank_to,
+                inputs[0].shape,
+                inputs[0].dtype,
+                device=str(inputs[0].device),
+                inp=inputs[0],
             )
         ]
@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
+    src_rank: int,
+    shape: Tuple[int],
+    dtype: type,
+    device: Optional[str] = None,
+    inp=None,
 ) -> Tensor:
     """
     Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@ def remote_recv(
     :param shape: the shape of the tensor to receive.
    :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
+    :param inp: dummy input to determine the received tensor type.
     """
     key = "{}->{}".format(src_rank, get_rank())
     if device is None:
         device = get_default_device()
-    # dummpy input
-    inp = tensor([0])
+    # dummy input
+    if inp is None:
+        inp = tensor([0])
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
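For reference, a minimal point-to-point sketch of `remote_send`/`remote_recv` usage (an illustration only, assuming a CUDA build with at least two GPUs; the worker name and tensor values are made up). Here `inp` is left at its default; the explicit `inp=inputs[0]` above is only needed in the `RemoteSend` backward rule so the received gradient adopts the same tensor type as the forward value.

```python
import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.core._imperative_rt.imperative import sync


@dist.launcher
def p2p_worker():
    # rank 0 sends one tensor to rank 1; any other ranks stay idle
    rank = dist.get_rank()
    if rank == 0:
        x = mge.tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
        dist.functional.remote_send(x, dest_rank=1)
        sync()  # send is the last job on this rank, same as in the test below
    elif rank == 1:
        # with inp=None (default), remote_recv builds its own dummy tensor([0])
        y = dist.functional.remote_recv(0, shape=(2, 3), dtype=np.float32)
        assert y.numpy().shape == (2, 3)


p2p_worker()
```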
@@ -5,12 +5,19 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
 import numpy as np
+import pytest
 import megengine as mge
+import megengine.distributed as dist
 import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
 from megengine.autodiff import GradManager
+from megengine.core._imperative_rt.imperative import sync
+from megengine.distributed.helper import get_device_count_by_fork

 def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
         c = b + 1
         gm.backward(c)
         assert int(b.grad.numpy()) == 1
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_remote_grad():
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        size = dist.get_world_size()
+        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
+        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
+        gm = GradManager().attach(m.parameters())
+        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
+
+        def train_func(x):
+            if rank != 0:
+                x = dist.functional.remote_recv(
+                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                )
+            print(rank, "x", x)
+            y = m(x)
+            print(rank, "y", y)
+            if rank != size - 1:
+                y = dist.functional.remote_send(y, dest_rank=rank + 1)
+            return y
+
+        with gm:
+            y = train_func(x)
+            if rank == size - 1:
+                y = y.mean()
+                gm.backward(y)
+            else:
+                gm.backward()
+        opt.step().clear_grad()
+
+        # sync because send is the last job
+        sync()
+
+    worker()