|
|
|
@@ -3433,6 +3433,7 @@ class GeneratorDataset(MappableDataset): |
|
|
|
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, |
|
|
|
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) |
|
|
|
self.source = source |
|
|
|
self.prepared_source = None # source to be sent to C++ |
|
|
|
|
|
|
|
self.python_multiprocessing = python_multiprocessing |
|
|
|
|
|
|
|
@@ -3463,9 +3464,9 @@ class GeneratorDataset(MappableDataset): |
|
|
|
if new_op.sampler is not None and hasattr(self.source, "__getitem__"): |
|
|
|
if new_op.num_parallel_workers > 1: |
|
|
|
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) |
|
|
|
new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) |
|
|
|
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) |
|
|
|
else: |
|
|
|
new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) |
|
|
|
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) |
|
|
|
new_op.sample_fn = sample_fn |
|
|
|
else: |
|
|
|
try: |
|
|
|
@@ -3476,11 +3477,11 @@ class GeneratorDataset(MappableDataset): |
|
|
|
iter(self.source) |
|
|
|
except TypeError: |
|
|
|
# Use generator function if input callable |
|
|
|
new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples)) |
|
|
|
new_op.prepared_source = (lambda: _generator_fn(self.source, new_op.num_samples)) |
|
|
|
else: |
|
|
|
# Use iterator function if input is iterable |
|
|
|
# Random accessible input is also iterable |
|
|
|
new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples)) |
|
|
|
new_op.prepared_source = (lambda: _iter_fn(self.source, new_op.num_samples)) |
|
|
|
|
|
|
|
return new_op |
|
|
|
|
|
|
|
@@ -3492,12 +3493,12 @@ class GeneratorDataset(MappableDataset): |
|
|
|
|
|
|
|
def parse(self, children=None): |
|
|
|
if self.schema is None: |
|
|
|
return cde.GeneratorNode(self.source, self.column_names, self.column_types, self.source_len, |
|
|
|
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len, |
|
|
|
self.sampler) |
|
|
|
schema = self.schema |
|
|
|
if isinstance(schema, Schema): |
|
|
|
schema = self.schema.cpp_schema |
|
|
|
return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler) |
|
|
|
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler) |
|
|
|
|
|
|
|
|
|
|
|
class TFRecordDataset(SourceDataset): |
|
|
|
|