|
|
|
@@ -213,6 +213,23 @@ def test_raise_error(): |
|
|
|
ds3.use_sampler(testsampler) |
|
|
|
assert excinfo.type == 'ValueError' |
|
|
|
|
|
|
|
def test_imagefolder_error(): |
|
|
|
DATA_DIR = "../data/dataset/testPK/data" |
|
|
|
data = ds.ImageFolderDataset(DATA_DIR, num_samples=14) |
|
|
|
|
|
|
|
data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)}, |
|
|
|
{'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)}, |
|
|
|
{'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)}, |
|
|
|
{'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)}, |
|
|
|
{'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)}, |
|
|
|
{'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}] |
|
|
|
|
|
|
|
data2 = ds.PaddedDataset(data1) |
|
|
|
data3 = data + data2 |
|
|
|
with pytest.raises(ValueError) as excinfo: |
|
|
|
testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None) |
|
|
|
data3.use_sampler(testsampler) |
|
|
|
assert excinfo.type == 'ValueError' |
|
|
|
|
|
|
|
def test_imagefolder_padded(): |
|
|
|
DATA_DIR = "../data/dataset/testPK/data" |
|
|
|
|