GitOrigin-RevId: b7a7bbd0da
tags/v0.6.0
| @@ -72,10 +72,10 @@ SymbolVar _Opr::remote_send( | |||
| const std::string& key, SymbolVar var, | |||
| const bool is_grad, | |||
| const OperatorNodeConfig& config) { | |||
| return RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var, | |||
| return RemoteSend::make(key, var, | |||
| std::make_shared<GroupClientProxy>(ssprintf( | |||
| "%s:%d", server_addr.c_str(), port)), | |||
| config); | |||
| is_grad, config); | |||
| } | |||
| SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
| @@ -85,8 +85,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
| const TensorShape ishape = npy::vec2shape(shape); | |||
| const DType idtype = npy::dtype_np2mgb(dtype); | |||
| return RemoteRecv::make({key, RemoteIOBase::Type::RECV, false}, | |||
| graph.get(), | |||
| return RemoteRecv::make(key, graph.get(), | |||
| std::make_shared<GroupClientProxy>( | |||
| ssprintf("%s:%d", server_addr.c_str(), port)), | |||
| config, ishape, idtype); | |||
| @@ -26,27 +26,28 @@ cudaStream_t get_stream(VarNode* var) { | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | |||
| RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var, | |||
| RemoteSend::RemoteSend(const std::string& key, VarNode* var, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config) : | |||
| Super(var->owner_graph(), config, "remote_send", {var}) { | |||
| m_peer = peer; | |||
| bool is_grad, const OperatorNodeConfig& config) : | |||
| Super(var->owner_graph(), config, "remote_send", {var}), | |||
| m_is_grad(is_grad) { | |||
| m_key = key; | |||
| m_group_client = group_client; | |||
| add_input({var}); | |||
| auto ovar = add_output(None); | |||
| if (!peer.is_grad) { | |||
| if (!m_is_grad) { | |||
| ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | |||
| .add_flag(VarNode::Flag::VOLATILE_CONTENT); | |||
| } | |||
| add_equivalence_component<ScalarHash<void*>>(this); | |||
| } | |||
| SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var, | |||
| SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config) { | |||
| return var.insert_single_output_opr<RemoteSend>(peer, var.node(), | |||
| group_client, config); | |||
| bool is_grad, const OperatorNodeConfig& config) { | |||
| return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client, | |||
| is_grad, config); | |||
| } | |||
| void RemoteSend::scn_do_execute() { | |||
| @@ -54,11 +55,11 @@ void RemoteSend::scn_do_execute() { | |||
| auto&& comp_node = output(0)->comp_node(); | |||
| // rank 0 for RemoteSend | |||
| auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, | |||
| auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, | |||
| comp_node.get_uid()); | |||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
| reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
| reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); | |||
| m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||
| @@ -76,7 +77,7 @@ void RemoteSend::scn_do_execute() { | |||
| auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); | |||
| mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); | |||
| if (m_peer.is_grad) { | |||
| if (m_is_grad) { | |||
| auto&& dest = output(0)->dev_tensor(); | |||
| if (m_output_val.empty()) { | |||
| m_output_val.comp_node(dest.comp_node()) | |||
| @@ -92,7 +93,7 @@ void RemoteSend::init_output_static_infer_desc() { | |||
| using namespace cg::static_infer; | |||
| auto&& mgr = owner_graph()->static_infer_manager(); | |||
| auto do_infer = [this](TensorShape& dest, const InpVal&) { | |||
| if (peer_desc().is_grad) { | |||
| if (m_is_grad) { | |||
| dest = {1}; | |||
| } else { | |||
| dest = {0}; | |||
| @@ -109,9 +110,8 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const { | |||
| } | |||
| MGB_IMPL_OPR_GRAD(RemoteSend) { | |||
| mgb_assert(opr.peer_desc().is_grad); | |||
| return RemoteRecv::make({opr.peer_desc().key + ":grad", | |||
| RemoteIOBase::Type::RECV, false}, | |||
| mgb_assert(opr.is_grad()); | |||
| return RemoteRecv::make(opr.key() + ":grad", | |||
| *opr.owner_graph(), opr.group_client(), | |||
| OperatorNodeConfig{opr.comp_node()}.name( | |||
| opr.name() + ":grad_recv"), | |||
| @@ -123,13 +123,13 @@ MGB_IMPL_OPR_GRAD(RemoteSend) { | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | |||
| RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
| RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, | |||
| const TensorShape& shape, DType dtype) : | |||
| Super(&graph, config, "remote_recv", {}), | |||
| m_shape(shape), m_dtype(dtype) { | |||
| m_peer = peer; | |||
| m_key = key; | |||
| m_group_client = group_client; | |||
| add_output(None) | |||
| @@ -139,12 +139,12 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
| add_equivalence_component<ScalarHash<void*>>(this); | |||
| } | |||
| SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
| SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, | |||
| const TensorShape& shape, DType dtype) { | |||
| auto opr = graph.insert_opr(std::make_unique<RemoteRecv>( | |||
| peer, graph, group_client, config, shape, dtype)); | |||
| key, graph, group_client, config, shape, dtype)); | |||
| return opr->output(0); | |||
| } | |||
| @@ -154,11 +154,11 @@ void RemoteRecv::scn_do_execute() { | |||
| // rank 1 for RemoteRecv | |||
| auto reg_info = m_group_client->opr_register( | |||
| m_peer.key, 2, false, 1, | |||
| m_key, 2, false, 1, | |||
| comp_node.get_uid()); | |||
| m_megray_comm = MegRayCommBuilder::get_megray_comm( | |||
| reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
| reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); | |||
| m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); | |||
| @@ -206,8 +206,8 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send( | |||
| const OperatorNodeConfig& config) { | |||
| mgb_assert(inputs.size() == 1); | |||
| auto&& opr = opr_.cast_final_safe<RemoteSend>(); | |||
| return RemoteSend::make(opr.peer_desc(), inputs[0], opr.group_client(), | |||
| config) | |||
| return RemoteSend::make(opr.key(), inputs[0], opr.group_client(), | |||
| opr.is_grad(), config) | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| @@ -218,7 +218,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv( | |||
| const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| auto&& opr = opr_.cast_final_safe<RemoteRecv>(); | |||
| return RemoteRecv::make(opr.peer_desc(), *opr.owner_graph(), | |||
| return RemoteRecv::make(opr.key(), *opr.owner_graph(), | |||
| opr.group_client(), config, inputs[0]->shape(), | |||
| inputs[0]->dtype()) | |||
| .node() | |||
| @@ -25,25 +25,14 @@ namespace opr { | |||
| */ | |||
| MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // { | |||
| public: | |||
| enum Type { | |||
| SEND, | |||
| RECV | |||
| }; | |||
| struct PeerDesc { | |||
| std::string key; | |||
| Type type; | |||
| bool is_grad; | |||
| }; | |||
| const PeerDesc& peer_desc() const { return m_peer; } | |||
| const std::string& key() const { return m_key; } | |||
| std::shared_ptr<GroupClient> group_client() const { | |||
| return m_group_client; | |||
| } | |||
| protected: | |||
| PeerDesc m_peer; | |||
| std::string m_key; | |||
| std::shared_ptr<GroupClient> m_group_client; | |||
| std::shared_ptr<MegRay::Communicator> m_megray_comm; | |||
| std::shared_ptr<MegRay::Context> m_megray_ctx; | |||
| @@ -53,21 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // { | |||
| /*! | |||
| * \brief send a variable to remote address; a virtual output is produced | |||
| * for expressing dependency | |||
| * for expressing dependency | |||
| */ | |||
| MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // { | |||
| public: | |||
| RemoteSend(const PeerDesc& peer, VarNode* var, | |||
| RemoteSend(const std::string& key, VarNode* var, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config); | |||
| bool is_grad, const OperatorNodeConfig& config); | |||
| static SymbolVar make( | |||
| const PeerDesc& peer, SymbolVar var, | |||
| const std::string& key, SymbolVar var, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config = {}); | |||
| bool is_grad, const OperatorNodeConfig& config = {}); | |||
| bool is_grad() const { return m_is_grad; } | |||
| private: | |||
| HostTensorND m_output_val; | |||
| bool m_is_grad; | |||
| void scn_do_execute() override; | |||
| void init_output_static_infer_desc() override; | |||
| @@ -75,19 +67,18 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // { | |||
| }; | |||
| /*! | |||
| * \brief receive from multiple remote addresses and write to a var | |||
| * | |||
| * Target computing node of the var must be specified in config | |||
| * \brief receive a variable from remote address; target computing node | |||
| * of the var must be specified in config | |||
| */ | |||
| MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | |||
| public: | |||
| RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph, | |||
| RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||
| DType dtype); | |||
| static SymbolVar make( | |||
| const PeerDesc& peer, cg::ComputingGraph& graph, | |||
| const std::string& key, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||
| DType dtype); | |||
| @@ -20,9 +20,6 @@ | |||
| using namespace mgb; | |||
| const auto send_tag = opr::RemoteIOBase::Type::SEND; | |||
| const auto recv_tag = opr::RemoteIOBase::Type::RECV; | |||
| TEST(TestOprIORemote, Identity) { | |||
| REQUIRE_GPU(2); | |||
| auto cn0 = CompNode::load("gpu0"); | |||
| @@ -36,8 +33,8 @@ TEST(TestOprIORemote, Identity) { | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0); | |||
| auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||
| auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
| auto xr = opr::RemoteSend::make("x", x, client, false); | |||
| auto y = opr::RemoteRecv::make("x", *graph.get(), | |||
| client, {cn1}, host_x->shape(), | |||
| host_x->dtype()); | |||
| @@ -59,7 +56,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||
| auto graph = ComputingGraph::make(); | |||
| sys::set_thread_name("sender"); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x), | |||
| xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||
| xr = opr::RemoteSend::make("x", x, client, false); | |||
| auto func = graph->compile({{xr, {}}}); | |||
| func->execute(); | |||
| }; | |||
| @@ -67,7 +64,7 @@ TEST(TestOprIORemote, IdentityMultiThread) { | |||
| auto receiver = [&]() { | |||
| sys::set_thread_name("receiver"); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
| auto x = opr::RemoteRecv::make("x", *graph.get(), | |||
| client, {cns[0]}, host_x->shape(), | |||
| host_x->dtype()); | |||
| auto func = graph->compile({make_callback_copy(x, host_x_get)}); | |||
| @@ -92,7 +89,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||
| sys::set_thread_name("sender"); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1, | |||
| xr = opr::RemoteSend::make({"x", send_tag, false}, x, client); | |||
| xr = opr::RemoteSend::make("x", x, client, false); | |||
| auto func = graph->compile({{xr, {}}}); | |||
| func->execute(); | |||
| }; | |||
| @@ -100,7 +97,7 @@ TEST(TestOprIORemote, IdentityWithGopt) { | |||
| auto receiver = [&]() { | |||
| sys::set_thread_name("receiver"); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
| auto x = opr::RemoteRecv::make("x", *graph.get(), | |||
| client, {cns[0]}, host_x->shape(), | |||
| host_x->dtype()); | |||
| auto func = | |||
| @@ -124,14 +121,14 @@ TEST(TestOprIORemote, APlusB) { | |||
| auto sender = [&]() { | |||
| auto graph = ComputingGraph::make(); | |||
| auto z = opr::RemoteRecv::make({"z", recv_tag, false}, *graph.get(), | |||
| auto z = opr::RemoteRecv::make("z", *graph.get(), | |||
| client, {cns[0]}, host_x->shape(), | |||
| host_x->dtype()); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"), | |||
| y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"), | |||
| xr = opr::RemoteSend::make({"x", send_tag, false}, x, client) | |||
| xr = opr::RemoteSend::make("x", x, client, false) | |||
| .rename("xr"), | |||
| yr = opr::RemoteSend::make({"y", send_tag, false}, y, client) | |||
| yr = opr::RemoteSend::make("y", y, client, false) | |||
| .rename("yr"); | |||
| auto func = graph->compile( | |||
| {{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)}); | |||
| @@ -142,14 +139,14 @@ TEST(TestOprIORemote, APlusB) { | |||
| auto receiver = [&]() { | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(), | |||
| auto x = opr::RemoteRecv::make("x", *graph.get(), | |||
| client, {cns[1]}, host_x->shape(), | |||
| host_x->dtype()), | |||
| y = opr::RemoteRecv::make({"y", recv_tag, false}, *graph.get(), | |||
| y = opr::RemoteRecv::make("y", *graph.get(), | |||
| client, {cns[1]}, host_y->shape(), | |||
| host_y->dtype()), | |||
| z = x + y, | |||
| zr = opr::RemoteSend::make({"z", send_tag, false}, z, client); | |||
| zr = opr::RemoteSend::make("z", z, client, false); | |||
| auto func = graph->compile({{zr, {}}}); | |||
| func->execute(); | |||
| }; | |||
| @@ -177,10 +174,10 @@ TEST(TestOprIORemote, SendGrad) { | |||
| sys::set_thread_name("sender"); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x), | |||
| loss = opr::RemoteSend::make({"loss", send_tag, false}, x, client); | |||
| loss = opr::RemoteSend::make("loss", x, client, false); | |||
| ASSERT_TRUE(!loss.shape().ndim && | |||
| loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); | |||
| loss = opr::RemoteSend::make({"loss", send_tag, true}, x, client); | |||
| loss = opr::RemoteSend::make("loss", x, client, true); | |||
| auto gx = cg::grad(loss, x); | |||
| set_priority(loss, 0); | |||
| set_priority(gx, 1); | |||
| @@ -197,10 +194,10 @@ TEST(TestOprIORemote, SendGrad) { | |||
| auto receiver = [&]() { | |||
| sys::set_thread_name("receiver"); | |||
| auto graph = ComputingGraph::make(); | |||
| auto x = opr::RemoteRecv::make({"loss", recv_tag, false}, *graph.get(), | |||
| auto x = opr::RemoteRecv::make("loss", *graph.get(), | |||
| client, {cns[1]}, host_x->shape(), | |||
| host_x->dtype()); | |||
| auto y = opr::RemoteSend::make({"loss:grad", send_tag, false}, x + 1, client); | |||
| auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false); | |||
| auto func = graph->compile({{y, {}}}); | |||
| func->execute(); | |||
| }; | |||