You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

stream.py 2.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from __future__ import absolute_import
  2. from ._base import _LIB, check_call
  3. import ctypes
  4. from . import ndarray
  5. class DLStream(ctypes.Structure):
  6. _fields_ = [("device_id", ctypes.c_int),
  7. ("handle", ctypes.c_void_p)]
  8. DLStreamHandle = ctypes.POINTER(DLStream)
  9. class Stream(ctypes.Structure):
  10. __slots__ = ["handle"]
  11. def __init__(self, handle):
  12. self.handle = handle
  13. def __del__(self):
  14. check_call(_LIB.DLStreamDestroy(self.handle))
  15. def sync(self):
  16. check_call(_LIB.DLStreamSync(self.handle))
  17. def create_stream_handle(ctx):
  18. assert ndarray.is_gpu_ctx(ctx)
  19. handle = DLStreamHandle()
  20. check_call(_LIB.DLStreamCreate(ctx.device_id, ctypes.byref(handle)))
  21. return Stream(handle)
  22. class DLEvent(ctypes.Structure):
  23. _fields_ = [("device_id", ctypes.c_int),
  24. ("handle", ctypes.c_void_p)]
  25. DLEventHandle = ctypes.POINTER(DLEvent)
  26. class Event(ctypes.Structure):
  27. __slots__ = ["handle"]
  28. def __init__(self, handle):
  29. self.handle = handle
  30. def __del__(self):
  31. check_call(_LIB.DLEventDestroy(self.handle))
  32. def sync(self):
  33. check_call(_LIB.DLEventSync(self.handle))
  34. def record(self, stream_handle):
  35. check_call(_LIB.DLEventRecord(stream_handle.handle, self.handle))
  36. def create_event_handle(ctx):
  37. assert ndarray.is_gpu_ctx(ctx)
  38. handle = DLEventHandle()
  39. check_call(_LIB.DLEventCreate(ctx.device_id, ctypes.byref(handle)))
  40. return Event(handle)
  41. class PSEvent(object):
  42. __slots__ = ["comm", "nid", "need_wait"]
  43. def __init__(self, comm, nid):
  44. self.comm = comm
  45. self.nid = nid
  46. self.need_wait = False
  47. def update(self):
  48. self.need_wait = True
  49. def sync(self):
  50. if self.need_wait:
  51. self.comm.Wait(self.nid)
  52. self.need_wait = False
  53. class CSEvent(PSEvent):
  54. __slots__ = ["tss"]
  55. def __init__(self, comm, nid):
  56. super().__init__(comm, nid)
  57. self.tss = []
  58. def update_ts(self, ts):
  59. self.tss.append(ts)
  60. def sync(self):
  61. super().sync()
  62. if self.tss != []:
  63. for ts in self.tss:
  64. ts.wait()
  65. self.tss = []