|
|
|
@@ -1857,6 +1857,40 @@ def test_cache_map_cifar3(): |
|
|
|
logger.info("test_cache_map_cifar3 Ended.\n") |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") |
|
|
|
def test_cache_map_cifar4(): |
|
|
|
""" |
|
|
|
Test mappable cifar10 leaf with cache op right over the leaf, and shuffle op over the cache op |
|
|
|
|
|
|
|
shuffle |
|
|
|
| |
|
|
|
cache |
|
|
|
| |
|
|
|
Cifar10 |
|
|
|
""" |
|
|
|
|
|
|
|
logger.info("Test cache map cifar4") |
|
|
|
if "SESSION_ID" in os.environ: |
|
|
|
session_id = int(os.environ['SESSION_ID']) |
|
|
|
else: |
|
|
|
raise RuntimeError("Testcase requires SESSION_ID environment variable") |
|
|
|
|
|
|
|
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) |
|
|
|
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) |
|
|
|
ds1 = ds1.shuffle(10) |
|
|
|
|
|
|
|
num_epoch = 1 |
|
|
|
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) |
|
|
|
|
|
|
|
epoch_count = 0 |
|
|
|
for _ in range(num_epoch): |
|
|
|
assert sum([1 for _ in iter1]) == 10 |
|
|
|
epoch_count += 1 |
|
|
|
assert epoch_count == num_epoch |
|
|
|
|
|
|
|
logger.info("test_cache_map_cifar4 Ended.\n") |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") |
|
|
|
def test_cache_map_voc1(): |
|
|
|
""" |
|
|
|
|