Browse Source

!5073 Add checks and exception handling DS callback

Merge pull request !5073 from h.farahat/map_callback_end
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f3fd7a5578
3 changed files with 40 additions and 10 deletions
  1. +5
    -1
      mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc
  2. +18
    -9
      mindspore/dataset/callback/ds_callback.py
  3. +17
    -0
      tests/ut/python/dataset/test_callbacks.py

+ 5
- 1
mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc View File

@@ -53,7 +53,11 @@ Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param
if (Py_IsInitialized() == 0) { if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
} }
f(cb_param);
try {
f(cb_param);
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
}
} }
return Status::OK(); return Status::OK();
} }


+ 18
- 9
mindspore/dataset/callback/ds_callback.py View File

@@ -144,6 +144,8 @@ class WaitedDSCallback(Callback, DSCallback):
self.epoch_event = threading.Event() self.epoch_event = threading.Event()
self.epoch_run_context = None self.epoch_run_context = None


self.training_ended = False

def sync_epoch_begin(self, train_run_context, ds_run_context): def sync_epoch_begin(self, train_run_context, ds_run_context):
""" """
Called before a new dataset epoch is started and after the previous training epoch is ended. Called before a new dataset epoch is started and after the previous training epoch is ended.
@@ -180,10 +182,11 @@ class WaitedDSCallback(Callback, DSCallback):
ds_run_context: Include some information of the pipeline. ds_run_context: Include some information of the pipeline.
""" """
if ds_run_context.cur_epoch_num > 1: if ds_run_context.cur_epoch_num > 1:
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
self.epoch_event.clear()
if not success:
raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)")
if not self.training_ended:
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
self.epoch_event.clear()
if not success:
raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)")
# by the time this thread wakes up, self.epoch_run_context is already available # by the time this thread wakes up, self.epoch_run_context is already available
self.sync_epoch_begin(self.epoch_run_context, ds_run_context) self.sync_epoch_begin(self.epoch_run_context, ds_run_context)


@@ -205,11 +208,12 @@ class WaitedDSCallback(Callback, DSCallback):
ds_run_context: Include some information of the pipeline. ds_run_context: Include some information of the pipeline.
""" """
if ds_run_context.cur_step_num > self.step_size: if ds_run_context.cur_step_num > self.step_size:
success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
self.step_event.clear()
if not success:
raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)")
# by the time this thread wakes up, self.epoch_run_context is already available
if not self.training_ended:
success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
self.step_event.clear()
if not success:
raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)")
# by the time this thread wakes up, self.epoch_run_context is already available
self.sync_step_begin(self.step_run_context, ds_run_context) self.sync_step_begin(self.step_run_context, ds_run_context)


def create_runtime_obj(self): def create_runtime_obj(self):
@@ -233,3 +237,8 @@ class WaitedDSCallback(Callback, DSCallback):
raise AttributeError("Provided Callback class did not override any of the 2 callback methods.") raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")


return c_cb return c_cb

def end(self, run_context):
self.epoch_end(run_context)
self.step_end(run_context)
self.training_ended = True

+ 17
- 0
tests/ut/python/dataset/test_callbacks.py View File

@@ -410,6 +410,22 @@ def test_callbacks_exceptions():
assert "RuntimeError: Bad begin" in str(err.value) assert "RuntimeError: Bad begin" in str(err.value)




def test_callbacks_train_end():
logger.info("test_callback_sink_simulation")
# No asserts are needed, just test there is no deadlock or exceptions
events = []
epochs = 2

my_cb = MyWaitedCallback(events, 1)
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=[my_cb])
data = data.to_device()
data.send(num_epochs=epochs)
time.sleep(0.5)
my_cb.end(run_context={})
time.sleep(0.5)


def test_callbacks_one_cb(): def test_callbacks_one_cb():
logger.info("test_callbacks_one_cb") logger.info("test_callbacks_one_cb")


@@ -458,3 +474,4 @@ if __name__ == '__main__':
test_callbacks_non_sink() test_callbacks_non_sink()
test_callbacks_one_cb() test_callbacks_one_cb()
test_callbacks_non_sink_mismatch_size() test_callbacks_non_sink_mismatch_size()
test_callbacks_train_end()

Loading…
Cancel
Save