|
|
|
@@ -20,6 +20,18 @@ from util import save_and_check_md5 |
|
|
|
|
|
|
|
# Golden-file switch; not referenced in this excerpt
# (presumably consumed by the md5-based checks - confirm against its usage).
GENERATE_GOLDEN = False

# Locations of the fixture data, relative to the directory the tests are run from.
IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
IMAGENET_TFFILE_DIR = [
    "../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
    "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
    "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
    "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data",
]
MNIST_DATA_DIR = "../data/dataset/testMnistData"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"

# COCO needs both the image directory and its annotation file.
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"

VOC_DATA_DIR = "../data/dataset/testVOC2012"
|
|
|
|
|
|
|
|
|
|
|
def test_numpyslices_sampler_no_chain(): |
|
|
|
""" |
|
|
|
@@ -107,6 +119,166 @@ def test_numpyslices_sampler_chain2(): |
|
|
|
logger.info("dataset: {}".format(res)) |
|
|
|
|
|
|
|
|
|
|
|
def test_imagefolder_sampler_chain():
    """
    Read ImageFolderDataset through a SequentialSampler that has a PKSampler child.

    Asserts that the reported dataset size and the number of rows actually
    iterated are both 3, then logs every row of a one-epoch tuple iterator.
    """
    logger.info("test_imagefolder_sampler_chain")

    expected_rows = 3

    # Build the chain: the parent is created first, then the child is attached to it.
    parent = ds.SequentialSampler(start_index=1, num_samples=3)
    parent.add_child(ds.PKSampler(2))
    dataset = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, sampler=parent)

    # Size reported by the pipeline.
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == expected_rows

    # Rows actually produced by iterating the dataset.
    row_count = sum(1 for _ in dataset)
    assert row_count == expected_rows

    # Contents are only logged, not compared against expected values.
    rows = []
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(row))
        rows.append(row)
    logger.info("dataset: {}".format(rows))
|
|
|
|
|
|
|
|
|
|
|
def test_mnist_sampler_chain():
    """
    Read MnistDataset through a DistributedSampler that has a RandomSampler child.

    Asserts that the reported dataset size and the number of rows actually
    iterated are both 3, then logs every row of a one-epoch tuple iterator.
    """
    logger.info("test_mnist_sampler_chain")

    expected_rows = 3

    # Build the chain: the parent is created first, then the child is attached to it.
    parent = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
    parent.add_child(ds.RandomSampler(replacement=True, num_samples=4))
    dataset = ds.MnistDataset(MNIST_DATA_DIR, sampler=parent)

    # Size reported by the pipeline.
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == expected_rows

    # Rows actually produced by iterating the dataset.
    row_count = sum(1 for _ in dataset)
    assert row_count == expected_rows

    # Contents are only logged, not compared against expected values.
    rows = []
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(row))
        rows.append(row)
    logger.info("dataset: {}".format(rows))
|
|
|
|
|
|
|
|
|
|
|
def test_manifest_sampler_chain():
    """
    Read ManifestDataset through a RandomSampler that has a DistributedSampler child.

    Asserts that the reported dataset size and the number of rows actually
    iterated are both 2, then logs every row of a one-epoch tuple iterator.
    """
    logger.info("test_manifest_sampler_chain")

    expected_rows = 2

    # Build the chain: the parent is created first, then the child is attached to it.
    parent = ds.RandomSampler(replacement=True, num_samples=2)
    parent.add_child(ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1))
    dataset = ds.ManifestDataset(MANIFEST_DATA_FILE, sampler=parent)

    # Size reported by the pipeline.
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == expected_rows

    # Rows actually produced by iterating the dataset.
    row_count = sum(1 for _ in dataset)
    assert row_count == expected_rows

    # Contents are only logged, not compared against expected values.
    rows = []
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(row))
        rows.append(row)
    logger.info("dataset: {}".format(rows))
|
|
|
|
|
|
|
|
|
|
|
def test_coco_sampler_chain():
    """
    Read CocoDataset (Detection task, decode on) through a DistributedSampler
    that has a RandomSampler child.

    Asserts that the reported dataset size and the number of rows actually
    iterated are both 1, then logs every row of a one-epoch tuple iterator.
    """
    logger.info("test_coco_sampler_chain")

    expected_rows = 1

    # Build the chain: the parent is created first, then the child is attached to it.
    parent = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
    parent.add_child(ds.RandomSampler(replacement=True, num_samples=2))
    dataset = ds.CocoDataset(
        COCO_DATA_DIR,
        annotation_file=ANNOTATION_FILE,
        task="Detection",
        decode=True,
        sampler=parent,
    )

    # Size reported by the pipeline.
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == expected_rows

    # Rows actually produced by iterating the dataset.
    row_count = sum(1 for _ in dataset)
    assert row_count == expected_rows

    # Contents are only logged, not compared against expected values.
    rows = []
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(row))
        rows.append(row)
    logger.info("dataset: {}".format(rows))
|
|
|
|
|
|
|
|
|
|
|
def test_cifar_sampler_chain():
    """
    Read Cifar10Dataset through a three-level sampler chain:
    DistributedSampler -> RandomSampler -> SequentialSampler.

    Asserts that the reported dataset size and the number of rows actually
    iterated are both 1, then logs every row of a one-epoch tuple iterator.
    """
    logger.info("test_cifar_sampler_chain")

    expected_rows = 1

    # Samplers are constructed top-down, then linked bottom-up.
    top = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
    middle = ds.RandomSampler(replacement=True, num_samples=4)
    bottom = ds.SequentialSampler(start_index=0, num_samples=2)
    middle.add_child(bottom)
    top.add_child(middle)
    dataset = ds.Cifar10Dataset(CIFAR10_DATA_DIR, sampler=top)

    # Size reported by the pipeline.
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == expected_rows

    # Rows actually produced by iterating the dataset.
    row_count = sum(1 for _ in dataset)
    assert row_count == expected_rows

    # Contents are only logged, not compared against expected values.
    rows = []
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(row))
        rows.append(row)
    logger.info("dataset: {}".format(rows))
|
|
|
|
|
|
|
|
|
|
|
def test_voc_sampler_chain():
    """
    Read VOCDataset (Segmentation task) through a DistributedSampler that has
    a SequentialSampler child.

    Asserts that the reported dataset size and the number of rows actually
    iterated are both 5, then logs every row of a one-epoch tuple iterator.
    """
    logger.info("test_voc_sampler_chain")

    expected_rows = 5

    # Build the chain: the parent is created first, then the child is attached to it.
    parent = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
    parent.add_child(ds.SequentialSampler(start_index=0))
    dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", sampler=parent)

    # Size reported by the pipeline.
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == expected_rows

    # Rows actually produced by iterating the dataset.
    row_count = sum(1 for _ in dataset)
    assert row_count == expected_rows

    # Contents are only logged, not compared against expected values.
    rows = []
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(row))
        rows.append(row)
    logger.info("dataset: {}".format(rows))
|
|
|
|
|
|
|
|
|
|
|
def test_numpyslices_sampler_chain_batch(): |
|
|
|
""" |
|
|
|
Test NumpySlicesDataset sampler chaining, with batch |
|
|
|
@@ -241,6 +413,12 @@ if __name__ == '__main__': |
|
|
|
test_numpyslices_sampler_no_chain() |
|
|
|
test_numpyslices_sampler_chain() |
|
|
|
test_numpyslices_sampler_chain2() |
|
|
|
test_imagefolder_sampler_chain() |
|
|
|
test_mnist_sampler_chain() |
|
|
|
test_manifest_sampler_chain() |
|
|
|
test_coco_sampler_chain() |
|
|
|
test_cifar_sampler_chain() |
|
|
|
test_voc_sampler_chain() |
|
|
|
test_numpyslices_sampler_chain_batch() |
|
|
|
test_sampler_chain_errors() |
|
|
|
test_manifest_sampler_chain_repeat() |
|
|
|
|