|
|
|
@@ -144,6 +144,8 @@ class WaitedDSCallback(Callback, DSCallback): |
|
|
|
self.epoch_event = threading.Event() |
|
|
|
self.epoch_run_context = None |
|
|
|
|
|
|
|
self.training_ended = False |
|
|
|
|
|
|
|
def sync_epoch_begin(self, train_run_context, ds_run_context): |
|
|
|
""" |
|
|
|
Called before a new dataset epoch is started and after the previous training epoch is ended. |
|
|
|
@@ -180,10 +182,11 @@ class WaitedDSCallback(Callback, DSCallback): |
|
|
|
ds_run_context: Include some information of the pipeline. |
|
|
|
""" |
|
|
|
if ds_run_context.cur_epoch_num > 1: |
|
|
|
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout()) |
|
|
|
self.epoch_event.clear() |
|
|
|
if not success: |
|
|
|
raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)") |
|
|
|
if not self.training_ended: |
|
|
|
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout()) |
|
|
|
self.epoch_event.clear() |
|
|
|
if not success: |
|
|
|
raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)") |
|
|
|
# by the time this thread wakes up, self.epoch_run_context is already available |
|
|
|
self.sync_epoch_begin(self.epoch_run_context, ds_run_context) |
|
|
|
|
|
|
|
@@ -205,11 +208,12 @@ class WaitedDSCallback(Callback, DSCallback): |
|
|
|
ds_run_context: Include some information of the pipeline. |
|
|
|
""" |
|
|
|
if ds_run_context.cur_step_num > self.step_size: |
|
|
|
success = self.step_event.wait(timeout=ds.config.get_callback_timeout()) |
|
|
|
self.step_event.clear() |
|
|
|
if not success: |
|
|
|
raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)") |
|
|
|
# by the time this thread wakes up, self.epoch_run_context is already available |
|
|
|
if not self.training_ended: |
|
|
|
success = self.step_event.wait(timeout=ds.config.get_callback_timeout()) |
|
|
|
self.step_event.clear() |
|
|
|
if not success: |
|
|
|
raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)") |
|
|
|
# by the time this thread wakes up, self.epoch_run_context is already available |
|
|
|
self.sync_step_begin(self.step_run_context, ds_run_context) |
|
|
|
|
|
|
|
def create_runtime_obj(self): |
|
|
|
@@ -233,3 +237,8 @@ class WaitedDSCallback(Callback, DSCallback): |
|
|
|
raise AttributeError("Provided Callback class did not override any of the 2 callback methods.") |
|
|
|
|
|
|
|
return c_cb |
|
|
|
|
|
|
|
def end(self, run_context): |
|
|
|
self.epoch_end(run_context) |
|
|
|
self.step_end(run_context) |
|
|
|
self.training_ended = True |