Merge pull request !3016 from jiangzhiwen/dataset/csvtags/v0.6.0-beta
@@ -31,6 +31,7 @@
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
@@ -88,6 +89,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
  {kBuildVocab, &DEPipeline::ParseBuildVocabOp},
  {kClue, &DEPipeline::ParseClueOp},
  {kEpochCtrl, &DEPipeline::ParseEpochCtrlOp},
  {kCsv, &DEPipeline::ParseCsvOp},
  {kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}};

DEPipeline::DEPipeline() : iterator_(nullptr) {
@@ -1848,6 +1850,86 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
  return Status::OK();
}
Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                              std::shared_ptr<DatasetOp> *bottom) {
  std::vector<std::string> files_list;
  std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
  if (!args["dataset_files"].is_none()) {
    files_list = ToStringVector(args["dataset_files"]);
    (void)builder->SetCsvFilesList(files_list);
  } else {
    RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
  }

  // Optional arguments
  bool shuffle_required = false;
  int64_t num_devices = 0;
  std::vector<std::string> col_names;
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
      } else if (key == "shuffle_files") {
        (void)builder->SetShuffleFiles(ToBool(value));
      } else if (key == "shuffle_global") {
        shuffle_required = ToBool(value);
      } else if (key == "num_samples") {
        (void)builder->SetNumSamples(ToInt(value));
      } else if (key == "num_shards") {
        num_devices = ToInt(value);
        (void)builder->SetNumDevices(num_devices);
      } else if (key == "shard_id") {
        (void)builder->SetDeviceId(ToInt(value));
      } else if (key == "field_delim") {
        (void)builder->SetFieldDelim(ToString(value)[0]);
      } else if (key == "column_defaults") {
        py::list py_object_list = py::reinterpret_borrow<py::list>(value);
        std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
        for (auto l : py_object_list) {
          std::string type_s = (std::string)py::str(l.get_type().attr("__name__"));
          if (type_s == "int") {
            column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, ToInt(l)));
          } else if (type_s == "float") {
            column_default_list.push_back(std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, ToFloat(l)));
          } else if (type_s == "str") {
            column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ToString(l)));
          } else {
            RETURN_STATUS_UNEXPECTED("Record type is not allowed");
          }
        }
        (void)builder->SetColumDefault(column_default_list);
      } else if (key == "column_names") {
        col_names = ToStringVector(value);
        (void)builder->SetColumName(col_names);
      }
    }
  }

  std::shared_ptr<CsvOp> csv_op;
  RETURN_IF_NOT_OK(builder->Build(&csv_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op));
  *top = csv_op;

  if (shuffle_required) {
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t shuffle_size = 0;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset and then compute the shuffle size
    RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(files_list, col_names.empty(), &num_rows));
    RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size));

    // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
    RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, csv_op, &shuffle_op));
    *top = shuffle_op;
    *bottom = csv_op;
  }
  return Status::OK();
}
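Reviewer note: `column_defaults` arrives as a heterogeneous Python list, and the loop above dispatches on each element's Python type name to build a typed default for every column. A minimal standalone sketch of the same tagged-record pattern (names simplified from the `CsvOp::BaseRecord`/`Record<T>` declarations later in this diff; the sample defaults are invented):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Mirrors the CsvOp::BaseRecord / Record<T> pattern: a type tag on the base
// class tells the parser how to materialize each column's value.
enum RecordType { INT = 0, FLOAT, STRING };

struct BaseRecord {
  explicit BaseRecord(RecordType t) : type(t) {}
  virtual ~BaseRecord() = default;
  RecordType type;
};

template <typename T>
struct Record : BaseRecord {
  Record(RecordType t, T v) : BaseRecord(t), value(std::move(v)) {}
  T value;
};

int main() {
  // One default per CSV column, e.g. from Python: column_defaults=[0, 0.0, ""].
  std::vector<std::shared_ptr<BaseRecord>> defaults = {
      std::make_shared<Record<int>>(INT, 0),
      std::make_shared<Record<float>>(FLOAT, 0.0f),
      std::make_shared<Record<std::string>>(STRING, "")};

  // The parser only consults the tag; here column 1 would be parsed as FLOAT.
  for (const auto &d : defaults) {
    std::cout << "column type tag: " << d->type << "\n";
  }
  return 0;
}
```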
// Helper function to inject a shuffle operator over top of the current operation being built.
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
                                std::shared_ptr<DatasetOp> *shuffle_op) {
@@ -73,6 +73,7 @@ enum OpName {
  kClue,
  kEpochCtrl,
  kSentencePieceVocab,
  kCsv
};

// The C++ binder class that we expose to the python script.
@@ -201,6 +202,8 @@ class DEPipeline {
  Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

  Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

 private:
  // Execution tree that links the dataset operators.
  std::shared_ptr<ExecutionTree> tree_;
@@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
@@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) {
    return count;
  });
(void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
  .def_static("get_num_rows", [](const py::list &files, bool csv_header) {
    int64_t count = 0;
    std::vector<std::string> filenames;
    for (auto file : files) {
      if (file.is_none()) {
        filenames.emplace_back("");
      } else {
        filenames.push_back(py::str(file));
      }
    }
    THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
    return count;
  });
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
  .def_static("get_num_rows",
              [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
@@ -1039,8 +1051,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
    .value("CELEBA", OpName::kCelebA)
    .value("TEXTFILE", OpName::kTextFile)
    .value("EPOCHCTRL", OpName::kEpochCtrl)
    .value("CSV", OpName::kCsv)
    .value("CLUE", OpName::kClue);

(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
  .value("DE_JIEBA_MIX", JiebaMode::kMix)
@@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
    celeba_op.cc
    text_file_op.cc
    clue_op.cc
    csv_op.cc
)

set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
@@ -29,4 +30,4 @@ if (ENABLE_PYTHON)
  )
endif()

add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES})
@@ -0,0 +1,757 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/csv_op.h"

#include <fstream>
#include <iomanip>
#include <stdexcept>

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/random.h"

namespace mindspore {
namespace dataset {
CsvOp::Builder::Builder()
    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  builder_num_workers_ = config_manager->num_parallel_workers();
  builder_op_connector_size_ = config_manager->op_connector_size();
  builder_rows_per_buffer_ = config_manager->rows_per_buffer();
  builder_worker_connector_size_ = config_manager->worker_connector_size();
}

Status CsvOp::Builder::ValidateInputs() const {
  std::string err;
  err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
  err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
}

Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
  RETURN_IF_NOT_OK(ValidateInputs());

  // Throttle the number of workers if we have more workers than files!
  if (static_cast<size_t>(builder_num_workers_) > builder_csv_files_list_.size()) {
    builder_num_workers_ = builder_csv_files_list_.size();
    MS_LOG(WARNING) << "CsvOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
  }

  std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
    builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
    builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
  RETURN_IF_NOT_OK(csv_op->Init());
  *op = std::move(csv_op);

  return Status::OK();
}
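Reviewer note: `Build()` validates the sharding inputs and then throttles parallelism to at most one worker per file. A hypothetical call site, just to show the chaining (paths and settings are placeholders, not from this PR):

```cpp
// Hypothetical usage of the builder above; file paths are illustrative only.
std::shared_ptr<CsvOp> op;
Status rc = CsvOp::Builder()
              .SetCsvFilesList({"/data/a.csv", "/data/b.csv"})
              .SetFieldDelim(',')
              .SetNumWorkers(8)  // Build() will throttle this to 2, one per file
              .Build(&op);
```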
CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
             const std::vector<std::shared_ptr<BaseRecord>> &column_default,
             const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
             int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
             int32_t num_device, int32_t device_id)
    : ParallelOp(num_workers, op_connector_size),
      csv_files_list_(std::move(csv_files_list)),
      field_delim_(field_delim),
      column_default_list_(column_default),
      column_name_list_(column_name),
      rows_per_buffer_(rows_per_buffer),
      num_rows_per_shard_(0),
      all_num_rows_(0),
      num_samples_(num_samples),
      filename_index_(std::make_unique<StringIndex>()),
      load_jagged_connector_(true),
      shuffle_files_(shuffle_files),
      finished_reading_dataset_(false),
      num_devices_(num_device),
      device_id_(device_id),
      load_io_block_queue_(true) {
  worker_connector_size_ = worker_connector_size;
}

Status CsvOp::Init() {
  RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_));

  // Note: cast to double before dividing; with integer division the std::ceil would be a no-op.
  int32_t safe_queue_size =
    static_cast<int32_t>(std::ceil(static_cast<double>(csv_files_list_.size()) / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
  jagged_buffer_connector_ = std::make_shared<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  return Status::OK();
}
int CsvOp::CsvParser::put_record(char c) {
  std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_);
  std::shared_ptr<Tensor> t;
  switch (column_default_[cur_col_]->type) {
    case CsvOp::INT:
      Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32));
      t->SetItemAt<int32_t>({0}, std::stoi(s));
      break;
    case CsvOp::FLOAT:
      Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32));
      t->SetItemAt<float>({0}, std::stof(s));
      break;
    case CsvOp::STRING:
      Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar());
      break;
    default:
      Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar());
      break;
  }
  (*tensor_table_)[cur_row_][cur_col_] = std::move(t);
  pos_ = 0;
  cur_col_++;
  return 0;
}

int CsvOp::CsvParser::put_row(char c) {
  if (total_rows_ < start_offset_) {
    total_rows_++;
    cur_col_ = 0;
    return 0;
  }

  if (total_rows_ >= end_offset_) {
    return 0;
  }

  put_record(c);
  total_rows_++;
  cur_row_++;
  cur_col_ = 0;

  if (cur_row_ == csv_rows_per_buffer_) {
    cur_buffer_->set_tensor_table(std::move(tensor_table_));
    buffer_connector_->Add(worker_id_, std::move(cur_buffer_));

    cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
    tensor_table_ = std::make_unique<TensorQTable>();
    cur_row_ = 0;
  }
  return 0;
}

int CsvOp::CsvParser::end_file(char c) {
  if (cur_col_ > 0) {
    put_row(c);
  }

  if (cur_row_ > 0) {
    cur_buffer_->set_tensor_table(std::move(tensor_table_));
    buffer_connector_->Add(worker_id_, std::move(cur_buffer_));
  }
  return 0;
}

int CsvOp::CsvParser::countRows(char c) {
  Message m;
  if (c == '"') {
    m = Message::MS_QUOTE;
  } else if (c == '\r' || c == '\n' || c == std::char_traits<char>::eof()) {
    m = Message::MS_END_OF_LINE;
  } else {
    m = Message::MS_NORMAL;
  }
  // Look the transition up in 'sdl' (the row-counting diagram), not 'sd'; a miss means malformed input.
  StateDiagram::iterator it = sdl.find({cur_state_, m});
  if (it == sdl.end()) {
    return -1;
  }
  cur_state_ = it->second.first;
  return it->second.second(*this, c);
}
Status CsvOp::CsvParser::initCsvParser() {
  str_buf_.resize(CSV_BUFFER_SIZE);

  // State diagram for counting rows
  sdl = {// START_OF_FILE
         // ┌───────────┬───────────┬─────────────┐
         // │    abc    │     "     │     \n      │
         // ├───────────┼───────────┼─────────────┤
         // │  UNQUOTE  │   QUOTE   │ END_OF_LINE │
         // ├───────────┼───────────┼─────────────┤
         // | null_func │ null_func │  null_func  │
         // └───────────┴───────────┴─────────────┘
         {{State::START_OF_FILE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
         {{State::START_OF_FILE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
         {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}},

         // UNQUOTE
         // ┌───────────┬───────────┬─────────────┐
         // │    abc    │     "     │     \n      │
         // ├───────────┼───────────┼─────────────┤
         // │  UNQUOTE  │   QUOTE   │ END_OF_LINE │
         // ├───────────┼───────────┼─────────────┤
         // | null_func │ null_func │   add_row   │
         // └───────────┴───────────┴─────────────┘
         {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
         {{State::UNQUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
         {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}},

         // QUOTE
         // ┌───────────┬──────────────┬───────────┐
         // │    abc    │      "       │    \n     │
         // ├───────────┼──────────────┼───────────┤
         // │   QUOTE   │ SECOND_QUOTE │   QUOTE   │
         // ├───────────┼──────────────┼───────────┤
         // | null_func │  null_func   │ null_func │
         // └───────────┴──────────────┴───────────┘
         {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::null_func}},
         {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}},
         {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::null_func}},

         // SECOND_QUOTE
         // ┌───────────┬───────────┬─────────────┐
         // │    abc    │     "     │     \n      │
         // ├───────────┼───────────┼─────────────┤
         // │  UNQUOTE  │   QUOTE   │ END_OF_LINE │
         // ├───────────┼───────────┼─────────────┤
         // | null_func │ null_func │   add_row   │
         // └───────────┴───────────┴─────────────┘
         {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
         {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
         {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}},

         // END_OF_LINE
         // ┌───────────┬───────────┬─────────────┐
         // │    abc    │     "     │     \n      │
         // ├───────────┼───────────┼─────────────┤
         // │  UNQUOTE  │   QUOTE   │ END_OF_LINE │
         // ├───────────┼───────────┼─────────────┤
         // | null_func │ null_func │  null_func  │
         // └───────────┴───────────┴─────────────┘
         {{State::END_OF_LINE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
         {{State::END_OF_LINE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
         {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}};

  // State diagram for the CSV parser
  sd = {// START_OF_FILE
        // ┌───────────┬──────────┬──────────┬────────────────┬────────────────┐
        // │    abc    │    ,     │    "     │       \n       │      EOF       │
        // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤
        // │  UNQUOTE  │  DELIM   │  QUOTE   │  END_OF_LINE   │  END_OF_FILE   │
        // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤
        // |  lambda   │  lambda  │  lambda  │   null_func    │   null_func    │
        // └───────────┴──────────┴──────────┴────────────────┴────────────────┘
        {{State::START_OF_FILE, Message::MS_NORMAL},
         {State::UNQUOTE,
          [this](CsvParser &, char c) -> int {
            this->tensor_table_ = std::make_unique<TensorQTable>();
            this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
            this->str_buf_[0] = c;
            this->pos_ = 1;
            return 0;
          }}},
        {{State::START_OF_FILE, Message::MS_DELIM},
         {State::DELIM,
          [this](CsvParser &, char c) -> int {
            this->tensor_table_ = std::make_unique<TensorQTable>();
            this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
            this->put_record(c);
            return 0;
          }}},
        {{State::START_OF_FILE, Message::MS_QUOTE},
         {State::QUOTE,
          [this](CsvParser &, char c) -> int {
            this->tensor_table_ = std::make_unique<TensorQTable>();
            this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
            this->pos_ = 0;
            return 0;
          }}},
        {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}},
        {{State::START_OF_FILE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::null_func}},

        // UNQUOTE
        // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐
        // │    abc    │     ,      │     "     │     \n      │      EOF       │
        // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
        // │  UNQUOTE  │   DELIM    │ EXCEPTION │ END_OF_LINE │  END_OF_FILE   │
        // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
        // | put_char  │ put_record │ exception │   put_row   │    end_file    │
        // └───────────┴────────────┴───────────┴─────────────┴────────────────┘
        {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}},
        {{State::UNQUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}},
        {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}},
        {{State::UNQUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}},
        // UNQUOTE-Exception
        {{State::UNQUOTE, Message::MS_QUOTE}, {State::EXCEPTION, &CsvParser::catch_exception}},

        // DELIM
        // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐
        // │    abc    │     ,      │     "     │     \n      │      EOF       │
        // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
        // │  UNQUOTE  │   DELIM    │   QUOTE   │ END_OF_LINE │  END_OF_FILE   │
        // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
        // | put_char  │ put_record │  lambda   │   put_row   │    end_file    │
        // └───────────┴────────────┴───────────┴─────────────┴────────────────┘
        {{State::DELIM, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}},
        {{State::DELIM, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}},
        {{State::DELIM, Message::MS_QUOTE},
         {State::QUOTE,
          [this](CsvParser &, char c) -> int {
            this->pos_ = 0;
            return 0;
          }}},
        {{State::DELIM, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}},
        {{State::DELIM, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}},

        // QUOTE
        // ┌───────────┬──────────┬──────────────┬──────────┬────────────────┐
        // │    abc    │    ,     │      "       │    \n    │      EOF       │
        // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤
        // │   QUOTE   │  QUOTE   │ SECOND_QUOTE │  QUOTE   │   EXCEPTION    │
        // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤
        // | put_char  │ put_char │  null_func   │ put_char │   exception    │
        // └───────────┴──────────┴──────────────┴──────────┴────────────────┘
        {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::put_char}},
        {{State::QUOTE, Message::MS_DELIM}, {State::QUOTE, &CsvParser::put_char}},
        {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}},
        {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::put_char}},
        // QUOTE-Exception
        {{State::QUOTE, Message::MS_END_OF_FILE}, {State::EXCEPTION, &CsvParser::catch_exception}},

        // SECOND_QUOTE
        // ┌───────────┬────────────┬──────────┬─────────────┬────────────────┐
        // │    abc    │     ,      │    "     │     \n      │      EOF       │
        // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤
        // │ EXCEPTION │   DELIM    │  QUOTE   │ END_OF_LINE │  END_OF_FILE   │
        // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤
        // | exception │ put_record │ put_char │   put_row   │    end_file    │
        // └───────────┴────────────┴──────────┴─────────────┴────────────────┘
        {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::put_char}},
        {{State::SECOND_QUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}},
        {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}},
        {{State::SECOND_QUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}},
        // SECOND_QUOTE-Exception
        {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::EXCEPTION, &CsvParser::catch_exception}},

        // END_OF_LINE
        // ┌─────────┬────────┬────────┬─────────────┬─────────────┐
        // │   abc   │   ,    │   "    │     \n      │     EOF     │
        // ├─────────┼────────┼────────┼─────────────┼─────────────┤
        // │ UNQUOTE │ DELIM  │ QUOTE  │ END_OF_LINE │ END_OF_FILE │
        // ├─────────┼────────┼────────┼─────────────┼─────────────┤
        // | lambda  │ lambda │ lambda │  null_func  │  end_file   │
        // └─────────┴────────┴────────┴─────────────┴─────────────┘
        {{State::END_OF_LINE, Message::MS_NORMAL},
         {State::UNQUOTE,
          [this](CsvParser &, char c) -> int {
            this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
            this->str_buf_[0] = c;
            this->pos_ = 1;
            return 0;
          }}},
        {{State::END_OF_LINE, Message::MS_DELIM},
         {State::DELIM,
          [this](CsvParser &, char c) -> int {
            this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
            this->put_record(c);
            return 0;
          }}},
        {{State::END_OF_LINE, Message::MS_QUOTE},
         {State::QUOTE,
          [this](CsvParser &, char c) -> int {
            this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
            return 0;
          }}},
        {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}},
        {{State::END_OF_LINE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}};

  return Status::OK();
}
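Reviewer note: both tables above are instances of a table-driven state machine: each (state, message) key maps to a (next state, action) pair, so the parser advances with one map lookup per character. A compilable miniature of the same pattern, cut down to two states and a row counter (an assumed simplification for illustration, not MindSpore code):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

enum State { FIELD, END_OF_LINE };
enum Message { MS_NORMAL, MS_END_OF_LINE };

using Action = std::function<int(char)>;
using StateDiagram = std::map<std::pair<State, Message>, std::pair<State, Action>>;

int main() {
  int rows = 0;
  // (current state, input class) -> (next state, action), as in 'sd'/'sdl'.
  StateDiagram sd = {
      {{FIELD, MS_NORMAL}, {FIELD, [](char) { return 0; }}},
      {{FIELD, MS_END_OF_LINE}, {END_OF_LINE, [&rows](char) { ++rows; return 0; }}},
      {{END_OF_LINE, MS_NORMAL}, {FIELD, [](char) { return 0; }}},
      {{END_OF_LINE, MS_END_OF_LINE}, {END_OF_LINE, [](char) { return 0; }}},
  };

  State cur = FIELD;
  for (char c : std::string("a,b\nc,d\n")) {
    Message m = (c == '\n') ? MS_END_OF_LINE : MS_NORMAL;
    auto it = sd.find({cur, m});
    if (it == sd.end()) return -1;  // no transition defined: syntax error
    cur = it->second.first;
    it->second.second(c);
  }
  std::cout << "rows: " << rows << "\n";  // prints 2
  return 0;
}
```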
Status CsvOp::Reset() {
  load_jagged_connector_ = true;
  load_io_block_queue_ = true;

  RETURN_IF_NOT_OK(ParallelOp::Reset());
  NotifyToFillIOBlockQueue();

  return Status::OK();
}

Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
                       const int32_t worker_id) {
  CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_);
  csv_parser.setStartOffset(start_offset);
  csv_parser.setEndOffset(end_offset);
  std::ifstream ifs;
  ifs.open(file, std::ifstream::in);
  if (column_name_list_.empty()) {
    // No column names were supplied, so the first line is a header: skip it.
    std::string tmp;
    getline(ifs, tmp);
  }
  csv_parser.Reset();
  try {
    while (ifs.good()) {
      char chr = ifs.get();
      if (csv_parser.processMessage(chr) != 0) {
        RETURN_STATUS_UNEXPECTED("Failed to parse CSV file " + file + ":" + std::to_string(csv_parser.total_rows_));
      }
    }
  } catch (std::invalid_argument &ia) {
    std::string err_row = std::to_string(csv_parser.total_rows_);
    RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", invalid argument of " + std::string(ia.what()));
  } catch (std::out_of_range &oor) {
    std::string err_row = std::to_string(csv_parser.total_rows_);
    RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of range error: " + std::string(oor.what()));
  }
  return Status::OK();
}
Status CsvOp::operator()() {
  RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

  // Launch one thread, responsible for filling the IoBlockQueue
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this)));

  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1)));

  // Must be called after launching workers.
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
  NotifyToFillIOBlockQueue();

  while (!finished_reading_dataset_) {
    int64_t buffer_id = 0;
    int32_t workers_done = 0;
    int64_t rows_read = 0;
    load_io_block_queue_ = true;

    while (workers_done < num_workers_) {
      std::unique_ptr<DataBuffer> buffer;
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
      if (buffer->eoe()) {
        workers_done++;
      } else if (num_samples_ == 0 || rows_read < num_samples_) {
        if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
          int64_t rows_to_remove = buffer->NumRows() - (num_samples_ - rows_read);
          RETURN_IF_NOT_OK(buffer->SliceOff(rows_to_remove));
        }
        rows_read += buffer->NumRows();
        buffer->set_id(buffer_id++);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
      } else {
        // End of epoch; stop loading and drain the remaining buffers.
        load_jagged_connector_ = false;
        load_io_block_queue_ = false;
      }
    }

    std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
      finished_reading_dataset_ = true;
      NotifyToFillIOBlockQueue();
    } else {
      jagged_buffer_connector_->DoReset();
      buffer_id = 0;
    }
  }

  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

  RETURN_IF_NOT_OK(PostEndOfData());

  return Status::OK();
}
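Reviewer note: the `SliceOff` arithmetic in the master loop caps an epoch at `num_samples` rows by trimming only the overshoot from the last buffer. A worked example of that computation with invented numbers:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Suppose num_samples = 10, 8 rows were already forwarded, and the next
  // buffer holds 5 rows: only 2 may be kept, so 3 rows are sliced off.
  int64_t num_samples = 10, rows_read = 8, buffer_rows = 5;
  int64_t rows_to_remove = buffer_rows - (num_samples - rows_read);
  assert(rows_to_remove == 3);  // passed to buffer->SliceOff(rows_to_remove)
  return 0;
}
```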
Status CsvOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();

  std::unique_ptr<FilenameBlock> io_block;
  RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  while (!io_block->eof()) {
    if (!io_block->eoe()) {
      if (load_jagged_connector_) {
        std::string filename;
        RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
        int64_t start_offset = io_block->GetStartOffset();
        int64_t end_offset = io_block->GetEndOffset();
        RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
      }
    } else {
      std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
    }

    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  }
  return Status::OK();
}
// A print method typically used for debugging
void CsvOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <CsvOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
        << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n";
    for (size_t i = 0; i < csv_files_list_.size(); ++i) {
      out << " " << csv_files_list_[i];
    }
    out << "\n\n";
  }
}

// Pops an element from a queue in io_block_queues
Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
  return Status::OK();
}

// Pushes an element to a queue in io_block_queues
Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
  return Status::OK();
}
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::mt19937 rng(seed);
  std::shuffle(i_keys->begin(), i_keys->end(), rng);
}

Status CsvOp::WaitToFillIOBlockQueue() {
  // Must be called first if this is run by a worker spawned by a task group
  TaskManager::FindMe()->Post();

  std::vector<int64_t> i_keys;
  if (shuffle_files_) {
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      i_keys.push_back(it.key());
    }
  }
  uint32_t seed = 0;
  while (true) {
    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
    io_block_queue_wait_post_.Clear();

    if (finished_reading_dataset_) {
      break;
    }

    if (shuffle_files_) {
      ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
    }
    RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
  }
  return Status::OK();
}
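Reviewer note: `io_block_queue_wait_post_` acts as a resettable event: this loop blocks on `Wait()`, re-arms with `Clear()`, and `NotifyToFillIOBlockQueue()` releases it via `Set()` once per epoch. A rough condition-variable analogue of those semantics (my reading of the WaitPost contract, not the MindSpore implementation; `Register` is omitted):

```cpp
#include <condition_variable>
#include <mutex>

// Sketch of the Set/Wait/Clear semantics assumed by WaitToFillIOBlockQueue().
class WaitPostSketch {
 public:
  void Set() {  // NotifyToFillIOBlockQueue(): request the next epoch's blocks
    std::lock_guard<std::mutex> lk(m_);
    signaled_ = true;
    cv_.notify_all();
  }
  void Wait() {  // block until an epoch is requested
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [this] { return signaled_; });
  }
  void Clear() {  // re-arm so the next Wait() blocks again
    std::lock_guard<std::mutex> lk(m_);
    signaled_ = false;
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  bool signaled_ = false;
};
```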
Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        if (!load_io_block_queue_) {
          break;
        }
        file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it));
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        if (!load_io_block_queue_) {
          break;
        }
        file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
      }
    }
    for (auto file_info : file_index) {
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto io_block =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(io_block)));
        queue_index = (queue_index + 1) % num_workers_;
      }

      pre_count += filename_numrows_[file_info.first];
    }

    finish = pre_count >= (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}
void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }

bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                     const int64_t &pre_count) {
  *start_offset = 0;
  *end_offset = 0;
  bool push = false;
  int64_t start_index = device_id_ * num_rows_per_shard_;
  if (device_id_ + 1 < 0) {
    MS_LOG(ERROR) << "Device id is invalid";
    return false;
  }

  int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
    *start_offset = start_index - pre_count;
    push = true;
    if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  if (pre_count >= start_index && pre_count < end_index) {
    *start_offset = 0;
    push = true;
    if (pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  return push;
}
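Reviewer note: the two branches above cover a file that contains the shard's first row and a file that starts inside the shard's range. A self-contained worked example with invented row counts, recombining the same overlap test and offset arithmetic:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Three files with made-up row counts; 2 shards over 30 total rows gives
  // num_rows_per_shard = 15, so shard 1 owns global rows [15, 30).
  std::vector<std::pair<std::string, int64_t>> files = {
      {"a.csv", 10}, {"b.csv", 10}, {"c.csv", 10}};
  int64_t num_rows_per_shard = 15, device_id = 1;
  int64_t start_index = device_id * num_rows_per_shard;      // 15
  int64_t end_index = (device_id + 1) * num_rows_per_shard;  // 30

  int64_t pre_count = 0;
  for (const auto &f : files) {
    int64_t rows = f.second;
    // Same overlap test as NeedPushFileToBlockQueue: push the file if it
    // contains the shard's first row, or if it starts inside the shard range.
    bool starts_here = pre_count <= start_index && pre_count + rows > start_index;
    bool inside = pre_count >= start_index && pre_count < end_index;
    if (starts_here || inside) {
      int64_t start_offset = starts_here ? start_index - pre_count : 0;
      int64_t end_offset = std::min(pre_count + rows, end_index) - pre_count;
      std::cout << f.first << " -> [" << start_offset << ", " << end_offset << ")\n";
    }
    pre_count += rows;
  }
  // Prints: b.csv -> [5, 10) and c.csv -> [0, 10)
  return 0;
}
```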
| } | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| Status CsvOp::PostEndOfEpoch(int32_t queue_index) { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CsvOp::CalculateNumRowsPerShard() { | |||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||||
| int64_t count = CountTotalRows(it.value()); | |||||
| filename_numrows_[it.value()] = count; | |||||
| all_num_rows_ += count; | |||||
| } | |||||
| if (all_num_rows_ == 0) { | |||||
| RETURN_STATUS_UNEXPECTED( | |||||
| "There is no valid data matching the dataset API CsvDataset. Please check file path or dataset API " | |||||
| "validation first."); | |||||
| } | |||||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_)); | |||||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | |||||
| return Status::OK(); | |||||
| } | |||||
int64_t CsvOp::CountTotalRows(const std::string &file) {
  CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_);
  std::ifstream ifs;
  ifs.open(file, std::ifstream::in);
  if (column_name_list_.empty()) {
    // The first line is the header; it does not count as a data row.
    std::string tmp;
    getline(ifs, tmp);
  }
  csv_parser.Reset();
  while (ifs.good()) {
    char chr = ifs.get();
    if (csv_parser.countRows(chr) != 0) {
      break;
    }
  }

  return csv_parser.total_rows_;
}

// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status CsvOp::PostEndOfData() {
  for (int i = 0; i < num_workers_; ++i) {
    std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
  }
  return Status::OK();
}

Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) {
  std::shared_ptr<CsvOp> op;
  *count = 0;
  if (csv_header) {
    RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).Build(&op));
  } else {
    RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).SetColumName({""}).Build(&op));
  }
  for (const auto &file : files) {
    *count += op->CountTotalRows(file);
  }
  return Status::OK();
}
std::vector<std::string> CsvOp::split(const std::string &s, char delim) {
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;
  while (getline(ss, item, delim)) {
    res.push_back(item);
  }
  return res;
}

Status CsvOp::ComputeColMap() {
  // Set the column name mapping (base class field)
  if (column_name_id_map_.empty()) {
    if (column_name_list_.empty()) {
      // No column names were given, so read them from the header line of the first file.
      std::string line;
      std::ifstream handle(csv_files_list_[0]);
      getline(handle, line);
      std::vector<std::string> col_names = split(line, field_delim_);
      for (size_t i = 0; i < col_names.size(); i++) {
        column_name_id_map_[col_names[i]] = i;
      }
    } else {
      for (size_t i = 0; i < column_name_list_.size(); i++) {
        column_name_id_map_[column_name_list_[i]] = i;
      }
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,451 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_

#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include <limits>

#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"

namespace mindspore {
namespace dataset {
const size_t CSV_BUFFER_SIZE = 4096;
using StringIndex = AutoIndexObj<std::string>;
class JaggedConnector;

class CsvOp : public ParallelOp {
 public:
  enum RecordType : uint8_t { INT = 0, FLOAT, STRING };

  struct BaseRecord {
   public:
    BaseRecord() = default;
    explicit BaseRecord(RecordType t) : type(t) {}
    virtual ~BaseRecord() {}
    RecordType type;
  };

  template <typename T>
  class Record : public BaseRecord {
   public:
    Record() = default;
    Record(RecordType t, T v) : BaseRecord(t), value(v) {}
    ~Record() {}
    T value;
  };

  // CsvParser is a class that parses CSV files.
  // We designed a state machine to implement the CSV syntactic analysis. It contains two state diagrams,
  // 'sd' and 'sdl'. 'sd' is used for full CSV syntactic parsing; it is complete but complex.
  // 'sdl' is used only for counting record rows; it is concise and runs fast.
  struct CsvParser {
   public:
    CsvParser() = delete;

    CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim,
              std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default)
        : worker_id_(worker_id),
          buffer_connector_(connector),
          csv_rows_per_buffer_(rows_per_buffer),
          csv_field_delim_(field_delim),
          column_default_(column_default),
          cur_state_(START_OF_FILE),
          pos_(0),
          cur_row_(0),
          cur_col_(0),
          total_rows_(0),
          start_offset_(0),
          end_offset_(std::numeric_limits<int64_t>::max()) {
      cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
      initCsvParser();
    }

    ~CsvParser() = default;

    void Reset() {
      cur_state_ = START_OF_FILE;
      pos_ = 0;
      cur_row_ = 0;
      cur_col_ = 0;
    }

    void setStartOffset(int64_t start_offset) { start_offset_ = start_offset; }

    void setEndOffset(int64_t end_offset) { end_offset_ = end_offset; }

    int processMessage(char c) {
      Message m = getMessage(c);
      StateDiagram::iterator it = sd.find({cur_state_, m});
      if (it == sd.end()) {
        return -1;
      }
      cur_state_ = it->second.first;
      return it->second.second(*this, c);
    }

    int countRows(char c);

    Status initCsvParser();

    enum State : uint8_t {
      START_OF_FILE = 0,
      UNQUOTE,
      DELIM,
      QUOTE,
      SECOND_QUOTE,
      END_OF_LINE,
      END_OF_FILE,
      EXCEPTION
    };

    enum Message : uint8_t {
      MS_NORMAL = 0,
      MS_DELIM,
      MS_QUOTE,
      MS_END_OF_LINE,
      MS_END_OF_FILE,
    };

    typedef std::pair<State, Message> StateMessagePair;
    typedef std::pair<State, std::function<int(CsvParser &, char)>> StateActionPair;
    typedef std::map<StateMessagePair, StateActionPair> StateDiagram;

    Message getMessage(char c) {
      if (c == csv_field_delim_) {
        return Message::MS_DELIM;
      } else if (c == '"') {
        return Message::MS_QUOTE;
      } else if (c == '\r' || c == '\n') {
        return Message::MS_END_OF_LINE;
      } else if (c == std::char_traits<char>::eof()) {
        return Message::MS_END_OF_FILE;
      } else {
        return Message::MS_NORMAL;
      }
    }

    int null_func(char c) { return 0; }

    int put_char(char c) {
      if (pos_ >= str_buf_.size()) {
        // Double the buffer when a field outgrows it (amortized constant cost per character).
        str_buf_.resize(str_buf_.size() * 2);
      }
      str_buf_[pos_] = c;
      pos_++;
      return 0;
    }

    int put_record(char c);

    int put_row(char c);

    int end_file(char c);

    int add_row(char c) {
      total_rows_++;
      return 0;
    }

    int catch_exception(char c) {
      MS_LOG(ERROR) << "Invalid syntax!";
      return -1;
    }

    int32_t worker_id_;
    std::shared_ptr<JaggedConnector> buffer_connector_;
    int64_t csv_rows_per_buffer_;
    const char csv_field_delim_;
    std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_;
    State cur_state_;
    size_t pos_;
    int cur_row_;
    int cur_col_;
    int64_t total_rows_;
    int64_t start_offset_;
    int64_t end_offset_;
    StateDiagram sd;
    StateDiagram sdl;
    std::vector<char> str_buf_;
    std::unique_ptr<TensorQTable> tensor_table_;
    std::unique_ptr<DataBuffer> cur_buffer_;
  };
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Checks if the inputs of the builder are valid.
    // @return Status - the error code returned.
    Status ValidateInputs() const;

    // Create the final object.
    // @param op - dataset op.
    // @return Status - the error code returned.
    Status Build(std::shared_ptr<CsvOp> *op);

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumDevices(int64_t num_dev) {
      builder_num_devices_ = num_dev;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetDeviceId(int64_t dev_id) {
      builder_device_id_ = dev_id;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetCsvFilesList(const std::vector<std::string> &files_list) {
      builder_csv_files_list_ = files_list;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetShuffleFiles(bool shuffle_files) {
      builder_shuffle_files_ = shuffle_files;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumSamples(int64_t num_samples) {
      builder_num_samples_ = num_samples;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetFieldDelim(char field_delim) {
      builder_field_delim_ = field_delim;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetColumDefault(std::vector<std::shared_ptr<CsvOp::BaseRecord>> record_list) {
      builder_column_default_list_ = record_list;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetColumName(std::vector<std::string> col_name_list) {
      builder_column_name_list_ = col_name_list;
      return *this;
    }

   private:
    int32_t builder_device_id_;
    int32_t builder_num_devices_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    int64_t builder_rows_per_buffer_;
    int64_t builder_num_samples_;
    int32_t builder_worker_connector_size_;
    std::vector<std::string> builder_csv_files_list_;
    bool builder_shuffle_files_;
    char builder_field_delim_;
    std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
    std::vector<std::string> builder_column_name_list_;
  };
  // Constructor of CsvOp
  CsvOp() = delete;

  CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
        const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
        int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
        int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);

  // Default destructor
  ~CsvOp() = default;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // Instantiates the internal queues and connectors
  // @return Status - the error code returned
  Status Init();

  // Class functor operator () override.
  // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - the error code returned.
  Status operator()() override;

  // Overrides the base class reset method. Cleans up any state info from its previous execution and
  // reinitializes itself so that it can be executed again, as if it was just created.
  // @return Status - the error code returned.
  Status Reset() override;

  // Get total rows in files.
  // @param files - all csv files.
  // @param csv_header - a bool that indicates whether the csv files include a header line.
  // @param count - number of rows.
  // @return Status - the error code returned.
  static Status CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count);

  // File names getter
  // @return Vector of the input file names
  std::vector<std::string> FileNames() { return csv_files_list_; }
| private: | |||||
| // The entry point for when workers are launched. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status WorkerEntry(int32_t worker_id) override; | |||||
| // Parses a single row and puts the data into a tensor table. | |||||
| // @param line - the content of the row. | |||||
| // @param tensor_table - the tensor table to put the parsed data in. | |||||
| // @param row - the id of the row filled in the tensor table. | |||||
| // @return Status - the error code returned. | |||||
| Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row); | |||||
| // Reads a csv file and loads the data into multiple buffers. | |||||
| // @param file - the file to read. | |||||
| // @param start_offset - the start offset of file. | |||||
| // @param end_offset - the end offset of file. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t worker_id); | |||||
| // Pops an element from a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to pop from. | |||||
| // @param out_block - the popped element. | |||||
| // @return Status - the error code returned. | |||||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||||
| // Pushes an element to a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to push to. | |||||
| // @param io_block - the element to push onto the queue. | |||||
| // @return Status - the error code returned. | |||||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||||
| // @return Status - the error code returned. | |||||
| Status WaitToFillIOBlockQueue(); | |||||
| // Fill the IOBlockQueue. | |||||
| // @param i_keys - keys of the files to fill into the IOBlockQueue | |||||
| // @return Status - the error code returned. | |||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); | |||||
| // Notifies the thread which called WaitToFillIOBlockQueue to resume execution | |||||
| void NotifyToFillIOBlockQueue(); | |||||
| // Determines whether a file should be read, and if so, calculates its read offsets. | |||||
| // @param file_name - file name. | |||||
| // @param start_offset - output, the offset of the first row to read from this file. | |||||
| // @param end_offset - output, the offset past the last row to read from this file. | |||||
| // @param pre_count - total rows of the previous files. | |||||
| // @return bool - whether the file needs to be pushed to the block queue. | |||||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count); | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfEpoch(int32_t queue_index); | |||||
| // Calculate number of rows in each shard. | |||||
| // @return Status - the error code returned. | |||||
| Status CalculateNumRowsPerShard(); | |||||
| // Counts the number of rows in a single csv file. | |||||
| // @param file - the csv file name. | |||||
| // @return int64_t - the total number of rows in the file. | |||||
| int64_t CountTotalRows(const std::string &file); | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfData(); | |||||
| // Private function for computing the assignment of the column name map. | |||||
| // @return - Status | |||||
| Status ComputeColMap() override; | |||||
| // Splits a string based on a character delimiter. | |||||
| // @param s - the string to split. | |||||
| // @param delim - the delimiter character. | |||||
| // @return - a vector of the split substrings. | |||||
| std::vector<std::string> split(const std::string &s, char delim); | |||||
| int32_t device_id_; | |||||
| bool shuffle_files_; | |||||
| bool finished_reading_dataset_; | |||||
| int32_t num_devices_; | |||||
| int64_t rows_per_buffer_; | |||||
| bool load_io_block_queue_; | |||||
| int64_t num_rows_per_shard_; | |||||
| int64_t all_num_rows_; | |||||
| int64_t num_samples_; | |||||
| std::map<std::string, int64_t> filename_numrows_; | |||||
| std::unique_ptr<StringIndex> filename_index_; | |||||
| std::vector<std::string> csv_files_list_; | |||||
| WaitPost io_block_queue_wait_post_; | |||||
| std::shared_ptr<JaggedConnector> jagged_buffer_connector_; | |||||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||||
| bool load_jagged_connector_; | |||||
| char field_delim_; | |||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_; | |||||
| std::vector<std::string> column_name_list_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ | |||||
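The header above describes a producer/consumer design: a filler thread distributes per-file work items across per-worker IO block queues, workers pop blocks and parse their slice of a file, and control indicators (PostEndOfData, PostEndOfEpoch) tell workers when to stop or wait. Below is a minimal Python sketch of that pattern; the names and logic are illustrative stand-ins, not the actual CsvOp internals.

import os
import queue
import tempfile
import threading

END_OF_DATA = None  # control indicator, analogous to the end-of-data IOBlock

def fill_io_block_queues(files, io_queues):
    """Round-robin per-file blocks onto the worker queues, then post a
    shutdown block to each queue (compare FillIOBlockQueue / PostEndOfData)."""
    for i, file_name in enumerate(files):
        io_queues[i % len(io_queues)].put(file_name)
    for q in io_queues:
        q.put(END_OF_DATA)

def worker_entry(worker_id, io_queue, results):
    """Pop blocks and process them until the control indicator arrives
    (compare WorkerEntry / PopIoBlockQueue)."""
    while True:
        block = io_queue.get()
        if block is END_OF_DATA:
            return
        with open(block, encoding="utf8") as f:  # stand-in for LoadFile
            results.append((worker_id, os.path.basename(block), sum(1 for _ in f)))

if __name__ == "__main__":
    # Build two throwaway csv files so the sketch is self-contained.
    tmp_dir = tempfile.mkdtemp()
    files = []
    for name, n_rows in (("a.csv", 3), ("b.csv", 5)):
        path = os.path.join(tmp_dir, name)
        with open(path, "w", encoding="utf8") as f:
            f.write("1,2,3,4\n" * n_rows)
        files.append(path)
    io_queues = [queue.Queue() for _ in range(2)]
    results = []
    workers = [threading.Thread(target=worker_entry, args=(i, q, results))
               for i, q in enumerate(io_queues)]
    for w in workers:
        w.start()
    fill_io_block_queues(files, io_queues)
    for w in workers:
        w.join()
    print(sorted(results))  # [(0, 'a.csv', 3), (1, 'b.csv', 5)]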
| @@ -21,7 +21,7 @@ can also create samplers with this module to sample data. | |||||
| from .core import config | from .core import config | ||||
| from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \ | from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \ | ||||
| GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\ | GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\ | ||||
| TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset | |||||
| TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset | |||||
| from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | ||||
| WeightedRandomSampler, Sampler | WeightedRandomSampler, Sampler | ||||
| from .engine.cache_client import DatasetCache | from .engine.cache_client import DatasetCache | ||||
| @@ -31,5 +31,5 @@ from .engine.graphdata import GraphData | |||||
| __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", | __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", | ||||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", | "MindDataset", "GeneratorDataset", "TFRecordDataset", | ||||
| "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset", | "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset", | ||||
| "CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler", | |||||
| "CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler", | |||||
| "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"] | "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"] | ||||
| @@ -29,7 +29,7 @@ from .samplers import * | |||||
| from ..core import config | from ..core import config | ||||
| __all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset", | __all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset", | ||||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", | |||||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "CSVDataset", | |||||
| "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", | "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", | ||||
| "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", | "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", | ||||
| "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] | "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] | ||||
| @@ -33,7 +33,7 @@ import copy | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ | from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ | ||||
| MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo | |||||
| MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo | |||||
| from mindspore._c_expression import typing | from mindspore._c_expression import typing | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | ||||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | ||||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | ||||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save | |||||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset | |||||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | ||||
| from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE | from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE | ||||
| @@ -1012,7 +1012,7 @@ class Dataset: | |||||
| if isinstance(sampler, samplers.DistributedSampler): | if isinstance(sampler, samplers.DistributedSampler): | ||||
| dev_id = sampler.shard_id | dev_id = sampler.shard_id | ||||
| return "", dev_id | return "", dev_id | ||||
| if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)): | |||||
| if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset, CSVDataset)): | |||||
| if output_dataset.shard_id is not None: | if output_dataset.shard_id is not None: | ||||
| dev_id = output_dataset.shard_id | dev_id = output_dataset.shard_id | ||||
| return "", dev_id | return "", dev_id | ||||
| @@ -4652,8 +4652,8 @@ class CLUEDataset(SourceDataset): | |||||
| } | } | ||||
| Args: | Args: | ||||
| dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of | |||||
| files. The list will be sorted in a lexicographical order. | |||||
| dataset_files (str or a list of strings): String or list of files to be read or glob strings to search for | |||||
| a pattern of files. The list will be sorted in a lexicographical order. | |||||
| task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'. | task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'. | ||||
| (default=AFQMC). | (default=AFQMC). | ||||
| usage (str, optional): Need train, test or eval data (default="train"). | usage (str, optional): Need train, test or eval data (default="train"). | ||||
| @@ -4860,6 +4860,108 @@ class CLUEDataset(SourceDataset): | |||||
| return False | return False | ||||
| class CSVDataset(SourceDataset): | |||||
| """ | |||||
| A source dataset that reads and parses CSV datasets. | |||||
| Args: | |||||
| dataset_files (str or a list of strings): String or list of files to be read or glob strings to search | |||||
| for a pattern of files. The list will be sorted in a lexicographical order. | |||||
| field_delim (str, optional): A character that serves as the delimiter to separate fields (default=','). | |||||
| column_defaults (list, optional): List of default values for the CSV fields (default=None). Each item | |||||
| in the list must be of a valid type (float, int, or string); the type of each item also determines | |||||
| the type of the corresponding column. If this is not provided, all columns are treated as | |||||
| string type. | |||||
| column_names (list of string, optional): List of column names of the dataset (default=None). If this | |||||
| is not provided, the column names are inferred from the first row of the CSV file. | |||||
| num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset). | |||||
| num_parallel_workers (int, optional): Number of workers to read the data | |||||
| (default=None, uses the number set in the config). | |||||
| shuffle (bool, Shuffle level, optional): Perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). | |||||
| If shuffle is False, no shuffling will be performed; | |||||
| if shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL. | |||||
| Otherwise, there are two levels of shuffling: | |||||
| - Shuffle.GLOBAL: Shuffle both the files and samples. | |||||
| - Shuffle.FILES: Shuffle files only. | |||||
| num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). | |||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||||
| argument should be specified only when num_shards is also specified. | |||||
| Examples: | |||||
| >>> import mindspore.dataset as ds | |||||
| >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files | |||||
| >>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4']) | |||||
| """ | |||||
| @check_csvdataset | |||||
| def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, | |||||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||||
| super().__init__(num_parallel_workers) | |||||
| self.dataset_files = self._find_files(dataset_files) | |||||
| self.dataset_files.sort() | |||||
| self.field_delim = field_delim | |||||
| self.column_defaults = column_defaults | |||||
| self.column_names = column_names | |||||
| self.num_samples = num_samples | |||||
| if not isinstance(shuffle, (bool, Shuffle)): | |||||
| raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") | |||||
| if not isinstance(shuffle, Shuffle): | |||||
| if shuffle: | |||||
| self.shuffle_level = Shuffle.GLOBAL | |||||
| self.shuffle_files = True | |||||
| else: | |||||
| self.shuffle_level = None | |||||
| self.shuffle_files = False | |||||
| else: | |||||
| self.shuffle_level = shuffle | |||||
| self.shuffle_files = True | |||||
| self.num_shards = num_shards | |||||
| self.shard_id = shard_id | |||||
| def get_args(self): | |||||
| args = super().get_args() | |||||
| args["dataset_files"] = self.dataset_files | |||||
| args['field_delim'] = self.field_delim | |||||
| args['column_defaults'] = self.column_defaults | |||||
| args['column_names'] = self.column_names | |||||
| args["num_samples"] = self.num_samples | |||||
| if self.shuffle_files is not None: | |||||
| args["shuffle_files"] = self.shuffle_files | |||||
| args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) | |||||
| args["shuffle"] = self.shuffle_level | |||||
| args["num_shards"] = self.num_shards | |||||
| args["shard_id"] = self.shard_id | |||||
| return args | |||||
| def get_dataset_size(self): | |||||
| """ | |||||
| Get the number of batches in an epoch. | |||||
| Returns: | |||||
| Number, number of batches. | |||||
| """ | |||||
| if self._dataset_size is None: | |||||
| num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) | |||||
| num_rows = get_num_rows(num_rows, self.num_shards) | |||||
| if self.num_samples is None: | |||||
| return num_rows | |||||
| return min(self.num_samples, num_rows) | |||||
| return self._dataset_size | |||||
| def is_shuffled(self): | |||||
| return self.shuffle_files | |||||
| def is_sharded(self): | |||||
| if self.num_shards is not None: | |||||
| return self.num_shards > 1 | |||||
| return False | |||||
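To make the new arguments concrete, here is a hedged usage sketch; the file path and column layout are hypothetical. Each entry in column_defaults supplies a fall-back value for an empty field, and its Python type (str, int or float) fixes the type of the corresponding output column.

import mindspore.dataset as ds

data = ds.CSVDataset(
    dataset_files="/path/to/data.csv",             # hypothetical path
    field_delim=',',
    column_defaults=["", 0, 0.0, ""],              # col2 -> int, col3 -> float
    column_names=['col1', 'col2', 'col3', 'col4'],
    shuffle=ds.Shuffle.FILES,                      # shuffle file order only
    num_shards=2, shard_id=0)                      # this pipeline reads shard 0 of 2
for row in data.create_dict_iterator():
    print(row['col1'], row['col2'])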
| class TextFileDataset(SourceDataset): | class TextFileDataset(SourceDataset): | ||||
| """ | """ | ||||
| A source dataset that reads and parses datasets stored on disk in text format. | A source dataset that reads and parses datasets stored on disk in text format. | ||||
| @@ -185,6 +185,8 @@ class Iterator: | |||||
| op_type = OpName.SENTENCEPIECEVOCAB | op_type = OpName.SENTENCEPIECEVOCAB | ||||
| elif isinstance(dataset, de.CLUEDataset): | elif isinstance(dataset, de.CLUEDataset): | ||||
| op_type = OpName.CLUE | op_type = OpName.CLUE | ||||
| elif isinstance(dataset, de.CSVDataset): | |||||
| op_type = OpName.CSV | |||||
| else: | else: | ||||
| raise ValueError("Unsupported DatasetOp") | raise ValueError("Unsupported DatasetOp") | ||||
| @@ -787,6 +787,49 @@ def check_cluedataset(method): | |||||
| return new_method | return new_method | ||||
| def check_csvdataset(method): | |||||
| """A wrapper that wrap a parameter checker to the original Dataset(CSVDataset).""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||||
| # check dataset_files; required argument | |||||
| dataset_files = param_dict.get('dataset_files') | |||||
| type_check(dataset_files, (str, list), "dataset files") | |||||
| # check field_delim | |||||
| field_delim = param_dict.get('field_delim') | |||||
| type_check(field_delim, (str,), 'field delim') | |||||
| if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1: | |||||
| raise ValueError("field_delim should be a single character and cannot be a double quote, carriage return or newline.") | |||||
| # check column_defaults | |||||
| column_defaults = param_dict.get('column_defaults') | |||||
| if column_defaults is not None: | |||||
| if not isinstance(column_defaults, list): | |||||
| raise TypeError("column_defaults should be type of list.") | |||||
| for item in column_defaults: | |||||
| if not isinstance(item, (str, int, float)): | |||||
| raise TypeError("column type is not legal.") | |||||
| # check column_names: must be a list of strings. | |||||
| column_names = param_dict.get("column_names") | |||||
| if column_names is not None: | |||||
| all_string = all(isinstance(item, str) for item in column_names) | |||||
| if not all_string: | |||||
| raise TypeError("column_names should be a list of str.") | |||||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||||
| check_sampler_shuffle_shard_options(param_dict) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | |||||
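The checker above constrains field_delim to a single character that is not a quote, carriage return or newline, and requires every column default to be a str, int or float. A standalone sketch of the delimiter rule, assuming the same semantics:

def is_legal_field_delim(field_delim):
    """Mirror of the field_delim constraint enforced by check_csvdataset."""
    return (isinstance(field_delim, str)
            and len(field_delim) == 1
            and field_delim not in ('"', '\r', '\n'))

assert is_legal_field_delim('|')        # the delimiter used by separated.csv below
assert not is_legal_field_delim('||')   # multi-character
assert not is_legal_field_delim('"')    # reserved for quoting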
| def check_textfiledataset(method): | def check_textfiledataset(method): | ||||
| """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" | """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" | ||||
| @@ -77,6 +77,7 @@ SET(DE_UT_SRCS | |||||
| celeba_op_test.cc | celeba_op_test.cc | ||||
| take_op_test.cc | take_op_test.cc | ||||
| clue_op_test.cc | clue_op_test.cc | ||||
| csv_op_test.cc | |||||
| text_file_op_test.cc | text_file_op_test.cc | ||||
| filter_op_test.cc | filter_op_test.cc | ||||
| concat_op_test.cc | concat_op_test.cc | ||||
| @@ -0,0 +1,122 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/client.h" | |||||
| #include "common/common.h" | |||||
| #include "common/utils.h" | |||||
| #include "gtest/gtest.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace common = mindspore::common; | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | |||||
| class MindDataTestCSVOp : public UT::DatasetOpTesting { | |||||
| }; | |||||
| TEST_F(MindDataTestCSVOp, TestCSVBasic) { | |||||
| // Start with an empty execution tree | |||||
| auto tree = std::make_shared<ExecutionTree>(); | |||||
| std::string dataset_path = datasets_root_path_ + "/testCSV/1.csv"; | |||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list; | |||||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||||
| column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0)); | |||||
| std::shared_ptr<CsvOp> op; | |||||
| CsvOp::Builder builder; | |||||
| builder.SetCsvFilesList({dataset_path}) | |||||
| .SetRowsPerBuffer(16) | |||||
| .SetNumWorkers(16) | |||||
| .SetShuffleFiles(false) | |||||
| .SetOpConnectorSize(2) | |||||
| .SetFieldDelim(',') | |||||
| .SetColumDefault(column_default_list) | |||||
| .SetColumName({"col1", "col2", "col3", "col4"}); | |||||
| Status rc = builder.Build(&op); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = tree->AssociateNode(op); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = tree->AssignRoot(op); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| MS_LOG(INFO) << "Launching tree and begin iteration."; | |||||
| rc = tree->Prepare(); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = tree->Launch(); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Start the loop of reading tensors from our pipeline | |||||
| DatasetIterator di(tree); | |||||
| TensorRow tensor_list; | |||||
| rc = di.FetchNextTensorRow(&tensor_list); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| int row_count = 0; | |||||
| while (!tensor_list.empty()) { | |||||
| // Display the tensor by calling the printer on it | |||||
| for (size_t i = 0; i < tensor_list.size(); i++) { | |||||
| std::ostringstream ss; | |||||
| ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; | |||||
| MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; | |||||
| } | |||||
| rc = di.FetchNextTensorRow(&tensor_list); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| row_count++; | |||||
| } | |||||
| ASSERT_EQ(row_count, 3); | |||||
| } | |||||
| TEST_F(MindDataTestCSVOp, TestTotalRows) { | |||||
| std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv"; | |||||
| std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv"; | |||||
| std::vector<std::string> files; | |||||
| files.push_back(csv_file1); | |||||
| int64_t total_rows = 0; | |||||
| Status rc = CsvOp::CountAllFileRows(files, false, &total_rows); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| ASSERT_EQ(total_rows, 3); | |||||
| files.clear(); | |||||
| files.push_back(csv_file2); | |||||
| rc = CsvOp::CountAllFileRows(files, false, &total_rows); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| ASSERT_EQ(total_rows, 5); | |||||
| files.clear(); | |||||
| files.push_back(csv_file1); | |||||
| files.push_back(csv_file2); | |||||
| rc = CsvOp::CountAllFileRows(files, false, &total_rows); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| ASSERT_EQ(total_rows, 8); | |||||
| files.clear(); | |||||
| } | |||||
| @@ -0,0 +1,3 @@ | |||||
| 1,2,3,4 | |||||
| 5,6,7,8 | |||||
| 9,10,11,12 | |||||
| @@ -0,0 +1,8 @@ | |||||
| ,"222",3,"4""" | |||||
| "5",6,,"8" | |||||
| 9,10,"1""1",12 | |||||
| ,,"", | |||||
| ,,, | |||||
| a,b,c,"" | |||||
| a,b,c,d | |||||
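The file above exercises the quoting rules a CSV parser must honor: empty unquoted fields, quoted empty strings, and a doubled quote inside a quoted field decoding to a single quote character. Python's standard csv module follows the same convention, so a short sketch can show the expected field values:

import csv
import io

sample = ',"222",3,"4"""\n9,10,"1""1",12\n'
rows = list(csv.reader(io.StringIO(sample)))
assert rows[0] == ['', '222', '3', '4"']    # "" inside quotes -> one quote
assert rows[1] == ['9', '10', '1"1', '12']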
| @@ -0,0 +1 @@ | |||||
| 大家,早上好,中午好,下午好,晚上好 | |||||
| @@ -0,0 +1,2 @@ | |||||
| "a,b","c""d","e | |||||
| f"," g " | |||||
| @@ -0,0 +1,3 @@ | |||||
| 1,2,3,4 | |||||
| 5,6,7,8 | |||||
| a,"c",d,"e | |||||
| @@ -0,0 +1,2 @@ | |||||
| col1,col2,col3,col4 | |||||
| a,b,c,d | |||||
| @@ -0,0 +1 @@ | |||||
| 3,0.3,4,55.5 | |||||
| @@ -0,0 +1 @@ | |||||
| "a","b","c","d" | |||||
| @@ -0,0 +1 @@ | |||||
| a|b|c|d | |||||
| @@ -0,0 +1,10 @@ | |||||
| 1,2,3,4 | |||||
| "a","b","c | |||||
| ","d | |||||
| e" | |||||
| 5,6,7,8 | |||||
| 9,10,11,12 | |||||
| a,"b | |||||
| ",c,"d | |||||
| e" | |||||
| @@ -0,0 +1,238 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import mindspore.dataset as ds | |||||
| import numpy as np | |||||
| import pytest | |||||
| DATA_FILE = '../data/dataset/testCSV/1.csv' | |||||
| def test_csv_dataset_basic(): | |||||
| """ | |||||
| Test CSVDataset with repeat and skip operations. | |||||
| """ | |||||
| TRAIN_FILE = '../data/dataset/testCSV/1.csv' | |||||
| buffer = [] | |||||
| data = ds.CSVDataset( | |||||
| TRAIN_FILE, | |||||
| column_defaults=["0", 0, 0.0, "0"], | |||||
| column_names=['1', '2', '3', '4'], | |||||
| shuffle=False) | |||||
| data = data.repeat(2) | |||||
| data = data.skip(2) | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.append(d) | |||||
| assert len(buffer) == 4 | |||||
| def test_csv_dataset_one_file(): | |||||
| data = ds.CSVDataset( | |||||
| DATA_FILE, | |||||
| column_defaults=["1", "2", "3", "4"], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.append(d) | |||||
| assert len(buffer) == 3 | |||||
| def test_csv_dataset_all_file(): | |||||
| APPEND_FILE = '../data/dataset/testCSV/2.csv' | |||||
| data = ds.CSVDataset( | |||||
| [DATA_FILE, APPEND_FILE], | |||||
| column_defaults=["1", "2", "3", "4"], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.append(d) | |||||
| assert len(buffer) == 10 | |||||
| def test_csv_dataset_num_samples(): | |||||
| data = ds.CSVDataset( | |||||
| DATA_FILE, | |||||
| column_defaults=["1", "2", "3", "4"], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False, num_samples=2) | |||||
| count = 0 | |||||
| for _ in data.create_dict_iterator(): | |||||
| count += 1 | |||||
| assert count == 2 | |||||
| def test_csv_dataset_distribution(): | |||||
| TEST_FILE = '../data/dataset/testCSV/1.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=["1", "2", "3", "4"], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False, num_shards=2, shard_id=0) | |||||
| count = 0 | |||||
| for _ in data.create_dict_iterator(): | |||||
| count += 1 | |||||
| assert count == 2 | |||||
| def test_csv_dataset_quoted(): | |||||
| TEST_FILE = '../data/dataset/testCSV/quoted.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=["", "", "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.extend([d['col1'].item().decode("utf8"), | |||||
| d['col2'].item().decode("utf8"), | |||||
| d['col3'].item().decode("utf8"), | |||||
| d['col4'].item().decode("utf8")]) | |||||
| assert buffer == ['a', 'b', 'c', 'd'] | |||||
| def test_csv_dataset_separated(): | |||||
| TEST_FILE = '../data/dataset/testCSV/separated.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| field_delim='|', | |||||
| column_defaults=["", "", "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.extend([d['col1'].item().decode("utf8"), | |||||
| d['col2'].item().decode("utf8"), | |||||
| d['col3'].item().decode("utf8"), | |||||
| d['col4'].item().decode("utf8")]) | |||||
| assert buffer == ['a', 'b', 'c', 'd'] | |||||
| def test_csv_dataset_embedded(): | |||||
| TEST_FILE = '../data/dataset/testCSV/embedded.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=["", "", "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.extend([d['col1'].item().decode("utf8"), | |||||
| d['col2'].item().decode("utf8"), | |||||
| d['col3'].item().decode("utf8"), | |||||
| d['col4'].item().decode("utf8")]) | |||||
| assert buffer == ['a,b', 'c"d', 'e\nf', ' g '] | |||||
| def test_csv_dataset_chinese(): | |||||
| TEST_FILE = '../data/dataset/testCSV/chinese.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=["", "", "", "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4', 'col5'], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.extend([d['col1'].item().decode("utf8"), | |||||
| d['col2'].item().decode("utf8"), | |||||
| d['col3'].item().decode("utf8"), | |||||
| d['col4'].item().decode("utf8"), | |||||
| d['col5'].item().decode("utf8")]) | |||||
| assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好'] | |||||
| def test_csv_dataset_header(): | |||||
| TEST_FILE = '../data/dataset/testCSV/header.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=["", "", "", ""], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.extend([d['col1'].item().decode("utf8"), | |||||
| d['col2'].item().decode("utf8"), | |||||
| d['col3'].item().decode("utf8"), | |||||
| d['col4'].item().decode("utf8")]) | |||||
| assert buffer == ['a', 'b', 'c', 'd'] | |||||
| def test_csv_dataset_number(): | |||||
| TEST_FILE = '../data/dataset/testCSV/number.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=[0.0, 0.0, 0, 0.0], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| buffer = [] | |||||
| for d in data.create_dict_iterator(): | |||||
| buffer.extend([d['col1'].item(), | |||||
| d['col2'].item(), | |||||
| d['col3'].item(), | |||||
| d['col4'].item()]) | |||||
| assert np.allclose(buffer, [3.0, 0.3, 4, 55.5]) | |||||
| def test_csv_dataset_size(): | |||||
| TEST_FILE = '../data/dataset/testCSV/size.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=[0.0, 0.0, 0, 0.0], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| assert data.get_dataset_size() == 5 | |||||
| def test_csv_dataset_exception(): | |||||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=["", "", "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| with pytest.raises(Exception) as err: | |||||
| for _ in data.create_dict_iterator(): | |||||
| pass | |||||
| assert "Failed to parse CSV file" in str(err.value) | |||||
| def test_csv_dataset_type_error(): | |||||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | |||||
| data = ds.CSVDataset( | |||||
| TEST_FILE, | |||||
| column_defaults=["", 0, "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| with pytest.raises(Exception) as err: | |||||
| for _ in data.create_dict_iterator(): | |||||
| pass | |||||
| assert "invalid argument of stoi" in str(err.value) | |||||
| if __name__ == "__main__": | |||||
| test_csv_dataset_basic() | |||||
| test_csv_dataset_one_file() | |||||
| test_csv_dataset_all_file() | |||||
| test_csv_dataset_num_samples() | |||||
| test_csv_dataset_distribution() | |||||
| test_csv_dataset_quoted() | |||||
| test_csv_dataset_separated() | |||||
| test_csv_dataset_embedded() | |||||
| test_csv_dataset_chinese() | |||||
| test_csv_dataset_header() | |||||
| test_csv_dataset_number() | |||||
| test_csv_dataset_size() | |||||
| test_csv_dataset_exception() | |||||
| test_csv_dataset_type_error() | |||||