|
|
|
@@ -24,6 +24,7 @@ from mindspore import log as logger |
|
|
|
|
|
|
|
DATA_DIR_10 = "../data/dataset/testCifar10Data" |
|
|
|
DATA_DIR_100 = "../data/dataset/testCifar100Data" |
|
|
|
NO_BIN_DIR = "../data/dataset/testMnistData" |
|
|
|
|
|
|
|
|
|
|
|
def load_cifar(path, kind="cifar10"): |
|
|
|
@@ -208,6 +209,12 @@ def test_cifar10_exception(): |
|
|
|
with pytest.raises(ValueError, match=error_msg_6): |
|
|
|
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88) |
|
|
|
|
|
|
|
error_msg_7 = "No .bin files found" |
|
|
|
with pytest.raises(RuntimeError, match=error_msg_7): |
|
|
|
ds1 = ds.Cifar10Dataset(NO_BIN_DIR) |
|
|
|
for _ in ds1.__iter__(): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def test_cifar10_visualize(plot=False): |
|
|
|
""" |
|
|
|
@@ -352,6 +359,12 @@ def test_cifar100_exception(): |
|
|
|
with pytest.raises(ValueError, match=error_msg_6): |
|
|
|
ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88) |
|
|
|
|
|
|
|
error_msg_7 = "No .bin files found" |
|
|
|
with pytest.raises(RuntimeError, match=error_msg_7): |
|
|
|
ds1 = ds.Cifar100Dataset(NO_BIN_DIR) |
|
|
|
for _ in ds1.__iter__(): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def test_cifar100_visualize(plot=False): |
|
|
|
""" |
|
|
|
|