Merge pull request !3016 from jiangzhiwen/dataset/csvtags/v0.6.0-beta
| @@ -31,6 +31,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||
| @@ -88,6 +89,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = { | |||
| {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, | |||
| {kClue, &DEPipeline::ParseClueOp}, | |||
| {kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}, | |||
| {kCsv, &DEPipeline::ParseCsvOp}, | |||
| {kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}}; | |||
| DEPipeline::DEPipeline() : iterator_(nullptr) { | |||
| @@ -1848,6 +1850,86 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||
| std::shared_ptr<DatasetOp> *bottom) { | |||
| std::vector<std::string> files_list; | |||
| std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>(); | |||
| if (!args["dataset_files"].is_none()) { | |||
| files_list = ToStringVector(args["dataset_files"]); | |||
| (void)builder->SetCsvFilesList(files_list); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); | |||
| } | |||
| // Optional arguments | |||
| bool shuffle_required = false; | |||
| int64_t num_devices = 0; | |||
| std::vector<std::string> col_names; | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "shuffle_files") { | |||
| (void)builder->SetShuffleFiles(ToBool(value)); | |||
| } else if (key == "shuffle_global") { | |||
| shuffle_required = ToBool(value); | |||
| } else if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_shards") { | |||
| num_devices = ToInt(value); | |||
| (void)builder->SetNumDevices(num_devices); | |||
| } else if (key == "shard_id") { | |||
| (void)builder->SetDeviceId(ToInt(value)); | |||
| } else if (key == "field_delim") { | |||
| (void)builder->SetFieldDelim(ToString(value)[0]); | |||
| } else if (key == "column_defaults") { | |||
| py::list py_object_list = py::reinterpret_borrow<py::list>(value); | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list; | |||
| for (auto l : py_object_list) { | |||
| std::string type_s = (std::string)py::str(l.get_type().attr("__name__")); | |||
| if (type_s == "int") { | |||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, ToInt(l))); | |||
| } else if (type_s == "float") { | |||
| column_default_list.push_back(std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, ToFloat(l))); | |||
| } else if (type_s == "str") { | |||
| column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ToString(l))); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Record type is not allowed"); | |||
| } | |||
| } | |||
| (void)builder->SetColumDefault(column_default_list); | |||
| } else if (key == "column_names") { | |||
| col_names = ToStringVector(value); | |||
| (void)builder->SetColumName(col_names); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<CsvOp> csv_op; | |||
| RETURN_IF_NOT_OK(builder->Build(&csv_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op)); | |||
| *top = csv_op; | |||
| if (shuffle_required) { | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t shuffle_size = 0; | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset and then compute the shuffle size | |||
| RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(files_list, col_names.empty(), &num_rows)); | |||
| RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); | |||
| // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller | |||
| RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, csv_op, &shuffle_op)); | |||
| *top = shuffle_op; | |||
| *bottom = csv_op; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
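For reference, the construction that the Python layer drives through ParseCsvOp can also be written directly against the builder added in this change. This is a minimal sketch, assuming it is compiled inside the MindSpore dataset engine tree; the file name, column names, and column defaults are placeholders, not values from this PR.

```cpp
#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/csv_op.h"

namespace mindspore {
namespace dataset {

// Illustrative sketch: build a CsvOp the same way ParseCsvOp does after
// translating the Python kwargs. "train.csv" and the defaults are placeholders.
Status BuildExampleCsvOp(std::shared_ptr<CsvOp> *out) {
  std::vector<std::shared_ptr<CsvOp::BaseRecord>> defaults = {
    std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0),             // column 0: int
    std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, 0.0f),      // column 1: float
    std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, "")  // column 2: string
  };
  std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
  (void)builder->SetCsvFilesList({"train.csv"})
    .SetFieldDelim(',')
    .SetColumDefault(defaults)
    .SetColumName({"id", "score", "text"})
    .SetNumSamples(0);  // 0 means no sample limit
  RETURN_IF_NOT_OK(builder->Build(out));
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
```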
| // Helper function to inject a shuffle operator over top of the current operation being built. | |||
| Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op, | |||
| std::shared_ptr<DatasetOp> *shuffle_op) { | |||
| @@ -73,6 +73,7 @@ enum OpName { | |||
| kClue, | |||
| kEpochCtrl, | |||
| kSentencePieceVocab, | |||
| kCsv | |||
| }; | |||
| // The C++ binder class that we expose to the python script. | |||
| @@ -201,6 +202,8 @@ class DEPipeline { | |||
| Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| private: | |||
| // Execution tree that links the dataset operators. | |||
| std::shared_ptr<ExecutionTree> tree_; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| @@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) { | |||
| return count; | |||
| }); | |||
| (void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp") | |||
| .def_static("get_num_rows", [](const py::list &files, bool csv_header) { | |||
| int64_t count = 0; | |||
| std::vector<std::string> filenames; | |||
| for (auto file : files) { | |||
| file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file)); | |||
| } | |||
| THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count)); | |||
| return count; | |||
| }); | |||
| (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp") | |||
| .def_static("get_num_rows", | |||
| [](const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| @@ -1039,8 +1051,9 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab) | |||
| .value("CELEBA", OpName::kCelebA) | |||
| .value("TEXTFILE", OpName::kTextFile) | |||
| .value("CLUE", OpName::kClue) | |||
| .value("EPOCHCTRL", OpName::kEpochCtrl); | |||
| .value("EPOCHCTRL", OpName::kEpochCtrl) | |||
| .value("CSV", OpName::kCsv) | |||
| .value("CLUE", OpName::kClue); | |||
| (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic()) | |||
| .value("DE_JIEBA_MIX", JiebaMode::kMix) | |||
| @@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| celeba_op.cc | |||
| text_file_op.cc | |||
| clue_op.cc | |||
| csv_op.cc | |||
| ) | |||
| set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| @@ -29,4 +30,4 @@ if (ENABLE_PYTHON) | |||
| ) | |||
| endif() | |||
| add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}) | |||
| add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}) | |||
| @@ -0,0 +1,757 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||
| #include <fstream> | |||
| #include <iomanip> | |||
| #include <stdexcept> | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/jagged_connector.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CsvOp::Builder::Builder() | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| builder_rows_per_buffer_ = config_manager->rows_per_buffer(); | |||
| builder_worker_connector_size_ = config_manager->worker_connector_size(); | |||
| } | |||
| Status CsvOp::Builder::ValidateInputs() const { | |||
| std::string err; | |||
| err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; | |||
| err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : ""; | |||
| return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err); | |||
| } | |||
| Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) { | |||
| RETURN_IF_NOT_OK(ValidateInputs()); | |||
| // Throttle the number of workers if we have more workers than files! | |||
| if (static_cast<size_t>(builder_num_workers_) > builder_csv_files_list_.size()) { | |||
| builder_num_workers_ = builder_csv_files_list_.size(); | |||
| MS_LOG(WARNING) << "CsvOp operator parallelism reduced to " << builder_num_workers_ << " workers."; | |||
| } | |||
| std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | |||
| builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, | |||
| builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); | |||
| RETURN_IF_NOT_OK(csv_op->Init()); | |||
| *op = std::move(csv_op); | |||
| return Status::OK(); | |||
| } | |||
| CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, | |||
| const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | |||
| int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | |||
| int32_t num_device, int32_t device_id) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| csv_files_list_(std::move(csv_files_list)), | |||
| field_delim_(field_delim), | |||
| column_default_list_(column_default), | |||
| column_name_list_(column_name), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| num_rows_per_shard_(0), | |||
| all_num_rows_(0), | |||
| num_samples_(num_samples), | |||
| filename_index_(std::make_unique<StringIndex>()), | |||
| load_jagged_connector_(true), | |||
| shuffle_files_(shuffle_files), | |||
| finished_reading_dataset_(false), | |||
| num_devices_(num_device), | |||
| device_id_(device_id), | |||
| load_io_block_queue_(true) { | |||
| worker_connector_size_ = worker_connector_size; | |||
| } | |||
| Status CsvOp::Init() { | |||
| RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_)); | |||
| int32_t safe_queue_size = static_cast<int32_t>(std::ceil(csv_files_list_.size() / static_cast<double>(num_workers_)) + 1); | |||
| io_block_queues_.Init(num_workers_, safe_queue_size); | |||
| RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); | |||
| jagged_buffer_connector_ = std::make_shared<JaggedConnector>(num_workers_, 1, worker_connector_size_); | |||
| return Status::OK(); | |||
| } | |||
| int CsvOp::CsvParser::put_record(char c) { | |||
| std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_); | |||
| std::shared_ptr<Tensor> t; | |||
| switch (column_default_[cur_col_]->type) { | |||
| case CsvOp::INT: | |||
| Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)); | |||
| t->SetItemAt<int32_t>({0}, std::stoi(s)); | |||
| break; | |||
| case CsvOp::FLOAT: | |||
| Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)); | |||
| t->SetItemAt<float>({0}, std::stof(s)); | |||
| break; | |||
| case CsvOp::STRING: | |||
| Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); | |||
| break; | |||
| default: | |||
| Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); | |||
| break; | |||
| } | |||
| (*tensor_table_)[cur_row_][cur_col_] = std::move(t); | |||
| pos_ = 0; | |||
| cur_col_++; | |||
| return 0; | |||
| } | |||
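put_record converts the characters accumulated for the current field into a tensor whose element type is chosen by the matching column default; the conversion itself is just std::stoi/std::stof on the field text. Below is a standalone sketch of that type dispatch, outside the Tensor machinery and using a hypothetical FieldValue struct for illustration only.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the INT/FLOAT/STRING column defaults used above.
enum class FieldType { kInt, kFloat, kString };

struct FieldValue {
  FieldType type;
  int i = 0;
  float f = 0.0f;
  std::string s;
};

// Convert one CSV field according to the declared column type.
// std::stoi / std::stof throw std::invalid_argument or std::out_of_range,
// which is why LoadFile wraps the parse loop in a try/catch.
FieldValue ConvertField(const std::string &field, FieldType type) {
  FieldValue v;
  v.type = type;
  switch (type) {
    case FieldType::kInt:
      v.i = std::stoi(field);
      break;
    case FieldType::kFloat:
      v.f = std::stof(field);
      break;
    case FieldType::kString:
    default:
      v.s = field;
      break;
  }
  return v;
}
```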
| int CsvOp::CsvParser::put_row(char c) { | |||
| if (total_rows_ < start_offset_) { | |||
| total_rows_++; | |||
| cur_col_ = 0; | |||
| return 0; | |||
| } | |||
| if (total_rows_ >= end_offset_) { | |||
| return 0; | |||
| } | |||
| put_record(c); | |||
| total_rows_++; | |||
| cur_row_++; | |||
| cur_col_ = 0; | |||
| if (cur_row_ == csv_rows_per_buffer_) { | |||
| cur_buffer_->set_tensor_table(std::move(tensor_table_)); | |||
| buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); | |||
| cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | |||
| tensor_table_ = std::make_unique<TensorQTable>(); | |||
| cur_row_ = 0; | |||
| } | |||
| return 0; | |||
| } | |||
| int CsvOp::CsvParser::end_file(char c) { | |||
| if (cur_col_ > 0) { | |||
| put_row(c); | |||
| } | |||
| if (cur_row_ > 0) { | |||
| cur_buffer_->set_tensor_table(std::move(tensor_table_)); | |||
| buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); | |||
| } | |||
| return 0; | |||
| } | |||
| int CsvOp::CsvParser::countRows(char c) { | |||
| Message m; | |||
| if (c == '"') { | |||
| m = Message::MS_QUOTE; | |||
| } else if (c == '\r' || c == '\n' || c == std::char_traits<char>::eof()) { | |||
| m = Message::MS_END_OF_LINE; | |||
| } else { | |||
| m = Message::MS_NORMAL; | |||
| } | |||
| StateDiagram::iterator it = sdl.find({cur_state_, m}); | |||
| if (it == sdl.end()) { | |||
| return -1; | |||
| } | |||
| cur_state_ = it->second.first; | |||
| return it->second.second(*this, c); | |||
| } | |||
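Both countRows and processMessage use the same table-driven pattern: the current (state, message) pair is looked up in a map whose value holds the next state and the action to run. The following self-contained sketch shows that dispatch shape with a toy two-state machine that only counts lines; the names are illustrative, not the CsvParser API.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

enum class State { kInLine, kEndOfLine };
enum class Msg { kNormal, kNewline };

// (state, message) -> (next state, action): same shape as CsvParser::sd / sdl.
using Action = std::function<int(char)>;
using Table = std::map<std::pair<State, Msg>, std::pair<State, Action>>;

int main() {
  int64_t rows = 0;
  Table table = {
    {{State::kInLine, Msg::kNormal},     {State::kInLine,    [](char) { return 0; }}},
    {{State::kInLine, Msg::kNewline},    {State::kEndOfLine, [&rows](char) { ++rows; return 0; }}},
    {{State::kEndOfLine, Msg::kNormal},  {State::kInLine,    [](char) { return 0; }}},
    {{State::kEndOfLine, Msg::kNewline}, {State::kEndOfLine, [](char) { return 0; }}},
  };

  State cur = State::kEndOfLine;
  for (char c : std::string("a,b\nc,d\n")) {
    Msg m = (c == '\n') ? Msg::kNewline : Msg::kNormal;
    auto it = table.find({cur, m});
    if (it == table.end()) return -1;  // no transition defined: syntax error
    cur = it->second.first;
    it->second.second(c);
  }
  std::cout << "rows: " << rows << "\n";  // prints 2
  return 0;
}
```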
| Status CsvOp::CsvParser::initCsvParser() { | |||
| str_buf_.resize(CSV_BUFFER_SIZE); | |||
| // State diagram for counting rows | |||
| sdl = {// START_OF_FILE | |||
| // ┌───────────┬───────────┬─────────────┐ | |||
| // │ abc │ " │ \n │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // │ UNQUOTE │ QUOTE │ END_OF_LINE │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // | null_func │ null_func │ null_func │ | |||
| // └───────────┴───────────┴─────────────┘ | |||
| {{State::START_OF_FILE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, | |||
| {{State::START_OF_FILE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, | |||
| {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, | |||
| // UNQUOTE | |||
| // ┌───────────┬───────────┬─────────────┐ | |||
| // │ abc │ " │ \n │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // │ UNQUOTE │ QUOTE │ END_OF_LINE │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // | null_func │ null_func │ add_row │ | |||
| // └───────────┴───────────┴─────────────┘ | |||
| {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, | |||
| {{State::UNQUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, | |||
| {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}}, | |||
| // QUOTE | |||
| // ┌───────────┬──────────────┬───────────┐ | |||
| // │ abc │ " │ \n │ | |||
| // ├───────────┼──────────────┼───────────┤ | |||
| // │ QUOTE │ SECOND_QUOTE │ QUOTE │ | |||
| // ├───────────┼──────────────┼───────────┤ | |||
| // | null_func │ null_func │ null_func │ | |||
| // └───────────┴──────────────┴───────────┘ | |||
| {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::null_func}}, | |||
| {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}}, | |||
| {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::null_func}}, | |||
| // SECOND_QUOTE | |||
| // ┌───────────┬───────────┬─────────────┐ | |||
| // │ abc │ " │ \n │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // │ UNQUOTE │ QUOTE │ END_OF_LINE │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // | null_func │ null_func │ add_row │ | |||
| // └───────────┴───────────┴─────────────┘ | |||
| {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, | |||
| {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, | |||
| {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}}, | |||
| // END_OF_LINE | |||
| // ┌───────────┬───────────┬─────────────┐ | |||
| // │ abc │ " │ \n │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // │ UNQUOTE │ QUOTE │ END_OF_LINE │ | |||
| // ├───────────┼───────────┼─────────────┤ | |||
| // | null_func │ null_func │ null_func │ | |||
| // └───────────┴───────────┴─────────────┘ | |||
| {{State::END_OF_LINE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, | |||
| {{State::END_OF_LINE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, | |||
| {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}}; | |||
| // State diagram for CSV parser | |||
| sd = {// START_OF_FILE | |||
| // ┌───────────┬──────────┬──────────┬────────────────┬────────────────┐ | |||
| // │ abc │ , │ " │ \n │ EOF │ | |||
| // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤ | |||
| // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ | |||
| // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤ | |||
| // | lambda │ lambda │ lambda │ null_func │ null_func │ | |||
| // └───────────┴──────────┴──────────┴────────────────┴────────────────┘ | |||
| {{State::START_OF_FILE, Message::MS_NORMAL}, | |||
| {State::UNQUOTE, | |||
| [this](CsvParser &, char c) -> int { | |||
| this->tensor_table_ = std::make_unique<TensorQTable>(); | |||
| this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); | |||
| this->str_buf_[0] = c; | |||
| this->pos_ = 1; | |||
| return 0; | |||
| }}}, | |||
| {{State::START_OF_FILE, Message::MS_DELIM}, | |||
| {State::DELIM, | |||
| [this](CsvParser &, char c) -> int { | |||
| this->tensor_table_ = std::make_unique<TensorQTable>(); | |||
| this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); | |||
| this->put_record(c); | |||
| return 0; | |||
| }}}, | |||
| {{State::START_OF_FILE, Message::MS_QUOTE}, | |||
| {State::QUOTE, | |||
| [this](CsvParser &, char c) -> int { | |||
| this->tensor_table_ = std::make_unique<TensorQTable>(); | |||
| this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); | |||
| this->pos_ = 0; | |||
| return 0; | |||
| }}}, | |||
| {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, | |||
| {{State::START_OF_FILE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::null_func}}, | |||
| // UNQUOTE | |||
| // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐ | |||
| // │ abc │ , │ " │ \n │ EOF │ | |||
| // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ | |||
| // │ UNQUOTE │ DELIM │ EXCEPTION │ END_OF_LINE │ END_OF_FILE │ | |||
| // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ | |||
| // | put_char │ put_record │ exception │ put_row │ end_file │ | |||
| // └───────────┴────────────┴───────────┴─────────────┴────────────────┘ | |||
| {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}}, | |||
| {{State::UNQUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, | |||
| {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, | |||
| {{State::UNQUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, | |||
| // UNQUOTE-Exception | |||
| {{State::UNQUOTE, Message::MS_QUOTE}, {State::EXCEPTION, &CsvParser::catch_exception}}, | |||
| // DELIM | |||
| // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐ | |||
| // │ abc │ , │ " │ \n │ EOF │ | |||
| // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ | |||
| // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ | |||
| // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ | |||
| // | put_char │ put_record │ lambda │ put_row │ end_file │ | |||
| // └───────────┴────────────┴───────────┴─────────────┴────────────────┘ | |||
| {{State::DELIM, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}}, | |||
| {{State::DELIM, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, | |||
| {{State::DELIM, Message::MS_QUOTE}, | |||
| {State::QUOTE, | |||
| [this](CsvParser &, char c) -> int { | |||
| this->pos_ = 0; | |||
| return 0; | |||
| }}}, | |||
| {{State::DELIM, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, | |||
| {{State::DELIM, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, | |||
| // QUOTE | |||
| // ┌───────────┬──────────┬──────────────┬──────────┬────────────────┐ | |||
| // │ abc │ , │ " │ \n │ EOF │ | |||
| // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤ | |||
| // │ QUOTE │ QUOTE │ SECOND_QUOTE │ QUOTE │ EXCEPTION │ | |||
| // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤ | |||
| // | put_char │ put_char │ null_func │ put_char │ exception │ | |||
| // └───────────┴──────────┴──────────────┴──────────┴────────────────┘ | |||
| {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::put_char}}, | |||
| {{State::QUOTE, Message::MS_DELIM}, {State::QUOTE, &CsvParser::put_char}}, | |||
| {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}}, | |||
| {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::put_char}}, | |||
| // QUOTE-Exception | |||
| {{State::QUOTE, Message::MS_END_OF_FILE}, {State::EXCEPTION, &CsvParser::catch_exception}}, | |||
| // SECOND_QUOTE | |||
| // ┌───────────┬────────────┬──────────┬─────────────┬────────────────┐ | |||
| // │ abc │ , │ " │ \n │ EOF │ | |||
| // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤ | |||
| // │ EXCEPTION │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ | |||
| // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤ | |||
| // | exception │ put_record │ put_char │ put_row │ end_file │ | |||
| // └───────────┴────────────┴──────────┴─────────────┴────────────────┘ | |||
| {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::put_char}}, | |||
| {{State::SECOND_QUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, | |||
| {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, | |||
| {{State::SECOND_QUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, | |||
| // SECOND_QUOTE-Exception | |||
| {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::EXCEPTION, &CsvParser::catch_exception}}, | |||
| // END_OF_LINE | |||
| // ┌─────────┬────────┬────────┬─────────────┬─────────────┐ | |||
| // │ abc │ , │ " │ \n │ EOF │ | |||
| // ├─────────┼────────┼────────┼─────────────┼─────────────┤ | |||
| // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ | |||
| // ├─────────┼────────┼────────┼─────────────┼─────────────┤ | |||
| // | lambda │ lambda │ lambda │ null_func │ end_file │ | |||
| // └─────────┴────────┴────────┴─────────────┴─────────────┘ | |||
| {{State::END_OF_LINE, Message::MS_NORMAL}, | |||
| {State::UNQUOTE, | |||
| [this](CsvParser &, char c) -> int { | |||
| this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); | |||
| this->str_buf_[0] = c; | |||
| this->pos_ = 1; | |||
| return 0; | |||
| }}}, | |||
| {{State::END_OF_LINE, Message::MS_DELIM}, | |||
| {State::DELIM, | |||
| [this](CsvParser &, char c) -> int { | |||
| this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); | |||
| this->put_record(c); | |||
| return 0; | |||
| }}}, | |||
| {{State::END_OF_LINE, Message::MS_QUOTE}, | |||
| {State::QUOTE, | |||
| [this](CsvParser &, char c) -> int { | |||
| this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); | |||
| return 0; | |||
| }}}, | |||
| {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, | |||
| {{State::END_OF_LINE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}}; | |||
| return Status::OK(); | |||
| } | |||
| Status CsvOp::Reset() { | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||
| NotifyToFillIOBlockQueue(); | |||
| return Status::OK(); | |||
| } | |||
| Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||
| const int32_t worker_id) { | |||
| CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_); | |||
| csv_parser.setStartOffset(start_offset); | |||
| csv_parser.setEndOffset(end_offset); | |||
| std::ifstream ifs; | |||
| ifs.open(file, std::ifstream::in); | |||
| if (column_name_list_.empty()) { | |||
| std::string tmp; | |||
| getline(ifs, tmp); | |||
| } | |||
| csv_parser.Reset(); | |||
| try { | |||
| while (ifs.good()) { | |||
| char chr = ifs.get(); | |||
| if (csv_parser.processMessage(chr) != 0) { | |||
| RETURN_STATUS_UNEXPECTED("Failed to parse CSV file " + file + ":" + std::to_string(csv_parser.total_rows_)); | |||
| } | |||
| } | |||
| } catch (std::invalid_argument &ia) { | |||
| std::string err_row = std::to_string(csv_parser.total_rows_); | |||
| RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", invalid argument of " + std::string(ia.what())); | |||
| } catch (std::out_of_range &oor) { | |||
| std::string err_row = std::to_string(csv_parser.total_rows_); | |||
| RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of range error: " + std::string(oor.what())); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CsvOp::operator()() { | |||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||
| // launch one thread, responsible for filling IoBlockQueue | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1))); | |||
| // must be called after launching workers. | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); | |||
| NotifyToFillIOBlockQueue(); | |||
| while (!finished_reading_dataset_) { | |||
| int64_t buffer_id = 0; | |||
| int32_t workers_done = 0; | |||
| int64_t rows_read = 0; | |||
| load_io_block_queue_ = true; | |||
| while (workers_done < num_workers_) { | |||
| std::unique_ptr<DataBuffer> buffer; | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); | |||
| if (buffer->eoe()) { | |||
| workers_done++; | |||
| } else if (num_samples_ == 0 || rows_read < num_samples_) { | |||
| if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { | |||
| int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); | |||
| RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); | |||
| } | |||
| rows_read += buffer->NumRows(); | |||
| buffer->set_id(buffer_id++); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); | |||
| } else { | |||
| // end of epoch | |||
| load_jagged_connector_ = false; | |||
| load_io_block_queue_ = false; | |||
| } | |||
| } | |||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||
| finished_reading_dataset_ = true; | |||
| NotifyToFillIOBlockQueue(); | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| } | |||
| } | |||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||
| return Status::OK(); | |||
| } | |||
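The master loop above enforces num_samples by trimming the last buffer: once rows_read plus the incoming buffer would exceed the requested sample count, the surplus rows are sliced off. The arithmetic is small but easy to get backwards, so here is a standalone sketch of just that bookkeeping; the buffer sizes are made up for the example.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t num_samples = 10;                      // stop after 10 rows (0 would mean no limit)
  const std::vector<int64_t> buffer_rows = {4, 4, 5};  // rows arriving per buffer (illustrative)

  int64_t rows_read = 0;
  for (int64_t rows_in_buffer : buffer_rows) {
    if (num_samples != 0 && rows_read >= num_samples) break;  // sample budget already met
    int64_t kept = rows_in_buffer;
    if (num_samples != 0 && rows_read + rows_in_buffer > num_samples) {
      // Same formula as the SliceOff call: drop the rows beyond the sample budget.
      int64_t rows_to_remove = rows_in_buffer - (num_samples - rows_read);
      kept = rows_in_buffer - rows_to_remove;
    }
    rows_read += kept;
    std::cout << "buffer kept " << kept << " rows, total " << rows_read << "\n";
  }
  // Output: totals 4, then 8, then the last buffer is trimmed from 5 to 2 rows -> 10.
  return 0;
}
```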
| Status CsvOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<FilenameBlock> io_block; | |||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||
| while (!io_block->eof()) { | |||
| if (!io_block->eoe()) { | |||
| if (load_jagged_connector_) { | |||
| std::string filename; | |||
| RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); | |||
| int64_t start_offset = io_block->GetStartOffset(); | |||
| int64_t end_offset = io_block->GetEndOffset(); | |||
| RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); | |||
| } | |||
| } else { | |||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); | |||
| } | |||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // A print method typically used for debugging | |||
| void CsvOp::Print(std::ostream &out, bool show_all) const { | |||
| // Always show the id and name as the first line, regardless of whether this is the summary or detailed print | |||
| out << "(" << std::setw(2) << operator_id_ << ") <CsvOp>:"; | |||
| if (!show_all) { | |||
| // Call the super class for displaying any common 1-liner info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal 1-liner info for this op | |||
| out << "\n"; | |||
| } else { | |||
| // Call the super class for displaying any common detailed info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ | |||
| << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | |||
| << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n"; | |||
| for (int i = 0; i < csv_files_list_.size(); ++i) { | |||
| out << " " << csv_files_list_[i]; | |||
| } | |||
| out << "\n\n"; | |||
| } | |||
| } | |||
| // Pops an element from a queue in io_block_queues | |||
| Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||
| return Status::OK(); | |||
| } | |||
| // Pushes an element to a queue in io_block_queues | |||
| Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||
| return Status::OK(); | |||
| } | |||
| static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) { | |||
| std::mt19937 rng(seed); | |||
| std::shuffle(i_keys->begin(), i_keys->end(), rng); | |||
| } | |||
| Status CsvOp::WaitToFillIOBlockQueue() { | |||
| // must be called first if this is called by a worker spawned by the task group | |||
| TaskManager::FindMe()->Post(); | |||
| std::vector<int64_t> i_keys; | |||
| if (shuffle_files_) { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| i_keys.push_back(it.key()); | |||
| } | |||
| } | |||
| uint32_t seed = 0; | |||
| while (true) { | |||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||
| io_block_queue_wait_post_.Clear(); | |||
| if (finished_reading_dataset_) { | |||
| break; | |||
| } | |||
| if (shuffle_files_) { | |||
| ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||
| } | |||
| RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
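When shuffle_files is enabled, the file order is reshuffled before every epoch. Note the seed choice above: with a single device the global random seed is used, but when sharded a deterministic, incrementing seed is used so every shard visits the files in the same order and the row-based sharding stays aligned. A small standalone sketch of that per-epoch reshuffle follows; the keys are illustrative, and std::random_device stands in for the real GetSeed() call.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::vector<int64_t> file_keys = {0, 1, 2, 3};  // keys into the filename index (illustrative)
  const int32_t num_devices = 2;                  // pretend we are sharded across 2 devices
  uint32_t seed = 0;

  for (int epoch = 0; epoch < 3; ++epoch) {
    // With num_devices == 1 the real op calls GetSeed() here; with sharding,
    // every device increments the same counter, so the shuffled orders match.
    uint32_t epoch_seed = (num_devices == 1) ? std::random_device{}() : ++seed;
    std::mt19937 rng(epoch_seed);
    std::shuffle(file_keys.begin(), file_keys.end(), rng);

    std::cout << "epoch " << epoch << ":";
    for (int64_t k : file_keys) std::cout << " " << k;
    std::cout << "\n";
  }
  return 0;
}
```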
| Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||
| int32_t queue_index = 0; | |||
| int64_t pre_count = 0; | |||
| int64_t start_offset = 0; | |||
| int64_t end_offset = 0; | |||
| bool finish = false; | |||
| while (!finish) { | |||
| std::vector<std::pair<std::string, int64_t>> file_index; | |||
| if (!i_keys.empty()) { | |||
| for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { | |||
| { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| } | |||
| file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it)); | |||
| } | |||
| } else { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| } | |||
| file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key())); | |||
| } | |||
| } | |||
| for (auto file_info : file_index) { | |||
| if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { | |||
| auto ioBlock = | |||
| std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); | |||
| queue_index = (queue_index + 1) % num_workers_; | |||
| } | |||
| pre_count += filename_numrows_[file_info.first]; | |||
| } | |||
| if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { | |||
| finish = false; | |||
| } else { | |||
| finish = true; | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); | |||
| return Status::OK(); | |||
| } | |||
| void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } | |||
| bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||
| const int64_t &pre_count) { | |||
| *start_offset = 0; | |||
| *end_offset = 0; | |||
| bool push = false; | |||
| int64_t start_index = device_id_ * num_rows_per_shard_; | |||
| if (device_id_ + 1 < 0) { | |||
| MS_LOG(ERROR) << "Device id is invalid"; | |||
| return false; | |||
| } | |||
| int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||
| if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { | |||
| *start_offset = start_index - pre_count; | |||
| push = true; | |||
| if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { | |||
| *end_offset = end_index - pre_count; | |||
| } else { | |||
| *end_offset = filename_numrows_[file_name]; | |||
| } | |||
| } | |||
| if (pre_count >= start_index && pre_count < end_index) { | |||
| *start_offset = 0; | |||
| push = true; | |||
| if (pre_count + filename_numrows_[file_name] >= end_index) { | |||
| *end_offset = end_index - pre_count; | |||
| } else { | |||
| *end_offset = filename_numrows_[file_name]; | |||
| } | |||
| } | |||
| return push; | |||
| } | |||
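NeedPushFileToBlockQueue decides, per file, whether this shard's global row window [device_id * num_rows_per_shard, (device_id + 1) * num_rows_per_shard) overlaps the file, and if so which local start/end offsets to read. The standalone sketch below reproduces that offset arithmetic with made-up file names and row counts: with three 40-row files and two devices (60 rows per shard), device 1 reads rows 20..40 of the second file and all of the third.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct FileRows { std::string name; int64_t rows; };

// Same window logic as NeedPushFileToBlockQueue, written standalone.
bool ShardWindow(int64_t file_rows, int64_t pre_count, int64_t start_index, int64_t end_index,
                 int64_t *start_offset, int64_t *end_offset) {
  *start_offset = 0;
  *end_offset = 0;
  bool push = false;
  if (pre_count <= start_index && pre_count + file_rows > start_index) {
    *start_offset = start_index - pre_count;
    push = true;
    if (pre_count < end_index && pre_count + file_rows >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = file_rows;
    }
  }
  if (pre_count >= start_index && pre_count < end_index) {
    *start_offset = 0;
    push = true;
    if (pre_count + file_rows >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = file_rows;
    }
  }
  return push;
}

int main() {
  std::vector<FileRows> files = {{"a.csv", 40}, {"b.csv", 40}, {"c.csv", 40}};  // illustrative
  const int64_t num_rows_per_shard = 60;  // ceil(120 rows / 2 devices)
  const int64_t device_id = 1;
  const int64_t start_index = device_id * num_rows_per_shard;      // 60
  const int64_t end_index = (device_id + 1) * num_rows_per_shard;  // 120

  int64_t pre_count = 0;
  for (const auto &f : files) {
    int64_t start = 0, end = 0;
    if (ShardWindow(f.rows, pre_count, start_index, end_index, &start, &end)) {
      std::cout << f.name << ": rows [" << start << ", " << end << ")\n";
    }
    pre_count += f.rows;
  }
  // Prints: b.csv: rows [20, 40)  and  c.csv: rows [0, 40)
  return 0;
}
```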
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||
| Status CsvOp::PostEndOfEpoch(int32_t queue_index) { | |||
| for (int i = 0; i < num_workers_; ++i) { | |||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CsvOp::CalculateNumRowsPerShard() { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| int64_t count = CountTotalRows(it.value()); | |||
| filename_numrows_[it.value()] = count; | |||
| all_num_rows_ += count; | |||
| } | |||
| if (all_num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API CsvDataset. Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_)); | |||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t CsvOp::CountTotalRows(const std::string &file) { | |||
| CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_); | |||
| std::ifstream ifs; | |||
| ifs.open(file, std::ifstream::in); | |||
| if (column_name_list_.empty()) { | |||
| std::string tmp; | |||
| getline(ifs, tmp); | |||
| } | |||
| csv_parser.Reset(); | |||
| while (ifs.good()) { | |||
| char chr = ifs.get(); | |||
| if (csv_parser.countRows(chr) != 0) { | |||
| break; | |||
| } | |||
| } | |||
| return csv_parser.total_rows_; | |||
| } | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||
| Status CsvOp::PostEndOfData() { | |||
| for (int i = 0; i < num_workers_; ++i) { | |||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) { | |||
| std::shared_ptr<CsvOp> op; | |||
| *count = 0; | |||
| if (csv_header) { | |||
| RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).Build(&op)); | |||
| } else { | |||
| RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).SetColumName({""}).Build(&op)); | |||
| } | |||
| for (auto file : files) { | |||
| *count += op->CountTotalRows(file); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::vector<std::string> CsvOp::split(const std::string &s, char delim) { | |||
| std::vector<std::string> res; | |||
| std::stringstream ss(s); | |||
| std::string item; | |||
| while (getline(ss, item, delim)) { | |||
| res.push_back(item); | |||
| } | |||
| return res; | |||
| } | |||
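split is used by ComputeColMap below: when no column_names were given, the first line of the first CSV file is split on the field delimiter and each header field becomes a key in column_name_id_map_. The standalone sketch below shows that header-to-index mapping with a made-up header line.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Same splitting approach as CsvOp::split.
std::vector<std::string> Split(const std::string &s, char delim) {
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) {
    res.push_back(item);
  }
  return res;
}

int main() {
  const std::string header = "id,score,text";  // illustrative header row
  std::map<std::string, int32_t> column_name_id_map;
  std::vector<std::string> names = Split(header, ',');
  for (int32_t i = 0; i < static_cast<int32_t>(names.size()); i++) {
    column_name_id_map[names[i]] = i;
  }
  for (const auto &kv : column_name_id_map) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}
```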
| Status CsvOp::ComputeColMap() { | |||
| // Set the column name mapping (base class field) | |||
| if (column_name_id_map_.empty()) { | |||
| if (column_name_list_.empty()) { | |||
| std::string line; | |||
| std::ifstream handle(csv_files_list_[0]); | |||
| getline(handle, line); | |||
| std::vector<std::string> col_names = split(line, field_delim_); | |||
| for (int32_t i = 0; i < col_names.size(); i++) { | |||
| column_name_id_map_[col_names[i]] = i; | |||
| } | |||
| } else { | |||
| for (int32_t i = 0; i < column_name_list_.size(); i++) { | |||
| column_name_id_map_[column_name_list_[i]] = i; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Column name map is already set!"; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,451 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <utility> | |||
| #include <limits> | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const size_t CSV_BUFFER_SIZE = 4096; | |||
| using StringIndex = AutoIndexObj<std::string>; | |||
| class JaggedConnector; | |||
| class CsvOp : public ParallelOp { | |||
| public: | |||
| enum RecordType : uint8_t { INT = 0, FLOAT, STRING }; | |||
| struct BaseRecord { | |||
| public: | |||
| BaseRecord() = default; | |||
| explicit BaseRecord(RecordType t) : type(t) {} | |||
| virtual ~BaseRecord() {} | |||
| RecordType type; | |||
| }; | |||
| template <typename T> | |||
| class Record : public BaseRecord { | |||
| public: | |||
| Record() = default; | |||
| Record(RecordType t, T v) : BaseRecord(t), value(v) {} | |||
| ~Record() {} | |||
| T value; | |||
| }; | |||
| // CsvParser is a class that parses CSV files. | |||
| // We design a state machine to implement CSV syntactic analysis. It contains two state diagrams, 'sd' and 'sdl'. | |||
| // The 'sd' diagram is used for full CSV parsing; it is complete but complicated. | |||
| // The 'sdl' diagram is used only for counting record rows; it is concise and runs fast. | |||
| struct CsvParser { | |||
| public: | |||
| CsvParser() = delete; | |||
| CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim, | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default) | |||
| : worker_id_(worker_id), | |||
| buffer_connector_(connector), | |||
| csv_rows_per_buffer_(rows_per_buffer), | |||
| csv_field_delim_(field_delim), | |||
| column_default_(column_default), | |||
| cur_state_(START_OF_FILE), | |||
| pos_(0), | |||
| cur_row_(0), | |||
| cur_col_(0), | |||
| total_rows_(0), | |||
| start_offset_(0), | |||
| end_offset_(std::numeric_limits<int64_t>::max()) { | |||
| cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | |||
| initCsvParser(); | |||
| } | |||
| ~CsvParser() = default; | |||
| void Reset() { | |||
| cur_state_ = START_OF_FILE; | |||
| pos_ = 0; | |||
| cur_row_ = 0; | |||
| cur_col_ = 0; | |||
| } | |||
| void setStartOffset(int64_t start_offset) { start_offset_ = start_offset; } | |||
| void setEndOffset(int64_t end_offset) { end_offset_ = end_offset; } | |||
| int processMessage(char c) { | |||
| Message m = getMessage(c); | |||
| StateDiagram::iterator it = sd.find({cur_state_, m}); | |||
| if (it == sd.end()) { | |||
| return -1; | |||
| } | |||
| cur_state_ = it->second.first; | |||
| return it->second.second(*this, c); | |||
| } | |||
| int countRows(char c); | |||
| Status initCsvParser(); | |||
| enum State : uint8_t { | |||
| START_OF_FILE = 0, | |||
| UNQUOTE, | |||
| DELIM, | |||
| QUOTE, | |||
| SECOND_QUOTE, | |||
| END_OF_LINE, | |||
| END_OF_FILE, | |||
| EXCEPTION | |||
| }; | |||
| enum Message : uint8_t { | |||
| MS_NORMAL = 0, | |||
| MS_DELIM, | |||
| MS_QUOTE, | |||
| MS_END_OF_LINE, | |||
| MS_END_OF_FILE, | |||
| }; | |||
| typedef std::pair<State, Message> StateMessagePair; | |||
| typedef std::pair<State, std::function<int(CsvParser &, char)>> StateActionPair; | |||
| typedef std::map<StateMessagePair, StateActionPair> StateDiagram; | |||
| Message getMessage(char c) { | |||
| if (c == csv_field_delim_) { | |||
| return Message::MS_DELIM; | |||
| } else if (c == '"') { | |||
| return Message::MS_QUOTE; | |||
| } else if (c == '\r' || c == '\n') { | |||
| return Message::MS_END_OF_LINE; | |||
| } else if (c == std::char_traits<char>::eof()) { | |||
| return Message::MS_END_OF_FILE; | |||
| } else { | |||
| return Message::MS_NORMAL; | |||
| } | |||
| } | |||
| int null_func(char c) { return 0; } | |||
| int put_char(char c) { | |||
| if (pos_ >= str_buf_.size()) { | |||
| str_buf_.resize(str_buf_.size() * 2); | |||
| } | |||
| str_buf_[pos_] = c; | |||
| pos_++; | |||
| return 0; | |||
| } | |||
| int put_record(char c); | |||
| int put_row(char c); | |||
| int end_file(char c); | |||
| int add_row(char c) { | |||
| total_rows_++; | |||
| return 0; | |||
| } | |||
| int catch_exception(char c) { | |||
| MS_LOG(ERROR) << "Invalid syntax!"; | |||
| return -1; | |||
| } | |||
| int32_t worker_id_; | |||
| std::shared_ptr<JaggedConnector> buffer_connector_; | |||
| int64_t csv_rows_per_buffer_; | |||
| const char csv_field_delim_; | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_; | |||
| State cur_state_; | |||
| size_t pos_; | |||
| int cur_row_; | |||
| int cur_col_; | |||
| int64_t total_rows_; | |||
| int64_t start_offset_; | |||
| int64_t end_offset_; | |||
| StateDiagram sd; | |||
| StateDiagram sdl; | |||
| std::vector<char> str_buf_; | |||
| std::unique_ptr<TensorQTable> tensor_table_; | |||
| std::unique_ptr<DataBuffer> cur_buffer_; | |||
| }; | |||
| class Builder { | |||
| public: | |||
| // Builder constructor. Creates the builder object. | |||
| // @note No default args | |||
| // @return This is a constructor. | |||
| Builder(); | |||
| // Default destructor | |||
| ~Builder() = default; | |||
| // Checks if the inputs of the builder are valid. | |||
| // @return Status - the error code returned. | |||
| Status ValidateInputs() const; | |||
| // Create the final object. | |||
| // @param op - dataset op. | |||
| // @return Status - the error code returned. | |||
| Status Build(std::shared_ptr<CsvOp> *op); | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| builder_num_workers_ = num_workers; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetOpConnectorSize(int32_t op_connector_size) { | |||
| builder_op_connector_size_ = op_connector_size; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { | |||
| builder_rows_per_buffer_ = rows_per_buffer; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumDevices(int64_t num_dev) { | |||
| builder_num_devices_ = num_dev; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetDeviceId(int64_t dev_id) { | |||
| builder_device_id_ = dev_id; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetCsvFilesList(const std::vector<std::string> &files_list) { | |||
| builder_csv_files_list_ = files_list; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetShuffleFiles(bool shuffle_files) { | |||
| builder_shuffle_files_ = shuffle_files; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetFieldDelim(char field_delim) { | |||
| builder_field_delim_ = field_delim; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetColumDefault(std::vector<std::shared_ptr<CsvOp::BaseRecord>> record_list) { | |||
| builder_column_default_list_ = record_list; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetColumName(std::vector<std::string> col_name_list) { | |||
| builder_column_name_list_ = col_name_list; | |||
| return *this; | |||
| } | |||
| private: | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int64_t builder_rows_per_buffer_; | |||
| int64_t builder_num_samples_; | |||
| int32_t builder_worker_connector_size_; | |||
| std::vector<std::string> builder_csv_files_list_; | |||
| bool builder_shuffle_files_; | |||
| char builder_field_delim_; | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; | |||
| std::vector<std::string> builder_column_name_list_; | |||
| }; | |||
| // Constructor of CsvOp | |||
| CsvOp() = delete; | |||
| CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, | |||
| int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| // Default destructor | |||
| ~CsvOp() = default; | |||
| // A print method typically used for debugging | |||
| // @param out - The output stream to write output to | |||
| // @param show_all - A bool to control if you want to show all info or just a summary | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // Instantiates the internal queues and connectors | |||
| // @return Status - the error code returned | |||
| Status Init(); | |||
| // Class functor operator () override. | |||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - the error code returned. | |||
| Status operator()() override; | |||
| // Overrides base class reset method. Cleans up any state info from its previous execution and | |||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||
| // @return Status - the error code returned. | |||
| Status Reset() override; | |||
| // Get total rows in files. | |||
| // @param files - all csv files. | |||
| // @param csv_header - a bool that indicates whether the csv files include a header line | |||
| // @param count - number of rows. | |||
| // @return Status - the error code returned. | |||
| static Status CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count); | |||
| // File names getter | |||
| // @return Vector of the input file names | |||
| std::vector<std::string> FileNames() { return csv_files_list_; } | |||
| private: | |||
| // The entry point for when workers are launched. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| // @return Status - the error code returned. | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Parses a single row and puts the data into a tensor table. | |||
| // @param line - the content of the row. | |||
| // @param tensor_table - the tensor table to put the parsed data in. | |||
| // @param row - the id of the row filled in the tensor table. | |||
| // @return Status - the error code returned. | |||
| Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row); | |||
| // Reads a csv file and loads the data into multiple buffers. | |||
| // @param file - the file to read. | |||
| // @param start_offset - the start offset of file. | |||
| // @param end_offset - the end offset of file. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| // @return Status - the error code returned. | |||
| Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||
| const int32_t worker_id); | |||
| // Pops an element from a queue in IOBlockQueue. | |||
| // @param index - the index of the queue to pop from. | |||
| // @param out_block - the popped element. | |||
| // @return Status - the error code returned. | |||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||
| // Pushes an element to a queue in IOBlockQueue. | |||
| // @param index - the index of the queue to push to. | |||
| // @param io_block - the element to push onto the queue. | |||
| // @return Status - the error code returned. | |||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||
| // @return Status - the error code returned. | |||
| Status WaitToFillIOBlockQueue(); | |||
| // Fill the IOBlockQueue. | |||
| // @param i_keys - keys of the files to fill into the IOBlockQueue | |||
| // @return Status - the error code returned. | |||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); | |||
| // Notifies the thread which called FillIOBlockQueue to resume execution. | |||
| void NotifyToFillIOBlockQueue(); | |||
| // Select file and push it to the block queue. | |||
| // @param file_name - File name. | |||
| // @param start_offset - Output; the starting row offset within the file for this shard. | |||
| // @param end_offset - Output; the ending row offset within the file for this shard. | |||
| // @param pre_count - Total rows of previous files. | |||
| // @return bool - whether the file should be pushed to the block queue. | |||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||
| const int64_t &pre_count); | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||
| // @return Status - the error code returned. | |||
| Status PostEndOfEpoch(int32_t queue_index); | |||
| // Calculate number of rows in each shard. | |||
| // @return Status - the error code returned. | |||
| Status CalculateNumRowsPerShard(); | |||
| // Count number of rows in each file. | |||
| // @param file - csv file name. | |||
| // @return int64_t - the total number of rows in file. | |||
| int64_t CountTotalRows(const std::string &file); | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||
| // @return Status - the error code returned. | |||
| Status PostEndOfData(); | |||
| // Private function for computing the assignment of the column name map. | |||
| // @return - Status | |||
| Status ComputeColMap() override; | |||
| // Split string based on a character delimiter | |||
| // @return - the vector of split fields | |||
| std::vector<std::string> split(const std::string &s, char delim); | |||
| int32_t device_id_; | |||
| bool shuffle_files_; | |||
| bool finished_reading_dataset_; | |||
| int32_t num_devices_; | |||
| int64_t rows_per_buffer_; | |||
| bool load_io_block_queue_; | |||
| int64_t num_rows_per_shard_; | |||
| int64_t all_num_rows_; | |||
| int64_t num_samples_; | |||
| std::map<std::string, int64_t> filename_numrows_; | |||
| std::unique_ptr<StringIndex> filename_index_; | |||
| std::vector<std::string> csv_files_list_; | |||
| WaitPost io_block_queue_wait_post_; | |||
| std::shared_ptr<JaggedConnector> jagged_buffer_connector_; | |||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||
| bool load_jagged_connector_; | |||
| char field_delim_; | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_; | |||
| std::vector<std::string> column_name_list_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ | |||
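The private members declared above (filename_numrows_, all_num_rows_, num_rows_per_shard_, device_id_, num_devices_) suggest per-shard row bookkeeping: each shard owns a contiguous range of rows, and a file is pushed to the IO block queue only when its rows overlap that range. The Python sketch below only illustrates that arithmetic under those assumptions; it is not the CsvOp implementation (csv_op.cc is not part of this excerpt), and the helper names are made up for illustration.

import math

def rows_per_shard(filename_numrows, num_devices):
    # Assumed behaviour: total rows split evenly across shards, rounding up.
    all_num_rows = sum(filename_numrows.values())
    return math.ceil(all_num_rows / num_devices)

def need_push_file(filename, filename_numrows, device_id, num_devices, pre_count):
    # Returns (need_push, start_offset, end_offset): whether `filename` overlaps the
    # row range [shard_start, shard_end) owned by this shard, and which of its rows to read.
    per_shard = rows_per_shard(filename_numrows, num_devices)
    shard_start = device_id * per_shard
    shard_end = shard_start + per_shard
    file_start, file_end = pre_count, pre_count + filename_numrows[filename]
    if file_end <= shard_start or file_start >= shard_end:
        return False, 0, 0
    start_offset = max(shard_start - file_start, 0)
    end_offset = min(shard_end, file_end) - file_start
    return True, start_offset, end_offset

# Example: two files of 3 and 5 rows, split over 2 shards (4 rows each).
counts = {"1.csv": 3, "size.csv": 5}
print(need_push_file("1.csv", counts, 0, 2, 0))     # (True, 0, 3)
print(need_push_file("size.csv", counts, 0, 2, 3))  # (True, 0, 1)
print(need_push_file("size.csv", counts, 1, 2, 3))  # (True, 1, 5)
print(need_push_file("1.csv", counts, 1, 2, 0))     # (False, 0, 0)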
| @@ -21,7 +21,7 @@ can also create samplers with this module to sample data. | |||
| from .core import config | |||
| from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \ | |||
| GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\ | |||
| TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset | |||
| TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset | |||
| from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | |||
| WeightedRandomSampler, Sampler | |||
| from .engine.cache_client import DatasetCache | |||
| @@ -31,5 +31,5 @@ from .engine.graphdata import GraphData | |||
| __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", | |||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", | |||
| "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset", | |||
| "CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler", | |||
| "CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler", | |||
| "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"] | |||
| @@ -29,7 +29,7 @@ from .samplers import * | |||
| from ..core import config | |||
| __all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset", | |||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", | |||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "CSVDataset", | |||
| "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", | |||
| "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", | |||
| "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] | |||
| @@ -33,7 +33,7 @@ import copy | |||
| import numpy as np | |||
| from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ | |||
| MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo | |||
| MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo | |||
| from mindspore._c_expression import typing | |||
| from mindspore import log as logger | |||
| @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | |||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE | |||
| @@ -1012,7 +1012,7 @@ class Dataset: | |||
| if isinstance(sampler, samplers.DistributedSampler): | |||
| dev_id = sampler.shard_id | |||
| return "", dev_id | |||
| if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)): | |||
| if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset, CSVDataset)): | |||
| if output_dataset.shard_id is not None: | |||
| dev_id = output_dataset.shard_id | |||
| return "", dev_id | |||
| @@ -4652,8 +4652,8 @@ class CLUEDataset(SourceDataset): | |||
| } | |||
| Args: | |||
| dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of | |||
| files. The list will be sorted in a lexicographical order. | |||
| dataset_files (str or a list of strings): String or list of files to be read or glob strings to search for | |||
| a pattern of files. The list will be sorted in a lexicographical order. | |||
| task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'. | |||
| (default=AFQMC). | |||
| usage (str, optional): Need train, test or eval data (default="train"). | |||
| @@ -4860,6 +4860,108 @@ class CLUEDataset(SourceDataset): | |||
| return False | |||
| class CSVDataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses CSV datasets. | |||
| Args: | |||
| dataset_files (str or a list of strings): String or list of files to be read or glob strings to search | |||
| for a pattern of files. The list will be sorted in a lexicographical order. | |||
| field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=','). | |||
| column_defaults (list, optional): List of default values for the CSV fields (default=None). Each item | |||
| in the list must be of a valid type (float, int, or string). If this is not provided, all | |||
| columns are treated as string type. | |||
| column_names (list of string, optional): List of column names of the dataset (default=None). If this | |||
| is not provided, infers the column_names from the first row of the CSV file. | |||
| num_samples (int, optional): number of samples (rows) to read (default=None, reads the full dataset). | |||
| num_parallel_workers (int, optional): number of workers to read the data | |||
| (default=None, number set in the config). | |||
| shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). | |||
| If shuffle is False, no shuffling will be performed; | |||
| If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL; | |||
| Otherwise, there are two levels of shuffling: | |||
| - Shuffle.GLOBAL: Shuffle both the files and samples. | |||
| - Shuffle.FILES: Shuffle files only. | |||
| num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument should be specified only when num_shards is also specified. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple csv files | |||
| >>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4']) | |||
| """ | |||
| @check_csvdataset | |||
| def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, | |||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_files = self._find_files(dataset_files) | |||
| self.dataset_files.sort() | |||
| self.field_delim = field_delim | |||
| self.column_defaults = column_defaults | |||
| self.column_names = column_names | |||
| self.num_samples = num_samples | |||
| if not isinstance(shuffle, (bool, Shuffle)): | |||
| raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") | |||
| if not isinstance(shuffle, Shuffle): | |||
| if shuffle: | |||
| self.shuffle_level = Shuffle.GLOBAL | |||
| self.shuffle_files = True | |||
| else: | |||
| self.shuffle_level = None | |||
| self.shuffle_files = False | |||
| else: | |||
| self.shuffle_level = shuffle | |||
| self.shuffle_files = True | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| args["dataset_files"] = self.dataset_files | |||
| args['field_delim'] = self.field_delim | |||
| args['column_defaults'] = self.column_defaults | |||
| args['column_names'] = self.column_names | |||
| args["num_samples"] = self.num_samples | |||
| if self.shuffle_files is not None: | |||
| args["shuffle_files"] = self.shuffle_files | |||
| args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) | |||
| args["shuffle"] = self.shuffle_level | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| return args | |||
| def get_dataset_size(self): | |||
| """ | |||
| Get the number of batches in an epoch. | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| if self._dataset_size is None: | |||
| num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) | |||
| num_rows = get_num_rows(num_rows, self.num_shards) | |||
| if self.num_samples is None: | |||
| return num_rows | |||
| return min(self.num_samples, num_rows) | |||
| return self._dataset_size | |||
| def is_shuffled(self): | |||
| return self.shuffle_files | |||
| def is_sharded(self): | |||
| if self.num_shards is not None: | |||
| return self.num_shards > 1 | |||
| return False | |||
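For context on the shuffle handling in __init__ above (a bool is mapped to Shuffle.GLOBAL or no shuffling, while a Shuffle enum is passed through), here is a minimal usage sketch; the file paths are placeholders:

import mindspore.dataset as ds

# shuffle=False            -> no shuffling of files or rows
# shuffle=True             -> equivalent to ds.Shuffle.GLOBAL (shuffle files and rows)
# shuffle=ds.Shuffle.FILES -> shuffle the file order only
data = ds.CSVDataset(
    ["/path/to/train_1.csv", "/path/to/train_2.csv"],  # placeholder paths
    column_names=['col1', 'col2', 'col3', 'col4'],
    num_shards=2, shard_id=0,
    shuffle=ds.Shuffle.FILES)
print(data.get_dataset_size())  # rows assigned to this shard, capped by num_samples if given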
| class TextFileDataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses datasets stored on disk in text format. | |||
| @@ -185,6 +185,8 @@ class Iterator: | |||
| op_type = OpName.SENTENCEPIECEVOCAB | |||
| elif isinstance(dataset, de.CLUEDataset): | |||
| op_type = OpName.CLUE | |||
| elif isinstance(dataset, de.CSVDataset): | |||
| op_type = OpName.CSV | |||
| else: | |||
| raise ValueError("Unsupported DatasetOp") | |||
| @@ -787,6 +787,49 @@ def check_cluedataset(method): | |||
| return new_method | |||
| def check_csvdataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(CSVDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||
| # check dataset_files; required argument | |||
| dataset_files = param_dict.get('dataset_files') | |||
| type_check(dataset_files, (str, list), "dataset files") | |||
| # check field_delim | |||
| field_delim = param_dict.get('field_delim') | |||
| type_check(field_delim, (str,), 'field delim') | |||
| if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1: | |||
| raise ValueError("field_delim is not legal.") | |||
| # check column_defaults | |||
| column_defaults = param_dict.get('column_defaults') | |||
| if column_defaults is not None: | |||
| if not isinstance(column_defaults, list): | |||
| raise TypeError("column_defaults should be type of list.") | |||
| for item in column_defaults: | |||
| if not isinstance(item, (str, int, float)): | |||
| raise TypeError("column type is not legal.") | |||
| # check column_names: must be list of string. | |||
| column_names = param_dict.get("column_names") | |||
| if column_names is not None: | |||
| all_string = all(isinstance(item, str) for item in column_names) | |||
| if not all_string: | |||
| raise TypeError("column_names should be a list of str.") | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
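As a quick illustration of the field_delim rule enforced above (a single character that is not a double quote, carriage return or newline): the check runs in the decorator before any file I/O, so invalid delimiters fail fast. The path below is a placeholder.

import pytest
import mindspore.dataset as ds

with pytest.raises(ValueError):
    ds.CSVDataset("/path/to/data.csv", field_delim='||')  # more than one character
with pytest.raises(ValueError):
    ds.CSVDataset("/path/to/data.csv", field_delim='"')   # double quote is reserved
# A single, non-reserved character such as '|' or '\t' passes this particular check
# (the dataset files themselves must of course exist for the dataset to be usable).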
| def check_textfiledataset(method): | |||
| """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" | |||
| @@ -77,6 +77,7 @@ SET(DE_UT_SRCS | |||
| celeba_op_test.cc | |||
| take_op_test.cc | |||
| clue_op_test.cc | |||
| csv_op_test.cc | |||
| text_file_op_test.cc | |||
| filter_op_test.cc | |||
| concat_op_test.cc | |||
| @@ -0,0 +1,122 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "common/common.h" | |||
| #include "common/utils.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace common = mindspore::common; | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| class MindDataTestCSVOp : public UT::DatasetOpTesting { | |||
| }; | |||
| TEST_F(MindDataTestCSVOp, TestCSVBasic) { | |||
| // Start with an empty execution tree | |||
| auto tree = std::make_shared<ExecutionTree>(); | |||
| std::string dataset_path; | |||
| dataset_path = datasets_root_path_ + "/testCSV/1.csv"; | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list; | |||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||
| std::shared_ptr<CsvOp> op; | |||
| CsvOp::Builder builder; | |||
| builder.SetCsvFilesList({dataset_path}) | |||
| .SetRowsPerBuffer(16) | |||
| .SetNumWorkers(16) | |||
| .SetShuffleFiles(false) | |||
| .SetOpConnectorSize(2) | |||
| .SetFieldDelim(',') | |||
| .SetColumDefault(column_default_list) | |||
| .SetColumName({"col1", "col2", "col3", "col4"}); | |||
| Status rc = builder.Build(&op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->AssociateNode(op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->AssignRoot(op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration."; | |||
| rc = tree->Prepare(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->Launch(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator di(tree); | |||
| TensorRow tensor_list; | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int row_count = 0; | |||
| while (!tensor_list.empty()) { | |||
| // Display the tensor by calling the printer on it | |||
| for (int i = 0; i < tensor_list.size(); i++) { | |||
| std::ostringstream ss; | |||
| ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; | |||
| MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; | |||
| } | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| row_count++; | |||
| } | |||
| ASSERT_EQ(row_count, 3); | |||
| } | |||
| TEST_F(MindDataTestCSVOp, TestTotalRows) { | |||
| std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv"; | |||
| std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv"; | |||
| std::vector<std::string> files; | |||
| files.push_back(csv_file1); | |||
| int64_t total_rows = 0; | |||
| CsvOp::CountAllFileRows(files, false, &total_rows); | |||
| ASSERT_EQ(total_rows, 3); | |||
| files.clear(); | |||
| files.push_back(csv_file2); | |||
| CsvOp::CountAllFileRows(files, false, &total_rows); | |||
| ASSERT_EQ(total_rows, 5); | |||
| files.clear(); | |||
| files.push_back(csv_file1); | |||
| files.push_back(csv_file2); | |||
| CsvOp::CountAllFileRows(files, false, &total_rows); | |||
| ASSERT_EQ(total_rows, 8); | |||
| files.clear(); | |||
| } | |||
| @@ -0,0 +1,3 @@ | |||
| 1,2,3,4 | |||
| 5,6,7,8 | |||
| 9,10,11,12 | |||
| @@ -0,0 +1,8 @@ | |||
| ,"222",3,"4""" | |||
| "5",6,,"8" | |||
| 9,10,"1""1",12 | |||
| ,,"", | |||
| ,,, | |||
| a,b,c,"" | |||
| a,b,c,d | |||
| @@ -0,0 +1 @@ | |||
| 大家,早上好,中午好,下午好,晚上好 | |||
| @@ -0,0 +1,2 @@ | |||
| "a,b","c""d","e | |||
| f"," g " | |||
| @@ -0,0 +1,3 @@ | |||
| 1,2,3,4 | |||
| 5,6,7,8 | |||
| a,"c",d,"e | |||
| @@ -0,0 +1,2 @@ | |||
| col1,col2,col3,col4 | |||
| a,b,c,d | |||
| @@ -0,0 +1 @@ | |||
| 3,0.3,4,55.5 | |||
| @@ -0,0 +1 @@ | |||
| "a","b","c","d" | |||
| @@ -0,0 +1 @@ | |||
| a|b|c|d | |||
| @@ -0,0 +1,10 @@ | |||
| 1,2,3,4 | |||
| "a","b","c | |||
| ","d | |||
| e" | |||
| 5,6,7,8 | |||
| 9,10,11,12 | |||
| a,"b | |||
| ",c,"d | |||
| e" | |||
| @@ -0,0 +1,238 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| import numpy as np | |||
| import pytest | |||
| DATA_FILE = '../data/dataset/testCSV/1.csv' | |||
| def test_csv_dataset_basic(): | |||
| """ | |||
| Test CSVDataset with repeat and skip: the file has 3 rows, repeat(2) yields 6, and skip(2) leaves 4 | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCSV/1.csv' | |||
| buffer = [] | |||
| data = ds.CSVDataset( | |||
| TRAIN_FILE, | |||
| column_defaults=["0", 0, 0.0, "0"], | |||
| column_names=['1', '2', '3', '4'], | |||
| shuffle=False) | |||
| data = data.repeat(2) | |||
| data = data.skip(2) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append(d) | |||
| assert len(buffer) == 4 | |||
| def test_csv_dataset_one_file(): | |||
| data = ds.CSVDataset( | |||
| DATA_FILE, | |||
| column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append(d) | |||
| assert len(buffer) == 3 | |||
| def test_csv_dataset_all_file(): | |||
| APPEND_FILE = '../data/dataset/testCSV/2.csv' | |||
| data = ds.CSVDataset( | |||
| [DATA_FILE, APPEND_FILE], | |||
| column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append(d) | |||
| assert len(buffer) == 10 | |||
| def test_csv_dataset_num_samples(): | |||
| data = ds.CSVDataset( | |||
| DATA_FILE, | |||
| column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False, num_samples=2) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| count += 1 | |||
| assert count == 2 | |||
| def test_csv_dataset_distribution(): | |||
| TEST_FILE = '../data/dataset/testCSV/1.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False, num_shards=2, shard_id=0) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| count += 1 | |||
| assert count == 2 | |||
| def test_csv_dataset_quoted(): | |||
| TEST_FILE = '../data/dataset/testCSV/quoted.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", "", "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.extend([d['col1'].item().decode("utf8"), | |||
| d['col2'].item().decode("utf8"), | |||
| d['col3'].item().decode("utf8"), | |||
| d['col4'].item().decode("utf8")]) | |||
| assert buffer == ['a', 'b', 'c', 'd'] | |||
| def test_csv_dataset_separated(): | |||
| TEST_FILE = '../data/dataset/testCSV/separated.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| field_delim='|', | |||
| column_defaults=["", "", "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.extend([d['col1'].item().decode("utf8"), | |||
| d['col2'].item().decode("utf8"), | |||
| d['col3'].item().decode("utf8"), | |||
| d['col4'].item().decode("utf8")]) | |||
| assert buffer == ['a', 'b', 'c', 'd'] | |||
| def test_csv_dataset_embedded(): | |||
| TEST_FILE = '../data/dataset/testCSV/embedded.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", "", "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.extend([d['col1'].item().decode("utf8"), | |||
| d['col2'].item().decode("utf8"), | |||
| d['col3'].item().decode("utf8"), | |||
| d['col4'].item().decode("utf8")]) | |||
| assert buffer == ['a,b', 'c"d', 'e\nf', ' g '] | |||
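The expected values above follow standard CSV quoting: a quoted field may contain the delimiter, an embedded newline, and a doubled quote for a literal quote character. As an independent cross-check (not the CsvOp parser), Python's stdlib csv module yields the same fields for the same content:

import csv
import io

content = '"a,b","c""d","e\nf"," g "\n'
rows = list(csv.reader(io.StringIO(content)))
assert rows == [['a,b', 'c"d', 'e\nf', ' g ']]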
| def test_csv_dataset_chinese(): | |||
| TEST_FILE = '../data/dataset/testCSV/chinese.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", "", "", "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4', 'col5'], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.extend([d['col1'].item().decode("utf8"), | |||
| d['col2'].item().decode("utf8"), | |||
| d['col3'].item().decode("utf8"), | |||
| d['col4'].item().decode("utf8"), | |||
| d['col5'].item().decode("utf8")]) | |||
| assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好'] | |||
| def test_csv_dataset_header(): | |||
| TEST_FILE = '../data/dataset/testCSV/header.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", "", "", ""], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.extend([d['col1'].item().decode("utf8"), | |||
| d['col2'].item().decode("utf8"), | |||
| d['col3'].item().decode("utf8"), | |||
| d['col4'].item().decode("utf8")]) | |||
| assert buffer == ['a', 'b', 'c', 'd'] | |||
| def test_csv_dataset_number(): | |||
| TEST_FILE = '../data/dataset/testCSV/number.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=[0.0, 0.0, 0, 0.0], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(): | |||
| buffer.extend([d['col1'].item(), | |||
| d['col2'].item(), | |||
| d['col3'].item(), | |||
| d['col4'].item()]) | |||
| assert np.allclose(buffer, [3.0, 0.3, 4, 55.5]) | |||
| def test_csv_dataset_size(): | |||
| TEST_FILE = '../data/dataset/testCSV/size.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=[0.0, 0.0, 0, 0.0], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| assert data.get_dataset_size() == 5 | |||
| def test_csv_dataset_exception(): | |||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", "", "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| with pytest.raises(Exception) as err: | |||
| for _ in data.create_dict_iterator(): | |||
| pass | |||
| assert "Failed to parse CSV file" in str(err.value) | |||
| def test_csv_dataset_type_error(): | |||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", 0, "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| with pytest.raises(Exception) as err: | |||
| for _ in data.create_dict_iterator(): | |||
| pass | |||
| assert "invalid argument of stoi" in str(err.value) | |||
| if __name__ == "__main__": | |||
| test_csv_dataset_basic() | |||
| test_csv_dataset_one_file() | |||
| test_csv_dataset_all_file() | |||
| test_csv_dataset_num_samples() | |||
| test_csv_dataset_distribution() | |||
| test_csv_dataset_quoted() | |||
| test_csv_dataset_separated() | |||
| test_csv_dataset_embedded() | |||
| test_csv_dataset_chinese() | |||
| test_csv_dataset_header() | |||
| test_csv_dataset_number() | |||
| test_csv_dataset_size() | |||
| test_csv_dataset_exception() | |||
| test_csv_dataset_type_error() | |||