|
|
@@ -4009,6 +4009,31 @@ class CelebADataset(MappableDataset): |
|
|
args["shard_id"] = self.shard_id |
|
|
args["shard_id"] = self.shard_id |
|
|
return args |
|
|
return args |
|
|
|
|
|
|
|
|
|
|
|
def get_dataset_size(self): |
|
|
|
|
|
""" |
|
|
|
|
|
Get the number of batches in an epoch. |
|
|
|
|
|
|
|
|
|
|
|
Return: |
|
|
|
|
|
Number, number of batches. |
|
|
|
|
|
""" |
|
|
|
|
|
if self._dataset_size is None: |
|
|
|
|
|
dir = os.path.realpath(self.dataset_dir) |
|
|
|
|
|
attr_file = os.path.join(dir, "list_attr_celeba.txt") |
|
|
|
|
|
num_rows = '' |
|
|
|
|
|
try: |
|
|
|
|
|
with open(attr_file, 'r') as f: |
|
|
|
|
|
num_rows = int(f.readline()) |
|
|
|
|
|
except Exception: |
|
|
|
|
|
raise RuntimeError("Get dataset size failed from attribution file.") |
|
|
|
|
|
rows_per_shard = get_num_rows(num_rows, self.num_shards) |
|
|
|
|
|
if self.num_samples is not None: |
|
|
|
|
|
rows_per_shard = min(self.num_samples, rows_per_shard) |
|
|
|
|
|
rows_from_sampler = self._get_sampler_dataset_size() |
|
|
|
|
|
if rows_from_sampler is None: |
|
|
|
|
|
return rows_per_shard |
|
|
|
|
|
return min(rows_from_sampler, rows_per_shard) |
|
|
|
|
|
return self._dataset_size |
|
|
|
|
|
|
|
|
def is_shuffled(self): |
|
|
def is_shuffled(self): |
|
|
if self.shuffle_level is None: |
|
|
if self.shuffle_level is None: |
|
|
return True |
|
|
return True |
|
|
|