Browse Source

add paramter check for numpyslices and num_shards

tags/v0.6.0-beta
ms_yan 5 years ago
parent
commit
7fa0d9e7e4
3 changed files with 10 additions and 3 deletions
  1. +7
    -2
      mindspore/dataset/engine/datasets.py
  2. +2
    -0
      mindspore/dataset/engine/validators.py
  3. +1
    -1
      tests/ut/python/dataset/test_minddataset_exception.py

+ 7
- 2
mindspore/dataset/engine/datasets.py View File

@@ -3069,7 +3069,7 @@ class GeneratorDataset(MappableDataset):
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.

@@ -4878,6 +4878,11 @@ class _NumpySlicesDataset:
else:
self.data = (np.array(data),)

# check whether the data length in each column is equal
data_len = [len(data_item) for data_item in self.data]
if data_len[1:] != data_len[:-1]:
raise ValueError("Data length in each column is not equal.")

# Init column_name
if column_list is not None:
self.column_list = column_list
@@ -4966,7 +4971,7 @@ class NumpySlicesDataset(GeneratorDataset):
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.



+ 2
- 0
mindspore/dataset/engine/validators.py View File

@@ -153,6 +153,7 @@ def check_sampler_shuffle_shard_options(param_dict):
raise RuntimeError("sampler and sharding cannot be specified at the same time.")

if num_shards is not None:
check_positive_int32(num_shards, "num_shards")
if shard_id is None:
raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
if shard_id < 0 or shard_id >= num_shards:
@@ -529,6 +530,7 @@ def check_generatordataset(method):
# These two parameters appear together.
raise ValueError("num_shards and shard_id need to be passed in together")
if num_shards is not None:
check_positive_int32(num_shards, "num_shards")
if shard_id >= num_shards:
raise ValueError("shard_id should be less than num_shards")



+ 1
- 1
tests/ut/python/dataset/test_minddataset_exception.py View File

@@ -185,7 +185,7 @@ def test_minddataset_invalidate_num_shards():
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="shard_id is invalid, "):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1)
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1


Loading…
Cancel
Save