GitOrigin-RevId: 841a0e45ab
Tags: v1.5.0
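This change makes the collective-communication backend of RemoteSend/RemoteRecv configurable ("nccl" or "ucx", defaulting to "nccl") instead of hard-coding MegRay::MEGRAY_NCCL: the backend string is added to the op definitions, threaded from the Python-level ops through the imperative apply functions into the graph operators, and exposed via a backend() accessor so the gradient and shallow-copy paths preserve it.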
@@ -265,6 +265,7 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     op.key = key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
+    op.backend = get_backend()
     (dummy,) = apply(_RemoteSend(op), inp)
     for g in grad_keys.values():
@@ -313,6 +314,7 @@ def remote_recv(
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
     op.rank_from = src_rank
+    op.backend = get_backend()
     (ret,) = apply(_RemoteRecv(op), inp)
     if _isscalar:
@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
     OperatorNodeConfig config{send.make_name()};
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
-                    send.key, inputs[0], group_client, true, config));
+                    send.key, inputs[0], group_client, true, send.backend, config));
     return opr;
 }
@@ -49,7 +49,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
     auto&& graph = inputs[0]->owner_graph();
     return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
             recv.key, inputs[0], *graph, group_client, config,
-            recv.shape, recv.dtype));
+            recv.shape, recv.dtype, recv.backend));
 }
 OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
@@ -34,7 +34,7 @@ TEST(TestImperative, IORemote) {
     auto run_send = [&](std::shared_ptr<HostTensorND> hnd) {
         auto def = imperative::RemoteSend::make(
-                "io_remote_test", server_addr, port, 1);
+                "io_remote_test", server_addr, port, 1, "nccl");
         auto inp = Tensor::make(*hnd);
         auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
     };
@@ -43,7 +43,7 @@ TEST(TestImperative, IORemote) {
         auto def = imperative::RemoteRecv::make(
                 "io_remote_test", server_addr, port, 0,
                 CompNode::load("gpu1"), TensorShape{vector_size},
-                dtype::Float32());
+                dtype::Float32(), "nccl");
         auto inp = Tensor::make(*hnd);
         auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
         HostTensorND host_v;
@@ -169,7 +169,8 @@ def RemoteSend : MgbHashableOp<"RemoteSend"> {
     MgbStringAttr:$key,
     MgbStringAttr:$addr,
     MgbUI32Attr:$port,
-    MgbUI32Attr:$rank_to
+    MgbUI32Attr:$rank_to,
+    MgbStringAttr:$backend
   );
 }
@@ -181,7 +182,8 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
     MgbUI32Attr:$rank_from,
     MgbCompNodeAttr:$cn,
     MgbTensorShapeAttr:$shape,
-    MgbDTypeAttr:$dtype
+    MgbDTypeAttr:$dtype,
+    MgbStringAttr:$backend
   );
 }
@@ -24,8 +24,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
 RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                        std::shared_ptr<GroupClient> group_client,
-                       bool is_grad, const OperatorNodeConfig& config) :
+                       bool is_grad, std::string backend, const OperatorNodeConfig& config) :
         Super(var->owner_graph(), config, "remote_send", {var}),
+        m_backend(backend),
         m_is_grad(is_grad) {
     m_key = key;
     m_group_client = group_client;
@@ -41,9 +42,9 @@ RemoteSend::RemoteSend(const std::string& key, VarNode* var,
 SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                            std::shared_ptr<GroupClient> group_client,
-                           bool is_grad, const OperatorNodeConfig& config) {
+                           bool is_grad, std::string backend, const OperatorNodeConfig& config) {
     return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
-                                                    is_grad, config);
+                                                    is_grad, backend, config);
 }
 void RemoteSend::scn_do_execute() {
@@ -64,7 +65,7 @@ void RemoteSend::scn_do_execute() {
     }
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
+            reg_info.hash, m_key, 2, 0, get_megray_backend(m_backend), m_group_client);
     m_megray_ctx = get_megray_context(output(0)->comp_node());
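Both send and recv now resolve the MegRay backend from the stored string through get_megray_backend(), whose definition is not part of this diff. A minimal sketch of what such a helper might look like, assuming MegRay's backend enum is named MegRay::Backend and exposes MEGRAY_UCX alongside MEGRAY_NCCL, and that mgb_throw/MegBrainError are available for error reporting:

    // Hypothetical helper (not part of this diff): map the user-facing
    // backend string onto MegRay's backend enum.
    static MegRay::Backend get_megray_backend(const std::string& backend) {
        if (backend == "nccl")
            return MegRay::MEGRAY_NCCL;
        if (backend == "ucx")
            return MegRay::MEGRAY_UCX;
        mgb_throw(MegBrainError, "unknown collective communication backend: %s",
                  backend.c_str());
    }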
@@ -122,7 +123,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
                    *opr.owner_graph(), opr.group_client(),
                    OperatorNodeConfig{opr.comp_node()}.name(
                            opr.name() + ":grad_recv"),
-                   opr.input(0)->shape(), opr.input(0)->dtype())
+                   opr.input(0)->shape(), opr.input(0)->dtype(), opr.backend())
                    .node();
 }
 #endif
@@ -134,9 +135,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
 RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
-                       const TensorShape& shape, DType dtype) :
+                       const TensorShape& shape, DType dtype, std::string backend) :
         Super(&graph, config, "remote_recv", {}),
-        m_shape(shape), m_dtype(dtype) {
+        m_shape(shape), m_dtype(dtype), m_backend(backend) {
     m_key = key;
     m_group_client = group_client;
@@ -150,9 +151,9 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
 RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
-                       const TensorShape& shape, DType dtype) :
+                       const TensorShape& shape, DType dtype, std::string backend) :
         Super(&graph, config, "remote_recv", {}),
-        m_shape(shape), m_dtype(dtype) {
+        m_shape(shape), m_dtype(dtype), m_backend(backend) {
     m_key = key;
     m_group_client = group_client;
@@ -167,18 +168,18 @@ RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph&
 SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client,
                            const OperatorNodeConfig& config,
-                           const TensorShape& shape, DType dtype) {
+                           const TensorShape& shape, DType dtype, std::string backend) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            key, graph, group_client, config, shape, dtype));
+            key, graph, group_client, config, shape, dtype, backend));
     return opr->output(0);
 }
 SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client,
                            const OperatorNodeConfig& config,
-                           const TensorShape& shape, DType dtype) {
+                           const TensorShape& shape, DType dtype, std::string backend) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            key, var.node(), graph, group_client, config, shape, dtype));
+            key, var.node(), graph, group_client, config, shape, dtype, backend));
     return opr->output(0);
 }
@@ -201,7 +202,7 @@ void RemoteRecv::scn_do_execute() {
     }
     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
+            reg_info.hash, m_key, 2, 1, get_megray_backend(m_backend), m_group_client);
     m_megray_ctx = get_megray_context(output(0)->comp_node());
@@ -251,7 +252,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
     mgb_assert(inputs.size() == 1);
     auto&& opr = opr_.cast_final_safe<RemoteSend>();
     return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
-                            opr.is_grad(), config)
+                            opr.is_grad(), opr.backend(), config)
             .node()
             ->owner_opr();
 }
@@ -265,14 +266,14 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
     if (inputs.size() == 1) {
         return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
                                 opr.group_client(), config, opr.shape(),
-                                opr.dtype())
+                                opr.dtype(), opr.backend())
                 .node()
                 ->owner_opr();
     } else {
         mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
         return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                                 opr.group_client(), config, opr.shape(),
-                                opr.dtype())
+                                opr.dtype(), opr.backend())
                 .node()
                 ->owner_opr();
     }
@@ -9,6 +9,8 @@ decl_raw_opr(
         Doc('key', 'key to bind send-recv pair', 'str'),
         Doc('var', 'variable to be sent', ':class:`.SymbolVar`'),
         Doc('is_grad', 'whether the send', 'bool'),
+        Doc('backend', 'Backend for collective communication, nccl or ucx',
+            'str', '\'nccl\''),
     ]
 )
@@ -24,7 +26,9 @@ decl_raw_opr(
             ':class:`.CompGraph`'),
         Doc('shape', 'output var shape'),
         Doc('dtype', 'data type of the output var; must match dtype at sender',
-            ':class:`numpy.dtype` compatible')
+            ':class:`numpy.dtype` compatible'),
+        Doc('backend', 'Backend for collective communication, nccl or ucx',
+            'str', '\'nccl\''),
     ]
 )
@@ -48,17 +48,19 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
     public:
         RemoteSend(const std::string& key, VarNode* var,
                    std::shared_ptr<GroupClient> group_client,
-                   bool is_grad, const OperatorNodeConfig& config);
+                   bool is_grad, std::string backend, const OperatorNodeConfig& config);
         static SymbolVar make(
                 const std::string& key, SymbolVar var,
                 std::shared_ptr<GroupClient> group_client,
-                bool is_grad, const OperatorNodeConfig& config = {});
+                bool is_grad, std::string backend, const OperatorNodeConfig& config = {});
+        const std::string& backend() const { return m_backend; }
         bool is_grad() const { return m_is_grad; }
     private:
         HostTensorND m_output_val;
+        std::string m_backend;
         bool m_is_grad;
         void scn_do_execute() override;
@@ -75,31 +77,33 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
         RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                    std::shared_ptr<GroupClient> group_client,
                    const OperatorNodeConfig& config, const TensorShape& shape,
-                   DType dtype);
+                   DType dtype, std::string backend);
         RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
                    std::shared_ptr<GroupClient> group_client,
                    const OperatorNodeConfig& config, const TensorShape& shape,
-                   DType dtype);
+                   DType dtype, std::string backend);
         static SymbolVar make(
                 const std::string& key, cg::ComputingGraph& graph,
                 std::shared_ptr<GroupClient> group_client,
                 const OperatorNodeConfig& config, const TensorShape& shape,
-                DType dtype);
+                DType dtype, std::string backend);
         static SymbolVar make(
                 const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
                 std::shared_ptr<GroupClient> group_client,
                 const OperatorNodeConfig& config, const TensorShape& shape,
-                DType dtype);
+                DType dtype, std::string backend);
         const TensorShape& shape() const { return m_shape; }
         const DType& dtype() const { return m_dtype; }
+        const std::string& backend() const { return m_backend; }
     private:
        const TensorShape m_shape;
        const DType m_dtype;
+       const std::string m_backend;
        const CompNode m_comp_node;
        DeviceTensorND m_dev_buffer;
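With the backend now a per-operator argument, callers that build a send/recv pair choose the transport explicitly. The tests below always pass "nccl"; a hypothetical fragment selecting ucx instead (reusing the client, host_x, host_y, cn0 and cn1 names from the Identity test, and assuming both ends request the same backend) might look like this:

    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
    // sender side: same call shape as in the tests, but requesting ucx
    auto xr = opr::RemoteSend::make("x", x, client, /*is_grad=*/false, "ucx");
    // receiver side: expected to name the same backend as the sender
    auto y = opr::RemoteRecv::make("x", *graph.get(), client, {cn1},
                                   host_x->shape(), host_x->dtype(), "ucx");
    auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});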
@@ -33,10 +33,10 @@ TEST(TestOprIORemote, Identity) {
     auto graph = ComputingGraph::make();
     auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
-    auto xr = opr::RemoteSend::make("x", x, client, false);
+    auto xr = opr::RemoteSend::make("x", x, client, false, "nccl");
     auto y = opr::RemoteRecv::make("x", *graph.get(),
                                    client, {cn1}, host_x->shape(),
-                                   host_x->dtype());
+                                   host_x->dtype(), "nccl");
     auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});
@@ -57,7 +57,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
         auto graph = ComputingGraph::make();
         sys::set_thread_name("sender");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             xr = opr::RemoteSend::make("x", x, client, false);
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl");
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -67,7 +67,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto func = graph->compile({make_callback_copy(x, host_x_get)});
         func->execute();
     };
@@ -91,7 +91,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
-             xr = opr::RemoteSend::make("x", x, client, false);
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl");
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -101,7 +101,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto func =
                 graph->compile({make_callback_copy((x - 1) / 2, host_x_get)});
         func->execute();
@@ -126,12 +126,12 @@ TEST(TestOprIORemote, APlusB) {
         auto graph = ComputingGraph::make();
         auto z = opr::RemoteRecv::make("z", *graph.get(),
                                        client, {cns[0]}, host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
              y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
-             xr = opr::RemoteSend::make("x", x, client, false)
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl")
                           .rename("xr"),
-             yr = opr::RemoteSend::make("y", y, client, false)
+             yr = opr::RemoteSend::make("y", y, client, false, "nccl")
                           .rename("yr");
         auto func = graph->compile(
                 {{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
@@ -144,12 +144,12 @@ TEST(TestOprIORemote, APlusB) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
-                                       host_x->dtype()),
+                                       host_x->dtype(), "nccl"),
             y = opr::RemoteRecv::make("y", *graph.get(),
                                       client, {cns[1]}, host_y->shape(),
-                                      host_y->dtype()),
+                                      host_y->dtype(), "nccl"),
            z = x + y,
-           zr = opr::RemoteSend::make("z", z, client, false);
+           zr = opr::RemoteSend::make("z", z, client, false, "nccl");
         auto func = graph->compile({{zr, {}}});
         func->execute();
     };
@@ -178,10 +178,10 @@ TEST(TestOprIORemote, SendGrad) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             loss = opr::RemoteSend::make("loss", x, client, false);
+             loss = opr::RemoteSend::make("loss", x, client, false, "nccl");
         ASSERT_TRUE(!loss.shape().ndim &&
                     loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
-        loss = opr::RemoteSend::make("loss", x, client, true);
+        loss = opr::RemoteSend::make("loss", x, client, true, "nccl");
         auto gx = cg::grad(loss, x);
         set_priority(loss, 0);
         set_priority(gx, 1);
@@ -200,8 +200,8 @@ TEST(TestOprIORemote, SendGrad) {
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("loss", *graph.get(),
                                        client, {cns[1]}, host_x->shape(),
-                                       host_x->dtype());
-        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
+                                       host_x->dtype(), "nccl");
+        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false, "nccl");
         auto func = graph->compile({{y, {}}});
         func->execute();
     };