Browse Source

Fix sampler error messages

tags/v1.2.0-rc1
hesham 4 years ago
parent
commit
aa4298721a
4 changed files with 31 additions and 8 deletions
  1. +4
    -1
      mindspore/dataset/core/validator_helpers.py
  2. +6
    -6
      mindspore/dataset/engine/samplers.py
  3. +1
    -1
      tests/ut/python/dataset/test_datasets_celeba.py
  4. +20
    -0
      tests/ut/python/dataset/test_sampler.py

+ 4
- 1
mindspore/dataset/core/validator_helpers.py View File

@@ -23,6 +23,7 @@ import numpy as np

import mindspore._c_dataengine as cde
from ..engine import samplers

# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
@@ -289,7 +290,6 @@ def check_sampler_shuffle_shard_options(param_dict):
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
num_samples = param_dict.get('num_samples')
check_sampler(sampler)

if sampler is not None:
if shuffle is not None:
@@ -348,6 +348,7 @@ def check_num_samples(value):
raise ValueError(
"num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))


def validate_dataset_param_value(param_list, param_dict, param_type):
for param_name in param_list:
if param_dict.get(param_name) is not None:
@@ -387,6 +388,7 @@ def check_tensor_op(param, param_name):
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))


def check_sampler(sampler):
"""
Check if the sampler is of valid input.
@@ -419,5 +421,6 @@ def check_sampler(sampler):
if not (builtin or base_sampler or list_num):
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")


def replace_none(value, default):
return value if value is not None else default

+ 6
- 6
mindspore/dataset/engine/samplers.py View File

@@ -73,11 +73,11 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler):
return input_sampler
if _is_iterable(input_sampler):
if not isinstance(input_sampler, str) and _is_iterable(input_sampler):
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
if isinstance(input_sampler, int):
return [input_sampler]
raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
return SubsetSampler([input_sampler])
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
if shuffle is None:
if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler
@@ -644,9 +644,9 @@ class SubsetSampler(BuiltinSampler):
indices = [indices]

for i, item in enumerate(indices):
if not isinstance(item, numbers.Number):
raise TypeError("type of indices element must be number, "
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
if not isinstance(item, int):
raise TypeError("SubsetSampler: Type of indices element must be int, "
"but got list[{}]: {}, type: {}.".format(i, item, type(item)))

if num_samples is not None:
if not isinstance(num_samples, int):


+ 1
- 1
tests/ut/python/dataset/test_datasets_celeba.py View File

@@ -179,7 +179,7 @@ def test_celeba_sampler_exception():
pass
assert False
except TypeError as e:
assert "Argument" in str(e)
assert "Unsupported sampler object of type (<class 'str'>)" in str(e)


if __name__ == '__main__':


+ 20
- 0
tests/ut/python/dataset/test_sampler.py View File

@@ -274,6 +274,26 @@ def test_sampler_list():

dataset_equal(data1, data21 + data22 + data23, 0)

data3 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=1)
dataset_equal(data3, data21, 0)

def bad_pipeline(sampler, msg):
with pytest.raises(Exception) as info:
data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=sampler)
for _ in data1:
pass
assert msg in str(info.value)

bad_pipeline(sampler=[1.5, 7],
msg="Type of indices element must be int, but got list[0]: 1.5, type: <class 'float'>")

bad_pipeline(sampler=["a", "b"],
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler=np.array([1, 2]),
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.")


if __name__ == '__main__':
test_sequential_sampler(True)


Loading…
Cancel
Save