GitOrigin-RevId: eb3d712704
tags/v1.11.0
| @@ -1,6 +1,7 @@ | |||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| from mprop import mproperty | from mprop import mproperty | ||||
| from ..core._imperative_rt.core2 import group_end, group_start | |||||
| from . import group | from . import group | ||||
| from .group import ( | from .group import ( | ||||
| WORLD, | WORLD, | ||||
| @@ -1,12 +1,11 @@ | |||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| from typing import Optional, Tuple | |||||
| from typing import Optional | |||||
| import numpy as np | import numpy as np | ||||
| from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
| from ..core.autodiff.grad import Function, _grad_manager_dict | from ..core.autodiff.grad import Function, _grad_manager_dict | ||||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||||
| from ..core.tensor.utils import isscalar | |||||
| from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend | |||||
| from ..device import get_default_device, what_is_xpu | from ..device import get_default_device, what_is_xpu | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from . import group | from . import group | ||||
| @@ -843,16 +842,13 @@ def remote_send(inp: Tensor, dest_rank: int): | |||||
| """ | """ | ||||
| group = _SendRecvGroup(get_rank(), dest_rank) | group = _SendRecvGroup(get_rank(), dest_rank) | ||||
| _bcast_shape_dtype(group, inp) | _bcast_shape_dtype(group, inp) | ||||
| _bcast_tracer_state(group, inp) | _bcast_tracer_state(group, inp) | ||||
| op = RemoteSend() | op = RemoteSend() | ||||
| op.key = group.key | op.key = group.key | ||||
| op.addr, op.port = get_mm_server_addr() | op.addr, op.port = get_mm_server_addr() | ||||
| op.rank_to = dest_rank | op.rank_to = dest_rank | ||||
| op.backend = _backend() | op.backend = _backend() | ||||
| out = _RemoteSend(op)(inp) | out = _RemoteSend(op)(inp) | ||||
| _save_output_for_autodiff(inp, out) | _save_output_for_autodiff(inp, out) | ||||
| @@ -900,6 +896,34 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||||
| op.addr, op.port = get_mm_server_addr() | op.addr, op.port = get_mm_server_addr() | ||||
| op.rank_from = src_rank | op.rank_from = src_rank | ||||
| op.backend = _backend() | op.backend = _backend() | ||||
| ret = _RemoteRecv(op)(inp) | ret = _RemoteRecv(op)(inp) | ||||
| return ret | return ret | ||||
def _remote_send_nobackward(inp: Tensor, dest_rank: int):
    """Send ``inp`` to ``dest_rank`` without any autodiff bookkeeping.

    The ``RemoteSend`` op is applied directly and its result is discarded,
    so this function returns ``None``.

    Args:
        inp: tensor to send.
        dest_rank: rank of the receiving process.
    """
    send_op = RemoteSend()
    # The key is "b<src>-><dst>"; the receiving side builds the same string.
    send_op.key = "b{}->{}".format(get_rank(), dest_rank)
    server_addr, server_port = get_mm_server_addr()
    send_op.addr = server_addr
    send_op.port = server_port
    send_op.rank_to = dest_rank
    send_op.backend = _backend()
    apply(send_op, inp)
def _remote_recv_nobackward(
    src_rank: int, device: Optional[str] = None, inp=None, shape=None, dtype=None,
):
    """Receive a tensor from ``src_rank`` without any autodiff bookkeeping.

    Args:
        src_rank: rank of the sending process.
        device: device the result is placed on; the default device is used
            when ``None``.
        inp: tensor passed as the op input; a scalar ``Tensor(0)`` on
            ``device`` is created when ``None``.
        shape: shape of the tensor to receive (required).
        dtype: dtype of the tensor to receive (required).

    Returns:
        The first output of the applied ``RemoteRecv`` op.
    """
    recv_op = RemoteRecv()
    # The key is "b<src>-><dst>"; the sending side builds the same string.
    recv_op.key = "b{}->{}".format(src_rank, get_rank())
    target_device = get_default_device() if device is None else device
    recv_op.cn = target_device
    op_input = Tensor(0, device=target_device) if inp is None else inp
    assert shape is not None and dtype is not None
    recv_op.shape = shape
    recv_op.dtype = dtype
    server_addr, server_port = get_mm_server_addr()
    recv_op.addr = server_addr
    recv_op.port = server_port
    recv_op.rank_from = src_rank
    recv_op.backend = _backend()
    outputs = apply(recv_op, op_input)
    return outputs[0]
| @@ -160,6 +160,13 @@ def init_process_group( | |||||
| set_default_device("{}{}".format(device_type, device)) | set_default_device("{}{}".format(device_type, device)) | ||||
| seed(int(time.time()) + rank) | seed(int(time.time()) + rank) | ||||
| if backend == "nccl": | |||||
| # init nccl env | |||||
| from ..core._imperative_rt.common import init_nccl_env | |||||
| group_barrier() | |||||
| init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0) | |||||
| def _set_machine_ranks(ranks) -> None: | def _set_machine_ranks(ranks) -> None: | ||||
| global _sd | global _sd | ||||
| @@ -8,6 +8,9 @@ | |||||
| #include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
| #include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
| #include "megbrain/imperative/physical_tensor.h" | #include "megbrain/imperative/physical_tensor.h" | ||||
| #if MGB_ENABLE_OPR_MM | |||||
| #include "megbrain/opr/mm_handler.h" | |||||
| #endif | |||||
| #if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
| #include "cuda_sm_gen.h" | #include "cuda_sm_gen.h" | ||||
| @@ -46,6 +49,18 @@ void set_default_device(const std::string& device) { | |||||
| default_device = device; | default_device = device; | ||||
| } | } | ||||
| void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) { | |||||
| #if MGB_ENABLE_OPR_MM | |||||
| auto&& help = mgb::opr::BatchSendRecvHelper::getInstance(); | |||||
| bool res = help->init(nranks, rank, ip, port, root); | |||||
| auto p = help->get(std::string("init_all_cards")); | |||||
| #else | |||||
| mgb_throw( | |||||
| MegBrainError, | |||||
| "MegEngine compiled without MM opr, doesn't support init_nccl_env"); | |||||
| #endif | |||||
| } | |||||
| std::string get_default_device() { | std::string get_default_device() { | ||||
| return default_device; | return default_device; | ||||
| } | } | ||||
| @@ -252,6 +267,8 @@ void init_common(py::module m) { | |||||
| m.def("what_is_xpu", | m.def("what_is_xpu", | ||||
| [] { return CompNode::Locator::parse("xpux").to_physical().type; }); | [] { return CompNode::Locator::parse("xpux").to_physical().type; }); | ||||
| m.def("init_nccl_env", &init_nccl_env); | |||||
| init_npy_num_bfloat16(m); | init_npy_num_bfloat16(m); | ||||
| init_npy_num_intbx(m); | init_npy_num_intbx(m); | ||||
| init_dtypes(m); | init_dtypes(m); | ||||
| @@ -8,3 +8,4 @@ void set_default_device(const std::string& device); | |||||
| std::string get_default_device(); | std::string get_default_device(); | ||||
| extern pybind11::handle py_comp_node_type; | extern pybind11::handle py_comp_node_type; | ||||
| void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root); | |||||
| @@ -9,6 +9,7 @@ | |||||
| #include "megbrain/imperative/transformations/dtype_promote.h" | #include "megbrain/imperative/transformations/dtype_promote.h" | ||||
| #include "megbrain/imperative/transformations/eval.h" | #include "megbrain/imperative/transformations/eval.h" | ||||
| #include "megbrain/imperative/transformations/format.h" | #include "megbrain/imperative/transformations/format.h" | ||||
| #include "megbrain/imperative/transformations/group_comm.h" | |||||
| #include "megbrain/imperative/transformations/lazy.h" | #include "megbrain/imperative/transformations/lazy.h" | ||||
| #include "megbrain/imperative/transformations/scalar.h" | #include "megbrain/imperative/transformations/scalar.h" | ||||
| #include "megbrain/imperative/transformations/symbol.h" | #include "megbrain/imperative/transformations/symbol.h" | ||||
| @@ -947,6 +948,13 @@ void init_tensor(py::module m) { | |||||
| m.def("enable_cupti", &cupti::enable); | m.def("enable_cupti", &cupti::enable); | ||||
| m.def("disable_cupti", &cupti::disable); | m.def("disable_cupti", &cupti::disable); | ||||
| m.def("cupti_available", &cupti::available); | m.def("cupti_available", &cupti::available); | ||||
| static std::unique_ptr<CleanupGuard<>> group_comm_guard; | |||||
| m.def("group_start", []() { | |||||
| auto commtrans = std::make_shared<GroupCommTransformation>(); | |||||
| group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans); | |||||
| }); | |||||
| m.def("group_end", []() { group_comm_guard.reset(); }); | |||||
| m.def("sync", [channel]() { | m.def("sync", [channel]() { | ||||
| if (channel->check_available()) { | if (channel->check_available()) { | ||||
| channel->sync(); | channel->sync(); | ||||
| @@ -16,6 +16,7 @@ struct TransformationManager { | |||||
| public: | public: | ||||
| enum Segment { | enum Segment { | ||||
| ModuleTrace, | ModuleTrace, | ||||
| GroupComm, | |||||
| DTypePromote, | DTypePromote, | ||||
| DimExpansion, | DimExpansion, | ||||
| Format, | Format, | ||||
| @@ -26,7 +27,7 @@ public: | |||||
| Eval, | Eval, | ||||
| }; | }; | ||||
| std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments; | |||||
| std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments; | |||||
| private: | private: | ||||
| template <Segment segment> | template <Segment segment> | ||||
| @@ -237,3 +237,32 @@ def test_get_cuda_compute_capability(): | |||||
| assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0 | assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0 | ||||
| worker() | worker() | ||||
@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_batch_send_recv():
    """Exchange tensors with both ring neighbours inside one comm group."""
    import megengine.distributed.functional as DF

    @dist.launcher(n_gpus=3)
    def worker():
        rank = dist.get_rank()
        next_rank = (rank + 1) % 3
        prev_rank = (rank - 1) % 3

        dist.group_start()
        for step in range(3):
            payload = mge.tensor(np.ones(10000)) * rank
            if step == 2:
                payload *= step
            # Exchange with the next rank, then with the previous rank.
            DF._remote_send_nobackward(payload, next_rank)
            DF._remote_recv_nobackward(
                src_rank=next_rank, dtype="float32", shape=(10000,)
            )
            DF._remote_send_nobackward(payload, prev_rank)
            recv = DF._remote_recv_nobackward(
                src_rank=prev_rank, dtype="float32", shape=(10000,)
            )
            if step == 2:
                # Only the last round's result is checked below.
                recv2 = recv
        dist.group_end()

        np.testing.assert_equal(recv2.numpy(), prev_rank * 2 * np.ones(10000))

    worker()
| @@ -1,14 +1,19 @@ | |||||
| #include "megbrain/imperative/ops/io_remote.h" | |||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
| #include <algorithm> | |||||
| #include <functional> | |||||
| #include <numeric> | |||||
| #include "../blob_manager_impl.h" | |||||
| #include "../op_trait.h" | #include "../op_trait.h" | ||||
| #include "megbrain/imperative/proxy_graph_detail.h" | #include "megbrain/imperative/proxy_graph_detail.h" | ||||
| #include "megbrain/opr/io_remote.h" | #include "megbrain/opr/io_remote.h" | ||||
| #include "megbrain/opr/megray_helper.h" | |||||
| #include "megbrain/opr/mm_handler.h" | #include "megbrain/opr/mm_handler.h" | ||||
| #endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| @@ -46,15 +51,164 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||||
| recv.backend)); | recv.backend)); | ||||
| } | } | ||||
| TensorPtr megray_recv_tensor( | |||||
| std::shared_ptr<MegRay::Communicator> megray_comm, TensorLayout& layout, | |||||
| CompNode cn, uint32_t rank_from) { | |||||
| DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, layout); | |||||
| auto megray_ctx = mgb::opr::get_megray_context(cn); | |||||
| size_t data_size = layout.total_nr_elems(); | |||||
| auto status = megray_comm->recv( | |||||
| out.raw_ptr(), data_size, mgb::opr::get_megray_dtype(layout.dtype), | |||||
| rank_from, megray_ctx); | |||||
| mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); | |||||
| return Tensor::make(out); | |||||
| } | |||||
| void megray_send_tensor( | |||||
| std::shared_ptr<MegRay::Communicator> megray_comm, const TensorPtr& src, | |||||
| uint32_t rank_to) { | |||||
| auto&& tensor = src->dev_tensor(); | |||||
| auto&& ishp = src->shape(); | |||||
| size_t data_size = ishp.total_nr_elems(); | |||||
| auto megray_ctx = mgb::opr::get_megray_context(src->comp_node()); | |||||
| auto status = megray_comm->send( | |||||
| src->dev_tensor().raw_ptr(), data_size, | |||||
| mgb::opr::get_megray_dtype(src->layout().dtype), rank_to, megray_ctx); | |||||
| mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | |||||
| } | |||||
| TensorLayout create_layout(const std::vector<int32_t>& shape, DType dtype) { | |||||
| TensorShape tshape; | |||||
| tshape.ndim = shape.size(); | |||||
| mgb_assert(tshape.ndim <= TensorLayout::MAX_NDIM); | |||||
| std::copy(shape.begin(), shape.end(), tshape.shape); | |||||
| return TensorLayout(tshape, dtype); | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_send( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||||
| auto&& dtype = input_descs[0].layout.dtype; | |||||
| auto&& cn = input_descs[0].comp_node; | |||||
| return {{{TensorLayout({0}, dtype), cn}}, true}; | |||||
| } | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor_remote_send( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto&& op = def.cast_final_safe<RemoteSend>(); | |||||
| auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||||
| std::string("init_all_cards")); | |||||
| if (!megray_comm) { | |||||
| return proxy_graph_detail::apply_on_physical_tensor( | |||||
| def, inputs, output_descs, validated); | |||||
| } | |||||
| mgb_assert(megray_comm != nullptr); | |||||
| megray_send_tensor(megray_comm, inputs[0], op.rank_to); | |||||
| TensorLayout layout({0}, inputs[0]->dtype()); | |||||
| DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag( | |||||
| inputs[0]->comp_node(), layout); | |||||
| return {Tensor::make(out)}; | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_recv( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||||
| auto& op = def.cast_final_safe<RemoteRecv>(); | |||||
| return {{{create_layout(op.shape, op.dtype), op.cn}}, true}; | |||||
| } | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor_remote_recv( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto&& op = def.cast_final_safe<RemoteRecv>(); | |||||
| auto layout = create_layout(op.shape, op.dtype); | |||||
| auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||||
| std::string("init_all_cards")); | |||||
| if (!megray_comm) { | |||||
| return proxy_graph_detail::apply_on_physical_tensor( | |||||
| def, inputs, output_descs, validated); | |||||
| } | |||||
| auto&& out = megray_recv_tensor(megray_comm, layout, op.cn, op.rank_from); | |||||
| return {out}; | |||||
| } | |||||
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
| SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||||
| for (size_t i; i < inputs.size(); i++) { | |||||
| layout_checker[i] = [](const TensorLayout& layout) { | |||||
| return layout.is_contiguous(); | |||||
| }; | |||||
| } | |||||
| return layout_checker; | |||||
| } | |||||
| OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend) | OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend) | ||||
| .apply_on_var_node(apply_on_var_node_remote_send) | .apply_on_var_node(apply_on_var_node_remote_send) | ||||
| .apply_on_physical_tensor(apply_on_physical_tensor_remote_send) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_send) | |||||
| .get_input_layout_constraint(get_input_layout_constraint) | |||||
| .fallback(); | .fallback(); | ||||
| OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | ||||
| .apply_on_var_node(apply_on_var_node_remote_recv) | .apply_on_var_node(apply_on_var_node_remote_recv) | ||||
| .apply_on_physical_tensor(apply_on_physical_tensor_remote_recv) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_recv) | |||||
| .get_input_layout_constraint(get_input_layout_constraint) | |||||
| .fallback(); | .fallback(); | ||||
| } // anonymous namespace | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor_batch_send_recv( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto&& op = def.cast_final_safe<BatchSendRecvOp>(); | |||||
| auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||||
| std::string("init_all_cards")); | |||||
| mgb_assert(megray_comm != nullptr); | |||||
| megray_comm->group_start(); | |||||
| SmallVector<TensorPtr> outputs; | |||||
| size_t ind = 0; | |||||
| for (auto&& op_ : op.op_list) { | |||||
| if (op_->same_type<RemoteSend>()) { | |||||
| auto&& send_op = op_->cast_final_safe<RemoteSend>(); | |||||
| auto&& tensor = inputs[ind]; | |||||
| megray_send_tensor(megray_comm, tensor, send_op.rank_to); | |||||
| ind++; | |||||
| } else { | |||||
| mgb_assert(op_->same_type<RemoteRecv>()); | |||||
| auto&& recv_op = op_->cast_final_safe<RemoteRecv>(); | |||||
| auto layout = create_layout(recv_op.shape, recv_op.dtype); | |||||
| auto&& out = megray_recv_tensor( | |||||
| megray_comm, layout, recv_op.cn, recv_op.rank_from); | |||||
| outputs.push_back(out); | |||||
| } | |||||
| } | |||||
| megray_comm->group_end(); | |||||
| return outputs; | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> | |||||
| infer_output_attrs_fallible_batch_send_recv( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||||
| auto& op = def.cast_final_safe<BatchSendRecvOp>(); | |||||
| SmallVector<LogicalTensorDesc> output_descs; | |||||
| for (auto&& op_ : op.op_list) { | |||||
| if (op_->same_type<RemoteRecv>()) { | |||||
| auto&& recv_op = op_->cast_final_safe<RemoteRecv>(); | |||||
| output_descs.push_back( | |||||
| {create_layout(recv_op.shape, recv_op.dtype), recv_op.cn}); | |||||
| } | |||||
| } | |||||
| return {output_descs, true}; | |||||
| } | |||||
| OP_TRAIT_REG(BatchSendRecvOp, BatchSendRecvOp) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor_batch_send_recv) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible_batch_send_recv) | |||||
| .get_input_layout_constraint(get_input_layout_constraint) | |||||
| .fallback(); | |||||
| } // namespace | |||||
| #endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchSendRecvOp); | |||||
| } // namespace imperative | } // namespace imperative | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -0,0 +1,67 @@ | |||||
| #include "megbrain/imperative/transformations/group_comm.h" | |||||
| #include "megbrain/imperative/blob_manager.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/imperative/ops/io_remote.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| ValueRefList GroupCommTransformation::apply_transformation( | |||||
| const Operator& op, Span<ValueRef> inputs) { | |||||
| for (auto inp : inputs) { | |||||
| mgb_assert( | |||||
| !inp.is(m_value_type), "Can not use PlaceholderValue as apply input"); | |||||
| } | |||||
| if (auto* apply_op = op.as<ApplyOp>()) { | |||||
| if (apply_op->op().same_type<RemoteSend>()) { | |||||
| auto&& send_op = apply_op->op().cast_final_safe<RemoteSend>(); | |||||
| if (send_op.key[0] == 'b') { | |||||
| send_inputs.push_back(inputs[0]); | |||||
| record_ops.push_back(send_op.shared_from_this()); | |||||
| return {}; | |||||
| } | |||||
| } | |||||
| if (apply_op->op().same_type<RemoteRecv>()) { | |||||
| auto&& recv_op = apply_op->op().cast_final_safe<RemoteRecv>(); | |||||
| if (recv_op.key[0] == 'b') { | |||||
| record_ops.push_back(recv_op.shared_from_this()); | |||||
| auto rst = m_value_type.make(); | |||||
| recv_tensors.push_back(rst); | |||||
| auto outputs = ValueRefList(1); | |||||
| outputs[0] = rst; | |||||
| return outputs; | |||||
| } | |||||
| } | |||||
| return imperative::apply(op, inputs); | |||||
| } else { | |||||
| return imperative::apply(op, inputs); | |||||
| } | |||||
| } | |||||
| ValueRefList GroupCommTransformation::execute_batch_op() { | |||||
| auto batch_op = BatchSendRecvOp::make(record_ops); | |||||
| auto outputs = imperative::apply(*batch_op, send_inputs); | |||||
| return outputs; | |||||
| } | |||||
| void GroupCommTransformation::on_unregister() noexcept { | |||||
| auto rst = execute_batch_op(); | |||||
| mgb_assert(rst.size() == recv_tensors.size()); | |||||
| for (size_t i = 0; i < rst.size(); i++) { | |||||
| auto v = recv_tensors[i].lock(); | |||||
| if (v != ValueRef::nil) { | |||||
| v.reset(rst[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| GroupCommTransformation::~GroupCommTransformation() { | |||||
| for (auto&& recv : recv_tensors) { | |||||
| mgb_assert( | |||||
| recv.lock() == ValueRef::nil, | |||||
| "Some PlaceholderValues are not reset after GroupCommTransformation " | |||||
| "destroyed!"); | |||||
| }; | |||||
| } | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
| @@ -0,0 +1,11 @@ | |||||
| #pragma once | |||||
| #include "megbrain/imperative/op_def.h" | |||||
| namespace mgb::imperative { | |||||
| struct BatchSendRecvOp final : OpDefImplBase<BatchSendRecvOp> { | |||||
| SmallVector<std::shared_ptr<OpDef>> op_list; | |||||
| BatchSendRecvOp() = default; | |||||
| BatchSendRecvOp(SmallVector<std::shared_ptr<OpDef>> op_list) : op_list{op_list} {} | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/transformations/group_comm.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/imperative/basic_operators.h" | |||||
| #include "megbrain/imperative/dispatch.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| namespace mgb::imperative { | |||||
| class PlaceholderValue final : public ObjectValue<PlaceholderValue> { | |||||
| public: | |||||
| std::string to_string() const override { return ssprintf("PlaceholderValue"); } | |||||
| void clear() override {} | |||||
| }; | |||||
| class GroupCommTransformation final : public Transformation { | |||||
| private: | |||||
| SmallVector<ValueRef> send_inputs; | |||||
| std::vector<PlaceholderValue::weak_ref_t> recv_tensors; | |||||
| SmallVector<std::shared_ptr<OpDef>> record_ops; | |||||
| ObjectType<PlaceholderValue> m_value_type{"PlaceholderValue"}; | |||||
| public: | |||||
| GroupCommTransformation() = default; | |||||
| ValueRefList apply_transformation( | |||||
| const Operator& op, Span<ValueRef> inputs) override; | |||||
| ValueRefList execute_batch_op(); | |||||
| ValueRef unwrap(ValueRef value) override { return value; } | |||||
| std::string name() const override { return "GroupCommTransformation"; } | |||||
| void on_unregister() noexcept override; | |||||
| ~GroupCommTransformation(); | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -1,4 +1,5 @@ | |||||
| #include "./helper.h" | #include "./helper.h" | ||||
| #include "megbrain/comp_node_env.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/opr/mm_handler.h" | #include "megbrain/opr/mm_handler.h" | ||||
| @@ -47,7 +48,4 @@ TEST(TestImperative, IORemote) { | |||||
| t0.join(); | t0.join(); | ||||
| t1.join(); | t1.join(); | ||||
| } | } | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| // ./imperative_test --gtest_filter TestIORemote | |||||
| @@ -151,6 +151,28 @@ void GroupManager::bcast_addr( | |||||
| } | } | ||||
| } | } | ||||
| void GroupManager::bcast_nccluniqueid( | |||||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
| uint32_t root) { | |||||
| std::unique_lock<std::mutex> lk{m_key2nccl_id_mtx}; | |||||
| if (rank == root) { | |||||
| m_key2nccl_id[key] = id; | |||||
| } | |||||
| m_key2nccl_id_size[key]++; | |||||
| if (m_key2nccl_id_size[key] == size) { | |||||
| m_key2nccl_id_flag[key] = true; | |||||
| m_bcast_cv.notify_all(); | |||||
| } else { | |||||
| m_bcast_cv.wait(lk, [&] { return m_key2nccl_id_flag.count(key) > 0; }); | |||||
| } | |||||
| id = m_key2nccl_id[key]; | |||||
| m_key2nccl_id_size[key]--; | |||||
| if (m_key2nccl_id_size[key] == 0) { | |||||
| m_key2nccl_id.erase(key); | |||||
| m_key2nccl_id_flag.erase(key); | |||||
| } | |||||
| } | |||||
| void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) { | void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) { | ||||
| auto&& group = get_group(key); | auto&& group = get_group(key); | ||||
| group.set_output_shape(key, shape); | group.set_output_shape(key, shape); | ||||
| @@ -67,6 +67,15 @@ void MegRayCommBuilder::emplace( | |||||
| m_megray_comms.emplace(hash, comm); | m_megray_comms.emplace(hash, comm); | ||||
| } | } | ||||
| void MegRayCommBuilder::remove( | |||||
| uint64_t hash, std::shared_ptr<MegRay::Communicator> comm) { | |||||
| std::unique_lock<std::mutex> lk(m_map_mtx); | |||||
| auto it = m_megray_comms.find(hash); | |||||
| if (it != m_megray_comms.end()) { | |||||
| m_megray_comms.erase(hash); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | ||||
| uint64_t hash, std::string key, uint32_t size, uint32_t rank, | uint64_t hash, std::string key, uint32_t size, uint32_t rank, | ||||
| MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) { | MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) { | ||||
| @@ -104,5 +113,3 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||||
| MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | ||||
| std::mutex MegRayCommBuilder::sm_instance_mtx; | std::mutex MegRayCommBuilder::sm_instance_mtx; | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -45,6 +45,7 @@ public: | |||||
| RUNSERVER(get_output_shape); | RUNSERVER(get_output_shape); | ||||
| RUNSERVER(bcast_addr); | RUNSERVER(bcast_addr); | ||||
| RUNSERVER(group_barrier); | RUNSERVER(group_barrier); | ||||
| RUNSERVER(bcast_nccluniqueid); | |||||
| mgb_assert(false, "invalid rpc request"); | mgb_assert(false, "invalid rpc request"); | ||||
| } | } | ||||
| @@ -53,6 +54,7 @@ private: | |||||
| void set_output_shape(void* input_ptr, size_t input_len, std::string* output); | void set_output_shape(void* input_ptr, size_t input_len, std::string* output); | ||||
| void get_output_shape(void* input_ptr, size_t input_len, std::string* output); | void get_output_shape(void* input_ptr, size_t input_len, std::string* output); | ||||
| void bcast_addr(void* input_ptr, size_t input_len, std::string* output); | void bcast_addr(void* input_ptr, size_t input_len, std::string* output); | ||||
| void bcast_nccluniqueid(void* input_ptr, size_t input_len, std::string* output); | |||||
| void group_barrier(void* input_ptr, size_t input_len, std::string* output); | void group_barrier(void* input_ptr, size_t input_len, std::string* output); | ||||
| private: | private: | ||||
| @@ -116,6 +118,15 @@ void GroupServerProxy::bcast_addr( | |||||
| rsp.SerializeToString(output); | rsp.SerializeToString(output); | ||||
| } | } | ||||
| void GroupServerProxy::bcast_nccluniqueid( | |||||
| void* input_ptr, size_t input_len, std::string* output) { | |||||
| INFO_INIT(mm_handler, BcastNcclUniqueId); | |||||
| std::string id = req.id(); | |||||
| m_mgr.bcast_nccluniqueid(req.key(), id, req.size(), req.rank(), req.root()); | |||||
| rsp.set_id(id); | |||||
| rsp.SerializeToString(output); | |||||
| } | |||||
| void GroupServerProxy::group_barrier( | void GroupServerProxy::group_barrier( | ||||
| void* input_ptr, size_t input_len, std::string* output) { | void* input_ptr, size_t input_len, std::string* output) { | ||||
| INFO_INIT(mm_handler, GroupBarrier); | INFO_INIT(mm_handler, GroupBarrier); | ||||
| @@ -201,6 +212,19 @@ void GroupClientProxy::bcast_addr( | |||||
| port = rsp.port(); | port = rsp.port(); | ||||
| } | } | ||||
| void GroupClientProxy::bcast_nccluniqueid( | |||||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
| uint32_t root) { | |||||
| INFO_INIT(mm_handler, bcast_nccluniqueid, BcastNcclUniqueId); | |||||
| req.set_id(id.data(), id.size()); | |||||
| req.set_key(key.data(), key.size()); | |||||
| req.set_size(size); | |||||
| req.set_rank(rank); | |||||
| req.set_root(root); | |||||
| SOLVE_REQUEST(func_name, req, rsp); | |||||
| id = rsp.id(); | |||||
| } | |||||
| uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | ||||
| INFO_INIT(mm_handler, group_barrier, GroupBarrier); | INFO_INIT(mm_handler, group_barrier, GroupBarrier); | ||||
| req.set_size(size); | req.set_size(size); | ||||
| @@ -209,6 +233,40 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||||
| return rsp.size(); | return rsp.size(); | ||||
| } | } | ||||
| std::shared_ptr<MegRay::Communicator> BatchSendRecvHelper::get(std::string&& key) { | |||||
| auto ptr = megray_comm_cache.find(key); | |||||
| if (ptr != megray_comm_cache.end()) { | |||||
| return megray_comm_cache[key]; | |||||
| } else { | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>> | |||||
| BatchSendRecvHelper::megray_comm_cache{}; | |||||
| bool BatchSendRecvHelper::init( | |||||
| int nranks, int rank, std::string ip, int port, int root) { | |||||
| auto megray_comm = | |||||
| MegRay::get_communicator(nranks, rank, MegRay::Backend::MEGRAY_NCCL); | |||||
| auto group_client = | |||||
| std::make_shared<opr::GroupClientProxy>(ssprintf("%s:%d", ip.data(), port)); | |||||
| auto cb = [=](char* nccl_buffer, size_t len) { | |||||
| std::string id; | |||||
| id.resize(128); | |||||
| if (rank == root) { | |||||
| memcpy(id.data(), nccl_buffer, len); | |||||
| } | |||||
| group_client->bcast_nccluniqueid("init_all_cards", id, nranks, rank, root); | |||||
| if (rank != root) { | |||||
| memcpy(nccl_buffer, id.data(), len); | |||||
| } | |||||
| }; | |||||
| megray_comm->init(cb); | |||||
| return megray_comm_cache.insert({std::string("init_all_cards"), megray_comm}) | |||||
| .second; | |||||
| } | |||||
| #undef INFO_INIT | #undef INFO_INIT | ||||
| #undef SOLVE_REQUEST | #undef SOLVE_REQUEST | ||||
| @@ -77,6 +77,11 @@ public: | |||||
| std::string& master_ip, int& port, const std::string& key, uint32_t size, | std::string& master_ip, int& port, const std::string& key, uint32_t size, | ||||
| uint32_t rank, uint32_t root); | uint32_t rank, uint32_t root); | ||||
| //! bcast uid | |||||
| void bcast_nccluniqueid( | |||||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
| uint32_t root); | |||||
| //! Set output shape of this key | //! Set output shape of this key | ||||
| void set_output_shape(const std::string& key, const TensorShape& shape); | void set_output_shape(const std::string& key, const TensorShape& shape); | ||||
| @@ -101,6 +106,12 @@ private: | |||||
| std::mutex m_key2addr_mtx; | std::mutex m_key2addr_mtx; | ||||
| std::condition_variable m_bcast_cv; | std::condition_variable m_bcast_cv; | ||||
| //! key -> ncclid | |||||
| std::unordered_map<std::string, std::string> m_key2nccl_id; | |||||
| std::unordered_map<std::string, uint32_t> m_key2nccl_id_size; | |||||
| std::unordered_map<std::string, bool> m_key2nccl_id_flag; | |||||
| std::mutex m_key2nccl_id_mtx; | |||||
| //! barrier | //! barrier | ||||
| uint32_t m_barrier_size; | uint32_t m_barrier_size; | ||||
| std::set<uint32_t> m_barrier_set; | std::set<uint32_t> m_barrier_set; | ||||
| @@ -128,6 +139,10 @@ public: | |||||
| std::string& master_ip, int& port, const std::string& key, uint32_t size, | std::string& master_ip, int& port, const std::string& key, uint32_t size, | ||||
| uint32_t rank, uint32_t root) = 0; | uint32_t rank, uint32_t root) = 0; | ||||
| virtual void bcast_nccluniqueid( | |||||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
| uint32_t root) = 0; | |||||
| virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; | virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; | ||||
| virtual TensorShape get_output_shape(const std::string& key) = 0; | virtual TensorShape get_output_shape(const std::string& key) = 0; | ||||
| @@ -23,6 +23,7 @@ class MegRayCommBuilder { | |||||
| private: | private: | ||||
| bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | ||||
| void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | ||||
| void remove(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||||
| std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | ||||
| std::mutex m_map_mtx; | std::mutex m_map_mtx; | ||||
| @@ -39,6 +39,10 @@ public: | |||||
| std::string& master_ip, int& port, const std::string& key, uint32_t size, | std::string& master_ip, int& port, const std::string& key, uint32_t size, | ||||
| uint32_t rank, uint32_t root) override; | uint32_t rank, uint32_t root) override; | ||||
| void bcast_nccluniqueid( | |||||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
| uint32_t root) override; | |||||
| void set_output_shape(const std::string& key, const TensorShape& shape) override; | void set_output_shape(const std::string& key, const TensorShape& shape) override; | ||||
| TensorShape get_output_shape(const std::string& key) override; | TensorShape get_output_shape(const std::string& key) override; | ||||
| @@ -52,6 +56,34 @@ private: | |||||
| void* m_stub; | void* m_stub; | ||||
| }; | }; | ||||
/*!
 * \brief Base class that hands out one lazily created, shared instance of T.
 *
 * The instance lives in a function-local static, so its construction is
 * thread safe. Copy construction and assignment are disabled.
 */
template <typename T>
class ProcessGlobal {
public:
    ProcessGlobal(const ProcessGlobal&) = delete;
    void operator=(const ProcessGlobal&) = delete;

    /*!
     * \brief Return a reference to the shared_ptr owning the instance; the
     *      instance is built from \p args the first time this is reached.
     *
     * Note: every distinct Args... pack is a separate specialisation of this
     * function template and therefore owns a separate static instance.
     */
    template <class... Args>
    static std::shared_ptr<T>& getInstance(Args&&... args) {
        static std::shared_ptr<T> holder =
                std::make_shared<T>(std::forward<Args>(args)...);
        return holder;
    }

protected:
    ProcessGlobal() = default;

    //! forwarding constructor; only declared here, no definition is provided
    template <class... Args>
    ProcessGlobal(Args&&... args);
};
| class BatchSendRecvHelper : public ProcessGlobal<BatchSendRecvHelper> { | |||||
| static std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>> | |||||
| megray_comm_cache; | |||||
| public: | |||||
| std::shared_ptr<MegRay::Communicator> get(std::string&&); | |||||
| bool init(int nranks, int rank, std::string ip, int port, int root); | |||||
| }; | |||||
| /* ======================== ZmqRpcServerMgr ========================== */ | /* ======================== ZmqRpcServerMgr ========================== */ | ||||
| int create_zmqrpc_server(const std::string& server_addr, int port); | int create_zmqrpc_server(const std::string& server_addr, int port); | ||||
| @@ -60,5 +92,3 @@ int create_zmqrpc_server(const std::string& server_addr, int port); | |||||
| } // namespace mgb | } // namespace mgb | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -30,6 +30,18 @@ message BcastAddrResponse { | |||||
| int32 port = 2; | int32 port = 2; | ||||
| } | } | ||||
// Request of the NCCL unique-id broadcast: the key, the id bytes, the group
// size, the rank of the caller and the root rank.
message BcastNcclUniqueIdRequest {
    string key = 1;
    bytes id = 2;
    uint32 size = 3;
    uint32 rank = 4;
    uint32 root = 5;
}
// Response of the NCCL unique-id broadcast: carries the id bytes.
message BcastNcclUniqueIdResponse {
    bytes id = 1;
}
| message SetOutputShapeRequest { | message SetOutputShapeRequest { | ||||
| string key = 1; | string key = 1; | ||||
| TensorShape shape = 2; | TensorShape shape = 2; | ||||
| @@ -26,6 +26,12 @@ public: | |||||
| return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | ||||
| } | } | ||||
| void bcast_nccluniqueid( | |||||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||||
| uint32_t root) override { | |||||
| return m_mgr.bcast_nccluniqueid(key, id, size, rank, root); | |||||
| } | |||||
| void set_output_shape(const std::string& key, const TensorShape& shape) override { | void set_output_shape(const std::string& key, const TensorShape& shape) override { | ||||
| m_mgr.set_output_shape(key, shape); | m_mgr.set_output_shape(key, shape); | ||||
| } | } | ||||