You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

receiver.py 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import time
  5. import pyarrow as pa
  6. from tqdm import tqdm
  7. from dora import Node
  8. from dora.cuda import ipc_buffer_to_ipc_handle, cudabuffer_to_torch
  9. from helper import record_results
  10. import torch
  11. torch.tensor([], device="cuda")
  12. pa.array([])
  13. pbar = tqdm(total=100)
  14. context = pa.cuda.Context()
  15. node = Node("node_2")
  16. current_size = 8
  17. n = 0
  18. i = 0
  19. latencies = []
  20. DEVICE = os.getenv("DEVICE", "cuda")
  21. NAME = f"dora torch {DEVICE}"
  22. ctx = pa.cuda.Context()
  23. while True:
  24. event = node.next()
  25. if event["type"] == "INPUT":
  26. t_send = event["metadata"]["time"]
  27. if event["metadata"]["device"] != "cuda":
  28. # BEFORE
  29. handle = event["value"].to_numpy()
  30. torch_tensor = torch.tensor(handle, device="cuda")
  31. else:
  32. # AFTER
  33. # storage needs to be spawned in the same file as where it's used. Don't ask me why.
  34. ipc_handle = ipc_buffer_to_ipc_handle(event["value"])
  35. cudabuffer = ctx.open_ipc_buffer(ipc_handle)
  36. torch_tensor = cudabuffer_to_torch(cudabuffer, event["metadata"]) # on cuda
  37. else:
  38. break
  39. t_received = time.perf_counter_ns()
  40. length = len(torch_tensor) * 8
  41. if length != current_size:
  42. if n > 0:
  43. pbar.close()
  44. pbar = tqdm(total=100)
  45. record_results(NAME, current_size, latencies)
  46. current_size = length
  47. n = 0
  48. start = time.perf_counter_ns()
  49. latencies = []
  50. pbar.update(1)
  51. latencies.append((t_received - t_send) / 1000)
  52. node.send_output("next", pa.array([]))
  53. n += 1
  54. i += 1
  55. record_results(NAME, current_size, latencies)