GitOrigin-RevId: 5e8c27ac81
tags/v1.1.0
| @@ -46,7 +46,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||||
| ssprintf("%s:%d", recv.addr.data(), recv.port)); | ssprintf("%s:%d", recv.addr.data(), recv.port)); | ||||
| auto&& graph = inputs[0]->owner_graph(); | auto&& graph = inputs[0]->owner_graph(); | ||||
| return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | ||||
| recv.key, *graph, group_client, OperatorNodeConfig{recv.cn}, | |||||
| recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, | |||||
| recv.shape, recv.dtype)); | recv.shape, recv.dtype)); | ||||
| } | } | ||||
| @@ -60,6 +60,43 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | |||||
| } // anonymous namespace | } // anonymous namespace | ||||
| #endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
| bool RemoteSend::is_same_st(const Hashable& another) const{ | |||||
| return as_tuple() == another.cast_final<RemoteSend>().as_tuple(); | |||||
| } | |||||
| size_t RemoteSend::hash() const{ | |||||
| XXHash xxhash; | |||||
| auto append = [&xxhash](auto field){ | |||||
| auto hash_val = HashTrait<decltype(field)>::eval(field); | |||||
| xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||||
| }; | |||||
| append(key); | |||||
| append(addr); | |||||
| append(port); | |||||
| append(rank_to); | |||||
| return xxhash.digest(); | |||||
| } | |||||
| bool RemoteRecv::is_same_st(const Hashable& another) const{ | |||||
| return as_tuple() == another.cast_final<RemoteRecv>().as_tuple(); | |||||
| } | |||||
| size_t RemoteRecv::hash() const{ | |||||
| XXHash xxhash; | |||||
| auto append = [&xxhash](auto field){ | |||||
| auto hash_val = HashTrait<decltype(field)>::eval(field); | |||||
| xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||||
| }; | |||||
| append(key); | |||||
| append(addr); | |||||
| append(port); | |||||
| append(rank_from); | |||||
| append(cn.to_string()); | |||||
| append(dtype.handle()); | |||||
| append(shape.to_string()); | |||||
| return xxhash.digest(); | |||||
| } | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | ||||
| @@ -31,6 +31,13 @@ public: | |||||
| std::string addr; | std::string addr; | ||||
| uint32_t port; | uint32_t port; | ||||
| uint32_t rank_to; | uint32_t rank_to; | ||||
| size_t hash() const override; | |||||
| bool is_same_st(const Hashable& another) const override; | |||||
| auto as_tuple() const{ | |||||
| return std::tuple(key, addr, port, rank_to); | |||||
| } | |||||
| }; | }; | ||||
| class RemoteRecv : public OpDefImplBase<RemoteRecv> { | class RemoteRecv : public OpDefImplBase<RemoteRecv> { | ||||
| @@ -55,6 +62,13 @@ public: | |||||
| CompNode cn; | CompNode cn; | ||||
| TensorShape shape; | TensorShape shape; | ||||
| DType dtype; | DType dtype; | ||||
| size_t hash() const override; | |||||
| bool is_same_st(const Hashable& another) const override; | |||||
| auto as_tuple() const{ | |||||
| return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string()); | |||||
| } | |||||
| }; | }; | ||||
| } // namespace imperative | } // namespace imperative | ||||
| @@ -151,6 +151,23 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||||
| add_equivalence_component<ScalarHash<void*>>(this); | add_equivalence_component<ScalarHash<void*>>(this); | ||||
| } | } | ||||
| RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, | |||||
| std::shared_ptr<GroupClient> group_client, | |||||
| const OperatorNodeConfig& config, | |||||
| const TensorShape& shape, DType dtype) : | |||||
| Super(&graph, config, "remote_recv", {}), | |||||
| m_shape(shape), m_dtype(dtype) { | |||||
| m_key = key; | |||||
| m_group_client = group_client; | |||||
| add_input({var}); | |||||
| add_output(None) | |||||
| ->dtype(dtype) | |||||
| .add_flag(VarNode::Flag::NO_MEM_RECLAIM) | |||||
| .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC); | |||||
| add_equivalence_component<ScalarHash<void*>>(this); | |||||
| } | |||||
| SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | ||||
| std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
| const OperatorNodeConfig& config, | const OperatorNodeConfig& config, | ||||
| @@ -160,6 +177,15 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||||
| return opr->output(0); | return opr->output(0); | ||||
| } | } | ||||
| SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph, | |||||
| std::shared_ptr<GroupClient> group_client, | |||||
| const OperatorNodeConfig& config, | |||||
| const TensorShape& shape, DType dtype) { | |||||
| auto opr = graph.insert_opr(std::make_unique<RemoteRecv>( | |||||
| key, var.node(), graph, group_client, config, shape, dtype)); | |||||
| return opr->output(0); | |||||
| } | |||||
| void RemoteRecv::scn_do_execute() { | void RemoteRecv::scn_do_execute() { | ||||
| if (!m_init) { | if (!m_init) { | ||||
| auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
| @@ -77,12 +77,23 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | |||||
| const OperatorNodeConfig& config, const TensorShape& shape, | const OperatorNodeConfig& config, const TensorShape& shape, | ||||
| DType dtype); | DType dtype); | ||||
//! Constructor overload that takes a VarNode* \p var in addition to key,
//! graph, group client, config, shape and dtype.
| RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, | |||||
| std::shared_ptr<GroupClient> group_client, | |||||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||||
| DType dtype); | |||||
| static SymbolVar make( | static SymbolVar make( | ||||
| const std::string& key, cg::ComputingGraph& graph, | const std::string& key, cg::ComputingGraph& graph, | ||||
| std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
| const OperatorNodeConfig& config, const TensorShape& shape, | const OperatorNodeConfig& config, const TensorShape& shape, | ||||
| DType dtype); | DType dtype); | ||||
//! Factory overload that takes a SymbolVar \p var in addition to key, graph,
//! group client, config, shape and dtype; returns a SymbolVar.
| static SymbolVar make( | |||||
| const std::string& key, SymbolVar var, cg::ComputingGraph& graph, | |||||
| std::shared_ptr<GroupClient> group_client, | |||||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||||
| DType dtype); | |||||
| private: | private: | ||||
| const TensorShape m_shape; | const TensorShape m_shape; | ||||
| const DType m_dtype; | const DType m_dtype; | ||||