diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 319b67f678..be6ef76aa7 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -4874,6 +4874,12 @@ class CelebADataset(MappableDataset): self.shard_id = shard_id self.shuffle_level = shuffle + if usage != "all": + dir = os.path.realpath(self.dataset_dir) + partition_file = os.path.join(dir, "list_eval_partition.txt") + if os.path.exists(partition_file) is False: + raise RuntimeError("Partition file can not be found when usage is not 'all'.") + def get_args(self): args = super().get_args() args["dataset_dir"] = self.dataset_dir