GitOrigin-RevId: 5e8c27ac81
tags/v1.1.0
| @@ -46,7 +46,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
| ssprintf("%s:%d", recv.addr.data(), recv.port)); | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | |||
| recv.key, *graph, group_client, OperatorNodeConfig{recv.cn}, | |||
| recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, | |||
| recv.shape, recv.dtype)); | |||
| } | |||
| @@ -60,6 +60,43 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | |||
| } // anonymous namespace | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| bool RemoteSend::is_same_st(const Hashable& another) const{ | |||
| return as_tuple() == another.cast_final<RemoteSend>().as_tuple(); | |||
| } | |||
| size_t RemoteSend::hash() const{ | |||
| XXHash xxhash; | |||
| auto append = [&xxhash](auto field){ | |||
| auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
| xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
| }; | |||
| append(key); | |||
| append(addr); | |||
| append(port); | |||
| append(rank_to); | |||
| return xxhash.digest(); | |||
| } | |||
| bool RemoteRecv::is_same_st(const Hashable& another) const{ | |||
| return as_tuple() == another.cast_final<RemoteRecv>().as_tuple(); | |||
| } | |||
| size_t RemoteRecv::hash() const{ | |||
| XXHash xxhash; | |||
| auto append = [&xxhash](auto field){ | |||
| auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
| xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
| }; | |||
| append(key); | |||
| append(addr); | |||
| append(port); | |||
| append(rank_from); | |||
| append(cn.to_string()); | |||
| append(dtype.handle()); | |||
| append(shape.to_string()); | |||
| return xxhash.digest(); | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | |||
| @@ -31,6 +31,13 @@ public: | |||
| std::string addr; | |||
| uint32_t port; | |||
| uint32_t rank_to; | |||
| size_t hash() const override; | |||
| bool is_same_st(const Hashable& another) const override; | |||
| auto as_tuple() const{ | |||
| return std::tuple(key, addr, port, rank_to); | |||
| } | |||
| }; | |||
| class RemoteRecv : public OpDefImplBase<RemoteRecv> { | |||
| @@ -55,6 +62,13 @@ public: | |||
| CompNode cn; | |||
| TensorShape shape; | |||
| DType dtype; | |||
| size_t hash() const override; | |||
| bool is_same_st(const Hashable& another) const override; | |||
| auto as_tuple() const{ | |||
| return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string()); | |||
| } | |||
| }; | |||
| } // namespace imperative | |||
| @@ -151,6 +151,23 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||
| add_equivalence_component<ScalarHash<void*>>(this); | |||
| } | |||
| RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, | |||
| const TensorShape& shape, DType dtype) : | |||
| Super(&graph, config, "remote_recv", {}), | |||
| m_shape(shape), m_dtype(dtype) { | |||
| m_key = key; | |||
| m_group_client = group_client; | |||
| add_input({var}); | |||
| add_output(None) | |||
| ->dtype(dtype) | |||
| .add_flag(VarNode::Flag::NO_MEM_RECLAIM) | |||
| .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC); | |||
| add_equivalence_component<ScalarHash<void*>>(this); | |||
| } | |||
| SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, | |||
| @@ -160,6 +177,15 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||
| return opr->output(0); | |||
| } | |||
| SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, | |||
| const TensorShape& shape, DType dtype) { | |||
| auto opr = graph.insert_opr(std::make_unique<RemoteRecv>( | |||
| key, var.node(), graph, group_client, config, shape, dtype)); | |||
| return opr->output(0); | |||
| } | |||
| void RemoteRecv::scn_do_execute() { | |||
| if (!m_init) { | |||
| auto&& comp_node = output(0)->comp_node(); | |||
| @@ -77,12 +77,23 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | |||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||
| DType dtype); | |||
| RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||
| DType dtype); | |||
| static SymbolVar make( | |||
| const std::string& key, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||
| DType dtype); | |||
| static SymbolVar make( | |||
| const std::string& key, SymbolVar var, cg::ComputingGraph& graph, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const OperatorNodeConfig& config, const TensorShape& shape, | |||
| DType dtype); | |||
| private: | |||
| const TensorShape m_shape; | |||
| const DType m_dtype; | |||