Browse Source

protect pyfunc timeout

tags/v1.0.0
yanghaitao1 5 years ago
parent
commit
8b4591c482
3 changed files with 40 additions and 17 deletions
  1. +2
    -1
      mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc
  2. +22
    -15
      mindspore/dataset/engine/datasets.py
  3. +16
    -1
      mindspore/dataset/engine/iterators.py

+ 2
- 1
mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc View File

@@ -61,6 +61,7 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
py::object ret_py_ele = ret_py_tuple[i];
// Object is none if pyfunc timeout
if (ret_py_ele.is_none()) {
MS_LOG(INFO) << "PyFunc execute time out";
goto TimeoutError;
}
if (!py::isinstance<py::array>(ret_py_ele)) {
@@ -92,7 +93,7 @@ ShapeMisMatch:
goto ComputeReturn;

TimeoutError:
ret = Status(StatusCode::kTimeOut, "PyFunc timeout");
ret = Status(StatusCode::kTimeOut, "PyFunc execute time out");
goto ComputeReturn;
}



+ 22
- 15
mindspore/dataset/engine/datasets.py View File

@@ -38,7 +38,7 @@ from mindspore._c_expression import typing

from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, check_device_send, \
check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -55,7 +55,6 @@ try:
except ModuleNotFoundError:
context = None


class Shuffle(str, Enum):
    """String-valued enumeration of shuffle modes.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    # NOTE(review): presumably shuffles across files and rows — confirm against callers.
    GLOBAL: str = "global"
    # The value is the singular "file" even though the member is named FILES.
    FILES: str = "file"
@@ -2012,18 +2011,19 @@ class _PythonCallable:

def __call__(self, *args):
if self.pool is not None:
try:
# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
return result.get(60)
except multiprocessing.TimeoutError:
# Ensure C++ pyfunc threads exit normally if the Python sub-process is killed abnormally.
return (None,)
except KeyboardInterrupt:
self.pool.terminate()
self.pool.join()
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
while check_iterator_cleanup() is False:
try:
return result.get(30)
except multiprocessing.TimeoutError:
continue
except KeyboardInterrupt:
self.pool.terminate()
self.pool.join()
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
return (None,)
# Invoke original Python callable in master process in case the pool is gone.
return self.py_callable(*args)

@@ -3395,7 +3395,14 @@ class _GeneratorWorkerMp(multiprocessing.Process):
"""
Get function for worker result queue. Block with timeout.
"""
return self.res_queue.get(timeout=10)
while check_iterator_cleanup() is False:
try:
return self.res_queue.get(timeout=10)
except multiprocessing.TimeoutError:
continue

raise Exception("Generator worker process timeout")


def __del__(self):
    """Call ``self.terminate()`` when this object is finalized."""
    # NOTE(review): presumably multiprocessing.Process.terminate, going by the
    # enclosing class named in the hunk header — confirm against the class definition.
    self.terminate()


+ 16
- 1
mindspore/dataset/engine/iterators.py View File

@@ -27,16 +27,30 @@ from mindspore import log as logger
from . import datasets as de


_ITERATOR_CLEANUP = False

def _set_iterator_cleanup():
    """Set the module-level ``_ITERATOR_CLEANUP`` flag to True."""
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = True

def _unset_iterator_cleanup():
    """Reset the module-level ``_ITERATOR_CLEANUP`` flag to False."""
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = False

def check_iterator_cleanup():
    """Return the current value of the module-level ``_ITERATOR_CLEANUP`` flag.

    Returns:
        bool, the flag as last written by ``_set_iterator_cleanup`` (True) or
        ``_unset_iterator_cleanup`` (False).
    """
    # Reading a module global needs no ``global`` declaration; that statement
    # is only required when the function assigns to the name.
    return _ITERATOR_CLEANUP

ITERATORS_LIST = list()

def _cleanup():
    """Flag iterator cleanup, then release every iterator that is still alive.

    ``ITERATORS_LIST`` holds weak references; entries whose referent has
    already been collected resolve to None and are skipped.
    """
    _set_iterator_cleanup()
    # Lazy generator: each weak reference is resolved immediately before its
    # iterator is released, one at a time.
    resolved = (ref() for ref in ITERATORS_LIST)
    for iterator in resolved:
        if iterator is None:
            continue
        iterator.release()


def alter_tree(node):
"""Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
if not node.children:
@@ -71,6 +85,7 @@ class Iterator:
self.num_epochs = num_epochs
self.output_numpy = output_numpy
ITERATORS_LIST.append(weakref.ref(self))
_unset_iterator_cleanup()
# create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset)
self.ori_dataset = dataset


Loading…
Cancel
Save