|
|
|
@@ -18,11 +18,13 @@ MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with |
|
|
|
high performance and parses data precisely. Some of the operations that are |
|
|
|
provided to users to preprocess data include shuffle, batch, repeat, map, and zip. |
|
|
|
""" |
|
|
|
import atexit |
|
|
|
import glob |
|
|
|
import json |
|
|
|
import math |
|
|
|
import os |
|
|
|
import signal |
|
|
|
import time |
|
|
|
import uuid |
|
|
|
import multiprocessing |
|
|
|
import queue |
|
|
|
@@ -1965,6 +1967,7 @@ class BatchDataset(Dataset): |
|
|
|
# Wrap per_batch_map into _PythonCallable |
|
|
|
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) |
|
|
|
self.hook = _ExceptHookHandler() |
|
|
|
atexit.register(_mp_pool_exit_preprocess) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None: |
|
|
|
@@ -2213,7 +2216,7 @@ class _PythonCallable: |
|
|
|
self.idx = idx |
|
|
|
|
|
|
|
def __call__(self, *args): |
|
|
|
if self.pool is not None and self.pool._state == 0: # pylint: disable=W0212 |
|
|
|
if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 |
|
|
|
# This call will send the tensors along with Python callable index to the process pool. |
|
|
|
# Block, yield GIL. Current thread will reacquire GIL once result is returned. |
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) |
|
|
|
@@ -2233,13 +2236,22 @@ class _PythonCallable: |
|
|
|
return self.py_callable(*args) |
|
|
|
|
|
|
|
|
|
|
|
def _mp_pool_exit_preprocess(): |
|
|
|
if check_iterator_cleanup() is False: |
|
|
|
logger.info("Execution preprocessing process before map exit.") |
|
|
|
# Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async |
|
|
|
# applied to the multiprocessing task to prevent multiprocessing from hang when exiting |
|
|
|
_set_iterator_cleanup() |
|
|
|
time.sleep(3) |
|
|
|
|
|
|
|
|
|
|
|
class _ExceptHookHandler: |
|
|
|
def __init__(self): |
|
|
|
sys.excepthook = self.__handler_exception |
|
|
|
|
|
|
|
def __handler_exception(self, type, value, tb): |
|
|
|
logger.error("Uncaught exception: ", exc_info=(type, value, tb)) |
|
|
|
_set_iterator_cleanup() |
|
|
|
_mp_pool_exit_preprocess() |
|
|
|
|
|
|
|
|
|
|
|
class MapDataset(Dataset): |
|
|
|
@@ -2400,11 +2412,13 @@ class MapDataset(Dataset): |
|
|
|
iter_specific_operations.append(op) |
|
|
|
self.operations = iter_specific_operations |
|
|
|
self.hook = _ExceptHookHandler() |
|
|
|
atexit.register(_mp_pool_exit_preprocess) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None: |
|
|
|
logger.info("Map process pool is being terminated.") |
|
|
|
self.process_pool.close() |
|
|
|
self.process_pool.join() |
|
|
|
|
|
|
|
|
|
|
|
class FilterDataset(Dataset): |
|
|
|
|