@@ -71,7 +71,7 @@ if sys.platform == "win32":
    kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.core2 import sync, release_trace_apply_func
from .core._imperative_rt.core2 import release_trace_apply_func, sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
@@ -46,9 +46,31 @@ def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


class GradKey(core2.GradKey):
    def __init__(self, name=None):
        if name:
            self.name = name

    def backward(self, ys, dys):
        return core2.backward(self, ys, dys)


class Grad:
    def __init__(self):
        self._impl = core2.GradKey()
    def __init__(self, name=None):
        global _grad_count
        if name is None:
            name = "grad_%d" % _grad_count
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        _grad_manager_dict[self._name] = self

    @property
    def _name(self):
        return self._impl.name

    def _is_attached_to(self, tensor):
        return self._impl.is_attached_to(tensor)

    def wrt(self, *tensors, callback=None):
        for x in tensors:
@@ -62,12 +84,16 @@ class Grad:
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        core2.backward(self._impl, ys, dys)
        self._impl.backward(ys, dys)
        self._refkeeper = None

    def __enter__(self):
        return self

    def __exit__(self, _1, _2, _3):
        self._refkeeper = None
        del self._impl
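For reference, a minimal usage sketch of the reworked Grad wrapper (not part of the diff; the tensor values and the callback are illustrative, while wrt, __call__ and the context-manager protocol mirror the tests further down):

    # illustrative only: drive the Python Grad wrapper, which now forwards to
    # the C++ GradKey through the name property and GradKey.backward
    import numpy as np
    from megengine import tensor
    from megengine.core.autodiff.grad import Grad

    grads = {}
    with Grad("example") as grad:  # name is optional; auto-named "grad_%d" otherwise
        x = tensor(np.random.rand(3).astype("float32"))
        grad.wrt(x, callback=lambda g: grads.update(x=g))
        y = x * x
        grad(y, tensor(np.ones(3, dtype="float32")))  # ends up in GradKey.backward
    print(grads["x"].numpy())  # approximately 2 * x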
@@ -9,8 +9,8 @@
from typing import Optional, Tuple

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import get_grad_managers
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

@@ -193,6 +193,48 @@ def all_to_all(
    return collective_comm(inp, mode, group, device)


class _RemoteSend(PyOpBase):
    def __init__(self, op: RemoteSend):
        self.op = op

    def _default_rule(self, data):
        return apply(self.op, data)

    def _grad_rule(self, data):
        self.dtype = data.dtype
        self.shape = data.shape
        self.device = data.device
        (self.dummy,) = self._default_rule(data)
        return self.dummy, self.backward

    def backward(self, grad):
        assert grad is None
        if get_client().check_is_grad(self.op.key):
            return remote_recv(
                self.op.rank_to,
                self.shape,
                self.dtype,
                device=str(self.device),
                inp=self.dummy,
            )


class _RemoteRecv(PyOpBase):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def _default_rule(self, dummy):
        return apply(self.op, dummy)

    def _grad_rule(self, dummy):
        return self._default_rule(dummy), self.backward

    def backward(self, grad):
        get_client().set_is_grad(self.op.key, grad is not None)
        if grad is not None:
            remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """
    Send a Tensor to a remote process.
@@ -200,11 +242,21 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    key = "{}->{}".format(get_rank(), dest_rank)
    grad_keys = {}
    for n, g in _grad_manager_dict.items():
        if g._is_attached_to(inp):
            grad_keys[n] = g
    get_client().set_remote_tracer(key, grad_keys)

    op = RemoteSend()
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.key = key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    return apply(op, inp)[0]
    (dummy,) = apply(_RemoteSend(op), inp)
    for g in grad_keys.values():
        g._refkeeper.append(dummy)


def remote_recv(
@@ -228,12 +280,14 @@ def remote_recv(
    if device is None:
        device = get_default_device()
    # dummy input
    if inp == None:
    if inp is None:
        inp = Tensor([0], device=device)
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
            grad_manager.wrt(inp)
    for n in tracer_set:
        g = _grad_manager_dict.get(n)
        if g is not None:
            g.wrt(inp)
            g._refkeeper.append(inp)

    op = RemoteRecv()
    op.key = key
@@ -243,4 +297,5 @@ def remote_recv(
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    return apply(op, inp)[0]
    (ret,) = apply(_RemoteRecv(op), inp)
    return ret
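A two-process usage sketch for the reworked remote_send/remote_recv (not part of the diff; it assumes two GPUs and a bare @dist.launcher, as in the tests below). The point is that remote_send no longer hands a meaningful placeholder back to the caller — the dummy output is kept alive through the attached grad managers' _refkeeper — so the sender simply synchronizes:

    # illustrative only: sender/receiver pair under dist.launcher
    import numpy as np
    import megengine.distributed as dist
    from megengine import Tensor
    from megengine.core._imperative_rt.core2 import sync
    from megengine.functional.distributed import remote_recv, remote_send

    @dist.launcher
    def worker():
        val = np.zeros((4, 5), dtype="float32")
        if dist.get_rank() == 0:  # sender
            x = Tensor(val, device="gpu0")
            remote_send(x, 1)  # dummy output is kept alive via _refkeeper, not the return value
            sync()
        else:  # receiver
            y = remote_recv(0, val.shape, val.dtype)
            assert y.device == "gpu1"

    worker()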
@@ -193,11 +193,15 @@ struct PythonBackward {
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (!input_grads) throw py::error_already_set();
        if (input_grads.is_none()) return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error("custom grad rule returned wrong number of grads");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(input_grads.ptr());
            }
            receiver(0, tw->m_tensor);
            return;
        }
@@ -210,6 +214,9 @@ struct PythonBackward {
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(g.ptr());
            }
            receiver(i, tw->m_tensor);
        }
    }
@@ -321,6 +328,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    if (!pyret) throw py::error_already_set();
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
@@ -507,8 +515,12 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
        ~CleanupGuard() {owner->cleanup();}
    } _cleanup_guard(this);

    if (tape.empty() || grads.empty()) return;
    PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr());
    if (tape.empty()) return;

    BackwardContext bctx;
    if (!grads.empty()) {
        bctx.pytype = Py_TYPE(grads[0]->self().ptr());
    }

    for (size_t i = 0; i < tensors.size(); ++i) {
        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
@@ -517,7 +529,6 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
        }
    }

    BackwardContext bctx{pytype};
    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());

    // back-propagation in reverse order
@@ -548,7 +559,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
            }
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                // I'm the last grad producer, invoke callback
                dst->callback(TensorWrapper::make(pytype, dst->grad));
                dst->callback(bctx.wrap_tensor(dst->grad));
            }
        }
        grad_fn->clear();
@@ -568,6 +579,31 @@ void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<T
    m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
    if (grad_fn && grad_fn->key.lock() == m_key) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

GradKey::~GradKey() {
    cleanup();
}
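Seen from the Python wrapper layer added in core/autodiff/grad.py, the new name getset and is_attached_to binding behave roughly as in this sketch (not part of the diff; the tensor and the no-op callback are placeholders):

    # illustrative only: name getset and is_attached_to, exercised via the Grad wrapper
    from megengine import tensor
    from megengine.core.autodiff.grad import Grad

    with Grad("k0") as g:
        x = tensor([1.0, 2.0])
        assert g._name == "k0"             # GradKeyWrapper::get_name / set_name
        assert not g._is_attached_to(x)    # GradKeyWrapper::is_attached_to
        g.wrt(x, callback=lambda dx: None)
        assert g._is_attached_to(x)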
@@ -41,8 +41,11 @@ struct GradKeyWrapper {
    inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

    PyObject* get_name();
    void set_name(pybind11::handle name);
    void attach(PyObject*const* args, size_t nargs);
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
    PyObject* is_attached_to(PyObject*const* args, size_t nargs);
};

struct BackwardContext {
@@ -733,15 +733,18 @@ void init_tensor(py::module m) {
                py_task_q.wait_all_task_finish();
            },
            py::call_guard<py::gil_scoped_release>());

    m.def("release_trace_apply_func", &release_trace_apply_func);

    py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
            .def<&GradKeyWrapper::attach>("attach")
            .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
            .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
            .finalize();
    if (!grad_key_type) throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward));
    m.def("backward", &GradKeyWrapper::backward);

    m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
    m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
    m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
@@ -141,6 +141,7 @@ def test_regression_1762():
    )


@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skip(reason="FIXME: remote_send/recv")
def test_remote_grad():
    @dist.launcher
    def worker():
@@ -16,9 +16,8 @@ import pytest

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply
from megengine.core._imperative_rt.imperative import sync
from megengine.core._imperative_rt import CompNode, TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise
from megengine.distributed.helper import get_device_count_by_fork

@@ -73,7 +72,7 @@ def test_dist_grad():
            x = as_tensor(x_np)
            grad.wrt(x, callback=save_to(x))
            # need a placeholder to trace operator
            send_x = remote_send(x, 1)
            remote_send(x, 1)
            recv_x = remote_recv(1, x_np.shape, x_np.dtype)
            y = recv_x * recv_x
@@ -83,13 +82,12 @@ def test_dist_grad():
            grad = Grad()
            recv_x = remote_recv(0, x_np.shape, x_np.dtype)
            send_x = remote_send(recv_x, 0)
            remote_send(recv_x, 0)
            grad([], [])

    worker()


def test_grad():
    x_np = np.random.rand(10).astype("float32")
    x = as_tensor(x_np)
@@ -14,6 +14,7 @@ import pytest

import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import (

@@ -333,8 +334,8 @@ def test_io_remote():
        rank = dist.get_rank()
        if rank == 0:  # remote send
            x = Tensor(val, device="gpu0")
            y = remote_send(x, 1)
            assert y.numpy()[0] == 0
            remote_send(x, 1)
            sync()
        else:  # remote recv
            y = remote_recv(0, val.shape, val.dtype)
            assert y.device == "gpu1"