diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index cf09e3a067..14797e568b 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -83,12 +83,12 @@ class DatasetHelper: class _DatasetIter: """Base iter for dataset help""" def __init__(self, dataset): - self.loop_size = 1 + if not hasattr(dataset, '__loop_size__'): + self.loop_size = dataset.get_dataset_size() + else: + self.loop_size = dataset.__loop_size__ + if not hasattr(dataset, '__ME_INITED__'): - if not hasattr(dataset, '__loop_size__'): - self.loop_size = dataset.get_dataset_size() - else: - self.loop_size = dataset.__loop_size__ dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name