Browse Source

add get_dataset_size to celebadataset

tags/v0.5.0-beta
yanghaitao 5 years ago
parent
commit
cc6c7a3f60
2 changed files with 30 additions and 0 deletions
  1. +25
    -0
      mindspore/dataset/engine/datasets.py
  2. +5
    -0
      tests/ut/python/dataset/test_datasets_celeba.py

+ 25
- 0
mindspore/dataset/engine/datasets.py View File

@@ -4009,6 +4009,31 @@ class CelebADataset(MappableDataset):
args["shard_id"] = self.shard_id
return args

def get_dataset_size(self):
"""
Get the number of batches in an epoch.

Return:
Number, number of batches.
"""
if self._dataset_size is None:
dir = os.path.realpath(self.dataset_dir)
attr_file = os.path.join(dir, "list_attr_celeba.txt")
num_rows = ''
try:
with open(attr_file, 'r') as f:
num_rows = int(f.readline())
except Exception:
raise RuntimeError("Get dataset size failed from attribution file.")
rows_per_shard = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None:
rows_per_shard = min(self.num_samples, rows_per_shard)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
return self._dataset_size

def is_shuffled(self):
if self.shuffle_level is None:
return True


+ 5
- 0
tests/ut/python/dataset/test_datasets_celeba.py View File

@@ -85,9 +85,14 @@ def test_celeba_dataset_distribute():
count = count + 1
assert count == 1

def test_celeba_get_dataset_size():
data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
size = data.get_dataset_size()
assert size == 2

if __name__ == '__main__':
test_celeba_dataset_label()
test_celeba_dataset_op()
test_celeba_dataset_ext()
test_celeba_dataset_distribute()
test_celeba_get_dataset_size()

Loading…
Cancel
Save