| @@ -3219,33 +3219,9 @@ class GeneratorDataset(MappableDataset): | |||
| def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, | |||
| num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.source = source | |||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | |||
| if self.sampler is not None and hasattr(source, "__getitem__"): | |||
| if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||
| samplers.WeightedRandomSampler, samplers.Sampler)): | |||
| sampler_instance = self.sampler.create() | |||
| sampler_instance.set_num_rows(len(source)) | |||
| sampler_instance.initialize() | |||
| if num_parallel_workers > 1: | |||
| self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) | |||
| else: | |||
| self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) | |||
| else: | |||
| if num_parallel_workers > 1: | |||
| self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers)) | |||
| else: | |||
| self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) | |||
| else: | |||
| try: | |||
| iter(source) | |||
| except TypeError: | |||
| # Use generator function if input callable | |||
| self.source = (lambda: _generator_fn(source, num_samples)) | |||
| else: | |||
| # Use iterator function if input is iterable | |||
| # Random accessible input is also iterable | |||
| self.source = (lambda: _iter_fn(source, num_samples)) | |||
| self.num_samples = num_samples | |||
| if column_names is not None and not isinstance(column_names, list): | |||
| column_names = [column_names] | |||
| @@ -3310,9 +3286,35 @@ class GeneratorDataset(MappableDataset): | |||
| new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) | |||
| new_op.column_types = copy.deepcopy(self.column_types, memodict) | |||
| new_op.column_names = copy.deepcopy(self.column_names, memodict) | |||
| new_op.source = self.source | |||
| new_op.sampler = self.sampler | |||
| new_op.num_samples = copy.deepcopy(self.num_samples, memodict) | |||
| new_op.sampler = copy.deepcopy(self.sampler) | |||
| if new_op.sampler is not None and hasattr(self.source, "__getitem__"): | |||
| if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||
| samplers.WeightedRandomSampler, samplers.Sampler)): | |||
| sampler_instance = new_op.sampler.create() | |||
| sampler_instance.set_num_rows(len(self.source)) | |||
| sampler_instance.initialize() | |||
| if new_op.num_parallel_workers > 1: | |||
| new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers)) | |||
| else: | |||
| new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) | |||
| else: | |||
| if new_op.num_parallel_workers > 1: | |||
| new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers)) | |||
| else: | |||
| new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) | |||
| else: | |||
| try: | |||
| iter(self.source) | |||
| except TypeError: | |||
| # Use generator function if input callable | |||
| new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples)) | |||
| else: | |||
| # Use iterator function if input is iterable | |||
| # Random accessible input is also iterable | |||
| new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples)) | |||
| return new_op | |||