Browse Source

!10844 Fix MapDataset hang when exiting

From: @heleiwang
Reviewed-by: @pandoublefeng,@liucunwei
Signed-off-by: @liucunwei
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
3e52d5dcd4
1 changed files with 16 additions and 2 deletions
  1. +16
    -2
      mindspore/dataset/engine/datasets.py

+ 16
- 2
mindspore/dataset/engine/datasets.py View File

@@ -18,11 +18,13 @@ MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
"""
import atexit
import glob
import json
import math
import os
import signal
import time
import uuid
import multiprocessing
import queue
@@ -1965,6 +1967,7 @@ class BatchDataset(Dataset):
# Wrap per_batch_map into _PythonCallable
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
self.hook = _ExceptHookHandler()
atexit.register(_mp_pool_exit_preprocess)

def __del__(self):
if hasattr(self, 'process_pool') and self.process_pool is not None:
@@ -2213,7 +2216,7 @@ class _PythonCallable:
self.idx = idx

def __call__(self, *args):
if self.pool is not None and self.pool._state == 0: # pylint: disable=W0212
if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212
# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
@@ -2233,13 +2236,22 @@ class _PythonCallable:
return self.py_callable(*args)


def _mp_pool_exit_preprocess():
    """Flag iterator cleanup and pause briefly before the process exits.

    Does nothing unless ``check_iterator_cleanup()`` currently returns
    ``False``. Otherwise it logs the event, calls ``_set_iterator_cleanup()``
    and then sleeps for 3 seconds.
    """
    if check_iterator_cleanup() is not False:
        return

    logger.info("Execution preprocessing process before map exit.")
    # Raise the iterator_cleanup flag first, then give every task that was
    # already submitted through apply_async 3 seconds to finish, so that
    # multiprocessing does not hang while the interpreter is exiting.
    _set_iterator_cleanup()
    time.sleep(3)


class _ExceptHookHandler:
    """Install a ``sys.excepthook`` that logs uncaught exceptions.

    After logging, the hook runs ``_mp_pool_exit_preprocess()`` so that the
    exit preprocessing also happens when the program dies from an uncaught
    exception.
    """

    def __init__(self):
        # Creating an instance replaces the process-wide exception hook.
        sys.excepthook = self.__handler_exception

    def __handler_exception(self, type, value, tb):
        """Log the uncaught exception and run the exit preprocessing.

        Args:
            type: Exception class passed in by the interpreter.
            value: Exception instance.
            tb: Traceback object.
        """
        logger.error("Uncaught exception: ", exc_info=(type, value, tb))
        # Do not call _set_iterator_cleanup() here: _mp_pool_exit_preprocess()
        # only acts while check_iterator_cleanup() is False and sets the flag
        # itself. Setting it beforehand would make the helper skip its wait.
        _mp_pool_exit_preprocess()


class MapDataset(Dataset):
@@ -2400,11 +2412,13 @@ class MapDataset(Dataset):
iter_specific_operations.append(op)
self.operations = iter_specific_operations
self.hook = _ExceptHookHandler()
atexit.register(_mp_pool_exit_preprocess)

def __del__(self):
    """Close and join the map process pool if this object owns one."""
    pool = getattr(self, 'process_pool', None)
    if pool is None:
        # Either the attribute was never set or no pool was created.
        return

    logger.info("Map process pool is being terminated.")
    pool.close()
    pool.join()


class FilterDataset(Dataset):


Loading…
Cancel
Save