GitOrigin-RevId: e1dac3c919
tags/v1.8.0
| @@ -47,8 +47,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||||
| auto group_client = std::make_shared<opr::GroupClientProxy>( | auto group_client = std::make_shared<opr::GroupClientProxy>( | ||||
| ssprintf("%s:%d", recv.addr.data(), recv.port)); | ssprintf("%s:%d", recv.addr.data(), recv.port)); | ||||
| auto&& graph = inputs[0]->owner_graph(); | auto&& graph = inputs[0]->owner_graph(); | ||||
| mgb_assert(!recv.shape.empty()); | |||||
| TensorShape shape; | |||||
| for (auto&& dim : recv.shape) { | |||||
| shape[shape.ndim++] = dim; | |||||
| } | |||||
| return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | ||||
| recv.key, inputs[0], *graph, group_client, config, recv.shape, recv.dtype, | |||||
| recv.key, inputs[0], *graph, group_client, config, shape, recv.dtype, | |||||
| recv.backend)); | recv.backend)); | ||||
| } | } | ||||
| @@ -42,7 +42,7 @@ TEST(TestImperative, IORemote) { | |||||
| auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { | auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { | ||||
| auto def = imperative::RemoteRecv::make( | auto def = imperative::RemoteRecv::make( | ||||
| "io_remote_test", server_addr, port, 0, CompNode::load("gpu1"), | "io_remote_test", server_addr, port, 0, CompNode::load("gpu1"), | ||||
| TensorShape{vector_size}, dtype::Float32(), "nccl"); | |||||
| std::vector<int32_t>{(int32_t)vector_size}, dtype::Float32(), "nccl"); | |||||
| auto inp = Tensor::make(*hnd); | auto inp = Tensor::make(*hnd); | ||||
| auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | ||||
| HostTensorND host_v; | HostTensorND host_v; | ||||
| @@ -284,7 +284,7 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> { | |||||
| MgbUI32Attr:$port, | MgbUI32Attr:$port, | ||||
| MgbUI32Attr:$rank_from, | MgbUI32Attr:$rank_from, | ||||
| MgbCompNodeAttr:$cn, | MgbCompNodeAttr:$cn, | ||||
| MgbTensorShapeAttr:$shape, | |||||
| MgbArrayAttr<MgbI32Attr>:$shape, | |||||
| MgbDTypeAttr:$dtype, | MgbDTypeAttr:$dtype, | ||||
| MgbStringAttr:$backend | MgbStringAttr:$backend | ||||
| ); | ); | ||||