Addressed code review comments. Added a check in the Python layer to exclude directories and to raise an error if a pattern does not match any file. Fixed clang-format. Fixed cppcheck (used std::accumulate and std::copy_if). Regenerated the tfrecord file to contain a correct header; it was a dummy header before. Fixed cppcheck: added const reference for the string parameter of lambdas. Fixed clang-format: whitespace adjustments. More clang-format whitespace fixes. Changed print to logger.info. (tags/v0.2.0-alpha)
| @@ -42,6 +42,7 @@ | |||||
| #include "dataset/util/status.h" | #include "dataset/util/status.h" | ||||
| #include "dataset/util/task_manager.h" | #include "dataset/util/task_manager.h" | ||||
| #include "dataset/util/wait_post.h" | #include "dataset/util/wait_post.h" | ||||
| #include "utils/system/crc32c.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder() | |||||
| builder_data_schema_ = std::make_unique<DataSchema>(); | builder_data_schema_ = std::make_unique<DataSchema>(); | ||||
| } | } | ||||
| bool ValidateFirstRowCrc(const std::string &filename) { | |||||
| std::ifstream reader; | |||||
| reader.open(filename); | |||||
| if (!reader) { | |||||
| return false; | |||||
| } | |||||
| // read data | |||||
| int64_t record_length = 0; | |||||
| (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t))); | |||||
| // read crc from file | |||||
| uint32_t masked_crc = 0; | |||||
| (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t))); | |||||
| // generate crc from data | |||||
| uint32_t generated_crc = | |||||
| system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t)); | |||||
| return masked_crc == generated_crc; | |||||
| } | |||||
| Status TFReaderOp::Builder::ValidateInputs() const { | Status TFReaderOp::Builder::ValidateInputs() const { | ||||
| std::string err_msg; | std::string err_msg; | ||||
| err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : ""; | |||||
| if (!builder_equal_rows_per_shard_) { | |||||
| err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_) | |||||
| ? "No enough tf_file files provided\n" | |||||
| : ""; | |||||
| if (builder_num_workers_ <= 0) { | |||||
| err_msg += "Number of parallel workers is smaller or equal to 0\n"; | |||||
| } | |||||
| if (!builder_equal_rows_per_shard_ && | |||||
| builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) { | |||||
| err_msg += "Not enough tfrecord files provided\n"; | |||||
| } | |||||
| if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { | |||||
| err_msg += "Wrong sharding configs\n"; | |||||
| } | } | ||||
| err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; | |||||
| std::vector<std::string> invalid_files(builder_dataset_files_list_.size()); | |||||
| auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), | |||||
| [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); | |||||
| invalid_files.resize(std::distance(invalid_files.begin(), it)); | |||||
| if (!invalid_files.empty()) { | |||||
| err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; | |||||
| std::string accumulated_filenames = std::accumulate( | |||||
| invalid_files.begin(), invalid_files.end(), std::string(""), | |||||
| [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); | |||||
| err_msg += accumulated_filenames; | |||||
| } | |||||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | ||||
| } | } | ||||
| @@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off | |||||
| RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); | RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); | ||||
| rows_read++; | rows_read++; | ||||
| } | } | ||||
| // ignore crc footer | // ignore crc footer | ||||
| (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t))); | (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t))); | ||||
| rows_total++; | rows_total++; | ||||
| @@ -900,13 +900,22 @@ class SourceDataset(Dataset): | |||||
| List, files. | List, files. | ||||
| """ | """ | ||||
| def flat(lists): | |||||
| return list(np.array(lists).flatten()) | |||||
| if not isinstance(patterns, list): | if not isinstance(patterns, list): | ||||
| patterns = [patterns] | patterns = [patterns] | ||||
| file_list = flat([glob.glob(file, recursive=True) for file in patterns]) | |||||
| file_list = [] | |||||
| unmatched_patterns = [] | |||||
| for pattern in patterns: | |||||
| matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)] | |||||
| if matches: | |||||
| file_list.extend(matches) | |||||
| else: | |||||
| unmatched_patterns.append(pattern) | |||||
| if unmatched_patterns: | |||||
| raise ValueError("The following patterns did not match any files: ", unmatched_patterns) | |||||
| if file_list: # not empty | if file_list: # not empty | ||||
| return file_list | return file_list | ||||
| raise ValueError("The list of path names matching the patterns is empty.") | raise ValueError("The list of path names matching the patterns is empty.") | ||||
| @@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) { | |||||
| TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true); | TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true); | ||||
| ASSERT_EQ(total_rows, 60); | ASSERT_EQ(total_rows, 60); | ||||
| } | } | ||||
| TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) { | |||||
| // Start with an empty execution tree | |||||
| auto my_tree = std::make_shared<ExecutionTree>(); | |||||
| std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data"; | |||||
| std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; | |||||
| std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt"; | |||||
| std::string nonexistent_file = "this/file/doesnt/exist"; | |||||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||||
| TFReaderOp::Builder builder; | |||||
| builder.SetDatasetFilesList({invalid_file, valid_file, schema_file}) | |||||
| .SetRowsPerBuffer(16) | |||||
| .SetNumWorkers(16); | |||||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||||
| schema->LoadSchemaFile(schema_file, {}); | |||||
| builder.SetDataSchema(std::move(schema)); | |||||
| Status rc = builder.Build(&my_tfreader_op); | |||||
| ASSERT_TRUE(!rc.IsOk()); | |||||
| builder.SetDatasetFilesList({invalid_file, valid_file, schema_file, nonexistent_file}) | |||||
| .SetRowsPerBuffer(16) | |||||
| .SetNumWorkers(16); | |||||
| schema = std::make_unique<DataSchema>(); | |||||
| schema->LoadSchemaFile(schema_file, {}); | |||||
| builder.SetDataSchema(std::move(schema)); | |||||
| rc = builder.Build(&my_tfreader_op); | |||||
| ASSERT_TRUE(!rc.IsOk()); | |||||
| } | |||||
| @@ -0,0 +1 @@ | |||||
| this is just a text file, not a valid tfrecord file. | |||||
| @@ -32,7 +32,7 @@ def test_case_tf_shape(): | |||||
| ds1 = ds.TFRecordDataset(FILES, schema_file) | ds1 = ds.TFRecordDataset(FILES, schema_file) | ||||
| ds1 = ds1.batch(2) | ds1 = ds1.batch(2) | ||||
| for data in ds1.create_dict_iterator(): | for data in ds1.create_dict_iterator(): | ||||
| print(data) | |||||
| logger.info(data) | |||||
| output_shape = ds1.output_shapes() | output_shape = ds1.output_shapes() | ||||
| assert (len(output_shape[-1]) == 1) | assert (len(output_shape[-1]) == 1) | ||||
| @@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list(): | |||||
| a = row["col_sint32"] | a = row["col_sint32"] | ||||
| assert "col_sint32" in str(info.value) | assert "col_sint32" in str(info.value) | ||||
def test_case_invalid_files():
    """
    Check the errors raised for unusable entries in the dataset file list.

    First a list of existing files is used and a RuntimeError naming the rejected files
    is expected when the first row is fetched. Then a path that does not exist is added
    and a ValueError naming only that path is expected when the dataset is constructed.
    """
    good_path = "../data/dataset/testTFTestAllTypes/test.data"
    bad_path = "../data/dataset/testTFTestAllTypes/invalidFile.txt"

    # Existing files only: the error is expected when iteration starts.
    dataset = ds.TFRecordDataset([bad_path, good_path, SCHEMA_FILE], SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    with pytest.raises(RuntimeError) as info:
        dataset.create_dict_iterator().get_next()
    message = str(info.value)
    assert "cannot be opened" in message
    assert "not valid tfrecord files" in message
    assert good_path not in message
    assert bad_path in message
    assert SCHEMA_FILE in message

    # A path that matches nothing: the error is expected at construction time.
    missing_path = "this/file/does/not/exist"
    with pytest.raises(ValueError) as info:
        dataset = ds.TFRecordDataset([bad_path, good_path, SCHEMA_FILE, missing_path], SCHEMA_FILE,
                                     shuffle=ds.Shuffle.FILES)
    message = str(info.value)
    assert "did not match any files" in message
    assert good_path not in message
    assert bad_path not in message
    assert SCHEMA_FILE not in message
    assert missing_path in message
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_case_tf_shape() | test_case_tf_shape() | ||||
| test_case_tf_file() | test_case_tf_file() | ||||
| @@ -212,3 +238,4 @@ if __name__ == '__main__': | |||||
| test_tf_record_schema() | test_tf_record_schema() | ||||
| test_tf_record_shuffle() | test_tf_record_shuffle() | ||||
| test_tf_shard_equal_rows() | test_tf_shard_equal_rows() | ||||
| test_case_invalid_files() | |||||