|
|
|
@@ -30,6 +30,14 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" |
|
|
|
|
|
|
|
|
|
|
|
def test_basic(): |
|
|
|
""" |
|
|
|
Test basic configuration functions |
|
|
|
""" |
|
|
|
# Save original configuration values |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
prefetch_size_original = ds.config.get_prefetch_size() |
|
|
|
seed_original = ds.config.get_seed() |
|
|
|
|
|
|
|
ds.config.load('../data/dataset/declient.cfg') |
|
|
|
|
|
|
|
# assert ds.config.get_rows_per_buffer() == 32 |
|
|
|
@@ -50,6 +58,11 @@ def test_basic(): |
|
|
|
assert ds.config.get_prefetch_size() == 4 |
|
|
|
assert ds.config.get_seed() == 5 |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_num_parallel_workers(num_parallel_workers_original) |
|
|
|
ds.config.set_prefetch_size(prefetch_size_original) |
|
|
|
ds.config.set_seed(seed_original) |
|
|
|
|
|
|
|
|
|
|
|
def test_get_seed(): |
|
|
|
""" |
|
|
|
@@ -62,6 +75,9 @@ def test_pipeline(): |
|
|
|
""" |
|
|
|
Test that our configuration pipeline works when we set parameters at different locations in dataset code |
|
|
|
""" |
|
|
|
# Save original configuration values |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) |
|
|
|
ds.config.set_num_parallel_workers(2) |
|
|
|
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)]) |
|
|
|
@@ -85,6 +101,9 @@ def test_pipeline(): |
|
|
|
except IOError: |
|
|
|
logger.info("Error while deleting: {}".format(f)) |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_num_parallel_workers(num_parallel_workers_original) |
|
|
|
|
|
|
|
|
|
|
|
def test_deterministic_run_fail(): |
|
|
|
""" |
|
|
|
@@ -92,6 +111,10 @@ def test_deterministic_run_fail(): |
|
|
|
""" |
|
|
|
logger.info("test_deterministic_run_fail") |
|
|
|
|
|
|
|
# Save original configuration values |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
seed_original = ds.config.get_seed() |
|
|
|
|
|
|
|
# when we set the seed all operations within our dataset should be deterministic |
|
|
|
ds.config.set_seed(0) |
|
|
|
ds.config.set_num_parallel_workers(1) |
|
|
|
@@ -120,12 +143,21 @@ def test_deterministic_run_fail(): |
|
|
|
logger.info("Got an exception in DE: {}".format(str(e))) |
|
|
|
assert "Array" in str(e) |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_num_parallel_workers(num_parallel_workers_original) |
|
|
|
ds.config.set_seed(seed_original) |
|
|
|
|
|
|
|
|
|
|
|
def test_deterministic_run_pass(): |
|
|
|
""" |
|
|
|
Test deterministic run with with setting the seed |
|
|
|
""" |
|
|
|
logger.info("test_deterministic_run_pass") |
|
|
|
|
|
|
|
# Save original configuration values |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
seed_original = ds.config.get_seed() |
|
|
|
|
|
|
|
ds.config.set_seed(0) |
|
|
|
ds.config.set_num_parallel_workers(1) |
|
|
|
|
|
|
|
@@ -152,13 +184,23 @@ def test_deterministic_run_pass(): |
|
|
|
logger.info("Got an exception in DE: {}".format(str(e))) |
|
|
|
assert "Array" in str(e) |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_num_parallel_workers(num_parallel_workers_original) |
|
|
|
ds.config.set_seed(seed_original) |
|
|
|
|
|
|
|
|
|
|
|
def test_seed_undeterministic(): |
|
|
|
""" |
|
|
|
Test seed with num parallel workers in c, this test is expected to fail some of the time |
|
|
|
""" |
|
|
|
logger.info("test_seed_undeterministic") |
|
|
|
|
|
|
|
# Save original configuration values |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
seed_original = ds.config.get_seed() |
|
|
|
|
|
|
|
ds.config.set_seed(0) |
|
|
|
ds.config.set_num_parallel_workers(1) |
|
|
|
|
|
|
|
# First dataset |
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) |
|
|
|
@@ -178,6 +220,10 @@ def test_seed_undeterministic(): |
|
|
|
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): |
|
|
|
np.testing.assert_equal(item1["image"], item2["image"]) |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_num_parallel_workers(num_parallel_workers_original) |
|
|
|
ds.config.set_seed(seed_original) |
|
|
|
|
|
|
|
|
|
|
|
def test_deterministic_run_distribution(): |
|
|
|
""" |
|
|
|
@@ -185,6 +231,10 @@ def test_deterministic_run_distribution(): |
|
|
|
""" |
|
|
|
logger.info("test_deterministic_run_distribution") |
|
|
|
|
|
|
|
# Save original configuration values |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
seed_original = ds.config.get_seed() |
|
|
|
|
|
|
|
# when we set the seed all operations within our dataset should be deterministic |
|
|
|
ds.config.set_seed(0) |
|
|
|
ds.config.set_num_parallel_workers(1) |
|
|
|
@@ -206,12 +256,21 @@ def test_deterministic_run_distribution(): |
|
|
|
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): |
|
|
|
np.testing.assert_equal(item1["image"], item2["image"]) |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_num_parallel_workers(num_parallel_workers_original) |
|
|
|
ds.config.set_seed(seed_original) |
|
|
|
|
|
|
|
|
|
|
|
def test_deterministic_python_seed(): |
|
|
|
""" |
|
|
|
Test deterministic execution with seed in python |
|
|
|
""" |
|
|
|
logger.info("deterministic_random_crop_op_python_2") |
|
|
|
|
|
|
|
# Save original configuration values |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
seed_original = ds.config.get_seed() |
|
|
|
|
|
|
|
ds.config.set_seed(0) |
|
|
|
ds.config.set_num_parallel_workers(1) |
|
|
|
|
|
|
|
@@ -242,12 +301,20 @@ def test_deterministic_python_seed(): |
|
|
|
|
|
|
|
np.testing.assert_equal(data1_output, data2_output) |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_num_parallel_workers(num_parallel_workers_original) |
|
|
|
ds.config.set_seed(seed_original) |
|
|
|
|
|
|
|
|
|
|
|
def test_deterministic_python_seed_multi_thread(): |
|
|
|
""" |
|
|
|
Test deterministic execution with seed in python, this fails with multi-thread pyfunc run |
|
|
|
""" |
|
|
|
logger.info("deterministic_random_crop_op_python_2") |
|
|
|
|
|
|
|
# Save original configuration values |
|
|
|
seed_original = ds.config.get_seed() |
|
|
|
|
|
|
|
ds.config.set_seed(0) |
|
|
|
# when we set the seed all operations within our dataset should be deterministic |
|
|
|
# First dataset |
|
|
|
@@ -282,6 +349,9 @@ def test_deterministic_python_seed_multi_thread(): |
|
|
|
logger.info("Got an exception in DE: {}".format(str(e))) |
|
|
|
assert "Array" in str(e) |
|
|
|
|
|
|
|
# Restore original configuration values |
|
|
|
ds.config.set_seed(seed_original) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_basic() |
|
|
|
|