Browse Source

fix cifar stuck problem

tags/v0.7.0-beta
xiefangqi 5 years ago
parent
commit
e3e7820413
2 changed files with 16 additions and 0 deletions
  1. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
  2. +13
    -0
      tests/ut/python/dataset/test_datasets_cifarop.py

+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc View File

@@ -336,6 +336,9 @@ Status CifarOp::GetCifarFiles() {
std::string err_msg = "Unable to open directory " + dataset_directory.toString();
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (cifar_files_.size() == 0) {
RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
}
std::sort(cifar_files_.begin(), cifar_files_.end());
return Status::OK();
}


+ 13
- 0
tests/ut/python/dataset/test_datasets_cifarop.py View File

@@ -24,6 +24,7 @@ from mindspore import log as logger

DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"


def load_cifar(path, kind="cifar10"):
@@ -208,6 +209,12 @@ def test_cifar10_exception():
with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88)

error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar10Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass


def test_cifar10_visualize(plot=False):
"""
@@ -352,6 +359,12 @@ def test_cifar100_exception():
with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88)

error_msg_7 = "No .bin files found"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass


def test_cifar100_visualize(plot=False):
"""


Loading…
Cancel
Save