|
|
|
@@ -362,8 +362,8 @@ Status MnistOp::ParseMnistData() { |
|
|
|
} |
|
|
|
|
|
|
|
Status MnistOp::WalkAllFiles() { |
|
|
|
const std::string kImageExtension = "idx3-ubyte"; |
|
|
|
const std::string kLabelExtension = "idx1-ubyte"; |
|
|
|
const std::string img_ext = "idx3-ubyte"; |
|
|
|
const std::string lbl_ext = "idx1-ubyte"; |
|
|
|
const std::string train_prefix = "train"; |
|
|
|
const std::string test_prefix = "t10k"; |
|
|
|
|
|
|
|
@@ -374,13 +374,13 @@ Status MnistOp::WalkAllFiles() { |
|
|
|
if (dir_it != nullptr) { |
|
|
|
while (dir_it->hasNext()) { |
|
|
|
Path file = dir_it->next(); |
|
|
|
std::string filename = file.Basename(); |
|
|
|
if (filename.find(prefix + "-images-" + kImageExtension) != std::string::npos) { |
|
|
|
std::string fname = file.Basename(); // name of the mnist file |
|
|
|
if ((fname.find(prefix + "-images") != std::string::npos) && (fname.find(img_ext) != std::string::npos)) { |
|
|
|
image_names_.push_back(file.toString()); |
|
|
|
MS_LOG(INFO) << "Mnist operator found image file at " << filename << "."; |
|
|
|
} else if (filename.find(prefix + "-labels-" + kLabelExtension) != std::string::npos) { |
|
|
|
MS_LOG(INFO) << "Mnist operator found image file at " << fname << "."; |
|
|
|
} else if ((fname.find(prefix + "-labels") != std::string::npos) && (fname.find(lbl_ext) != std::string::npos)) { |
|
|
|
label_names_.push_back(file.toString()); |
|
|
|
MS_LOG(INFO) << "Mnist Operator found label file at " << filename << "."; |
|
|
|
MS_LOG(INFO) << "Mnist Operator found label file at " << fname << "."; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
|