Browse Source

added checking of first row crc to find invalid tfrecord files

Addressed code review comments. Added a check in the Python layer to exclude directories and to raise an error if a pattern does not match any file.

fixed clang format

fixed cppcheck

Fixed cppcheck (used std::accumulate and std::copy_if). Regenerated the tfrecord file so that it contains a correct header; it was a dummy header before.

fixed cppcheck: added const reference for string parameter for lambdas, fixed clang format: whitespace adjustments

more clang whitespace fixes...

changed print to logger.info
tags/v0.2.0-alpha
Peilin Wang 5 years ago
parent
commit
9bc2134cb7
8 changed files with 127 additions and 11 deletions
  1. +51
    -6
      mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
  2. +13
    -4
      mindspore/dataset/engine/datasets.py
  3. +34
    -0
      tests/ut/cpp/dataset/tfReader_op_test.cc
  4. BIN
      tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data
  5. BIN
      tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data
  6. BIN
      tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data
  7. +1
    -0
      tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt
  8. +28
    -1
      tests/ut/python/dataset/test_tfreader_op.py

+ 51
- 6
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc View File

@@ -42,6 +42,7 @@
#include "dataset/util/status.h"
#include "dataset/util/task_manager.h"
#include "dataset/util/wait_post.h"
#include "utils/system/crc32c.h"

namespace mindspore {
namespace dataset {
@@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder()
builder_data_schema_ = std::make_unique<DataSchema>();
}

// Checks whether the first record header of a file carries a consistent CRC.
//
// The function reads an 8-byte record length followed by a 4-byte masked CRC from the start of the
// file, recomputes the masked CRC over the length bytes, and compares the two values.
//
// \param[in] filename Path of the file to inspect.
// \return true only if the file could be opened, both header fields were read in full, and the stored
//     CRC equals the recomputed one; false in every other case.
bool ValidateFirstRowCrc(const std::string &filename) {
  // The header is raw bytes, so open in binary mode to rule out any text-mode translation.
  std::ifstream reader(filename, std::ios::in | std::ios::binary);
  if (!reader) {
    return false;
  }

  // read the record length; a short read means the file cannot hold a complete header
  int64_t record_length = 0;
  if (!reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)))) {
    return false;
  }

  // read crc from file; again reject the file if the field is truncated
  uint32_t masked_crc = 0;
  if (!reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)))) {
    return false;
  }

  // generate crc from the length bytes that were just read
  uint32_t generated_crc =
    system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));

  return masked_crc == generated_crc;
}

// Validates the builder's settings and input files.
// Every problem found is appended to a single error message; the function returns Status::OK() when
// that message is empty and a kUnexpectedError status carrying the message otherwise.
//
// NOTE(review): this listing contains each of the worker-count, file-count and sharding checks twice,
// once as a ternary appended to err_msg and once as an if-statement, and the braces do not balance as
// shown. It looks like the old and new sides of a diff rendered together - confirm which lines belong
// to the final function before relying on it.
Status TFReaderOp::Builder::ValidateInputs() const {
std::string err_msg;
// Ternary form of the worker-count check (see review note above).
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : "";
// Ternary form of the file-count check; only applied when rows are not split equally per shard.
if (!builder_equal_rows_per_shard_) {
err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)
? "No enough tf_file files provided\n"
: "";

// if-statement form of the worker-count check.
if (builder_num_workers_ <= 0) {
err_msg += "Number of parallel workers is smaller or equal to 0\n";
}

// if-statement form of the file-count check: fewer files than devices is reported as an error
// unless equal-rows-per-shard mode is enabled.
if (!builder_equal_rows_per_shard_ &&
builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
err_msg += "Not enough tfrecord files provided\n";
}

// The device id must be below the device count, and the device count must be at least 1.
if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
err_msg += "Wrong sharding configs\n";
}
// Ternary form of the sharding check (see review note above).
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";

// Collect every file for which ValidateFirstRowCrc() returns false. The destination is pre-sized to
// the full list so copy_if can write through begin(), then trimmed to the number of entries copied.
std::vector<std::string> invalid_files(builder_dataset_files_list_.size());
auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(),
[](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
invalid_files.resize(std::distance(invalid_files.begin(), it));

if (!invalid_files.empty()) {
err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n";

// Build the list of offending files, one per line, each prefixed with a space.
std::string accumulated_filenames = std::accumulate(
invalid_files.begin(), invalid_files.end(), std::string(""),
[](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; });
err_msg += accumulated_filenames;
}

// An empty message means no check failed.
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

@@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read));
rows_read++;
}

// ignore crc footer
(void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
rows_total++;


+ 13
- 4
mindspore/dataset/engine/datasets.py View File

@@ -900,13 +900,22 @@ class SourceDataset(Dataset):
List, files.
"""

def flat(lists):
return list(np.array(lists).flatten())

if not isinstance(patterns, list):
patterns = [patterns]

file_list = flat([glob.glob(file, recursive=True) for file in patterns])
file_list = []
unmatched_patterns = []
for pattern in patterns:
matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]

if matches:
file_list.extend(matches)
else:
unmatched_patterns.append(pattern)

if unmatched_patterns:
raise ValueError("The following patterns did not match any files: ", unmatched_patterns)

if file_list: # not empty
return file_list
raise ValueError("The list of path names matching the patterns is empty.")


+ 34
- 0
tests/ut/cpp/dataset/tfReader_op_test.cc View File

@@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) {
TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true);
ASSERT_EQ(total_rows, 60);
}

// Building a TFReaderOp must fail when the file list contains files that are not tfrecord files
// (a plain text file, the schema JSON) and also when it additionally contains a missing path.
TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) {
  // Start with an empty execution tree
  auto my_tree = std::make_shared<ExecutionTree>();

  std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
  std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt";
  std::string nonexistent_file = "this/file/doesnt/exist";

  std::shared_ptr<TFReaderOp> my_tfreader_op;
  TFReaderOp::Builder builder;

  // Configures the shared builder with the given file list and a freshly loaded schema, then builds.
  const auto build_with_files = [&](const std::vector<std::string> &dataset_files) {
    builder.SetDatasetFilesList(dataset_files).SetRowsPerBuffer(16).SetNumWorkers(16);

    auto fresh_schema = std::make_unique<DataSchema>();
    fresh_schema->LoadSchemaFile(schema_file, {});
    builder.SetDataSchema(std::move(fresh_schema));

    return builder.Build(&my_tfreader_op);
  };

  // Existing files, two of which are not tfrecord files.
  Status rc = build_with_files({invalid_file, valid_file, schema_file});
  ASSERT_FALSE(rc.IsOk());

  // Same list plus a path that does not exist.
  rc = build_with_files({invalid_file, valid_file, schema_file, nonexistent_file});
  ASSERT_FALSE(rc.IsOk());
}

BIN
tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data View File


BIN
tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data View File


BIN
tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data View File


+ 1
- 0
tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt View File

@@ -0,0 +1 @@
this is just a text file, not a valid tfrecord file.

+ 28
- 1
tests/ut/python/dataset/test_tfreader_op.py View File

@@ -32,7 +32,7 @@ def test_case_tf_shape():
ds1 = ds.TFRecordDataset(FILES, schema_file)
ds1 = ds1.batch(2)
for data in ds1.create_dict_iterator():
print(data)
logger.info(data)
output_shape = ds1.output_shapes()
assert (len(output_shape[-1]) == 1)

@@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list():
a = row["col_sint32"]
assert "col_sint32" in str(info.value)

def test_case_invalid_files():
    """
    Check how TFRecordDataset reacts to unusable input files.

    Files that exist but are not tfrecord files are expected to raise a RuntimeError
    naming them when iteration starts; a path that matches nothing is expected to
    raise a ValueError when the dataset is constructed.
    """
    valid_file = "../data/dataset/testTFTestAllTypes/test.data"
    invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"

    # Existing files, two of which (the text file and the schema) are not tfrecord files.
    dataset = ds.TFRecordDataset([invalid_file, valid_file, SCHEMA_FILE], SCHEMA_FILE,
                                 shuffle=ds.Shuffle.FILES)
    with pytest.raises(RuntimeError) as info:
        dataset.create_dict_iterator().get_next()
    # Checks run after the `with` block: anything placed after the raising call inside it is skipped.
    message = str(info.value)
    assert "cannot be opened" in message
    assert "not valid tfrecord files" in message
    assert valid_file not in message
    assert invalid_file in message
    assert SCHEMA_FILE in message

    # Adding a path that does not exist is rejected at construction time.
    nonexistent_file = "this/file/does/not/exist"
    with pytest.raises(ValueError) as info:
        ds.TFRecordDataset([invalid_file, valid_file, SCHEMA_FILE, nonexistent_file], SCHEMA_FILE,
                           shuffle=ds.Shuffle.FILES)
    message = str(info.value)
    assert "did not match any files" in message
    assert valid_file not in message
    assert invalid_file not in message
    assert SCHEMA_FILE not in message
    assert nonexistent_file in message

if __name__ == '__main__':
test_case_tf_shape()
test_case_tf_file()
@@ -212,3 +238,4 @@ if __name__ == '__main__':
test_tf_record_schema()
test_tf_record_shuffle()
test_tf_shard_equal_rows()
test_case_invalid_files()

Loading…
Cancel
Save