|
|
|
@@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples(): |
|
|
|
count += 1 |
|
|
|
assert count == 2 |
|
|
|
|
|
|
|
def test_celeba_padded(): |
|
|
|
data = ds.CelebADataset("../data/dataset/testCelebAData/") |
|
|
|
|
|
|
|
padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}] |
|
|
|
padded_ds = ds.PaddedDataset(padded_samples) |
|
|
|
data = data + padded_ds |
|
|
|
dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None) |
|
|
|
data.use_sampler(dis_sampler) |
|
|
|
data = data.repeat(2) |
|
|
|
|
|
|
|
count = 0 |
|
|
|
for _ in data.create_dict_iterator(): |
|
|
|
count = count + 1 |
|
|
|
assert count == 2 |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_TFRecord_Padded() |
|
|
|
test_GeneratorDataSet_Padded() |
|
|
|
|