Merge pull request !64 from anzhengqi/fix-check-num-samples
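Adds a check_num_samples validator so num_samples must be a positive integer, switches the sharding tests to num_samples=None for the uncapped cases, and adds a TFRecordDataset test covering the new error message.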
@@ -243,6 +243,8 @@ def check_param_type(param_list, param_dict, param_type):
         if param_dict.get(param_name) is not None:
             if param_name == 'num_parallel_workers':
                 check_num_parallel_workers(param_dict.get(param_name))
+            elif param_name == 'num_samples':
+                check_num_samples(param_dict.get(param_name))
             else:
                 check_type(param_dict.get(param_name), param_name, param_type)
@@ -262,6 +264,12 @@ def check_num_parallel_workers(value):
         raise ValueError("num_parallel_workers exceeds the boundary between 0 and {}!".format(cpu_count()))


+def check_num_samples(value):
+    check_type(value, 'num_samples', int)
+    if value <= 0:
+        raise ValueError("num_samples must be greater than 0!")
+
+
 def check_dataset_dir(dataset_dir):
     if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
         raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
@@ -33,14 +33,14 @@ def test_imagefolder_shardings(print_res=False):
     # total 44 rows in dataset
     assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1])  # 5 rows
     assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3])  # 11 rows
-    assert (sharding_config(4, 3, 0, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
+    assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
     # total 22 rows in dataset because class indexing takes only 2 folders
-    assert (len(sharding_config(4, 0, 0, True, {"class1": 111, "class2": 999})) == 6)
+    assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
     assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
     # test with repeat
     assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
     assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
-    assert (len(sharding_config(5, 1, 0, True, {"class1": 111, "class2": 999}, 4)) == 20)
+    assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)


 def test_manifest_shardings(print_res=False):
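The expected counts in these assertions follow from dealing the dataset's rows (44, or 22 once class indexing restricts it to two folders) across the shards and then applying the num_samples cap, where None means uncapped. A sketch of that arithmetic (the round-robin distribution rule is inferred from the assertions, not quoted from sharding_config):

def expected_rows(total_rows, num_shards, shard_id, num_samples=None, repeat=1):
    # Round-robin split: the first (total_rows % num_shards) shards get one extra row.
    base, extra = divmod(total_rows, num_shards)
    per_shard = base + (1 if shard_id < extra else 0)
    if num_samples is not None:
        per_shard = min(per_shard, num_samples)   # num_samples caps each shard
    return per_shard * repeat

assert expected_rows(44, 4, 0, num_samples=5) == 5     # the 5-row case
assert expected_rows(44, 4, 3) == 11                   # the 11-row case
assert expected_rows(22, 4, 2, num_samples=3) == 3     # class-indexed, capped
assert expected_rows(22, 5, 1, repeat=4) == 20         # class-indexed, repeated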
@@ -18,6 +18,7 @@ import pytest
 import mindspore.dataset as ds

 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


 def skip_test_exception():
@@ -29,5 +30,23 @@ def skip_test_exception():
     assert "The shape size 1 of input tensor is invalid" in str(info.value)


+def test_sample_exception():
+    num_samples = 0
+    with pytest.raises(ValueError) as info:
+        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
+    assert "num_samples must be greater than 0" in str(info.value)
+    num_samples = -1
+    with pytest.raises(ValueError) as info:
+        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
+    assert "num_samples must be greater than 0" in str(info.value)
+    num_samples = 1
+    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
+    data = data.map(input_columns=["image"], operations=vision.Decode())
+    data = data.map(input_columns=["image"], operations=vision.Resize((100, 100)))
+    num_iters = 0
+    for item in data.create_dict_iterator():
+        num_iters += 1
+    assert num_iters == 1
+
 if __name__ == '__main__':
     test_exception()
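The new test body references vision, which this hunk does not import; in MindSpore test files of this vintage the name typically comes from an import along these lines (an assumption here, since the actual import statement sits outside the diff):

import mindspore.dataset.transforms.vision.c_transforms as vision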