|
|
|
@@ -113,6 +113,20 @@ def test_manifest_dataset_multi_label_onehot(): |
|
|
|
count = count + 1 |
|
|
|
|
|
|
|
|
|
|
|
def test_manifest_dataset_get_num_class(): |
|
|
|
data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False) |
|
|
|
assert data.num_classes() == 3 |
|
|
|
|
|
|
|
padded_samples = [{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}] |
|
|
|
padded_ds = ds.PaddedDataset(padded_samples) |
|
|
|
|
|
|
|
data = data.repeat(2) |
|
|
|
padded_ds = padded_ds.repeat(2) |
|
|
|
|
|
|
|
data1 = data + padded_ds |
|
|
|
assert data1.num_classes() == 3 |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_manifest_dataset_train() |
|
|
|
test_manifest_dataset_eval() |
|
|
|
@@ -120,3 +134,4 @@ if __name__ == '__main__': |
|
|
|
test_manifest_dataset_get_class_index() |
|
|
|
test_manifest_dataset_multi_label() |
|
|
|
test_manifest_dataset_multi_label_onehot() |
|
|
|
test_manifest_dataset_get_num_class() |