@@ -35,6 +35,9 @@ def test_imagefolder_shardings(print_res=False):
    assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1])  # 5 rows
    assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3])  # 11 rows
    assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
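    # num_samples=55 asks for more rows than the dataset holds, so the reads below
    # are evidently capped at the rows actually available (44 total, 22 per 2-way shard)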
    assert (sharding_config(1, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])  # 44 rows
    assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])  # 22 rows
    assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3])  # 22 rows
    # total 22 rows in dataset because class indexing takes only 2 folders
    assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
    assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
@@ -44,6 +47,86 @@ def test_imagefolder_shardings(print_res=False):
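    # 22 rows over 5 shards comes to 5 rows per shard (rounded up, as in the == 6 case above);
    # 4 repeats of 5 rows -> 20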
    assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)


def test_tfrecord_shardings1(print_res=False):
    """ Test TFRecordDataset sharding with num_parallel_workers=1 """
    # total 40 rows in dataset
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=1)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each item is a dictionary
            res.append(item["scalars"][0])
        if print_res:
            logger.info("scalars of dataset: {}".format(res))
        return res

    assert sharding_config(2, 0, None, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, None, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(2, 0, 3, 1) == [11, 12, 13]  # 3 rows
    assert sharding_config(2, 1, 3, 1) == [1, 2, 3]  # 3 rows
    assert sharding_config(2, 0, 40, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, 40, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(2, 0, 55, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, 55, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(3, 0, 8, 1) == [11, 12, 13, 14, 15, 16, 17, 18]  # 8 rows
    assert sharding_config(3, 1, 8, 1) == [1, 2, 3, 4, 5, 6, 7, 8]  # 8 rows
    assert sharding_config(3, 2, 8, 1) == [21, 22, 23, 24, 25, 26, 27, 28]  # 8 rows
    assert sharding_config(4, 0, 2, 1) == [11, 12]  # 2 rows
    assert sharding_config(4, 1, 2, 1) == [1, 2]  # 2 rows
    assert sharding_config(4, 2, 2, 1) == [21, 22]  # 2 rows
    assert sharding_config(4, 3, 2, 1) == [31, 32]  # 2 rows
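    # with repeat_cnt=2, Shuffle.FILES appears to reshuffle the file order for each epoch,
    # so the second pass can read a different file than the first (e.g. 11-14 then 21-24)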
    assert sharding_config(3, 0, 4, 2) == [11, 12, 13, 14, 21, 22, 23, 24]  # 8 rows
    assert sharding_config(3, 1, 4, 2) == [1, 2, 3, 4, 11, 12, 13, 14]  # 8 rows
    assert sharding_config(3, 2, 4, 2) == [21, 22, 23, 24, 31, 32, 33, 34]  # 8 rows


def test_tfrecord_shardings4(print_res=False):
    """ Test TFRecordDataset sharding with num_parallel_workers=4 """
    # total 40 rows in dataset
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=4)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator():  # each item is a dictionary
            res.append(item["scalars"][0])
        if print_res:
            logger.info("scalars of dataset: {}".format(res))
        return res

    def check_result(result_list, expect_length, expect_set):
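        # order varies across parallel workers, so only length and set membership are compared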
        assert len(result_list) == expect_length
        assert set(result_list) == expect_set

    check_result(sharding_config(2, 0, None, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, None, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21})
    check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31})
    check_result(sharding_config(2, 0, 40, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, 40, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(2, 0, 55, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, 55, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31})
    check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8})
    check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28})
    check_result(sharding_config(4, 0, 2, 1), 2, {11, 12})
    check_result(sharding_config(4, 1, 2, 1), 2, {1, 2})
    check_result(sharding_config(4, 2, 2, 1), 2, {21, 22})
    check_result(sharding_config(4, 3, 2, 1), 2, {31, 32})
    check_result(sharding_config(3, 0, 4, 2), 8, {32, 1, 2, 11, 12, 21, 22, 31})
    check_result(sharding_config(3, 1, 4, 2), 8, {1, 2, 3, 4, 11, 12, 13, 14})
    check_result(sharding_config(3, 2, 4, 2), 8, {32, 33, 34, 21, 22, 23, 24, 31})


def test_manifest_shardings(print_res=False):
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
@@ -157,6 +240,8 @@ def test_mnist_shardings(print_res=False):
if __name__ == '__main__':
    test_imagefolder_shardings(True)
    test_tfrecord_shardings1(True)
    test_tfrecord_shardings4(True)
    test_manifest_shardings(True)
    test_voc_shardings(True)
    test_cifar10_shardings(True)