diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 0723abbd9d..d0fbedfd35 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -874,45 +874,118 @@ std::vector CsvOp::split(const std::string &s, char delim) { Status CsvOp::ComputeColMap() { // Set the column name mapping (base class field) if (column_name_id_map_.empty()) { - if (column_name_list_.empty()) { + if (!ColumnNameValidate()) { + RETURN_STATUS_UNEXPECTED("Fail to validate column name for input CSV file list"); + } + + for (auto &csv_file : csv_files_list_) { + Status rc = ColMapAnalyse(csv_file); + + /* Process exception if ERROR in column name solving*/ + if (!rc.IsOk()) { + MS_LOG(ERROR) << "Fail to analyse column name map, invalid file: " + csv_file; + RETURN_STATUS_UNEXPECTED("Fail to analyse column name map, invalid file: " + csv_file); + } + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + + if (column_default_list_.size() < column_name_id_map_.size()) { + for (int32_t i = column_default_list_.size(); i < column_name_id_map_.size(); i++) { + column_default_list_.push_back(std::make_shared>(CsvOp::STRING, "")); + } + } + + if (column_default_list_.size() != column_name_id_map_.size()) { + RETURN_STATUS_UNEXPECTED( + "Invalid parameter, the number of column names does not match the column defaults, column_default_list: " + + std::to_string(column_default_list_.size()) + + ", column_name_id_map: " + std::to_string(column_name_id_map_.size())); + } + + return Status::OK(); +} + +Status CsvOp::ColMapAnalyse(const std::string &csv_file_name) { + if (column_name_list_.empty()) { + // Actually we only deal with the first file, because the column name set in other files must remain the same + if (!check_flag_) { std::string line; - std::ifstream handle(csv_files_list_[0]); + std::ifstream handle(csv_file_name); + getline(handle, line); std::vector col_names = split(line, field_delim_); + for (int32_t i = 0; i < col_names.size(); i++) { - // consider the case of CRLF + // consider the case of CRLF on windows col_names[i].erase(col_names[i].find_last_not_of('\r') + 1); + if (column_name_id_map_.find(col_names[i]) == column_name_id_map_.end()) { column_name_id_map_[col_names[i]] = i; } else { - RETURN_STATUS_UNEXPECTED("Invalid parameter, duplicate column names are not allowed: " + col_names[i]); + MS_LOG(ERROR) << "Invalid parameter, duplicate column names are not allowed: " + col_names[i] + + ", The corresponding data files: " + csv_file_name; + + RETURN_STATUS_UNEXPECTED("Invalid parameter, duplicate column names are not allowed: " + col_names[i] + + ", The corresponding data files: " + csv_file_name); } } - } else { - for (int32_t i = 0; i < column_name_list_.size(); i++) { + check_flag_ = true; + } + } else { + if (!check_flag_) { // Case the first CSV file, validate the column names + for (int32_t i = 0; i < column_name_list_.size(); ++i) { if (column_name_id_map_.find(column_name_list_[i]) == column_name_id_map_.end()) { column_name_id_map_[column_name_list_[i]] = i; } else { + MS_LOG(ERROR) << "Invalid parameter, duplicate column names are not allowed: " + column_name_list_[i] + + ", The corresponding data files: " + csv_file_name; + RETURN_STATUS_UNEXPECTED("Invalid parameter, duplicate column names are not allowed: " + - column_name_list_[i]); + column_name_list_[i] + ", The corresponding data files: " + csv_file_name); } } + check_flag_ = true; } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; } - if (column_default_list_.size() < column_name_id_map_.size()) { - for (int32_t i = column_default_list_.size(); i < column_name_id_map_.size(); i++) { - column_default_list_.push_back(std::make_shared>(CsvOp::STRING, "")); + return Status::OK(); +} + +bool CsvOp::ColumnNameValidate() { + /* Case 1: Users specify the column_names */ + if (!column_name_list_.empty()) { + return true; + } + + /* Case 2: Inferring the column_names from the first row of CSV files + \\ record: the column name set in first CSV file. + \\ match_file: First file same */ + std::vector record; + std::string match_file; + + for (auto &csv_file : csv_files_list_) { + std::string line; + std::ifstream handle(csv_file); + + // Parse the csv_file into column name set + getline(handle, line); + std::vector col_names = split(line, field_delim_); + + /* Analyse the column name and draw a conclusion*/ + if (record.empty()) { // Case the first file + record = col_names; + match_file = csv_file; + } else { // Case the other files + if (col_names != record) { + MS_LOG(ERROR) + << "Every corresponding column name must be identical, either element or permutation. Invalid files are: " + + match_file + " and " + csv_file; + return false; + } } } - if (column_default_list_.size() != column_name_id_map_.size()) { - RETURN_STATUS_UNEXPECTED( - "Invalid parameter, the number of column names does not match the column defaults, column_default_list: " + - std::to_string(column_default_list_.size()) + - ", column_name_id_map: " + std::to_string(column_name_id_map_.size())); - } - return Status::OK(); + return true; } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index d101c56545..4027b15fd4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -386,6 +386,14 @@ class CsvOp : public ParallelOp { // @return - the a string vector std::vector split(const std::string &s, char delim); + // Private function for analysing the column name in every CSV file + // @return - Status + Status ColMapAnalyse(const std::string &csv_file_name); + + // Private function for validating whether the column name set in every CSV file remain the same + // @return bool - whether column name identical in all CSV files + bool ColumnNameValidate(); + int32_t device_id_; bool shuffle_files_; bool finished_reading_dataset_; @@ -405,6 +413,7 @@ class CsvOp : public ParallelOp { char field_delim_; std::vector> column_default_list_; std::vector column_name_list_; + bool check_flag_ = false; }; } // namespace dataset } // namespace mindspore