diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 1648734704..dff443dc1e 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2569,6 +2569,8 @@ class GeneratorDataset(SourceDataset): # Random accessible input is also iterable self.source = (lambda: _iter_fn(source, num_samples)) + if column_names is not None and not isinstance(column_names, list): + column_names = [column_names] self.column_names = column_names if column_types is not None: