| @@ -22,15 +22,54 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/jagged_connector.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | ||||
| #include "minddata/dataset/engine/jagged_connector.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "utils/system/crc32c.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| bool ValidateFirstRowCrc(const std::string &filename) { | |||||
| std::ifstream reader; | |||||
| reader.open(filename); | |||||
| if (!reader) { | |||||
| return false; | |||||
| } | |||||
| // read data | |||||
| int64_t record_length = 0; | |||||
| (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t))); | |||||
| // read crc from file | |||||
| uint32_t masked_crc = 0; | |||||
| (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t))); | |||||
| // generate crc from data | |||||
| uint32_t generated_crc = | |||||
| system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t)); | |||||
| return masked_crc == generated_crc; | |||||
| } | |||||
| // Validator for TFRecordNode | // Validator for TFRecordNode | ||||
| Status TFRecordNode::ValidateParams() { return Status::OK(); } | |||||
| Status TFRecordNode::ValidateParams() { | |||||
| std::vector<std::string> invalid_files(dataset_files_.size()); | |||||
| auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(), | |||||
| [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); | |||||
| invalid_files.resize(std::distance(invalid_files.begin(), it)); | |||||
| std::string err_msg; | |||||
| if (!invalid_files.empty()) { | |||||
| err_msg += "Invalid file, the following files either cannot be opened, or are not valid tfrecord files:\n"; | |||||
| std::string accumulated_filenames = std::accumulate( | |||||
| invalid_files.begin(), invalid_files.end(), std::string(""), | |||||
| [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); | |||||
| err_msg += accumulated_filenames; | |||||
| } | |||||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||||
| } | |||||
| // Function to build TFRecordNode | // Function to build TFRecordNode | ||||
| std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { | ||||
| @@ -398,11 +398,9 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetShard) { | |||||
| // Create a TFRecord Dataset | // Create a TFRecord Dataset | ||||
| // Each file has two columns("image", "label") and 3 rows | // Each file has two columns("image", "label") and 3 rows | ||||
| std::vector<std::string> files = { | |||||
| datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", | |||||
| datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", | |||||
| datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data" | |||||
| }; | |||||
| std::vector<std::string> files = {datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data", | |||||
| datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data", | |||||
| datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0003.data"}; | |||||
| std::shared_ptr<Dataset> ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true); | std::shared_ptr<Dataset> ds1 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, true); | ||||
| EXPECT_NE(ds1, nullptr); | EXPECT_NE(ds1, nullptr); | ||||
| std::shared_ptr<Dataset> ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false); | std::shared_ptr<Dataset> ds2 = TFRecord(files, "", {}, 0, ShuffleMode::kFalse, 2, 1, false); | ||||
| @@ -505,3 +503,12 @@ TEST_F(MindDataTestPipeline, TestIncorrectTFSchemaObject) { | |||||
| // this will fail due to the incorrect schema used | // this will fail due to the incorrect schema used | ||||
| EXPECT_FALSE(itr->GetNextRow(&mp)); | EXPECT_FALSE(itr->GetNextRow(&mp)); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestIncorrectTFrecordFile) { | |||||
| std::string path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json"; | |||||
| std::shared_ptr<api::Dataset> ds = api::TFRecord({path}); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| // the tf record file is incorrect, hence validate param will fail | |||||
| auto itr = ds->CreateIterator(); | |||||
| EXPECT_EQ(itr, nullptr); | |||||
| } | |||||