|
|
|
@@ -1,19 +1,18 @@ |
|
|
|
import pyarrow as pa |
|
|
|
|
|
|
|
# To install pyarrow.cuda, run `conda install pyarrow "arrow-cpp-proc=*=cuda" -c conda-forge` |
|
|
|
import pyarrow.cuda as cuda |
|
|
|
|
|
|
|
# Make sure to install torch with cuda |
|
|
|
import torch |
|
|
|
from numba.cuda import to_device |
|
|
|
|
|
|
|
# Make sure to install numba with cuda |
|
|
|
from numba.cuda.cudadrv.devicearray import DeviceNDArray |
|
|
|
from numba.cuda import to_device |
|
|
|
|
|
|
|
# To install pyarrow.cuda, run `conda install pyarrow "arrow-cpp-proc=*=cuda" -c conda-forge` |
|
|
|
from pyarrow import cuda |
|
|
|
|
|
|
|
|
|
|
|
def torch_to_ipc_buffer(tensor: torch.TensorType) -> tuple[pa.array, dict]: |
|
|
|
""" |
|
|
|
Converts a Pytorch tensor into a pyarrow buffer containing the IPC handle and its metadata. |
|
|
|
"""Converts a Pytorch tensor into a pyarrow buffer containing the IPC handle and its metadata. |
|
|
|
|
|
|
|
Example Use: |
|
|
|
```python |
|
|
|
@@ -34,8 +33,7 @@ def torch_to_ipc_buffer(tensor: torch.TensorType) -> tuple[pa.array, dict]: |
|
|
|
|
|
|
|
|
|
|
|
def ipc_buffer_to_ipc_handle(handle_buffer: pa.array) -> cuda.IpcMemHandle: |
|
|
|
""" |
|
|
|
Converts a buffer containing a serialized handler into cuda IPC MemHandle. |
|
|
|
"""Converts a buffer containing a serialized handler into cuda IPC MemHandle. |
|
|
|
|
|
|
|
example use: |
|
|
|
```python |
|
|
|
@@ -57,8 +55,7 @@ def ipc_buffer_to_ipc_handle(handle_buffer: pa.array) -> cuda.IpcMemHandle: |
|
|
|
|
|
|
|
|
|
|
|
def cudabuffer_to_numba(buffer: cuda.CudaBuffer, metadata: dict) -> DeviceNDArray: |
|
|
|
""" |
|
|
|
Converts a pyarrow CUDA buffer to numba. |
|
|
|
"""Converts a pyarrow CUDA buffer to numba. |
|
|
|
|
|
|
|
example use: |
|
|
|
```python |
|
|
|
@@ -74,7 +71,6 @@ def cudabuffer_to_numba(buffer: cuda.CudaBuffer, metadata: dict) -> DeviceNDArra |
|
|
|
numba_tensor = cudabuffer_to_numbda(cudabuffer, event["metadata"]) |
|
|
|
``` |
|
|
|
""" |
|
|
|
|
|
|
|
shape = metadata["shape"] |
|
|
|
strides = metadata["strides"] |
|
|
|
dtype = metadata["dtype"] |
|
|
|
@@ -83,8 +79,7 @@ def cudabuffer_to_numba(buffer: cuda.CudaBuffer, metadata: dict) -> DeviceNDArra |
|
|
|
|
|
|
|
|
|
|
|
def cudabuffer_to_torch(buffer: cuda.CudaBuffer, metadata: dict) -> torch.Tensor: |
|
|
|
""" |
|
|
|
Converts a pyarrow CUDA buffer to a torch tensor. |
|
|
|
"""Converts a pyarrow CUDA buffer to a torch tensor. |
|
|
|
|
|
|
|
example use: |
|
|
|
```python |
|
|
|
@@ -100,7 +95,6 @@ def cudabuffer_to_torch(buffer: cuda.CudaBuffer, metadata: dict) -> torch.Tensor |
|
|
|
torch_tensor = cudabuffer_to_torch(cudabuffer, event["metadata"]) # on cuda |
|
|
|
``` |
|
|
|
""" |
|
|
|
|
|
|
|
device_arr = cudabuffer_to_numba(buffer, metadata) |
|
|
|
torch_tensor = torch.as_tensor(device_arr, device="cuda") |
|
|
|
return torch_tensor |