|
|
|
@@ -22,10 +22,12 @@ import glob |
|
|
|
import json |
|
|
|
import math |
|
|
|
import os |
|
|
|
import signal |
|
|
|
import uuid |
|
|
|
import multiprocessing |
|
|
|
import queue |
|
|
|
from enum import Enum |
|
|
|
from functools import partial |
|
|
|
from importlib import import_module |
|
|
|
import sys |
|
|
|
import threading |
|
|
|
@@ -3447,6 +3449,7 @@ class SamplerFn: |
|
|
|
self.workers = [] |
|
|
|
self.num_worker = num_worker |
|
|
|
self.multi_process = multi_process |
|
|
|
self.joined = False |
|
|
|
# Event for end of epoch |
|
|
|
if multi_process is True: |
|
|
|
self.eof = multiprocessing.Event() |
|
|
|
@@ -3485,29 +3488,47 @@ class SamplerFn: |
|
|
|
|
|
|
|
# Fetch results |
|
|
|
for i in range(len(indices)): |
|
|
|
if self.eof.is_set(): |
|
|
|
self._stop_subprocess() |
|
|
|
return |
|
|
|
# Fetch result and put index |
|
|
|
try: |
|
|
|
result = self.workers[i % self.num_worker].get() |
|
|
|
except queue.Empty: |
|
|
|
self._stop_subprocess() |
|
|
|
raise Exception("Generator worker process timeout.") |
|
|
|
except KeyboardInterrupt: |
|
|
|
self.eof.set() |
|
|
|
for w in self.workers: |
|
|
|
w.terminate() |
|
|
|
w.join() |
|
|
|
self._stop_subprocess() |
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt.") |
|
|
|
if self.eof.is_set(): |
|
|
|
self._stop_subprocess() |
|
|
|
return |
|
|
|
if idx_cursor < len(indices): |
|
|
|
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) |
|
|
|
yield tuple([np.array(x, copy=False) for x in result]) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
def _stop_subprocess(self):
    """
    Signal end-of-epoch and join every worker once.

    The eof event is set on every call; the join pass over `self.workers`
    runs only while `self.joined` is still False, after which the flag is
    flipped to True so later calls skip the join.
    """
    self.eof.set()
    if self.joined is not False:
        # Workers were already joined by an earlier call.
        return
    for worker in self.workers:
        worker.join()
    self.joined = True
|
|
|
|
|
|
|
def __del__(self):
    """Finalizer: delegate all cleanup to `_stop_subprocess`."""
    self._stop_subprocess()
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eof): |
|
|
|
def _subprocess_handle(eof, signum, frame):
    """
    Signal handler body: log the termination request and set the eof event.

    Args:
        eof: Event-like object; its `set()` method is called.
        signum: Signal number supplied by the signal machinery (unused).
        frame: Interrupted stack frame supplied by the signal machinery (unused).
    """
    logger.info("The subprocess receives a termination signal.")
    # Setting eof is how the handler reports the termination request.
    eof.set()
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing): |
|
|
|
""" |
|
|
|
Multithread or multiprocess generator worker process loop. |
|
|
|
""" |
|
|
|
if is_multiprocessing: |
|
|
|
signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof)) |
|
|
|
while True: |
|
|
|
# Fetch index, block |
|
|
|
try: |
|
|
|
@@ -3516,6 +3537,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof): |
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt.") |
|
|
|
except queue.Empty: |
|
|
|
if eof.is_set(): |
|
|
|
if is_multiprocessing: |
|
|
|
idx_queue.cancel_join_thread() |
|
|
|
result_queue.cancel_join_thread() |
|
|
|
return |
|
|
|
# If end-of-file (eof) is not set, continue to get data from idx_queue |
|
|
|
continue |
|
|
|
@@ -3525,6 +3549,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof): |
|
|
|
assert eof.is_set(), "" |
|
|
|
return |
|
|
|
if eof.is_set(): |
|
|
|
if is_multiprocessing: |
|
|
|
idx_queue.cancel_join_thread() |
|
|
|
result_queue.cancel_join_thread() |
|
|
|
return |
|
|
|
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process |
|
|
|
result = dataset[idx] |
|
|
|
@@ -3536,6 +3563,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof): |
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt.") |
|
|
|
except queue.Full: |
|
|
|
if eof.is_set(): |
|
|
|
if is_multiprocessing: |
|
|
|
idx_queue.cancel_join_thread() |
|
|
|
result_queue.cancel_join_thread() |
|
|
|
return |
|
|
|
# If eof is not set, continue to put data to result_queue |
|
|
|
continue |
|
|
|
@@ -3551,7 +3581,7 @@ class _GeneratorWorkerMt(threading.Thread): |
|
|
|
def __init__(self, dataset, eof):
    """
    Create the index/result queues and register the worker loop as the thread target.

    Args:
        dataset: Object passed to `_generator_worker_loop` as its first argument.
        eof: Event-like object passed through to the worker loop.
    """
    # Both queues are capped at 16 entries.
    self.idx_queue = queue.Queue(16)
    self.res_queue = queue.Queue(16)
    # Initialise the base class exactly once. The trailing False fills the
    # worker loop's `is_multiprocessing` parameter (thread mode); the earlier
    # four-argument call did not match that signature and has been dropped.
    super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False))
|
|
|
|
|
|
|
def put(self, item): |
|
|
|
""" |
|
|
|
@@ -3567,10 +3597,10 @@ class _GeneratorWorkerMt(threading.Thread): |
|
|
|
|
|
|
|
def queue_empty(self):
    """
    Report whether both the index queue and the result queue are empty.

    Returns:
        bool, True when both queues are empty. Otherwise False, after logging
        one warning that names the first non-empty queue found.
    """
    # Each condition is logged once, at warning level; the duplicate
    # error-level line carrying the same message has been removed.
    if not self.idx_queue.empty():
        logger.warning("idx_queue is not empty")
        return False
    if not self.res_queue.empty():
        logger.warning("res_queue is not empty")
        return False
    return True
|
|
|
|
|
|
|
@@ -3583,7 +3613,7 @@ class _GeneratorWorkerMp(multiprocessing.Process): |
|
|
|
def __init__(self, dataset, eof):
    """
    Create the index/result queues and register the worker loop as the process target.

    Args:
        dataset: Object passed to `_generator_worker_loop` as its first argument.
        eof: Event-like object passed through to the worker loop.
    """
    # Both queues are capped at 16 entries.
    self.idx_queue = multiprocessing.Queue(16)
    self.res_queue = multiprocessing.Queue(16)
    # Initialise the base class exactly once. The trailing True fills the
    # worker loop's `is_multiprocessing` parameter (process mode); the earlier
    # four-argument call did not match that signature and has been dropped.
    super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))
|
|
|
|
|
|
|
def put(self, item): |
|
|
|
""" |
|
|
|
@@ -3601,21 +3631,13 @@ class _GeneratorWorkerMp(multiprocessing.Process): |
|
|
|
|
|
|
|
def queue_empty(self):
    """
    Report whether both the index queue and the result queue are empty.

    Returns:
        bool, True when both queues are empty. Otherwise False, after logging
        one warning that names the first non-empty queue found.
    """
    # Each condition is logged once, at warning level; the duplicate
    # error-level line carrying the same message has been removed.
    if not self.idx_queue.empty():
        logger.warning("idx_queue is not empty.")
        return False
    if not self.res_queue.empty():
        logger.warning("res_queue is not empty.")
        return False
    return True
|
|
|
|
|
|
|
def __del__(self):
    """
    Best-effort finalizer: attempt to terminate the worker.

    By the time this runs, parts of the object may already have been torn
    down, so an AttributeError raised while calling `terminate()` is
    deliberately ignored.
    """
    try:
        self.terminate()
    except AttributeError:
        # Nothing left to terminate; stay silent during teardown.
        pass
|
|
|
|
|
|
|
|
|
|
|
class GeneratorDataset(MappableDataset): |
|
|
|
""" |
|
|
|
|