Browse Source

Fix error handling when the number of columns does not match

tags/v0.7.0-beta
jiangzhiwen 5 years ago
parent
commit
2cc6b5cb52
3 changed files with 29 additions and 13 deletions
  1. +24
    -10
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
  2. +3
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
  3. +2
    -2
      tests/ut/python/dataset/test_datasets_csv.py

+ 24
- 10
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc View File

@@ -100,6 +100,10 @@ Status CsvOp::Init() {
int CsvOp::CsvParser::put_record(char c) { int CsvOp::CsvParser::put_record(char c) {
std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_); std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_);
std::shared_ptr<Tensor> t; std::shared_ptr<Tensor> t;
if (cur_col_ >= column_default_.size()) {
err_message_ = "Number of file columns does not match the default records";
return -1;
}
switch (column_default_[cur_col_]->type) { switch (column_default_[cur_col_]->type) {
case CsvOp::INT: case CsvOp::INT:
Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)); Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32));
@@ -116,6 +120,10 @@ int CsvOp::CsvParser::put_record(char c) {
Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar());
break; break;
} }
if (cur_col_ >= (*tensor_table_)[cur_row_].size()) {
err_message_ = "Number of file columns does not match the tensor table";
return -1;
}
(*tensor_table_)[cur_row_][cur_col_] = std::move(t); (*tensor_table_)[cur_row_][cur_col_] = std::move(t);
pos_ = 0; pos_ = 0;
cur_col_++; cur_col_++;
@@ -134,7 +142,11 @@ int CsvOp::CsvParser::put_row(char c) {
return 0; return 0;
} }


put_record(c);
int ret = put_record(c);
if (ret < 0) {
return ret;
}

total_rows_++; total_rows_++;
cur_row_++; cur_row_++;
cur_col_ = 0; cur_col_ = 0;
@@ -265,8 +277,7 @@ Status CsvOp::CsvParser::initCsvParser() {
[this](CsvParser &, char c) -> int { [this](CsvParser &, char c) -> int {
this->tensor_table_ = std::make_unique<TensorQTable>(); this->tensor_table_ = std::make_unique<TensorQTable>();
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
this->put_record(c);
return 0;
return this->put_record(c);
}}}, }}},
{{State::START_OF_FILE, Message::MS_QUOTE}, {{State::START_OF_FILE, Message::MS_QUOTE},
{State::QUOTE, {State::QUOTE,
@@ -367,8 +378,7 @@ Status CsvOp::CsvParser::initCsvParser() {
if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) { if (this->total_rows_ > this->start_offset_ && this->total_rows_ <= this->end_offset_) {
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
} }
this->put_record(c);
return 0;
return this->put_record(c);
}}}, }}},
{{State::END_OF_LINE, Message::MS_QUOTE}, {{State::END_OF_LINE, Message::MS_QUOTE},
{State::QUOTE, {State::QUOTE,
@@ -408,15 +418,16 @@ Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, cons
while (ifs.good()) { while (ifs.good()) {
char chr = ifs.get(); char chr = ifs.get();
if (csv_parser.processMessage(chr) != 0) { if (csv_parser.processMessage(chr) != 0) {
RETURN_STATUS_UNEXPECTED("Failed to parse CSV file " + file + ":" + std::to_string(csv_parser.total_rows_));
RETURN_STATUS_UNEXPECTED("Failed to parse file " + file + ":" + std::to_string(csv_parser.total_rows_ + 1) +
". error message: " + csv_parser.err_message_);
} }
} }
} catch (std::invalid_argument &ia) { } catch (std::invalid_argument &ia) {
std::string err_row = std::to_string(csv_parser.total_rows_);
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", invalid argument of " + std::string(ia.what()));
std::string err_row = std::to_string(csv_parser.total_rows_ + 1);
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", type does not match");
} catch (std::out_of_range &oor) { } catch (std::out_of_range &oor) {
std::string err_row = std::to_string(csv_parser.total_rows_);
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of Range error: " + std::string(oor.what()));
std::string err_row = std::to_string(csv_parser.total_rows_ + 1);
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of range");
} }
return Status::OK(); return Status::OK();
} }
@@ -763,6 +774,9 @@ Status CsvOp::ComputeColMap() {
column_default_list_.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, "")); column_default_list_.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ""));
} }
} }
if (column_default_list_.size() != column_name_id_map_.size()) {
RETURN_STATUS_UNEXPECTED("The number of column names does not match the column defaults");
}
return Status::OK(); return Status::OK();
} }
} // namespace dataset } // namespace dataset


+ 3
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h View File

@@ -76,7 +76,8 @@ class CsvOp : public ParallelOp {
cur_col_(0), cur_col_(0),
total_rows_(0), total_rows_(0),
start_offset_(0), start_offset_(0),
end_offset_(std::numeric_limits<int64_t>::max()) {
end_offset_(std::numeric_limits<int64_t>::max()),
err_message_("unkonw") {
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
initCsvParser(); initCsvParser();
} }
@@ -189,6 +190,7 @@ class CsvOp : public ParallelOp {
std::vector<char> str_buf_; std::vector<char> str_buf_;
std::unique_ptr<TensorQTable> tensor_table_; std::unique_ptr<TensorQTable> tensor_table_;
std::unique_ptr<DataBuffer> cur_buffer_; std::unique_ptr<DataBuffer> cur_buffer_;
std::string err_message_;
}; };


class Builder { class Builder {


+ 2
- 2
tests/ut/python/dataset/test_datasets_csv.py View File

@@ -205,7 +205,7 @@ def test_csv_dataset_exception():
with pytest.raises(Exception) as err: with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator(): for _ in data.create_dict_iterator():
pass pass
assert "Failed to parse CSV file" in str(err.value)
assert "Failed to parse file" in str(err.value)




def test_csv_dataset_type_error(): def test_csv_dataset_type_error():
@@ -218,7 +218,7 @@ def test_csv_dataset_type_error():
with pytest.raises(Exception) as err: with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator(): for _ in data.create_dict_iterator():
pass pass
assert "invalid argument of stoi" in str(err.value)
assert "type does not match" in str(err.value)




if __name__ == "__main__": if __name__ == "__main__":


Loading…
Cancel
Save