|
|
|
@@ -107,6 +107,7 @@ def test_two_sync(): |
|
|
|
if count % 2 == 0: |
|
|
|
dataset.sync_update(condition_name="every 2 batches") |
|
|
|
|
|
|
|
|
|
|
|
def test_sync_epoch(): |
|
|
|
""" |
|
|
|
Test sync wait with epochs: test sync with epochs in dataset pipeline |
|
|
|
@@ -130,6 +131,34 @@ def test_sync_epoch(): |
|
|
|
dataset.sync_update(condition_name="policy", data=data) |
|
|
|
|
|
|
|
|
|
|
|
def test_multiple_iterators(): |
|
|
|
""" |
|
|
|
Test sync wait with multiple iterators: will start multiple |
|
|
|
""" |
|
|
|
logger.info("test_sync_epoch") |
|
|
|
batch_size = 30 |
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"]) |
|
|
|
|
|
|
|
aug = Augment(0) |
|
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) |
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) |
|
|
|
dataset = dataset.batch(batch_size, drop_remainder=True) |
|
|
|
# 2nd dataset |
|
|
|
dataset2 = ds.GeneratorDataset(gen, column_names=["input"]) |
|
|
|
|
|
|
|
aug = Augment(0) |
|
|
|
dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update) |
|
|
|
dataset2 = dataset2.map(input_columns=["input"], operations=[aug.preprocess]) |
|
|
|
dataset2 = dataset2.batch(batch_size, drop_remainder=True) |
|
|
|
|
|
|
|
for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()): |
|
|
|
assert (item1["input"][0] == item2["input"][0]) |
|
|
|
data1 = {"loss": item1["input"][0]} |
|
|
|
data2 = {"loss": item2["input"][0]} |
|
|
|
dataset.sync_update(condition_name="policy", data=data1) |
|
|
|
dataset2.sync_update(condition_name="policy", data=data2) |
|
|
|
|
|
|
|
|
|
|
|
def test_sync_exception_01(): |
|
|
|
""" |
|
|
|
Test sync: with shuffle in sync mode |
|
|
|
@@ -179,4 +208,5 @@ if __name__ == "__main__": |
|
|
|
test_two_sync() |
|
|
|
test_sync_exception_01() |
|
|
|
test_sync_exception_02() |
|
|
|
test_sync_epoch() |
|
|
|
test_sync_epoch() |
|
|
|
test_multiple_iterators() |