|
|
|
@@ -11,7 +11,7 @@ from numba.cuda.cudadrv.devicearray import DeviceNDArray |
|
|
|
from numba.cuda import to_device |
|
|
|
|
|
|
|
|
|
|
|
def torch_to_buffer(tensor: torch.TensorType) -> tuple[pa.array, dict]: |
|
|
|
def torch_to_ipc_buffer(tensor: torch.TensorType) -> tuple[pa.array, dict]: |
|
|
|
"""Converts a Pytorch tensor into a pyarrow buffer containing the IPC handle and its metadata.""" |
|
|
|
device_arr = to_device(tensor) |
|
|
|
cuda_buf = pa.cuda.CudaBuffer.from_numba(device_arr.gpu_data) |
|
|
|
@@ -24,7 +24,7 @@ def torch_to_buffer(tensor: torch.TensorType) -> tuple[pa.array, dict]: |
|
|
|
return pa.array(handle_buffer, type=pa.uint8()), metadata |
|
|
|
|
|
|
|
|
|
|
|
def buffer_to_ipc_handle(handle_buffer: pa.array) -> cuda.IpcMemHandle: |
|
|
|
def ipc_buffer_to_ipc_handle(handle_buffer: pa.array) -> cuda.IpcMemHandle: |
|
|
|
"""Converts a buffer containing a serialized handler into cuda IPC MemHandle.""" |
|
|
|
handle_buffer = handle_buffer.buffers()[1] |
|
|
|
ipc_handle = pa.cuda.IpcMemHandle.from_buffer(handle_buffer) |
|
|
|
|