|
|
|
@@ -4940,6 +4940,12 @@ class CelebADataset(MappableDataset): |
|
|
|
self.shard_id = shard_id |
|
|
|
self.shuffle_level = shuffle |
|
|
|
|
|
|
|
if usage != "all": |
|
|
|
dir = os.path.realpath(self.dataset_dir) |
|
|
|
partition_file = os.path.join(dir, "list_eval_partition.txt") |
|
|
|
if os.path.exists(partition_file) is False: |
|
|
|
raise RuntimeError("Partition file can not be found when usage is not 'all'.") |
|
|
|
|
|
|
|
def get_args(self): |
|
|
|
args = super().get_args() |
|
|
|
args["dataset_dir"] = self.dataset_dir |
|
|
|
|