diff --git a/tests/ut/python/dataset/test_datasets_sharding.py b/tests/ut/python/dataset/test_datasets_sharding.py index 59c8c2c7bf..0b8c1900c9 100644 --- a/tests/ut/python/dataset/test_datasets_sharding.py +++ b/tests/ut/python/dataset/test_datasets_sharding.py @@ -35,9 +35,9 @@ def test_imagefolder_shardings(print_res=False): assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1]) # 5 rows assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3]) # 11 rows assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) # 11 rows - assert (sharding_config(1, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] ) # 44 rows - assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3] ) # 22 rows - assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3] ) # 22 rows + assert (sharding_config(1, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]) # 44 rows + assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) # 22 rows + assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]) # 22 rows # total 22 in dataset rows because of class indexing which takes only 2 folders assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6) assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3) diff --git a/tests/ut/python/dataset/test_exceptions.py b/tests/ut/python/dataset/test_exceptions.py index 8dcca8160e..cb79d456d4 100644 --- a/tests/ut/python/dataset/test_exceptions.py +++ b/tests/ut/python/dataset/test_exceptions.py @@ -27,7 +27,6 @@ def test_exception_01(): Test single exception with invalid input """ logger.info("test_exception_01") - ds.config.set_num_parallel_workers(1) data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"]) with pytest.raises(ValueError) as info: data = data.map(input_columns=["image"], operations=vision.Resize(100, 100))