Browse Source

fix: MindDataset parameter shard_id & num_shards check

tags/v0.3.0-alpha
jonyguo 5 years ago
parent
commit
be2e7531ca
2 changed files with 47 additions and 1 deletions
  1. +3
    -0
      mindspore/dataset/engine/validators.py
  2. +44
    -1
      tests/ut/python/dataset/test_minddataset_exception.py

+ 3
- 0
mindspore/dataset/engine/validators.py View File

@@ -534,6 +534,7 @@ def check_minddataset(method):
check_dataset_file(f) check_dataset_file(f)
else: else:
check_dataset_file(dataset_file) check_dataset_file(dataset_file)

check_param_type(nreq_param_int, param_dict, int) check_param_type(nreq_param_int, param_dict, int)


check_param_type(nreq_param_list, param_dict, list) check_param_type(nreq_param_list, param_dict, list)
@@ -544,6 +545,8 @@ def check_minddataset(method):
if (num_shards is not None and shard_id is None) or (num_shards is None and shard_id is not None): if (num_shards is not None and shard_id is None) or (num_shards is None and shard_id is not None):
raise ValueError("num_shards and shard_id need to be set or not set at the same time") raise ValueError("num_shards and shard_id need to be set or not set at the same time")


check_sampler_shuffle_shard_options(param_dict)

return method(*args, **kwargs) return method(*args, **kwargs)


return new_method return new_method


+ 44
- 1
tests/ut/python/dataset/test_minddataset_exception.py View File

@@ -128,7 +128,7 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
num_readers = 4 num_readers = 4
sampler = ds.PKSampler(2) sampler = ds.PKSampler(2)
with pytest.raises(Exception, match="shuffle not allowed when use sampler"):
with pytest.raises(Exception, match="sampler and shuffle cannot be specified at the same time."):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
sampler=sampler, shuffle=False) sampler=sampler, shuffle=False)
num_iter = 0 num_iter = 0
@@ -168,3 +168,46 @@ def test_cv_minddataset_reader_different_page_size():
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
os.remove(CV1_FILE_NAME) os.remove(CV1_FILE_NAME)
os.remove("{}.db".format(CV1_FILE_NAME)) os.remove("{}.db".format(CV1_FILE_NAME))

def test_minddataset_invalidate_num_shards():
    """Expect MindDataset to reject a sharding configuration built with a zero value.

    A mindrecord file is generated with create_cv_mindrecord(1), then building
    (or iterating) the dataset must raise an exception whose message matches
    "shard_id is invalid, ".

    The generated files are removed in a ``finally`` clause so that a failing
    assertion does not leave them behind on disk for later tests.
    """
    create_cv_mindrecord(1)
    try:
        columns_list = ["data", "label"]
        num_readers = 4
        with pytest.raises(Exception, match="shard_id is invalid, "):
            # NOTE(review): the trailing positional arguments are presumably
            # shuffle=True, num_shards=0, shard_id=1 -- confirm against the
            # MindDataset signature.
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                num_iter += 1
    finally:
        # Always clean up the generated record file and its companion .db file.
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))

def test_minddataset_invalidate_shard_id():
    """Expect MindDataset to reject a sharding configuration containing -1.

    A mindrecord file is generated with create_cv_mindrecord(1), then building
    (or iterating) the dataset must raise an exception whose message matches
    "shard_id is invalid, ".

    The generated files are removed in a ``finally`` clause so that a failing
    assertion does not leave them behind on disk for later tests.
    """
    create_cv_mindrecord(1)
    try:
        columns_list = ["data", "label"]
        num_readers = 4
        with pytest.raises(Exception, match="shard_id is invalid, "):
            # NOTE(review): the trailing positional arguments are presumably
            # shuffle=True, num_shards=1, shard_id=-1 -- confirm against the
            # MindDataset signature.
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                num_iter += 1
    finally:
        # Always clean up the generated record file and its companion .db file.
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))

def test_minddataset_shard_id_bigger_than_num_shard():
    """Expect MindDataset to reject sharding values of (2, 2) and (2, 5).

    A mindrecord file is generated with create_cv_mindrecord(1). Both
    configurations must raise an exception whose message matches
    "shard_id is invalid, ".

    The generated files are removed in a ``finally`` clause so that a failing
    assertion in either case does not leave them behind on disk.
    """
    create_cv_mindrecord(1)
    try:
        columns_list = ["data", "label"]
        num_readers = 4

        # NOTE(review): the trailing positional arguments are presumably
        # shuffle, num_shards, shard_id -- confirm against the MindDataset
        # signature. First case: last two values are equal (2, 2).
        with pytest.raises(Exception, match="shard_id is invalid, "):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                num_iter += 1

        # Second case: last value exceeds the one before it (2, 5).
        with pytest.raises(Exception, match="shard_id is invalid, "):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                num_iter += 1
    finally:
        # Always clean up the generated record file and its companion .db file.
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))

Loading…
Cancel
Save