|
|
|
@@ -270,6 +270,21 @@ def test_simple_sync_wait_empty_condition_name(): |
|
|
|
dataset.sync_update(condition_name="", data=data) |
|
|
|
|
|
|
|
|
|
|
|
def test_sync_exception_06(): |
|
|
|
""" |
|
|
|
Test sync: with string batch size |
|
|
|
""" |
|
|
|
logger.info("test_sync_exception_03") |
|
|
|
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"]) |
|
|
|
|
|
|
|
aug = Augment(0) |
|
|
|
# try to create dataset with batch_size < 0 |
|
|
|
with pytest.raises(TypeError) as e: |
|
|
|
dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update) |
|
|
|
assert "is not of type (<class 'int'>" in str(e.value) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
test_simple_sync_wait() |
|
|
|
test_simple_shuffle_sync() |
|
|
|
@@ -279,6 +294,7 @@ if __name__ == "__main__": |
|
|
|
test_sync_exception_03() |
|
|
|
test_sync_exception_04() |
|
|
|
test_sync_exception_05() |
|
|
|
test_sync_exception_06() |
|
|
|
test_sync_epoch() |
|
|
|
test_multiple_iterators() |
|
|
|
test_simple_sync_wait_empty_condition_name() |