From 41d629b486cb64b4be1befb1f63a8307f8223358 Mon Sep 17 00:00:00 2001 From: heleiwang Date: Wed, 30 Dec 2020 18:09:30 +0800 Subject: [PATCH] fix map exit hang --- mindspore/dataset/engine/datasets.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7bbbf39b39..54332d0627 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -18,11 +18,13 @@ MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with high performance and parses data precisely. Some of the operations that are provided to users to preprocess data include shuffle, batch, repeat, map, and zip. """ +import atexit import glob import json import math import os import signal +import time import uuid import multiprocessing import queue @@ -1965,6 +1967,7 @@ class BatchDataset(Dataset): # Wrap per_batch_map into _PythonCallable self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) self.hook = _ExceptHookHandler() + atexit.register(_mp_pool_exit_preprocess) def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: @@ -2213,7 +2216,7 @@ class _PythonCallable: self.idx = idx def __call__(self, *args): - if self.pool is not None and self.pool._state == 0: # pylint: disable=W0212 + if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 # This call will send the tensors along with Python callable index to the process pool. # Block, yield GIL. Current thread will reacquire GIL once result is returned. result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) @@ -2233,13 +2236,22 @@ class _PythonCallable: return self.py_callable(*args) +def _mp_pool_exit_preprocess(): + if check_iterator_cleanup() is False: + logger.info("Execution preprocessing process before map exit.") + # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async + # applied to the multiprocessing task to prevent multiprocessing from hang when exiting + _set_iterator_cleanup() + time.sleep(3) + + class _ExceptHookHandler: def __init__(self): sys.excepthook = self.__handler_exception def __handler_exception(self, type, value, tb): logger.error("Uncaught exception: ", exc_info=(type, value, tb)) - _set_iterator_cleanup() + _mp_pool_exit_preprocess() class MapDataset(Dataset): @@ -2400,11 +2412,13 @@ class MapDataset(Dataset): iter_specific_operations.append(op) self.operations = iter_specific_operations self.hook = _ExceptHookHandler() + atexit.register(_mp_pool_exit_preprocess) def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: logger.info("Map process pool is being terminated.") self.process_pool.close() + self.process_pool.join() class FilterDataset(Dataset):