GitOrigin-RevId: a890c206a5
tags/v1.1.0
@@ -127,7 +127,7 @@ class GradManager:
         self._after_backward_callback.append(callback)
         return self
 
-    def backward(self, ys, dys=None):
+    def backward(self, ys=None, dys=None):
         r"""
         Performs back-propagation and computes gradients.
@@ -146,6 +146,8 @@ class GradManager:
                 "call a method that clears the history?"
             )
         assert self._grad is not None
+        if ys is None:
+            ys = []
         if not isinstance(ys, (tuple, list)):
            ys = [ys]
         if dys is None:
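With ys defaulting to None (normalized to an empty list), GradManager.backward() can now be called with no outputs at all, which is what a pipeline rank does once its activation has already been handed off via remote_send. The following sketch of the two call forms is not part of this diff; it assumes MegEngine around v1.1 running on CPU:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    w = mge.Parameter(np.ones((3,), dtype=np.float32))
    gm = GradManager().attach([w])

    with gm:
        loss = F.sum(w * 2)
        gm.backward(loss)   # usual form: pass the output(s) to differentiate
    print(w.grad)           # gradient of loss w.r.t. w

    with gm:
        y = F.sum(w * 3)
        gm.backward()       # new form: no ys, nothing local to differentiate
                            # (what a non-final pipeline rank does in the test below)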
@@ -14,6 +14,8 @@ import weakref
 
 import numpy as np
 
+import megengine as mge
+
 from ..ops.builtin import Elemwise, OpDef
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
             for i in dys:
                 if isinstance(i, TensorWrapperBase):
                     return type(i)
+            # use Tensor as default wrapper
+            return mge.Tensor
 
         Wrapper = check_wrapper()
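The mge.Tensor fallback pairs with the backward() change above: when backward() is called with no outputs, dys ends up empty, the loop never finds a TensorWrapperBase, and check_wrapper() would otherwise return None. A standalone sketch of the selection rule; pick_wrapper is an illustrative name, not MegEngine API, and the TensorWrapperBase import path is assumed from the relative import in this hunk:

    import megengine as mge
    from megengine.core.tensor.core import TensorWrapperBase  # assumed resolution of ..tensor.core

    def pick_wrapper(dys):
        # prefer the wrapper class of any provided dy, else fall back to the public Tensor
        for i in dys:
            if isinstance(i, TensorWrapperBase):
                return type(i)
        return mge.Tensor

    print(pick_wrapper([]))  # mge.Tensor: the empty-dys case, e.g. backward() called with no ys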
@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
     def backward(*args):
         return [
             remote_recv(
-                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
+                op.rank_to,
+                inputs[0].shape,
+                inputs[0].dtype,
+                device=str(inputs[0].device),
+                inp=inputs[0],
             )
         ]
@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
+    src_rank: int,
+    shape: Tuple[int],
+    dtype: type,
+    device: Optional[str] = None,
+    inp=None,
 ) -> Tensor:
     """
     Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@ def remote_recv(
     :param shape: the shape of the tensor to receive.
     :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
+    :param inp: dummy input to determine the received tensor type.
     """
     key = "{}->{}".format(src_rank, get_rank())
     if device is None:
         device = get_default_device()
-    # dummpy input
-    inp = tensor([0])
+    # dummy input
+    if inp is None:
+        inp = tensor([0])
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
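Design note: the new inp argument lets the RemoteSend gradient rule (see the hunk above) pass its own forward input, so the gradient tensor received by remote_recv comes back as the same tensor type as that input, per the :param inp: description. Keeping the default inp=None with the internal tensor([0]) fallback preserves the existing public call signature, which the test below still uses with only src_rank, shape, and dtype.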
@@ -5,12 +5,19 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
+
 import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.distributed as dist
 import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
 from megengine.autodiff import GradManager
+from megengine.core._imperative_rt.imperative import sync
+from megengine.distributed.helper import get_device_count_by_fork
 
 
 def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
         c = b + 1
     gm.backward(c)
     assert int(b.grad.numpy()) == 1
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="GPU mode is not implemented on macOS yet"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="Windows builds disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need at least 2 GPU devices")
+@pytest.mark.isolated_distributed
+def test_remote_grad():
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        size = dist.get_world_size()
+        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
+        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
+        gm = GradManager().attach(m.parameters())
+        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
+
+        def train_func(x):
+            if rank != 0:
+                x = dist.functional.remote_recv(
+                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                )
+            print(rank, "x", x)
+            y = m(x)
+            print(rank, "y", y)
+            if rank != size - 1:
+                y = dist.functional.remote_send(y, dest_rank=rank + 1)
+            return y
+
+        with gm:
+            y = train_func(x)
+            if rank == size - 1:
+                y = y.mean()
+                gm.backward(y)
+            else:
+                gm.backward()
+        opt.step().clear_grad()
+        # sync because send is the last job
+        sync()
+
+    worker()
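The test wires the ranks into a pipeline: rank 0 feeds its activation to rank 1 via remote_send, each later rank receives it with remote_recv, applies its own Linear layer, and forwards the result, and only the last rank holds a scalar loss. That last rank calls gm.backward(y) as usual; every other rank calls gm.backward() with no arguments, relying on the new ys=None default and on the RemoteSend gradient rule (with its inp argument) to pull the gradient back across processes. The trailing sync() keeps a worker from exiting while its final send is still in flight.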