|
|
|
@@ -501,6 +501,35 @@ def test_cifar_exception_file_path(): |
|
|
|
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) |
|
|
|
|
|
|
|
|
|
|
|
def test_cifar10_pk_sampler_get_dataset_size(): |
|
|
|
""" |
|
|
|
Test Cifar10Dataset with PKSampler and get_dataset_size |
|
|
|
""" |
|
|
|
sampler = ds.PKSampler(3) |
|
|
|
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) |
|
|
|
num_iter = 0 |
|
|
|
ds_sz = data.get_dataset_size() |
|
|
|
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): |
|
|
|
num_iter += 1 |
|
|
|
|
|
|
|
assert ds_sz == num_iter == 30 |
|
|
|
|
|
|
|
|
|
|
|
def test_cifar10_with_chained_sampler_get_dataset_size(): |
|
|
|
""" |
|
|
|
Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size |
|
|
|
""" |
|
|
|
sampler = ds.SequentialSampler(start_index=0, num_samples=5) |
|
|
|
child_sampler = ds.PKSampler(4) |
|
|
|
sampler.add_child(child_sampler) |
|
|
|
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) |
|
|
|
num_iter = 0 |
|
|
|
ds_sz = data.get_dataset_size() |
|
|
|
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): |
|
|
|
num_iter += 1 |
|
|
|
assert ds_sz == num_iter == 5 |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_cifar10_content_check() |
|
|
|
test_cifar10_basic() |
|
|
|
@@ -517,3 +546,6 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
test_cifar_usage() |
|
|
|
test_cifar_exception_file_path() |
|
|
|
|
|
|
|
test_cifar10_with_chained_sampler_get_dataset_size() |
|
|
|
test_cifar10_pk_sampler_get_dataset_size() |