|
|
|
@@ -204,7 +204,6 @@ def test_sync_exception_03(): |
|
|
|
Test sync: with wrong batch size |
|
|
|
""" |
|
|
|
logger.info("test_sync_exception_03") |
|
|
|
batch_size = 6 |
|
|
|
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"]) |
|
|
|
|
|
|
|
@@ -223,7 +222,6 @@ def test_sync_exception_04(): |
|
|
|
Test sync: with negative batch size in update |
|
|
|
""" |
|
|
|
logger.info("test_sync_exception_04") |
|
|
|
batch_size = 6 |
|
|
|
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"]) |
|
|
|
|
|
|
|
@@ -233,7 +231,7 @@ def test_sync_exception_04(): |
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) |
|
|
|
count = 0 |
|
|
|
try: |
|
|
|
for item in dataset.create_dict_iterator(): |
|
|
|
for _ in dataset.create_dict_iterator(): |
|
|
|
count += 1 |
|
|
|
data = {"loss": count} |
|
|
|
# dataset.disable_sync() |
|
|
|
@@ -246,7 +244,6 @@ def test_sync_exception_05(): |
|
|
|
Test sync: with wrong batch size in update |
|
|
|
""" |
|
|
|
logger.info("test_sync_exception_05") |
|
|
|
batch_size = 6 |
|
|
|
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"]) |
|
|
|
count = 0 |
|
|
|
@@ -255,7 +252,7 @@ def test_sync_exception_05(): |
|
|
|
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) |
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) |
|
|
|
try: |
|
|
|
for item in dataset.create_dict_iterator(): |
|
|
|
for _ in dataset.create_dict_iterator(): |
|
|
|
dataset.disable_sync() |
|
|
|
count += 1 |
|
|
|
data = {"loss": count} |
|
|
|
|