|
|
|
@@ -1,4 +1,4 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd. |
|
|
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd. |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@@ -20,6 +20,10 @@ DATA_DIR = "../data/dataset/testCelebAData/" |
|
|
|
|
|
|
|
|
|
|
|
def test_celeba_dataset_label(): |
|
|
|
""" |
|
|
|
Test CelebA dataset with labels |
|
|
|
""" |
|
|
|
logger.info("Test CelebA labels") |
|
|
|
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True) |
|
|
|
expect_labels = [ |
|
|
|
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, |
|
|
|
@@ -43,6 +47,10 @@ def test_celeba_dataset_label(): |
|
|
|
|
|
|
|
|
|
|
|
def test_celeba_dataset_op(): |
|
|
|
""" |
|
|
|
Test CelebA dataset with decode |
|
|
|
""" |
|
|
|
logger.info("Test CelebA with decode") |
|
|
|
data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0) |
|
|
|
crop_size = (80, 80) |
|
|
|
resize_size = (24, 24) |
|
|
|
@@ -62,6 +70,10 @@ def test_celeba_dataset_op(): |
|
|
|
|
|
|
|
|
|
|
|
def test_celeba_dataset_ext(): |
|
|
|
""" |
|
|
|
Test CelebA dataset with extension |
|
|
|
""" |
|
|
|
logger.info("Test CelebA extension option") |
|
|
|
ext = [".JPEG"] |
|
|
|
data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext) |
|
|
|
expect_labels = [ |
|
|
|
@@ -82,6 +94,10 @@ def test_celeba_dataset_ext(): |
|
|
|
|
|
|
|
|
|
|
|
def test_celeba_dataset_distribute(): |
|
|
|
""" |
|
|
|
Test CelebA dataset with distributed options |
|
|
|
""" |
|
|
|
logger.info("Test CelebA with sharding") |
|
|
|
data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0) |
|
|
|
count = 0 |
|
|
|
for item in data.create_dict_iterator(num_epochs=1): |
|
|
|
@@ -94,6 +110,10 @@ def test_celeba_dataset_distribute(): |
|
|
|
|
|
|
|
|
|
|
|
def test_celeba_get_dataset_size(): |
|
|
|
""" |
|
|
|
Test CelebA dataset get dataset size |
|
|
|
""" |
|
|
|
logger.info("Test CelebA get dataset size") |
|
|
|
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True) |
|
|
|
size = data.get_dataset_size() |
|
|
|
assert size == 4 |
|
|
|
@@ -112,6 +132,10 @@ def test_celeba_get_dataset_size(): |
|
|
|
|
|
|
|
|
|
|
|
def test_celeba_dataset_exception_file_path(): |
|
|
|
""" |
|
|
|
Test CelebA dataset with bad file path |
|
|
|
""" |
|
|
|
logger.info("Test CelebA with bad file path") |
|
|
|
def exception_func(item): |
|
|
|
raise Exception("Error occur!") |
|
|
|
|
|
|
|
@@ -144,6 +168,20 @@ def test_celeba_dataset_exception_file_path(): |
|
|
|
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) |
|
|
|
|
|
|
|
|
|
|
|
def test_celeba_sampler_exception(): |
|
|
|
""" |
|
|
|
Test CelebA with bad sampler input |
|
|
|
""" |
|
|
|
logger.info("Test CelebA with bad sampler input") |
|
|
|
try: |
|
|
|
data = ds.CelebADataset(DATA_DIR, sampler="") |
|
|
|
for _ in data.create_dict_iterator(): |
|
|
|
pass |
|
|
|
assert False |
|
|
|
except TypeError as e: |
|
|
|
assert "Argument" in str(e) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_celeba_dataset_label() |
|
|
|
test_celeba_dataset_op() |
|
|
|
@@ -151,3 +189,4 @@ if __name__ == '__main__': |
|
|
|
test_celeba_dataset_distribute() |
|
|
|
test_celeba_get_dataset_size() |
|
|
|
test_celeba_dataset_exception_file_path() |
|
|
|
test_celeba_sampler_exception() |