From b78894e02bb5af62897643a620fc267b8b8d3f5f Mon Sep 17 00:00:00 2001
From: Cathy Wong
Date: Tue, 19 May 2020 17:09:54 -0400
Subject: [PATCH] Cleanup dataset UT: unskip and enhance TFRecord sharding tests

---
 tests/ut/python/dataset/test_concat.py            |  9 +-
 .../ut/python/dataset/test_datasets_sharding.py   | 85 +++++++++++++++++++
 tests/ut/python/dataset/test_five_crop.py         |  2 +-
 tests/ut/python/dataset/test_tfreader_op.py       | 13 ++-
 4 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/tests/ut/python/dataset/test_concat.py b/tests/ut/python/dataset/test_concat.py
index 465554869e..19d81edb10 100644
--- a/tests/ut/python/dataset/test_concat.py
+++ b/tests/ut/python/dataset/test_concat.py
@@ -21,19 +21,18 @@ import mindspore.dataset.transforms.vision.py_transforms as F
 from mindspore import log as logger
 
 
-# In generator dataset: Number of rows is 3, its value is 0, 1, 2
+# In generator dataset: Number of rows is 3; its values are 0, 1, 2
 def generator():
     for i in range(3):
         yield np.array([i]),
 
 
-# In generator_10 dataset: Number of rows is 7, its value is 3, 4, 5 ... 10
+# In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9
 def generator_10():
     for i in range(3, 10):
         yield np.array([i]),
 
-
-# In generator_20 dataset: Number of rows is 10, its value is 10, 11, 12 ... 20
+# In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
 def generator_20():
     for i in range(10, 20):
         yield np.array([i]),
@@ -135,7 +134,7 @@ def test_concat_05():
 
 def test_concat_06():
     """
-    Test concat: test concat muti datasets in one time
+    Test concat: concatenate multiple datasets at once
     """
     logger.info("test_concat_06")
     data1 = ds.GeneratorDataset(generator, ["col1"])
diff --git a/tests/ut/python/dataset/test_datasets_sharding.py b/tests/ut/python/dataset/test_datasets_sharding.py
index 825ceb661a..59c8c2c7bf 100644
--- a/tests/ut/python/dataset/test_datasets_sharding.py
+++ b/tests/ut/python/dataset/test_datasets_sharding.py
@@ -35,6 +35,9 @@ def test_imagefolder_shardings(print_res=False):
     assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1])  # 5 rows
     assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3])  # 11 rows
     assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
+    assert (sharding_config(1, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])  # 44 rows
+    assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])  # 22 rows
+    assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3])  # 22 rows
     # total 22 in dataset rows because of class indexing which takes only 2 folders
     assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
     assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
@@ -44,6 +47,86 @@ def test_imagefolder_shardings(print_res=False):
     assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
 
 
+def test_tfrecord_shardings1(print_res=False):
+    """ Test TFRecordDataset sharding with num_parallel_workers=1 """
+
+    # total 40 rows in dataset
+    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
+                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
+
+    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
+        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
+                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=1)
+        data1 = data1.repeat(repeat_cnt)
+        res = []
+        for item in data1.create_dict_iterator():  # each item is a dictionary
+            res.append(item["scalars"][0])
+        if print_res:
+            logger.info("scalars of dataset: {}".format(res))
+        return res
+
+    assert sharding_config(2, 0, None, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
+    assert sharding_config(2, 1, None, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
+    assert sharding_config(2, 0, 3, 1) == [11, 12, 13]  # 3 rows
+    assert sharding_config(2, 1, 3, 1) == [1, 2, 3]  # 3 rows
+    assert sharding_config(2, 0, 40, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
+    assert sharding_config(2, 1, 40, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
+    assert sharding_config(2, 0, 55, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
+    assert sharding_config(2, 1, 55, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
+    assert sharding_config(3, 0, 8, 1) == [11, 12, 13, 14, 15, 16, 17, 18]  # 8 rows
+    assert sharding_config(3, 1, 8, 1) == [1, 2, 3, 4, 5, 6, 7, 8]  # 8 rows
+    assert sharding_config(3, 2, 8, 1) == [21, 22, 23, 24, 25, 26, 27, 28]  # 8 rows
+    assert sharding_config(4, 0, 2, 1) == [11, 12]  # 2 rows
+    assert sharding_config(4, 1, 2, 1) == [1, 2]  # 2 rows
+    assert sharding_config(4, 2, 2, 1) == [21, 22]  # 2 rows
+    assert sharding_config(4, 3, 2, 1) == [31, 32]  # 2 rows
+    assert sharding_config(3, 0, 4, 2) == [11, 12, 13, 14, 21, 22, 23, 24]  # 8 rows
+    assert sharding_config(3, 1, 4, 2) == [1, 2, 3, 4, 11, 12, 13, 14]  # 8 rows
+    assert sharding_config(3, 2, 4, 2) == [21, 22, 23, 24, 31, 32, 33, 34]  # 8 rows
+
+
+def test_tfrecord_shardings4(print_res=False):
+    """ Test TFRecordDataset sharding with num_parallel_workers=4 """
+
+    # total 40 rows in dataset
+    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
+                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
+
+    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
+        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
+                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=4)
+        data1 = data1.repeat(repeat_cnt)
+        res = []
+        for item in data1.create_dict_iterator():  # each item is a dictionary
+            res.append(item["scalars"][0])
+        if print_res:
+            logger.info("scalars of dataset: {}".format(res))
+        return res
+
+    def check_result(result_list, expect_length, expect_set):
+        assert len(result_list) == expect_length
+        assert set(result_list) == expect_set
+
+    check_result(sharding_config(2, 0, None, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, None, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21})
+    check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31})
+    check_result(sharding_config(2, 0, 40, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, 40, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(2, 0, 55, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
+    check_result(sharding_config(2, 1, 55, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
+    check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31})
+    check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8})
+    check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28})
+    check_result(sharding_config(4, 0, 2, 1), 2, {11, 12})
+    check_result(sharding_config(4, 1, 2, 1), 2, {1, 2})
+    check_result(sharding_config(4, 2, 2, 1), 2, {21, 22})
+    check_result(sharding_config(4, 3, 2, 1), 2, {31, 32})
+    check_result(sharding_config(3, 0, 4, 2), 8, {32, 1, 2, 11, 12, 21, 22, 31})
+    check_result(sharding_config(3, 1, 4, 2), 8, {1, 2, 3, 4, 11, 12, 13, 14})
+    check_result(sharding_config(3, 2, 4, 2), 8, {32, 33, 34, 21, 22, 23, 24, 31})
+
+
 def test_manifest_shardings(print_res=False):
     manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
@@ -157,6 +240,8 @@ def test_mnist_shardings(print_res=False):
 
 if __name__ == '__main__':
     test_imagefolder_shardings(True)
+    test_tfrecord_shardings1(True)
+    test_tfrecord_shardings4(True)
     test_manifest_shardings(True)
     test_voc_shardings(True)
     test_cifar10_shardings(True)
diff --git a/tests/ut/python/dataset/test_five_crop.py b/tests/ut/python/dataset/test_five_crop.py
index c5cba721c7..3a08cc17c4 100644
--- a/tests/ut/python/dataset/test_five_crop.py
+++ b/tests/ut/python/dataset/test_five_crop.py
@@ -43,7 +43,7 @@ def visualize(image_1, image_2):
     plt.show()
 
 
-def skip_test_five_crop_op():
+def test_five_crop_op():
     """
     Test FiveCrop
     """
diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py
index e4c991eef2..09e4dc1fd3 100644
--- a/tests/ut/python/dataset/test_tfreader_op.py
+++ b/tests/ut/python/dataset/test_tfreader_op.py
@@ -153,7 +153,7 @@ def test_tf_record_shuffle():
     assert np.array_equal(t1, t2)
 
 
-def skip_test_tf_record_shard():
+def test_tf_record_shard():
     tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
 
@@ -171,12 +171,14 @@ def skip_test_tf_record_shard():
     # 2. with enough epochs, both workers will get the entire dataset (e.g. ep1_wrkr1: f1&f3, ep2_wrkr1: f2&f4)
     worker1_res = get_res(0, 16)
     worker2_res = get_res(1, 16)
+    # Confirm each worker gets 3x16=48 rows
+    assert len(worker1_res) == 48
+    assert len(worker1_res) == len(worker2_res)
     # check criteria 1
     for i in range(len(worker1_res)):
         assert (worker1_res[i] != worker2_res[i])
     # check criteria 2
     assert (set(worker2_res) == set(worker1_res))
-    assert (len(set(worker2_res)) == 12)
 
 
 def test_tf_shard_equal_rows():
@@ -198,7 +200,10 @@
     for i in range(len(worker1_res)):
         assert (worker1_res[i] != worker2_res[i])
         assert (worker2_res[i] != worker3_res[i])
-    assert (len(worker1_res) == 28)
+    # Confirm each worker gets the same number of rows
+    assert len(worker1_res) == 28
+    assert len(worker1_res) == len(worker2_res)
+    assert len(worker2_res) == len(worker3_res)
     worker4_res = get_res(1, 0, 1)
     assert (len(worker4_res) == 40)
@@ -272,7 +277,7 @@ if __name__ == '__main__':
     test_tf_files()
     test_tf_record_schema()
     test_tf_record_shuffle()
-    # test_tf_record_shard()
+    test_tf_record_shard()
     test_tf_shard_equal_rows()
     test_case_tf_file_no_schema_columns_list()
     test_tf_record_schema_columns_list()
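
Note on the expected values: with shuffle=ds.Shuffle.FILES, TFRecordDataset shards at file granularity, so every row a shard sees comes from whole files rather than interleaved rows. A minimal sketch of that distribution, assuming a round-robin file-to-shard deal; the per-file row contents are inferred from the asserts above, and shard_rows with the shown file order are illustrative stand-ins, not MindSpore API:

    # Toy model of file-level sharding: whole files are dealt to shards.
    # File contents inferred from the tests: test1.data holds rows 1..10,
    # test2.data 11..20, test3.data 21..30, test4.data 31..40.
    FILES = {1: range(1, 11), 2: range(11, 21), 3: range(21, 31), 4: range(31, 41)}

    def shard_rows(file_order, num_shards, shard_id, num_samples=None):
        """Rows seen by shard_id given a shuffled file order (round-robin deal)."""
        my_files = [f for i, f in enumerate(file_order) if i % num_shards == shard_id]
        rows = [r for f in my_files for r in FILES[f]]
        return rows[:num_samples] if num_samples else rows

    # With file order [2, 1, 3, 4] (one possible Shuffle.FILES outcome),
    # shard 0 of 2 reads files 2 and 3, shard 1 reads files 1 and 4,
    # matching sharding_config(2, 0, None, 1) and (2, 1, None, 1) above.
    assert shard_rows([2, 1, 3, 4], 2, 0) == list(range(11, 31))
    assert shard_rows([2, 1, 3, 4], 2, 1) == list(range(1, 11)) + list(range(31, 41))

This also explains why num_samples=55 still yields 20 rows per shard of two: a shard can never see more rows than its files contain, so the cap is applied after the file deal.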