|
|
|
@@ -1055,6 +1055,11 @@ class Dataset: |
|
|
|
return self.input[0].get_sync_notifiers() |
|
|
|
return {} |
|
|
|
|
|
|
|
def disable_sync(self): |
|
|
|
if self.input: |
|
|
|
return self.input[0].disable_sync() |
|
|
|
return {} |
|
|
|
|
|
|
|
def is_sync(self): |
|
|
|
if self.input: |
|
|
|
return self.input[0].is_sync() |
|
|
|
@@ -1062,16 +1067,23 @@ class Dataset: |
|
|
|
|
|
|
|
def sync_update(self, condition_name, num_batch=None, data=None): |
|
|
|
""" |
|
|
|
Release a blocking condition and triger callback with given data. |
|
|
|
Release a blocking condition and trigger callback with given data. |
|
|
|
|
|
|
|
Args: |
|
|
|
condition_name (str): The condition name that is used to toggle sending next row. |
|
|
|
num_batch (int or None): The number of batches(rows) that are released. |
|
|
|
When num_batch is None, it will default to the number specified by the sync_wait operator. |
|
|
|
data (dict or None): The data passed to the callback. |
|
|
|
""" |
|
|
|
When num_batch is None, it will default to the number specified by the |
|
|
|
sync_wait operator (default=None). |
|
|
|
data (dict or None): The data passed to the callback (default=None). |
|
|
|
""" |
|
|
|
if isinstance(num_batch, int) and num_batch <= 0: |
|
|
|
# throwing exception, disable all sync_wait in pipeline |
|
|
|
self.disable_sync() |
|
|
|
raise RuntimeError("Sync_update batch size can only be positive, got : {}".format(num_batch)) |
|
|
|
notifiers_dict = self.get_sync_notifiers() |
|
|
|
if condition_name not in notifiers_dict: |
|
|
|
# throwing exception, disable all sync_wait in pipeline |
|
|
|
self.disable_sync() |
|
|
|
raise RuntimeError("Condition name not found") |
|
|
|
if num_batch is not None: |
|
|
|
num_batch *= self.get_batch_size() |
|
|
|
@@ -1439,7 +1451,6 @@ class BatchDataset(DatasetOp): |
|
|
|
for input_dataset in dataset.input: |
|
|
|
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) |
|
|
|
|
|
|
|
|
|
|
|
class BatchInfo(CBatchInfo): |
|
|
|
""" |
|
|
|
The information object associates with the current batch of tensors. |
|
|
|
@@ -1472,10 +1483,13 @@ class BlockReleasePair: |
|
|
|
callback (function): The callback funciton that will be called when release is called. |
|
|
|
""" |
|
|
|
def __init__(self, init_release_rows, callback=None): |
|
|
|
if isinstance(init_release_rows, int) and init_release_rows <= 0: |
|
|
|
raise ValueError("release_rows need to be greater than 0.") |
|
|
|
self.row_count = -init_release_rows |
|
|
|
self.cv = threading.Condition() |
|
|
|
self.callback = callback |
|
|
|
self.default_rows = init_release_rows |
|
|
|
self.disable = False |
|
|
|
|
|
|
|
def __deepcopy__(self, memodict): |
|
|
|
if id(self) in memodict: |
|
|
|
@@ -1491,13 +1505,18 @@ class BlockReleasePair: |
|
|
|
self.cv.notify_all() |
|
|
|
|
|
|
|
def update_batched_size(self, batch_size): |
|
|
|
# sanity check |
|
|
|
if isinstance(batch_size, int) and batch_size <= 0: |
|
|
|
raise ValueError("batch_size need to be greater than 0.") |
|
|
|
|
|
|
|
# should only use before the pipeline creates |
|
|
|
self.row_count *= batch_size |
|
|
|
self.default_rows *= batch_size |
|
|
|
|
|
|
|
def block_func(self): |
|
|
|
with self.cv: |
|
|
|
self.cv.wait_for(lambda: self.row_count < 0) |
|
|
|
# if disable is true, the always evaluate to true |
|
|
|
self.cv.wait_for(lambda: (self.row_count < 0 or self.disable)) |
|
|
|
self.row_count += 1 |
|
|
|
return True |
|
|
|
|
|
|
|
@@ -1510,6 +1529,12 @@ class BlockReleasePair: |
|
|
|
self.callback(data) |
|
|
|
self.cv.notify_all() |
|
|
|
|
|
|
|
def disable_lock(self): |
|
|
|
with self.cv: |
|
|
|
self.disable = True |
|
|
|
self.cv.notify_all() |
|
|
|
|
|
|
|
|
|
|
|
class SyncWaitDataset(DatasetOp): |
|
|
|
""" |
|
|
|
The result of adding a blocking condition to the input Dataset. |
|
|
|
@@ -1530,6 +1555,9 @@ class SyncWaitDataset(DatasetOp): |
|
|
|
input_dataset.output.append(self) |
|
|
|
# set to the default value, waiting for the batch to update it |
|
|
|
self._condition_name = condition_name |
|
|
|
if isinstance(num_batch, int) and num_batch <= 0: |
|
|
|
raise ValueError("num_batch need to be greater than 0.") |
|
|
|
|
|
|
|
self._pair = BlockReleasePair(num_batch, callback) |
|
|
|
if self._condition_name in self.input[0].get_sync_notifiers(): |
|
|
|
raise RuntimeError("Condition name is already in use") |
|
|
|
@@ -1549,8 +1577,14 @@ class SyncWaitDataset(DatasetOp): |
|
|
|
return args |
|
|
|
|
|
|
|
def update_sync_batch_size(self, batch_size): |
|
|
|
if isinstance(batch_size, int) and batch_size <= 0: |
|
|
|
raise ValueError("num_batch need to be greater than 0.") |
|
|
|
self._pair.update_batched_size(batch_size) |
|
|
|
|
|
|
|
def disable_sync(self): |
|
|
|
logger.info("Disabling Sync") |
|
|
|
self._pair.disable_lock() |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _is_ancestor_of_batch(dataset): |
|
|
|
""" |
|
|
|
|