|
|
|
@@ -3157,6 +3157,9 @@ class GeneratorDataset(MappableDataset): |
|
|
|
self.column_names.append(col["name"]) |
|
|
|
self.column_types.append(DataType(col["type"])) |
|
|
|
|
|
|
|
if source is not None and hasattr(source, "__len__"): |
|
|
|
self._dataset_size = len(source) |
|
|
|
|
|
|
|
def get_args(self): |
|
|
|
args = super().get_args() |
|
|
|
args["source"] = self.source |
|
|
|
@@ -3177,6 +3180,7 @@ class GeneratorDataset(MappableDataset): |
|
|
|
return self._dataset_size |
|
|
|
if self._dataset_size is None: |
|
|
|
return None |
|
|
|
|
|
|
|
return min(rows_from_sampler, self._dataset_size) |
|
|
|
|
|
|
|
# manually set dataset_size as a temporary solution. |
|
|
|
|