|
|
|
@@ -605,7 +605,7 @@ class SubsetSampler(BuiltinSampler): |
|
|
|
Samples the elements from a sequence of indices. |
|
|
|
|
|
|
|
Args: |
|
|
|
indices (list[int]): A sequence of indices. |
|
|
|
indices (Any iterable python object but string): A sequence of indices. |
|
|
|
num_samples (int, optional): Number of elements to sample (default=None, all elements). |
|
|
|
|
|
|
|
Examples: |
|
|
|
@@ -633,6 +633,13 @@ class SubsetSampler(BuiltinSampler): |
|
|
|
|
|
|
|
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] |
|
|
|
|
|
|
|
if num_samples is not None: |
|
|
|
if not isinstance(num_samples, int): |
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) |
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX: |
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" |
|
|
|
.format(0, validator.INT64_MAX)) |
|
|
|
|
|
|
|
if not isinstance(indices, str) and validator.is_iterable(indices): |
|
|
|
indices = _get_sample_ids_as_list(indices, num_samples) |
|
|
|
elif isinstance(indices, int): |
|
|
|
@@ -645,13 +652,6 @@ class SubsetSampler(BuiltinSampler): |
|
|
|
raise TypeError("SubsetSampler: Type of indices element must be int, " |
|
|
|
"but got list[{}]: {}, type: {}.".format(i, item, type(item))) |
|
|
|
|
|
|
|
if num_samples is not None: |
|
|
|
if not isinstance(num_samples, int): |
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples)) |
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX: |
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!" |
|
|
|
.format(0, validator.INT64_MAX)) |
|
|
|
|
|
|
|
self.indices = indices |
|
|
|
super().__init__(num_samples) |
|
|
|
|
|
|
|
|