GitOrigin-RevId: eb3d712704
tags/v1.11.0
| @@ -1,6 +1,7 @@ | |||
| # -*- coding: utf-8 -*- | |||
| from mprop import mproperty | |||
| from ..core._imperative_rt.core2 import group_end, group_start | |||
| from . import group | |||
| from .group import ( | |||
| WORLD, | |||
| @@ -1,12 +1,11 @@ | |||
| # -*- coding: utf-8 -*- | |||
| from typing import Optional, Tuple | |||
| from typing import Optional | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core.autodiff.grad import Function, _grad_manager_dict | |||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
| from ..core.tensor.utils import isscalar | |||
| from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend | |||
| from ..device import get_default_device, what_is_xpu | |||
| from ..tensor import Tensor | |||
| from . import group | |||
| @@ -843,16 +842,13 @@ def remote_send(inp: Tensor, dest_rank: int): | |||
| """ | |||
| group = _SendRecvGroup(get_rank(), dest_rank) | |||
| _bcast_shape_dtype(group, inp) | |||
| _bcast_tracer_state(group, inp) | |||
| op = RemoteSend() | |||
| op.key = group.key | |||
| op.addr, op.port = get_mm_server_addr() | |||
| op.rank_to = dest_rank | |||
| op.backend = _backend() | |||
| out = _RemoteSend(op)(inp) | |||
| _save_output_for_autodiff(inp, out) | |||
| @@ -900,6 +896,34 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor | |||
| op.addr, op.port = get_mm_server_addr() | |||
| op.rank_from = src_rank | |||
| op.backend = _backend() | |||
| ret = _RemoteRecv(op)(inp) | |||
| return ret | |||
def _remote_send_nobackward(inp: Tensor, dest_rank: int):
    """Send ``inp`` to ``dest_rank`` by applying a bare ``RemoteSend`` op.

    The op key is built as ``"b<src>-><dst>"``. The leading ``b`` is the
    marker ``GroupCommTransformation`` checks (``key[0] == 'b'``) before it
    records the op for batched execution.

    Args:
        inp: tensor to send.
        dest_rank: rank of the receiving process.

    Returns:
        None; the result of ``apply`` is discarded.
    """
    send_op = RemoteSend()
    send_op.key = "b{}->{}".format(get_rank(), dest_rank)
    server_addr, server_port = get_mm_server_addr()
    send_op.addr = server_addr
    send_op.port = server_port
    send_op.rank_to = dest_rank
    send_op.backend = _backend()
    apply(send_op, inp)
def _remote_recv_nobackward(
    src_rank: int, device: Optional[str] = None, inp=None, shape=None, dtype=None,
):
    """Receive a tensor from ``src_rank`` by applying a bare ``RemoteRecv`` op.

    The op key is built as ``"b<src>-><dst>"``, mirroring the key used by
    ``_remote_send_nobackward`` on the sending side.

    Args:
        src_rank: rank of the sending process.
        device: device of the received tensor; ``get_default_device()`` is
            used when ``None``.
        inp: tensor passed as the op input; a scalar ``Tensor(0)`` on
            ``device`` is created when ``None``.
        shape: shape of the tensor to receive; must not be ``None``.
        dtype: dtype of the tensor to receive; must not be ``None``.

    Returns:
        The first output of the applied op.
    """
    recv_op = RemoteRecv()
    recv_op.key = "b{}->{}".format(src_rank, get_rank())

    device = get_default_device() if device is None else device
    recv_op.cn = device

    if inp is None:
        inp = Tensor(0, device=device)

    assert shape is not None and dtype is not None
    recv_op.shape = shape
    recv_op.dtype = dtype

    server_addr, server_port = get_mm_server_addr()
    recv_op.addr = server_addr
    recv_op.port = server_port
    recv_op.rank_from = src_rank
    recv_op.backend = _backend()

    outputs = apply(recv_op, inp)
    return outputs[0]
| @@ -160,6 +160,13 @@ def init_process_group( | |||
| set_default_device("{}{}".format(device_type, device)) | |||
| seed(int(time.time()) + rank) | |||
| if backend == "nccl": | |||
| # init nccl env | |||
| from ..core._imperative_rt.common import init_nccl_env | |||
| group_barrier() | |||
| init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0) | |||
| def _set_machine_ranks(ranks) -> None: | |||
| global _sd | |||
| @@ -8,6 +8,9 @@ | |||
| #include "megbrain/comp_node.h" | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/imperative/physical_tensor.h" | |||
| #if MGB_ENABLE_OPR_MM | |||
| #include "megbrain/opr/mm_handler.h" | |||
| #endif | |||
| #if MEGDNN_WITH_CUDA | |||
| #include "cuda_sm_gen.h" | |||
| @@ -46,6 +49,18 @@ void set_default_device(const std::string& device) { | |||
| default_device = device; | |||
| } | |||
| void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) { | |||
| #if MGB_ENABLE_OPR_MM | |||
| auto&& help = mgb::opr::BatchSendRecvHelper::getInstance(); | |||
| bool res = help->init(nranks, rank, ip, port, root); | |||
| auto p = help->get(std::string("init_all_cards")); | |||
| #else | |||
| mgb_throw( | |||
| MegBrainError, | |||
| "MegEngine compiled without MM opr, doesn't support init_nccl_env"); | |||
| #endif | |||
| } | |||
//! Return a copy of the `default_device` string that set_default_device() assigns.
std::string get_default_device() {
    return default_device;
}
| @@ -252,6 +267,8 @@ void init_common(py::module m) { | |||
| m.def("what_is_xpu", | |||
| [] { return CompNode::Locator::parse("xpux").to_physical().type; }); | |||
| m.def("init_nccl_env", &init_nccl_env); | |||
| init_npy_num_bfloat16(m); | |||
| init_npy_num_intbx(m); | |||
| init_dtypes(m); | |||
| @@ -8,3 +8,4 @@ void set_default_device(const std::string& device); | |||
| std::string get_default_device(); | |||
| extern pybind11::handle py_comp_node_type; | |||
| void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root); | |||
| @@ -9,6 +9,7 @@ | |||
| #include "megbrain/imperative/transformations/dtype_promote.h" | |||
| #include "megbrain/imperative/transformations/eval.h" | |||
| #include "megbrain/imperative/transformations/format.h" | |||
| #include "megbrain/imperative/transformations/group_comm.h" | |||
| #include "megbrain/imperative/transformations/lazy.h" | |||
| #include "megbrain/imperative/transformations/scalar.h" | |||
| #include "megbrain/imperative/transformations/symbol.h" | |||
| @@ -947,6 +948,13 @@ void init_tensor(py::module m) { | |||
| m.def("enable_cupti", &cupti::enable); | |||
| m.def("disable_cupti", &cupti::disable); | |||
| m.def("cupti_available", &cupti::available); | |||
| static std::unique_ptr<CleanupGuard<>> group_comm_guard; | |||
| m.def("group_start", []() { | |||
| auto commtrans = std::make_shared<GroupCommTransformation>(); | |||
| group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans); | |||
| }); | |||
| m.def("group_end", []() { group_comm_guard.reset(); }); | |||
| m.def("sync", [channel]() { | |||
| if (channel->check_available()) { | |||
| channel->sync(); | |||
| @@ -16,6 +16,7 @@ struct TransformationManager { | |||
| public: | |||
| enum Segment { | |||
| ModuleTrace, | |||
| GroupComm, | |||
| DTypePromote, | |||
| DimExpansion, | |||
| Format, | |||
| @@ -26,7 +27,7 @@ public: | |||
| Eval, | |||
| }; | |||
| std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments; | |||
| std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments; | |||
| private: | |||
| template <Segment segment> | |||
| @@ -237,3 +237,32 @@ def test_get_cuda_compute_capability(): | |||
| assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0 | |||
| worker() | |||
@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_batch_send_recv():
    """Exchange tensors around a ring of 3 ranks inside one group_start/group_end."""
    import megengine.distributed.functional as DF

    @dist.launcher(n_gpus=3)
    def worker():
        rank = dist.get_rank()
        # Everything issued until group_end() is sent/received as one group.
        dist.group_start()
        for i in range(3):
            tensor = mge.tensor(np.ones(10000)) * rank
            if i == 2:
                # Last round sends rank * 2 so it can be told apart below.
                tensor *= i
            # Exchange with the next rank in the ring ...
            DF._remote_send_nobackward(tensor, (rank + 1) % 3)
            DF._remote_recv_nobackward(
                src_rank=(rank + 1) % 3, dtype="float32", shape=(10000,)
            )
            # ... and with the previous rank; only this result is checked.
            DF._remote_send_nobackward(tensor, (rank - 1) % 3)
            recv = DF._remote_recv_nobackward(
                src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,)
            )
            if i == 2:
                recv2 = recv
        dist.group_end()
        # The previous rank sent (rank - 1) % 3 * 2 in the last round.
        np.testing.assert_equal(recv2.numpy(), (rank - 1) % 3 * 2 * np.ones(10000))

    worker()
| @@ -1,14 +1,19 @@ | |||
| #include "megbrain/imperative/ops/io_remote.h" | |||
| #include "megbrain_build_config.h" | |||
| #if MGB_ENABLE_OPR_MM | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include "../blob_manager_impl.h" | |||
| #include "../op_trait.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/opr/io_remote.h" | |||
| #include "megbrain/opr/megray_helper.h" | |||
| #include "megbrain/opr/mm_handler.h" | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -46,15 +51,164 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
| recv.backend)); | |||
| } | |||
| TensorPtr megray_recv_tensor( | |||
| std::shared_ptr<MegRay::Communicator> megray_comm, TensorLayout& layout, | |||
| CompNode cn, uint32_t rank_from) { | |||
| DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, layout); | |||
| auto megray_ctx = mgb::opr::get_megray_context(cn); | |||
| size_t data_size = layout.total_nr_elems(); | |||
| auto status = megray_comm->recv( | |||
| out.raw_ptr(), data_size, mgb::opr::get_megray_dtype(layout.dtype), | |||
| rank_from, megray_ctx); | |||
| mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); | |||
| return Tensor::make(out); | |||
| } | |||
| void megray_send_tensor( | |||
| std::shared_ptr<MegRay::Communicator> megray_comm, const TensorPtr& src, | |||
| uint32_t rank_to) { | |||
| auto&& tensor = src->dev_tensor(); | |||
| auto&& ishp = src->shape(); | |||
| size_t data_size = ishp.total_nr_elems(); | |||
| auto megray_ctx = mgb::opr::get_megray_context(src->comp_node()); | |||
| auto status = megray_comm->send( | |||
| src->dev_tensor().raw_ptr(), data_size, | |||
| mgb::opr::get_megray_dtype(src->layout().dtype), rank_to, megray_ctx); | |||
| mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | |||
| } | |||
| TensorLayout create_layout(const std::vector<int32_t>& shape, DType dtype) { | |||
| TensorShape tshape; | |||
| tshape.ndim = shape.size(); | |||
| mgb_assert(tshape.ndim <= TensorLayout::MAX_NDIM); | |||
| std::copy(shape.begin(), shape.end(), tshape.shape); | |||
| return TensorLayout(tshape, dtype); | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_send( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||
| auto&& dtype = input_descs[0].layout.dtype; | |||
| auto&& cn = input_descs[0].comp_node; | |||
| return {{{TensorLayout({0}, dtype), cn}}, true}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor_remote_send( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| auto&& op = def.cast_final_safe<RemoteSend>(); | |||
| auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||
| std::string("init_all_cards")); | |||
| if (!megray_comm) { | |||
| return proxy_graph_detail::apply_on_physical_tensor( | |||
| def, inputs, output_descs, validated); | |||
| } | |||
| mgb_assert(megray_comm != nullptr); | |||
| megray_send_tensor(megray_comm, inputs[0], op.rank_to); | |||
| TensorLayout layout({0}, inputs[0]->dtype()); | |||
| DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag( | |||
| inputs[0]->comp_node(), layout); | |||
| return {Tensor::make(out)}; | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_recv( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||
| auto& op = def.cast_final_safe<RemoteRecv>(); | |||
| return {{{create_layout(op.shape, op.dtype), op.cn}}, true}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor_remote_recv( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| auto&& op = def.cast_final_safe<RemoteRecv>(); | |||
| auto layout = create_layout(op.shape, op.dtype); | |||
| auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||
| std::string("init_all_cards")); | |||
| if (!megray_comm) { | |||
| return proxy_graph_detail::apply_on_physical_tensor( | |||
| def, inputs, output_descs, validated); | |||
| } | |||
| auto&& out = megray_recv_tensor(megray_comm, layout, op.cn, op.rank_from); | |||
| return {out}; | |||
| } | |||
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||
| for (size_t i; i < inputs.size(); i++) { | |||
| layout_checker[i] = [](const TensorLayout& layout) { | |||
| return layout.is_contiguous(); | |||
| }; | |||
| } | |||
| return layout_checker; | |||
| } | |||
// Trait registration for RemoteSend: graph-mode construction, eager execution
// on physical tensors, output inference and the contiguous-input constraint
// all point at the functions defined in this file.
OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
        .apply_on_var_node(apply_on_var_node_remote_send)
        .apply_on_physical_tensor(apply_on_physical_tensor_remote_send)
        .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_send)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
// Trait registration for RemoteRecv, mirroring the RemoteSend entry.
OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
        .apply_on_var_node(apply_on_var_node_remote_recv)
        .apply_on_physical_tensor(apply_on_physical_tensor_remote_recv)
        .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_recv)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
| } // anonymous namespace | |||
| SmallVector<TensorPtr> apply_on_physical_tensor_batch_send_recv( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| auto&& op = def.cast_final_safe<BatchSendRecvOp>(); | |||
| auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( | |||
| std::string("init_all_cards")); | |||
| mgb_assert(megray_comm != nullptr); | |||
| megray_comm->group_start(); | |||
| SmallVector<TensorPtr> outputs; | |||
| size_t ind = 0; | |||
| for (auto&& op_ : op.op_list) { | |||
| if (op_->same_type<RemoteSend>()) { | |||
| auto&& send_op = op_->cast_final_safe<RemoteSend>(); | |||
| auto&& tensor = inputs[ind]; | |||
| megray_send_tensor(megray_comm, tensor, send_op.rank_to); | |||
| ind++; | |||
| } else { | |||
| mgb_assert(op_->same_type<RemoteRecv>()); | |||
| auto&& recv_op = op_->cast_final_safe<RemoteRecv>(); | |||
| auto layout = create_layout(recv_op.shape, recv_op.dtype); | |||
| auto&& out = megray_recv_tensor( | |||
| megray_comm, layout, recv_op.cn, recv_op.rank_from); | |||
| outputs.push_back(out); | |||
| } | |||
| } | |||
| megray_comm->group_end(); | |||
| return outputs; | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> | |||
| infer_output_attrs_fallible_batch_send_recv( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||
| auto& op = def.cast_final_safe<BatchSendRecvOp>(); | |||
| SmallVector<LogicalTensorDesc> output_descs; | |||
| for (auto&& op_ : op.op_list) { | |||
| if (op_->same_type<RemoteRecv>()) { | |||
| auto&& recv_op = op_->cast_final_safe<RemoteRecv>(); | |||
| output_descs.push_back( | |||
| {create_layout(recv_op.shape, recv_op.dtype), recv_op.cn}); | |||
| } | |||
| } | |||
| return {output_descs, true}; | |||
| } | |||
// Trait registration for BatchSendRecvOp. Only eager execution, output
// inference and the contiguous-input constraint are provided here; no
// apply_on_var_node is registered for this op.
OP_TRAIT_REG(BatchSendRecvOp, BatchSendRecvOp)
        .apply_on_physical_tensor(apply_on_physical_tensor_batch_send_recv)
        .infer_output_attrs_fallible(infer_output_attrs_fallible_batch_send_recv)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
| } // namespace | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchSendRecvOp); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -0,0 +1,67 @@ | |||
| #include "megbrain/imperative/transformations/group_comm.h" | |||
| #include "megbrain/imperative/blob_manager.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/io_remote.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| ValueRefList GroupCommTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| for (auto inp : inputs) { | |||
| mgb_assert( | |||
| !inp.is(m_value_type), "Can not use PlaceholderValue as apply input"); | |||
| } | |||
| if (auto* apply_op = op.as<ApplyOp>()) { | |||
| if (apply_op->op().same_type<RemoteSend>()) { | |||
| auto&& send_op = apply_op->op().cast_final_safe<RemoteSend>(); | |||
| if (send_op.key[0] == 'b') { | |||
| send_inputs.push_back(inputs[0]); | |||
| record_ops.push_back(send_op.shared_from_this()); | |||
| return {}; | |||
| } | |||
| } | |||
| if (apply_op->op().same_type<RemoteRecv>()) { | |||
| auto&& recv_op = apply_op->op().cast_final_safe<RemoteRecv>(); | |||
| if (recv_op.key[0] == 'b') { | |||
| record_ops.push_back(recv_op.shared_from_this()); | |||
| auto rst = m_value_type.make(); | |||
| recv_tensors.push_back(rst); | |||
| auto outputs = ValueRefList(1); | |||
| outputs[0] = rst; | |||
| return outputs; | |||
| } | |||
| } | |||
| return imperative::apply(op, inputs); | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } | |||
| ValueRefList GroupCommTransformation::execute_batch_op() { | |||
| auto batch_op = BatchSendRecvOp::make(record_ops); | |||
| auto outputs = imperative::apply(*batch_op, send_inputs); | |||
| return outputs; | |||
| } | |||
| void GroupCommTransformation::on_unregister() noexcept { | |||
| auto rst = execute_batch_op(); | |||
| mgb_assert(rst.size() == recv_tensors.size()); | |||
| for (size_t i = 0; i < rst.size(); i++) { | |||
| auto v = recv_tensors[i].lock(); | |||
| if (v != ValueRef::nil) { | |||
| v.reset(rst[i]); | |||
| } | |||
| } | |||
| } | |||
| GroupCommTransformation::~GroupCommTransformation() { | |||
| for (auto&& recv : recv_tensors) { | |||
| mgb_assert( | |||
| recv.lock() == ValueRef::nil, | |||
| "Some PlaceholderValues are not reset after GroupCommTransformation " | |||
| "destroyed!"); | |||
| }; | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -0,0 +1,11 @@ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| namespace mgb::imperative { | |||
//! Op that bundles a list of recorded ops so they can be executed together.
struct BatchSendRecvOp final : OpDefImplBase<BatchSendRecvOp> {
    //! Recorded ops in issue order; the executor in io_remote.cpp accepts
    //! only RemoteSend and RemoteRecv entries.
    SmallVector<std::shared_ptr<OpDef>> op_list;
    BatchSendRecvOp() = default;
    //! Takes the list by value and copies it into the member.
    BatchSendRecvOp(SmallVector<std::shared_ptr<OpDef>> op_list) : op_list{op_list} {}
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
| } // namespace mgb::imperative | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/transformations/group_comm.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/basic_operators.h" | |||
| #include "megbrain/imperative/dispatch.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| namespace mgb::imperative { | |||
//! Stand-in value returned by GroupCommTransformation for a recorded
//! RemoteRecv; it carries no data of its own.
class PlaceholderValue final : public ObjectValue<PlaceholderValue> {
public:
    //! Always prints the fixed text "PlaceholderValue".
    std::string to_string() const override { return ssprintf("PlaceholderValue"); }
    //! Nothing to release.
    void clear() override {}
};
//! Transformation that records RemoteSend/RemoteRecv ops while it is
//! registered and executes them as one batch when it is unregistered.
class GroupCommTransformation final : public Transformation {
private:
    //! First input of every recorded RemoteSend, in record order.
    SmallVector<ValueRef> send_inputs;
    //! Weak references to the placeholders returned for recorded RemoteRecv.
    std::vector<PlaceholderValue::weak_ref_t> recv_tensors;
    //! Recorded ops (sends and receives interleaved) in record order.
    SmallVector<std::shared_ptr<OpDef>> record_ops;
    //! Type object used to create and recognise PlaceholderValue instances.
    ObjectType<PlaceholderValue> m_value_type{"PlaceholderValue"};

public:
    GroupCommTransformation() = default;
    //! Record matching ops, forward everything else.
    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override;
    //! Apply one BatchSendRecvOp built from record_ops to send_inputs.
    ValueRefList execute_batch_op();
    //! Values are passed through unchanged.
    ValueRef unwrap(ValueRef value) override { return value; }
    std::string name() const override { return "GroupCommTransformation"; }
    //! Executes the batch and fills the still-referenced placeholders.
    void on_unregister() noexcept override;
    ~GroupCommTransformation();
};
| } // namespace mgb::imperative | |||
| @@ -1,4 +1,5 @@ | |||
| #include "./helper.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/mm_handler.h" | |||
| @@ -47,7 +48,4 @@ TEST(TestImperative, IORemote) { | |||
| t0.join(); | |||
| t1.join(); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| // ./imperative_test --gtest_filter TestIORemote | |||
| @@ -151,6 +151,28 @@ void GroupManager::bcast_addr( | |||
| } | |||
| } | |||
| void GroupManager::bcast_nccluniqueid( | |||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
| uint32_t root) { | |||
| std::unique_lock<std::mutex> lk{m_key2nccl_id_mtx}; | |||
| if (rank == root) { | |||
| m_key2nccl_id[key] = id; | |||
| } | |||
| m_key2nccl_id_size[key]++; | |||
| if (m_key2nccl_id_size[key] == size) { | |||
| m_key2nccl_id_flag[key] = true; | |||
| m_bcast_cv.notify_all(); | |||
| } else { | |||
| m_bcast_cv.wait(lk, [&] { return m_key2nccl_id_flag.count(key) > 0; }); | |||
| } | |||
| id = m_key2nccl_id[key]; | |||
| m_key2nccl_id_size[key]--; | |||
| if (m_key2nccl_id_size[key] == 0) { | |||
| m_key2nccl_id.erase(key); | |||
| m_key2nccl_id_flag.erase(key); | |||
| } | |||
| } | |||
| void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) { | |||
| auto&& group = get_group(key); | |||
| group.set_output_shape(key, shape); | |||
| @@ -67,6 +67,15 @@ void MegRayCommBuilder::emplace( | |||
| m_megray_comms.emplace(hash, comm); | |||
| } | |||
| void MegRayCommBuilder::remove( | |||
| uint64_t hash, std::shared_ptr<MegRay::Communicator> comm) { | |||
| std::unique_lock<std::mutex> lk(m_map_mtx); | |||
| auto it = m_megray_comms.find(hash); | |||
| if (it != m_megray_comms.end()) { | |||
| m_megray_comms.erase(hash); | |||
| } | |||
| } | |||
| std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||
| uint64_t hash, std::string key, uint32_t size, uint32_t rank, | |||
| MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) { | |||
| @@ -104,5 +113,3 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||
| MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; | |||
| std::mutex MegRayCommBuilder::sm_instance_mtx; | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -45,6 +45,7 @@ public: | |||
| RUNSERVER(get_output_shape); | |||
| RUNSERVER(bcast_addr); | |||
| RUNSERVER(group_barrier); | |||
| RUNSERVER(bcast_nccluniqueid); | |||
| mgb_assert(false, "invalid rpc request"); | |||
| } | |||
| @@ -53,6 +54,7 @@ private: | |||
| void set_output_shape(void* input_ptr, size_t input_len, std::string* output); | |||
| void get_output_shape(void* input_ptr, size_t input_len, std::string* output); | |||
| void bcast_addr(void* input_ptr, size_t input_len, std::string* output); | |||
| void bcast_nccluniqueid(void* input_ptr, size_t input_len, std::string* output); | |||
| void group_barrier(void* input_ptr, size_t input_len, std::string* output); | |||
| private: | |||
| @@ -116,6 +118,15 @@ void GroupServerProxy::bcast_addr( | |||
| rsp.SerializeToString(output); | |||
| } | |||
//! Server side of the bcast_nccluniqueid RPC: hands the request fields to
//! GroupManager::bcast_nccluniqueid and serialises the resulting id into
//! \p output.
void GroupServerProxy::bcast_nccluniqueid(
        void* input_ptr, size_t input_len, std::string* output) {
    // `req` and `rsp` are not declared in this function; presumably
    // INFO_INIT introduces them from input_ptr/input_len -- confirm against
    // the macro definition.
    INFO_INIT(mm_handler, BcastNcclUniqueId);
    // Local copy: the manager overwrites it with the root's id.
    std::string id = req.id();
    m_mgr.bcast_nccluniqueid(req.key(), id, req.size(), req.rank(), req.root());
    rsp.set_id(id);
    rsp.SerializeToString(output);
}
| void GroupServerProxy::group_barrier( | |||
| void* input_ptr, size_t input_len, std::string* output) { | |||
| INFO_INIT(mm_handler, GroupBarrier); | |||
| @@ -201,6 +212,19 @@ void GroupClientProxy::bcast_addr( | |||
| port = rsp.port(); | |||
| } | |||
//! Client side of the bcast_nccluniqueid RPC: fills the request from the
//! arguments, issues the call and overwrites \p id with the id carried by
//! the response.
void GroupClientProxy::bcast_nccluniqueid(
        const std::string& key, std::string& id, uint32_t size, uint32_t rank,
        uint32_t root) {
    // `req`, `rsp` and `func_name` are not declared in this function;
    // presumably INFO_INIT introduces them -- confirm against the macro
    // definition.
    INFO_INIT(mm_handler, bcast_nccluniqueid, BcastNcclUniqueId);
    // Pointer + length overloads are used for both fields.
    req.set_id(id.data(), id.size());
    req.set_key(key.data(), key.size());
    req.set_size(size);
    req.set_rank(rank);
    req.set_root(root);
    SOLVE_REQUEST(func_name, req, rsp);
    id = rsp.id();
}
| uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||
| INFO_INIT(mm_handler, group_barrier, GroupBarrier); | |||
| req.set_size(size); | |||
| @@ -209,6 +233,40 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||
| return rsp.size(); | |||
| } | |||
| std::shared_ptr<MegRay::Communicator> BatchSendRecvHelper::get(std::string&& key) { | |||
| auto ptr = megray_comm_cache.find(key); | |||
| if (ptr != megray_comm_cache.end()) { | |||
| return megray_comm_cache[key]; | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>> | |||
| BatchSendRecvHelper::megray_comm_cache{}; | |||
| bool BatchSendRecvHelper::init( | |||
| int nranks, int rank, std::string ip, int port, int root) { | |||
| auto megray_comm = | |||
| MegRay::get_communicator(nranks, rank, MegRay::Backend::MEGRAY_NCCL); | |||
| auto group_client = | |||
| std::make_shared<opr::GroupClientProxy>(ssprintf("%s:%d", ip.data(), port)); | |||
| auto cb = [=](char* nccl_buffer, size_t len) { | |||
| std::string id; | |||
| id.resize(128); | |||
| if (rank == root) { | |||
| memcpy(id.data(), nccl_buffer, len); | |||
| } | |||
| group_client->bcast_nccluniqueid("init_all_cards", id, nranks, rank, root); | |||
| if (rank != root) { | |||
| memcpy(nccl_buffer, id.data(), len); | |||
| } | |||
| }; | |||
| megray_comm->init(cb); | |||
| return megray_comm_cache.insert({std::string("init_all_cards"), megray_comm}) | |||
| .second; | |||
| } | |||
| #undef INFO_INIT | |||
| #undef SOLVE_REQUEST | |||
| @@ -77,6 +77,11 @@ public: | |||
| std::string& master_ip, int& port, const std::string& key, uint32_t size, | |||
| uint32_t rank, uint32_t root); | |||
| //! bcast uid | |||
| void bcast_nccluniqueid( | |||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
| uint32_t root); | |||
| //! Set output shape of this key | |||
| void set_output_shape(const std::string& key, const TensorShape& shape); | |||
| @@ -101,6 +106,12 @@ private: | |||
| std::mutex m_key2addr_mtx; | |||
| std::condition_variable m_bcast_cv; | |||
| //! key -> ncclid | |||
| std::unordered_map<std::string, std::string> m_key2nccl_id; | |||
| std::unordered_map<std::string, uint32_t> m_key2nccl_id_size; | |||
| std::unordered_map<std::string, bool> m_key2nccl_id_flag; | |||
| std::mutex m_key2nccl_id_mtx; | |||
| //! barrier | |||
| uint32_t m_barrier_size; | |||
| std::set<uint32_t> m_barrier_set; | |||
| @@ -128,6 +139,10 @@ public: | |||
| std::string& master_ip, int& port, const std::string& key, uint32_t size, | |||
| uint32_t rank, uint32_t root) = 0; | |||
| virtual void bcast_nccluniqueid( | |||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
| uint32_t root) = 0; | |||
| virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; | |||
| virtual TensorShape get_output_shape(const std::string& key) = 0; | |||
| @@ -23,6 +23,7 @@ class MegRayCommBuilder { | |||
| private: | |||
| bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm); | |||
| void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||
| void remove(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm); | |||
| std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; | |||
| std::mutex m_map_mtx; | |||
| @@ -39,6 +39,10 @@ public: | |||
| std::string& master_ip, int& port, const std::string& key, uint32_t size, | |||
| uint32_t rank, uint32_t root) override; | |||
| void bcast_nccluniqueid( | |||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
| uint32_t root) override; | |||
| void set_output_shape(const std::string& key, const TensorShape& shape) override; | |||
| TensorShape get_output_shape(const std::string& key) override; | |||
| @@ -52,6 +56,34 @@ private: | |||
| void* m_stub; | |||
| }; | |||
//! Base class giving T a shared, lazily created instance.
template <typename T>
class ProcessGlobal {  // thread safe
public:
    //! Returns the instance held in a function-local static, which is
    //! constructed from \p args on first use.
    //! NOTE(review): getInstance is itself a template, so each distinct
    //! argument-type list gets its own static instance; callers should
    //! stick to one call form.
    template <class... Args>
    static std::shared_ptr<T>& getInstance(Args&&... args) {
        static auto instance = std::make_shared<T>(std::forward<Args>(args)...);
        return instance;
    }

protected:
    //! Declared only; no definition is visible in this header.
    template <class... Args>
    ProcessGlobal(Args&&... args);
    ProcessGlobal() = default;

public:
    // Not copyable.
    ProcessGlobal(ProcessGlobal const&) = delete;
    void operator=(ProcessGlobal const&) = delete;
};
//! Keeps the communicators used by batched send/recv, looked up by name.
class BatchSendRecvHelper : public ProcessGlobal<BatchSendRecvHelper> {
    //! name -> communicator. Static, so shared by every instance.
    //! NOTE(review): get() and init() access this map without taking a
    //! lock -- confirm that callers serialise access.
    static std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>>
            megray_comm_cache;

public:
    //! Returns the communicator cached under the given name, or nullptr.
    std::shared_ptr<MegRay::Communicator> get(std::string&&);
    //! Creates a communicator and caches it as "init_all_cards"; returns
    //! whether a new entry was inserted.
    bool init(int nranks, int rank, std::string ip, int port, int root);
};
| /* ======================== ZmqRpcServerMgr ========================== */ | |||
| int create_zmqrpc_server(const std::string& server_addr, int port); | |||
| @@ -60,5 +92,3 @@ int create_zmqrpc_server(const std::string& server_addr, int port); | |||
| } // namespace mgb | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -30,6 +30,18 @@ message BcastAddrResponse { | |||
| int32 port = 2; | |||
| } | |||
// Request of the bcast_nccluniqueid RPC.
message BcastNcclUniqueIdRequest {
    string key = 1;   // name of the rendezvous
    bytes id = 2;     // id supplied by the caller; stored only for the root
    uint32 size = 3;  // number of participants
    uint32 rank = 4;  // rank of the caller
    uint32 root = 5;  // rank whose id is published
}

// Response of the bcast_nccluniqueid RPC.
message BcastNcclUniqueIdResponse {
    bytes id = 1;  // id published by the root
}
| message SetOutputShapeRequest { | |||
| string key = 1; | |||
| TensorShape shape = 2; | |||
| @@ -26,6 +26,12 @@ public: | |||
| return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | |||
| } | |||
| void bcast_nccluniqueid( | |||
| const std::string& key, std::string& id, uint32_t size, uint32_t rank, | |||
| uint32_t root) override { | |||
| return m_mgr.bcast_nccluniqueid(key, id, size, rank, root); | |||
| } | |||
//! Forwards directly to the wrapped manager.
void set_output_shape(const std::string& key, const TensorShape& shape) override {
    m_mgr.set_output_shape(key, shape);
}