|
- from __future__ import absolute_import
-
- from ._base import _LIB, check_call
- import ctypes
- from . import ndarray
-
-
class DLStream(ctypes.Structure):
    """ctypes structure describing a stream.

    Fields:
        device_id: C ``int``.
        handle: untyped pointer (``void *``).

    NOTE(review): presumably mirrors the stream struct of the native
    library loaded as ``_LIB`` -- confirm against the C header.
    """

    _fields_ = [
        ("device_id", ctypes.c_int),
        ("handle", ctypes.c_void_p),
    ]


# Pointer-to-DLStream type; instances are passed by reference to the
# native library so it can fill them in.
DLStreamHandle = ctypes.POINTER(DLStream)
-
-
class Stream(ctypes.Structure):
    """Owner of a stream handle obtained from the native library.

    The wrapped handle is released through ``_LIB.DLStreamDestroy`` when
    the Python object is finalized.
    """

    __slots__ = ["handle"]

    def __init__(self, handle):
        """Store ``handle`` (the value later passed to the ``_LIB`` calls)."""
        self.handle = handle

    def __del__(self):
        # Release the native stream when this wrapper is collected.
        status = _LIB.DLStreamDestroy(self.handle)
        check_call(status)

    def sync(self):
        """Call ``DLStreamSync`` on the handle and check its status.

        NOTE(review): presumably blocks until work queued on the stream
        has finished -- confirm against the native API.
        """
        status = _LIB.DLStreamSync(self.handle)
        check_call(status)
-
-
def create_stream_handle(ctx):
    """Create a native stream for ``ctx`` and wrap it in a ``Stream``.

    Args:
        ctx: context object; must satisfy ``ndarray.is_gpu_ctx`` and
            expose a ``device_id`` attribute.

    Returns:
        A ``Stream`` owning the newly created handle.

    Raises:
        AssertionError: if ``ctx`` does not satisfy ``ndarray.is_gpu_ctx``.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if not ndarray.is_gpu_ctx(ctx):
        raise AssertionError(
            "create_stream_handle: ctx %r does not satisfy "
            "ndarray.is_gpu_ctx" % (ctx,))
    handle = DLStreamHandle()
    # The library writes the new stream pointer into ``handle``.
    check_call(_LIB.DLStreamCreate(ctx.device_id, ctypes.byref(handle)))
    return Stream(handle)
-
-
class DLEvent(ctypes.Structure):
    """ctypes structure describing an event.

    Fields:
        device_id: C ``int``.
        handle: untyped pointer (``void *``).

    NOTE(review): presumably mirrors the event struct of the native
    library loaded as ``_LIB`` -- confirm against the C header.
    """

    _fields_ = [
        ("device_id", ctypes.c_int),
        ("handle", ctypes.c_void_p),
    ]


# Pointer-to-DLEvent type; instances are passed by reference to the
# native library so it can fill them in.
DLEventHandle = ctypes.POINTER(DLEvent)
-
-
class Event(ctypes.Structure):
    """Owner of an event handle obtained from the native library.

    The wrapped handle is released through ``_LIB.DLEventDestroy`` when
    the Python object is finalized.
    """

    __slots__ = ["handle"]

    def __init__(self, handle):
        """Store ``handle`` (the value later passed to the ``_LIB`` calls)."""
        self.handle = handle

    def __del__(self):
        # Release the native event when this wrapper is collected.
        status = _LIB.DLEventDestroy(self.handle)
        check_call(status)

    def sync(self):
        """Call ``DLEventSync`` on the handle and check its status.

        NOTE(review): presumably blocks until the event has completed --
        confirm against the native API.
        """
        status = _LIB.DLEventSync(self.handle)
        check_call(status)

    def record(self, stream_handle):
        """Call ``DLEventRecord`` with a stream's handle and this event's.

        Args:
            stream_handle: object exposing a ``handle`` attribute (for
                example a ``Stream``); its handle is passed first.
        """
        status = _LIB.DLEventRecord(stream_handle.handle, self.handle)
        check_call(status)
-
-
def create_event_handle(ctx):
    """Create a native event for ``ctx`` and wrap it in an ``Event``.

    Args:
        ctx: context object; must satisfy ``ndarray.is_gpu_ctx`` and
            expose a ``device_id`` attribute.

    Returns:
        An ``Event`` owning the newly created handle.

    Raises:
        AssertionError: if ``ctx`` does not satisfy ``ndarray.is_gpu_ctx``.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if not ndarray.is_gpu_ctx(ctx):
        raise AssertionError(
            "create_event_handle: ctx %r does not satisfy "
            "ndarray.is_gpu_ctx" % (ctx,))
    handle = DLEventHandle()
    # The library writes the new event pointer into ``handle``.
    check_call(_LIB.DLEventCreate(ctx.device_id, ctypes.byref(handle)))
    return Event(handle)
-
-
class PSEvent(object):
    """Deferred wait on a communicator for a single id.

    ``update`` marks a wait as pending; ``sync`` performs it through
    ``comm.Wait(nid)`` at most once per ``update`` and clears the mark.

    NOTE(review): "PS" presumably stands for parameter server and ``nid``
    for a node id -- confirm with the callers.
    """

    __slots__ = ["comm", "nid", "need_wait"]

    def __init__(self, comm, nid):
        """Remember the communicator and id; no wait is pending initially."""
        self.comm, self.nid = comm, nid
        self.need_wait = False

    def update(self):
        """Flag that the next ``sync`` must wait on the communicator."""
        self.need_wait = True

    def sync(self):
        """Wait on ``comm`` for ``nid`` if flagged, then clear the flag."""
        if not self.need_wait:
            return
        self.comm.Wait(self.nid)
        self.need_wait = False
-
-
class CSEvent(PSEvent):
    """``PSEvent`` that additionally collects objects to ``wait()`` on.

    Objects registered with ``update_ts`` are waited on, in insertion
    order, by the next ``sync`` and then dropped.

    NOTE(review): ``ts`` items only need a ``wait()`` method here; what
    they represent is not visible from this file -- confirm with callers.
    """

    __slots__ = ["tss"]

    def __init__(self, comm, nid):
        """Initialize the base event and start with nothing collected."""
        super().__init__(comm, nid)
        self.tss = []

    def update_ts(self, ts):
        """Register ``ts`` to be waited on by the next ``sync``."""
        self.tss.append(ts)

    def sync(self):
        """Run the base-class sync, then wait on every collected item."""
        super().sync()
        if self.tss == []:
            return
        for pending in self.tss:
            pending.wait()
        # Rebind to a fresh list only after every wait has returned.
        self.tss = []
|