|
|
|
@@ -226,6 +226,27 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): |
|
|
|
logger.info("Got an exception in DE: {}".format(str(e))) |
|
|
|
assert "num_ops" in str(e) |
|
|
|
|
|
|
|
def test_cpp_uniform_augment_random_crop_ut(): |
|
|
|
batch_size=2 |
|
|
|
cifar10_dir = "../data/dataset/testCifar10Data" |
|
|
|
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3] |
|
|
|
|
|
|
|
transforms_ua = [ |
|
|
|
C.RandomCrop(size=[224, 224]), |
|
|
|
C.RandomHorizontalFlip() |
|
|
|
] |
|
|
|
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=1) |
|
|
|
ds1 = ds1.map(input_columns="image", operations=uni_aug) |
|
|
|
|
|
|
|
# apply DatasetOps |
|
|
|
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1) |
|
|
|
num_batches = 0 |
|
|
|
try: |
|
|
|
for data in ds1.create_dict_iterator(): |
|
|
|
num_batches += 1 |
|
|
|
except BaseException as e: |
|
|
|
assert "Crop size" in str(e) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
test_uniform_augment(num_ops=1) |
|
|
|
@@ -233,3 +254,4 @@ if __name__ == "__main__": |
|
|
|
test_cpp_uniform_augment_exception_pyops(num_ops=1) |
|
|
|
test_cpp_uniform_augment_exception_large_numops(num_ops=6) |
|
|
|
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0) |
|
|
|
test_cpp_uniform_augment_random_crop_ut() |