| @@ -3157,6 +3157,9 @@ class GeneratorDataset(MappableDataset): | |||||
| self.column_names.append(col["name"]) | self.column_names.append(col["name"]) | ||||
| self.column_types.append(DataType(col["type"])) | self.column_types.append(DataType(col["type"])) | ||||
| if source is not None and hasattr(source, "__len__"): | |||||
| self._dataset_size = len(source) | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["source"] = self.source | args["source"] = self.source | ||||
| @@ -3177,6 +3180,7 @@ class GeneratorDataset(MappableDataset): | |||||
| return self._dataset_size | return self._dataset_size | ||||
| if self._dataset_size is None: | if self._dataset_size is None: | ||||
| return None | return None | ||||
| return min(rows_from_sampler, self._dataset_size) | return min(rows_from_sampler, self._dataset_size) | ||||
| # manually set dataset_size as a temporary solution. | # manually set dataset_size as a temporary solution. | ||||