|
|
|
@@ -57,6 +57,24 @@ def test_imagefolder_numsamples(): |
|
|
|
logger.info("Number of data in data1: {}".format(num_iter)) |
|
|
|
assert num_iter == 10 |
|
|
|
|
|
|
|
random_sampler = ds.RandomSampler(num_samples=3, replacement=True) |
|
|
|
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler) |
|
|
|
|
|
|
|
num_iter = 0 |
|
|
|
for item in data1.create_dict_iterator(): |
|
|
|
num_iter += 1 |
|
|
|
|
|
|
|
assert num_iter == 3 |
|
|
|
|
|
|
|
random_sampler = ds.RandomSampler(num_samples=3, replacement=False) |
|
|
|
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler) |
|
|
|
|
|
|
|
num_iter = 0 |
|
|
|
for item in data1.create_dict_iterator(): |
|
|
|
num_iter += 1 |
|
|
|
|
|
|
|
assert num_iter == 3 |
|
|
|
|
|
|
|
|
|
|
|
def test_imagefolder_numshards(): |
|
|
|
logger.info("Test Case numShards") |
|
|
|
|