Browse Source

!12765 Fix numpy input to samplers

From: @hfarahat
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
1c191a65fb
3 changed files with 43 additions and 62 deletions
  1. +0
    -35
      mindspore/dataset/core/validator_helpers.py
  2. +24
    -23
      mindspore/dataset/engine/samplers.py
  3. +19
    -4
      tests/ut/python/dataset/test_sampler.py

+ 0
- 35
mindspore/dataset/core/validator_helpers.py View File

@@ -16,13 +16,11 @@
General Validators. General Validators.
""" """
import inspect import inspect
import numbers
from multiprocessing import cpu_count from multiprocessing import cpu_count
import os import os
import numpy as np import numpy as np


import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
from ..engine import samplers


# POS_INT_MIN is used to limit values from starting from 0 # POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1 POS_INT_MIN = 1
@@ -389,38 +387,5 @@ def check_tensor_op(param, param_name):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))




def check_sampler(sampler):
"""
Check if the sampler is of valid input.

Args:
param(Union[list, samplers.Sampler, samplers.BuiltinSampler, None]): sampler

Returns:
Exception: TypeError if error
"""
builtin = False
base_sampler = False
list_num = False
if sampler is not None:
if isinstance(sampler, samplers.BuiltinSampler):
builtin = True
elif isinstance(sampler, samplers.Sampler):
base_sampler = True
else:
# check for list of numbers
list_num = True
# subset sampler check
subset_sampler = sampler
if not isinstance(sampler, list):
subset_sampler = [sampler]

for _, item in enumerate(subset_sampler):
if not isinstance(item, numbers.Number):
list_num = False
if not (builtin or base_sampler or list_num):
raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")


def replace_none(value, default): def replace_none(value, default):
return value if value is not None else default return value if value is not None else default

+ 24
- 23
mindspore/dataset/engine/samplers.py View File

@@ -41,22 +41,6 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
Sampler, sampler selected based on user input. Sampler, sampler selected based on user input.
""" """


def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True

def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)

if isinstance(sampler, list):
return sampler[:number_of_samples]

return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]

if input_sampler is not None: if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because # If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler. # we are being asked specifically to use the given sampler.
@@ -73,11 +57,8 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler): if isinstance(input_sampler, BuiltinSampler):
return input_sampler return input_sampler
if not isinstance(input_sampler, str) and _is_iterable(input_sampler):
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
if isinstance(input_sampler, int):
return SubsetSampler([input_sampler])
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler)))
return SubsetSampler(input_sampler, num_samples)

if shuffle is None: if shuffle is None:
if num_shards is not None: if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler # If shuffle is not specified, sharding enabled, use distributed random sampler
@@ -640,11 +621,31 @@ class SubsetSampler(BuiltinSampler):
""" """


def __init__(self, indices, num_samples=None): def __init__(self, indices, num_samples=None):
if not isinstance(indices, list):
def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True

def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)

if isinstance(sampler, list):
return sampler[:number_of_samples]

return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]

if not isinstance(indices, str) and _is_iterable(indices):
indices = _get_sample_ids_as_list(indices, num_samples)
elif isinstance(indices, int):
indices = [indices] indices = [indices]
else:
raise TypeError('Unsupported sampler object of type ({})'.format(type(indices)))


for i, item in enumerate(indices): for i, item in enumerate(indices):
if not isinstance(item, int):
if not isinstance(item, (int, np.integer)):
raise TypeError("SubsetSampler: Type of indices element must be int, " raise TypeError("SubsetSampler: Type of indices element must be int, "
"but got list[{}]: {}, type: {}.".format(i, item, type(item))) "but got list[{}]: {}, type: {}.".format(i, item, type(item)))




+ 19
- 4
tests/ut/python/dataset/test_sampler.py View File

@@ -177,13 +177,23 @@ def test_subset_sampler():
def pipeline(): def pipeline():
sampler = ds.SubsetSampler(indices, num_samples) sampler = ds.SubsetSampler(indices, num_samples)
data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler) data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler)
data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples)
dataset_size = data.get_dataset_size() dataset_size = data.get_dataset_size()
return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
dataset_size2 = data.get_dataset_size()
res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2
return res1, res2


if exception_msg is None: if exception_msg is None:
res, size = pipeline()
res, res2 = pipeline()
res, size = res
res2, size2 = res2
if not isinstance(indices, list):
indices = list(indices)
assert indices[:num_samples] == res assert indices[:num_samples] == res
assert len(indices[:num_samples]) == size assert len(indices[:num_samples]) == size
assert indices[:num_samples] == res2
assert len(indices[:num_samples]) == size2
else: else:
with pytest.raises(Exception) as error_info: with pytest.raises(Exception) as error_info:
pipeline() pipeline()
@@ -205,6 +215,8 @@ def test_subset_sampler():
test_config([0, 9, 3, 2], num_samples=2) test_config([0, 9, 3, 2], num_samples=2)
test_config([0, 9, 3, 2], num_samples=5) test_config([0, 9, 3, 2], num_samples=5)


test_config(np.array([1, 2, 3]))

test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]") test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]")
test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]") test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]")
test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
@@ -212,6 +224,9 @@ def test_subset_sampler():
# test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset
test_config([0, 9, 3, 2], num_samples=-1, test_config([0, 9, 3, 2], num_samples=-1,
exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)") exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)")
test_config(np.array([[1], [5]]), num_samples=10,
exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1],"
" type: <class 'numpy.ndarray'>.")




def test_sampler_chain(): def test_sampler_chain():
@@ -291,8 +306,8 @@ def test_sampler_list():
msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.") msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.")
bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)") bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)") bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)")
bad_pipeline(sampler=np.array([1, 2]),
msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.")
bad_pipeline(sampler=np.array([[1, 2]]),
msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save