Browse Source

protect pyfunc timeout

tags/v1.0.0
yanghaitao1 5 years ago
parent
commit
8b4591c482
3 changed files with 40 additions and 17 deletions
  1. +2
    -1
      mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc
  2. +22
    -15
      mindspore/dataset/engine/datasets.py
  3. +16
    -1
      mindspore/dataset/engine/iterators.py

+ 2
- 1
mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc View File

@@ -61,6 +61,7 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
py::object ret_py_ele = ret_py_tuple[i]; py::object ret_py_ele = ret_py_tuple[i];
// Object is none if pyfunc timeout // Object is none if pyfunc timeout
if (ret_py_ele.is_none()) { if (ret_py_ele.is_none()) {
MS_LOG(INFO) << "PyFunc execute time out";
goto TimeoutError; goto TimeoutError;
} }
if (!py::isinstance<py::array>(ret_py_ele)) { if (!py::isinstance<py::array>(ret_py_ele)) {
@@ -92,7 +93,7 @@ ShapeMisMatch:
goto ComputeReturn; goto ComputeReturn;


TimeoutError: TimeoutError:
ret = Status(StatusCode::kTimeOut, "PyFunc timeout");
ret = Status(StatusCode::kTimeOut, "PyFunc execute time out");
goto ComputeReturn; goto ComputeReturn;
} }




+ 22
- 15
mindspore/dataset/engine/datasets.py View File

@@ -38,7 +38,7 @@ from mindspore._c_expression import typing


from mindspore import log as logger from mindspore import log as logger
from . import samplers from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, check_device_send, \ check_rename, check_numpyslicesdataset, check_device_send, \
check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \ check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -55,7 +55,6 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
context = None context = None



class Shuffle(str, Enum): class Shuffle(str, Enum):
GLOBAL: str = "global" GLOBAL: str = "global"
FILES: str = "file" FILES: str = "file"
@@ -2012,18 +2011,19 @@ class _PythonCallable:


def __call__(self, *args): def __call__(self, *args):
if self.pool is not None: if self.pool is not None:
try:
# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
return result.get(60)
except multiprocessing.TimeoutError:
# Ensure c++ pyfunc threads exit normally if python sub-process is killed unnormally.
return (None,)
except KeyboardInterrupt:
self.pool.terminate()
self.pool.join()
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
while check_iterator_cleanup() is False:
try:
return result.get(30)
except multiprocessing.TimeoutError:
continue
except KeyboardInterrupt:
self.pool.terminate()
self.pool.join()
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
return (None,)
# Invoke original Python callable in master process in case the pool is gone. # Invoke original Python callable in master process in case the pool is gone.
return self.py_callable(*args) return self.py_callable(*args)


@@ -3395,7 +3395,14 @@ class _GeneratorWorkerMp(multiprocessing.Process):
""" """
Get function for worker result queue. Block with timeout. Get function for worker result queue. Block with timeout.
""" """
return self.res_queue.get(timeout=10)
while check_iterator_cleanup() is False:
try:
return self.res_queue.get(timeout=10)
except multiprocessing.TimeoutError:
continue

raise Exception("Generator worker process timeout")



def __del__(self): def __del__(self):
self.terminate() self.terminate()


+ 16
- 1
mindspore/dataset/engine/iterators.py View File

@@ -27,16 +27,30 @@ from mindspore import log as logger
from . import datasets as de from . import datasets as de




_ITERATOR_CLEANUP = False

def _set_iterator_cleanup():
    """Mark global iterator-cleanup as in progress.

    Sets the module-level ``_ITERATOR_CLEANUP`` flag to True so that
    blocking retry loops elsewhere (which poll ``check_iterator_cleanup``)
    can exit instead of waiting indefinitely during shutdown.
    """
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = True

def _unset_iterator_cleanup():
    """Clear the global iterator-cleanup flag.

    Resets the module-level ``_ITERATOR_CLEANUP`` flag to False,
    re-enabling the blocking retry loops that poll
    ``check_iterator_cleanup`` (e.g. when a new Iterator is created).
    """
    global _ITERATOR_CLEANUP
    _ITERATOR_CLEANUP = False

def check_iterator_cleanup():
    """Return True when iterator cleanup has been requested.

    Reads the module-level ``_ITERATOR_CLEANUP`` flag; a ``global``
    declaration is not needed for a read-only access.
    """
    return _ITERATOR_CLEANUP

ITERATORS_LIST = list() ITERATORS_LIST = list()


def _cleanup(): def _cleanup():
"""Release all the Iterator.""" """Release all the Iterator."""
_set_iterator_cleanup()
for itr_ref in ITERATORS_LIST: for itr_ref in ITERATORS_LIST:
itr = itr_ref() itr = itr_ref()
if itr is not None: if itr is not None:
itr.release() itr.release()



def alter_tree(node): def alter_tree(node):
"""Traversing the Python dataset tree/graph to perform some alteration to some specific nodes.""" """Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
if not node.children: if not node.children:
@@ -71,6 +85,7 @@ class Iterator:
self.num_epochs = num_epochs self.num_epochs = num_epochs
self.output_numpy = output_numpy self.output_numpy = output_numpy
ITERATORS_LIST.append(weakref.ref(self)) ITERATORS_LIST.append(weakref.ref(self))
_unset_iterator_cleanup()
# create a copy of tree and work on it. # create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset) self.dataset = copy.deepcopy(dataset)
self.ori_dataset = dataset self.ori_dataset = dataset


Loading…
Cancel
Save