|
|
|
@@ -3780,6 +3780,18 @@ class GeneratorDataset(MappableDataset): |
|
|
|
self.schema = schema |
|
|
|
if not isinstance(schema, Schema): |
|
|
|
self.schema = Schema(schema) |
|
|
|
# Compute dataset_size from len() here rather than in parse(), because self.source |
|
|
|
# loses its '__len__' attribute after deepcopy. |
|
|
|
self.dataset_size = None |
|
|
|
if hasattr(self.source, "__len__"): |
|
|
|
if not self.num_shards: |
|
|
|
self.dataset_size = len(self.source) |
|
|
|
else: |
|
|
|
self.dataset_size = math.ceil(len(self.source) / self.num_shards) |
|
|
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size() |
|
|
|
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: |
|
|
|
self.dataset_size = rows_from_sampler |
|
|
|
|
|
|
|
def __deepcopy__(self, memodict): |
|
|
|
if id(self) in memodict: |
|
|
|
@@ -3838,24 +3850,16 @@ class GeneratorDataset(MappableDataset): |
|
|
|
return self.sampler.is_sharded() |
|
|
|
|
|
|
|
def parse(self, children=None):
    """Build the ``cde.GeneratorNode`` for this dataset.

    The dataset size is read from ``self.dataset_size`` instead of being
    recomputed from ``len(self.source)`` here.

    Args:
        children: Not used by this method.

    Returns:
        The result of chaining ``SetGeneratorDatasetSize`` and
        ``SetNumWorkers`` on a newly created ``cde.GeneratorNode``.
    """
    # No size was recorded: pass -1 to the node instead of None.
    # NOTE(review): -1 presumably means "size unknown" on the C++ side -- confirm.
    if self.dataset_size is None:
        self.dataset_size = -1

    if self.schema is None:
        # Without a schema the node is described by column names and types.
        node = cde.GeneratorNode(self.source, self.column_names, self.column_types)
    else:
        schema = self.schema
        if isinstance(schema, Schema):
            # A Schema wrapper is unwrapped to its cpp_schema before being passed on.
            schema = schema.cpp_schema
        node = cde.GeneratorNode(self.source, schema)

    return node.SetGeneratorDatasetSize(self.dataset_size).SetNumWorkers(self.num_parallel_workers)
|
|
|
|
|
|
|
|
|
|
|
|