|
|
|
@@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset): |
|
|
|
>>> new_sampler = ds.DistributedSampler(10, 2) |
|
|
|
>>> data.use_sampler(new_sampler) |
|
|
|
""" |
|
|
|
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): |
|
|
|
raise TypeError("new_sampler is not an instance of a sampler.") |
|
|
|
if new_sampler is None: |
|
|
|
raise TypeError("Input sampler could not be None.") |
|
|
|
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): |
|
|
|
raise TypeError("Input sampler is not an instance of a sampler.") |
|
|
|
|
|
|
|
self.sampler = self.sampler.child_sampler |
|
|
|
self.add_sampler(new_sampler) |
|
|
|
@@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset): |
|
|
|
Return: |
|
|
|
Number, number of batches. |
|
|
|
""" |
|
|
|
if self.num_samples is None: |
|
|
|
num_samples = 0 |
|
|
|
else: |
|
|
|
num_samples = self.num_samples |
|
|
|
|
|
|
|
if self.class_indexing is None: |
|
|
|
class_indexing = dict() |
|
|
|
else: |
|
|
|
class_indexing = self.class_indexing |
|
|
|
|
|
|
|
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) |
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards) |
|
|
|
rows_from_sampler = self._get_sampler_dataset_size() |
|
|
|
|
|
|
|
if rows_from_sampler is None: |
|
|
|
return self.num_samples |
|
|
|
return rows_per_shard |
|
|
|
|
|
|
|
return min(rows_from_sampler, self.num_samples) |
|
|
|
return min(rows_from_sampler, rows_per_shard) |
|
|
|
|
|
|
|
def get_class_indexing(self): |
|
|
|
""" |
|
|
|
|