|
|
|
@@ -22,15 +22,54 @@ |
|
|
|
#include <utility> |
|
|
|
#include <vector> |
|
|
|
|
|
|
|
#include "minddata/dataset/engine/jagged_connector.h" |
|
|
|
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" |
|
|
|
|
|
|
|
#include "minddata/dataset/engine/jagged_connector.h" |
|
|
|
#include "minddata/dataset/util/status.h" |
|
|
|
#include "utils/system/crc32c.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace dataset { |
|
|
|
namespace api { |
|
|
|
|
|
|
|
bool ValidateFirstRowCrc(const std::string &filename) { |
|
|
|
std::ifstream reader; |
|
|
|
reader.open(filename); |
|
|
|
if (!reader) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// read data |
|
|
|
int64_t record_length = 0; |
|
|
|
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t))); |
|
|
|
|
|
|
|
// read crc from file |
|
|
|
uint32_t masked_crc = 0; |
|
|
|
(void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t))); |
|
|
|
|
|
|
|
// generate crc from data |
|
|
|
uint32_t generated_crc = |
|
|
|
system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t)); |
|
|
|
|
|
|
|
return masked_crc == generated_crc; |
|
|
|
} |
|
|
|
|
|
|
|
// Validator for TFRecordNode |
|
|
|
/// \brief Validate the parameters of TFRecordNode.
/// \details Runs ValidateFirstRowCrc on every entry of dataset_files_ and collects the ones that fail,
///     so that a single error reports all offending files at once.
/// \return Status::OK() when every file passes; otherwise a kUnexpectedError status whose message lists
///     each failing file on its own line.
Status TFRecordNode::ValidateParams() {
  // Build the per-file list directly; each entry is formatted as " <filename>\n".
  std::string invalid_files;
  for (const std::string &filename : dataset_files_) {
    if (!ValidateFirstRowCrc(filename)) {
      invalid_files += " ";
      invalid_files += filename;
      invalid_files += "\n";
    }
  }

  if (invalid_files.empty()) {
    return Status::OK();
  }

  std::string err_msg = "Invalid file, the following files either cannot be opened, or are not valid tfrecord files:\n";
  err_msg += invalid_files;
  return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
|
|
|
|
|
|
|
// Function to build TFRecordNode |
|
|
|
std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { |
|
|
|
|