Merge pull request !425 from yanghaitao/yht_textfiledatasetv2
| @@ -28,10 +28,10 @@ | |||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "mindrecord/include/shard_category.h" | |||
| #include "mindrecord/include/shard_sample.h" | |||
| #include "mindrecord/include/shard_shuffle.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -61,7 +61,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D | |||
| {kVoc, &DEPipeline::ParseVOCOp}, | |||
| {kCifar10, &DEPipeline::ParseCifar10Op}, | |||
| {kCifar100, &DEPipeline::ParseCifar100Op}, | |||
| {kCelebA, &DEPipeline::ParseCelebAOp}}; | |||
| {kCelebA, &DEPipeline::ParseCelebAOp}, | |||
| {kTextFile, &DEPipeline::ParseTextFileOp}}; | |||
| DEPipeline::DEPipeline() : iterator_(nullptr) { | |||
| try { | |||
| @@ -985,5 +986,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| // Required arguments | |||
| std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>(); | |||
| if (!args["dataset_files"].is_none()) { | |||
| (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"])); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); | |||
| } | |||
| // Optional arguments | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "shuffle_files") { | |||
| (void)builder->SetShuffleFiles(ToBool(value)); | |||
| } else if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_shards") { | |||
| (void)builder->SetNumDevices(ToInt(value)); | |||
| } else if (key == "shard_id") { | |||
| (void)builder->SetDeviceId(ToInt(value)); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<TextFileOp> op; | |||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
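For orientation, here is a hypothetical Python mirror of the dispatch pattern `ParseTextFileOp` implements — one hard-required key, then a tolerant loop over the optional ones. The `builder` object and its setter names are stand-ins for the C++ `TextFileOp::Builder`, not a real MindSpore API:

```python
# Hypothetical sketch of ParseTextFileOp's argument dispatch; the real work
# happens in C++ against TextFileOp::Builder.
def parse_text_file_args(args, builder):
    # Required argument: fail fast if missing.
    if args.get("dataset_files") is None:
        raise ValueError("Error: dataset_files is missing")
    builder.set_text_files_list(args["dataset_files"])

    # Optional arguments: apply only the ones that are present and non-None.
    optional_setters = {
        "num_parallel_workers": builder.set_num_workers,
        "shuffle_files": builder.set_shuffle_files,
        "num_samples": builder.set_num_samples,
        "num_shards": builder.set_num_devices,
        "shard_id": builder.set_device_id,
    }
    for key, value in args.items():
        if value is not None and key in optional_setters:
            optional_setters[key](value)
```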
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -58,7 +58,8 @@ enum OpName { | |||
| kVoc, | |||
| kCifar10, | |||
| kCifar100, | |||
| kCelebA | |||
| kCelebA, | |||
| kTextFile | |||
| }; | |||
| // The C++ binder class that we expose to the python script. | |||
| @@ -148,6 +149,8 @@ class DEPipeline { | |||
| Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| private: | |||
| // Execution tree that links the dataset operators. | |||
| std::shared_ptr<ExecutionTree> tree_; | |||
| @@ -55,6 +55,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/jagged_connector.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/kernels/data/to_float16_op.h" | |||
| #include "dataset/util/random.h" | |||
| #include "mindrecord/include/shard_operator.h" | |||
| @@ -176,6 +177,17 @@ void bindDatasetOps(py::module *m) { | |||
| THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count)); | |||
| return count; | |||
| }); | |||
| (void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp") | |||
| .def_static("get_num_rows", [](const py::list &files) { | |||
| int64_t count = 0; | |||
| std::vector<std::string> filenames; | |||
| for (auto file : files) { | |||
| if (!file.is_none()) { | |||
| filenames.push_back(py::str(file)); | |||
| } else { | |||
| filenames.emplace_back(""); | |||
| } | |||
| } | |||
| THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); | |||
| return count; | |||
| }); | |||
| } | |||
| void bindTensor(py::module *m) { | |||
| (void)py::class_<GlobalContext>(*m, "GlobalContext") | |||
| @@ -463,7 +475,8 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("VOC", OpName::kVoc) | |||
| .value("CIFAR10", OpName::kCifar10) | |||
| .value("CIFAR100", OpName::kCifar100) | |||
| .value("CELEBA", OpName::kCelebA); | |||
| .value("CELEBA", OpName::kCelebA) | |||
| .value("TEXTFILE", OpName::kTextFile); | |||
| (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic()) | |||
| .value("DE_INTER_LINEAR", InterpolationMode::kLinear) | |||
| @@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT | |||
| manifest_op.cc | |||
| cifar_op.cc | |||
| celeba_op.cc | |||
| text_file_op.cc | |||
| ) | |||
| add_dependencies(engine-datasetops-source mindspore::protobuf) | |||
| @@ -0,0 +1,459 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "common/utils.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/util/wait_post.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| TextFileOp::Builder::Builder() | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| builder_rows_per_buffer_ = config_manager->rows_per_buffer(); | |||
| builder_worker_connector_size_ = config_manager->worker_connector_size(); | |||
| } | |||
| Status TextFileOp::Builder::ValidateInputs() const { | |||
| std::string err_msg; | |||
| err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; | |||
| err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; | |||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||
| RETURN_IF_NOT_OK(ValidateInputs()); | |||
| // Throttle the number of workers if we have more workers than files! | |||
| if (static_cast<size_t>(builder_num_workers_) > builder_text_files_list_.size()) { | |||
| builder_num_workers_ = builder_text_files_list_.size(); | |||
| MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK( | |||
| builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, | |||
| std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, | |||
| builder_num_devices_, builder_device_id_); | |||
| RETURN_IF_NOT_OK(text_file_op->Init()); | |||
| *op = std::move(text_file_op); | |||
| return Status::OK(); | |||
| } | |||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| device_id_(device_id), | |||
| num_devices_(num_device), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| num_samples_(num_samples), | |||
| text_files_list_(std::move(text_files_list)), | |||
| shuffle_files_(shuffle_files), | |||
| data_schema_(std::move(schema)), | |||
| all_num_rows_(0), | |||
| num_rows_per_shard_(0), | |||
| filename_index_(std::make_unique<StringIndex>()), | |||
| finished_reading_dataset_(false), | |||
| load_io_block_queue_(true), | |||
| load_jagged_connector_(true) { | |||
| worker_connector_size_ = worker_connector_size; | |||
| } | |||
| Status TextFileOp::Init() { | |||
| RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); | |||
| int32_t safe_queue_size = static_cast<int32_t>(std::ceil(static_cast<double>(text_files_list_.size()) / num_workers_)) + 1; | |||
| io_block_queues_.Init(num_workers_, safe_queue_size); | |||
| for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { | |||
| col_name_map_[data_schema_->column(i).name()] = i; | |||
| } | |||
| RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); | |||
| jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_); | |||
| return Status::OK(); | |||
| } | |||
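The queue sizing in `Init` deserves a note: each worker can be handed at most ceil(files / workers) file blocks per epoch, plus one control block (the EOE indicator) — hence the +1. A small sketch of the arithmetic in Python, which also sidesteps C++'s silent integer division (exactly the pitfall fixed above, where truncation would happen before `std::ceil` could round up):

```python
import math

# Sketch of the IO-block queue sizing in TextFileOp::Init.
def safe_queue_size(num_files, num_workers):
    return math.ceil(num_files / num_workers) + 1

assert safe_queue_size(5, 2) == 4  # up to 3 file blocks per worker, plus EOE
assert safe_queue_size(4, 4) == 2  # 1 file block per worker, plus EOE
```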
| Status TextFileOp::Reset() { | |||
| load_jagged_connector_ = true; | |||
| load_io_block_queue_ = true; | |||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||
| NotifyToFillIOBlockQueue(); | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) { | |||
| TensorRow tRow(1, nullptr); | |||
| (*tensor_table)->push_back(std::move(tRow)); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK( | |||
| Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(), | |||
| TensorShape(std::vector<dsize_t>(1, line.size())), data_schema_->column(0).type(), | |||
| const_cast<unsigned char *>(reinterpret_cast<const unsigned char *>(common::SafeCStr(line))))); | |||
| (**tensor_table)[row][0] = std::move(tensor); | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||
| const int32_t worker_id) { | |||
| std::ifstream handle(file); | |||
| if (!handle.is_open()) { | |||
| RETURN_STATUS_UNEXPECTED("Failed to open file " + file); | |||
| } | |||
| int64_t rows_each_buffer = 0; | |||
| int64_t rows_total = 0; | |||
| std::string line; | |||
| std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | |||
| cur_buffer->set_column_name_map(col_name_map_); | |||
| std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>(); | |||
| while (getline(handle, line)) { | |||
| // If we have read up to the end offset of this file, break. | |||
| if (rows_total >= end_offset) { | |||
| break; | |||
| } | |||
| // Skip lines before the start offset. | |||
| if (rows_total < start_offset) { | |||
| rows_total++; | |||
| continue; | |||
| } | |||
| RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); | |||
| rows_each_buffer++; | |||
| rows_total++; | |||
| if (rows_each_buffer == rows_per_buffer_) { | |||
| cur_buffer->set_tensor_table(std::move(tensor_table)); | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); | |||
| cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | |||
| cur_buffer->set_column_name_map(col_name_map_); | |||
| tensor_table = std::make_unique<TensorQTable>(); | |||
| rows_each_buffer = 0; | |||
| } | |||
| } | |||
| if (rows_each_buffer > 0) { | |||
| cur_buffer->set_tensor_table(std::move(tensor_table)); | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
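`LoadFile` is the core read path: it skips rows below `start_offset`, stops at `end_offset`, and groups everything in between into buffers of `rows_per_buffer` rows, flushing the final partial buffer. A simplified, runnable model of that loop:

```python
# Simplified model of TextFileOp::LoadFile. Buffers are returned as plain
# lists here; the real op hands each one to a jagged connector.
def load_file(path, start_offset, end_offset, rows_per_buffer):
    buffers, current = [], []
    with open(path) as handle:
        for row, line in enumerate(handle):
            if row >= end_offset:        # read up to the end offset: done
                break
            if row < start_offset:       # before this shard's window: skip
                continue
            current.append(line.rstrip("\n"))
            if len(current) == rows_per_buffer:
                buffers.append(current)  # a full buffer goes downstream
                current = []
    if current:
        buffers.append(current)          # flush the trailing partial buffer
    return buffers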
| Status TextFileOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<FilenameBlock> io_block; | |||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||
| while (!io_block->eof()) { | |||
| if (!io_block->eoe()) { | |||
| if (load_jagged_connector_) { | |||
| std::string filename; | |||
| RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); | |||
| int64_t start_offset = io_block->GetStartOffset(); | |||
| int64_t end_offset = io_block->GetEndOffset(); | |||
| RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); | |||
| } | |||
| } else { | |||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); | |||
| } | |||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Pops an element from a queue in io_block_queues | |||
| Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||
| return Status::OK(); | |||
| } | |||
| // Pushes an element to a queue in io_block_queues | |||
| Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||
| return Status::OK(); | |||
| } | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||
| Status TextFileOp::PostEndOfData() { | |||
| for (int i = 0; i < num_workers_; ++i) { | |||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||
| Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { | |||
| for (int i = 0; i < num_workers_; ++i) { | |||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) { | |||
| std::mt19937 rng(seed); | |||
| std::shuffle(i_keys->begin(), i_keys->end(), rng); | |||
| } | |||
| bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||
| const int64_t &pre_count) { | |||
| *start_offset = 0; | |||
| *end_offset = 0; | |||
| bool push = false; | |||
| int64_t start_index = device_id_ * num_rows_per_shard_; | |||
| if (device_id_ + 1 < 0) { | |||
| MS_LOG(ERROR) << "Device id is invalid"; | |||
| return false; | |||
| } | |||
| int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||
| if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { | |||
| *start_offset = start_index - pre_count; | |||
| push = true; | |||
| if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { | |||
| *end_offset = end_index - pre_count; | |||
| } else { | |||
| *end_offset = filename_numrows_[file_name]; | |||
| } | |||
| } | |||
| if (pre_count >= start_index && pre_count < end_index) { | |||
| *start_offset = 0; | |||
| push = true; | |||
| if (pre_count + filename_numrows_[file_name] >= end_index) { | |||
| *end_offset = end_index - pre_count; | |||
| } else { | |||
| *end_offset = filename_numrows_[file_name]; | |||
| } | |||
| } | |||
| return push; | |||
| } | |||
| Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||
| int32_t queue_index = 0; | |||
| int64_t pre_count = 0; | |||
| int64_t start_offset = 0; | |||
| int64_t end_offset = 0; | |||
| bool finish = false; | |||
| while (!finish) { | |||
| std::vector<std::pair<std::string, int64_t>> file_index; | |||
| if (!i_keys.empty()) { | |||
| for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| auto file_it = filename_index_->Search(*it); | |||
| file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it)); | |||
| } | |||
| } else { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key())); | |||
| } | |||
| } | |||
| for (auto file_info : file_index) { | |||
| if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { | |||
| auto ioBlock = | |||
| std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); | |||
| queue_index = (queue_index + 1) % num_workers_; | |||
| } | |||
| pre_count += filename_numrows_[file_info.first]; | |||
| } | |||
| if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { | |||
| finish = false; | |||
| } else { | |||
| finish = true; | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); | |||
| return Status::OK(); | |||
| } | |||
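The interplay of `CalculateNumRowsPerShard` and `NeedPushFileToBlockQueue` is easiest to see with concrete numbers. The sketch below models the row-range assignment, including the wrap-around at the bottom of `FillIOBlockQueue` that lets late shards pad themselves back up to `num_rows_per_shard_` rows; the 5-row/2-shard case matches the distribution unit test later in this PR:

```python
import math
from itertools import cycle

# Worked model of the sharding logic: each shard is assigned rows
# [device_id * rows_per_shard, (device_id + 1) * rows_per_shard) of a stream
# that cycles through the files, so late shards wrap around to the start of
# the file list to fill their quota.
def shard_plan(filename_numrows, num_devices, device_id):
    rows_per_shard = math.ceil(sum(filename_numrows.values()) / num_devices)
    start_index = device_id * rows_per_shard
    end_index = (device_id + 1) * rows_per_shard
    plan, pre_count = [], 0
    for name in cycle(filename_numrows):
        rows = filename_numrows[name]
        lo = max(pre_count, start_index)
        hi = min(pre_count + rows, end_index)
        if lo < hi:  # this pass over the file overlaps the shard's row range
            plan.append((name, lo - pre_count, hi - pre_count))
        pre_count += rows
        if pre_count >= end_index:
            break
    return plan

# 5 total rows over 2 shards -> 3 rows per shard; shard 1 reads both rows of
# "2.txt" plus the first row of "1.txt" again.
print(shard_plan({"1.txt": 3, "2.txt": 2}, num_devices=2, device_id=1))
# [('2.txt', 0, 2), ('1.txt', 0, 1)]
```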
| Status TextFileOp::WaitToFillIOBlockQueue() { | |||
| // must be called first if called by a worker spawned by a task group | |||
| TaskManager::FindMe()->Post(); | |||
| std::vector<int64_t> i_keys; | |||
| if (shuffle_files_) { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| i_keys.push_back(it.key()); | |||
| } | |||
| } | |||
| uint32_t seed = 0; | |||
| while (true) { | |||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||
| io_block_queue_wait_post_.Clear(); | |||
| if (finished_reading_dataset_) { | |||
| break; | |||
| } | |||
| if (shuffle_files_) { | |||
| ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||
| } | |||
| RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
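Note the seeding choice here: with a single device the shuffle consults the globally configured seed, but with multiple devices every shard advances the same local counter, so all shards see an identical file order and the per-shard row ranges stay aligned. A sketch of that policy, with the real `GetSeed()` modeled by a fixed `global_seed` argument (an assumption for illustration):

```python
import random

# Per-epoch file-order policy from WaitToFillIOBlockQueue. In the real op,
# GetSeed() is consulted each epoch in the single-device case.
def shuffled_file_orders(keys, num_epochs, num_devices, global_seed=5489):
    seed = 0
    for _ in range(num_epochs):
        seed = global_seed if num_devices == 1 else seed + 1
        order = list(keys)
        random.Random(seed).shuffle(order)  # all shards derive the same order
        yield order
```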
| void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } | |||
| Status TextFileOp::operator()() { | |||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||
| // launch one thread, responsible for filling IoBlockQueue | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); | |||
| // Read data from disk into buffers | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); | |||
| // must be called after launching workers. | |||
| TaskManager::FindMe()->Post(); | |||
| io_block_queue_wait_post_.Register(tree_->AllTasks()); | |||
| NotifyToFillIOBlockQueue(); | |||
| while (!finished_reading_dataset_) { | |||
| int64_t buffer_id = 0; | |||
| int32_t workers_done = 0; | |||
| int64_t rows_read = 0; | |||
| load_io_block_queue_ = true; | |||
| while (workers_done < num_workers_) { | |||
| std::unique_ptr<DataBuffer> buffer; | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); | |||
| if (buffer->eoe()) { | |||
| workers_done++; | |||
| } else if (num_samples_ == 0 || rows_read < num_samples_) { | |||
| if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { | |||
| int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); | |||
| RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); | |||
| } | |||
| rows_read += buffer->NumRows(); | |||
| buffer->set_id(buffer_id++); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); | |||
| } else { | |||
| // end of epoch | |||
| load_jagged_connector_ = false; | |||
| load_io_block_queue_ = false; | |||
| } | |||
| } | |||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||
| if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||
| finished_reading_dataset_ = true; | |||
| NotifyToFillIOBlockQueue(); | |||
| } else { | |||
| jagged_buffer_connector_->DoReset(); | |||
| buffer_id = 0; | |||
| } | |||
| } | |||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||
| return Status::OK(); | |||
| } | |||
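The master loop caps output at `num_samples` by slicing overflow rows off the last buffer (`SliceOff`). A compact model of that accounting:

```python
# Model of the num_samples cap: pass buffers through until the quota is met,
# trimming the buffer that would overshoot it.
def cap_buffers(buffers, num_samples):
    rows_read = 0
    for buf in buffers:
        if num_samples and rows_read >= num_samples:
            break                                # epoch quota already met
        if num_samples and rows_read + len(buf) > num_samples:
            buf = buf[:num_samples - rows_read]  # SliceOff drops the overflow
        rows_read += len(buf)
        yield buf

assert sum(map(len, cap_buffers([[1, 2], [3, 4], [5]], num_samples=3))) == 3
```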
| int64_t TextFileOp::CountTotalRows(const std::string &file) { | |||
| std::ifstream handle(file); | |||
| if (!handle.is_open()) { | |||
| MS_LOG(ERROR) << "Failed to open file: " << file; | |||
| return 0; | |||
| } | |||
| std::string line; | |||
| int64_t count = 0; | |||
| while (getline(handle, line)) { | |||
| count++; | |||
| } | |||
| return count; | |||
| } | |||
| Status TextFileOp::CalculateNumRowsPerShard() { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| int64_t count = CountTotalRows(it.value()); | |||
| filename_numrows_[it.value()] = count; | |||
| all_num_rows_ += count; | |||
| } | |||
| if (all_num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Number of rows can not be zero"); | |||
| } | |||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_)); | |||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | |||
| return Status::OK(); | |||
| } | |||
| Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) { | |||
| std::shared_ptr<TextFileOp> op; | |||
| *count = 0; | |||
| RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); | |||
| for (auto file : files) { | |||
| *count += op->CountTotalRows(file); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,263 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ | |||
| #include <memory> | |||
| #include <map> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "dataset/util/status.h" | |||
| #include "dataset/util/auto_index.h" | |||
| #include "dataset/engine/data_schema.h" | |||
| #include "dataset/engine/datasetops/parallel_op.h" | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| #include "dataset/util/queue.h" | |||
| #include "dataset/util/wait_post.h" | |||
| #include "dataset/engine/jagged_connector.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| using StringIndex = AutoIndexObj<std::string>; | |||
| class TextFileOp : public ParallelOp { | |||
| public: | |||
| class Builder { | |||
| public: | |||
| // Builder constructor. Creates the builder object. | |||
| // @note No default args | |||
| // @return This is a constructor. | |||
| Builder(); | |||
| // Default destructor | |||
| ~Builder() = default; | |||
| // Checks if the inputs of the builder is valid. | |||
| // @return Status - the error code returned. | |||
| Status ValidateInputs() const; | |||
| // Create the final object. | |||
| // @param op - dataset op. | |||
| // @return - the error code returned. | |||
| Status Build(std::shared_ptr<TextFileOp> *op); | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| builder_num_workers_ = num_workers; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetOpConnectorSize(int32_t op_connector_size) { | |||
| builder_op_connector_size_ = op_connector_size; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { | |||
| builder_rows_per_buffer_ = rows_per_buffer; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumDevices(int64_t num_dev) { | |||
| builder_num_devices_ = num_dev; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetDeviceId(int64_t dev_id) { | |||
| builder_device_id_ = dev_id; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetTextFilesList(const std::vector<std::string> &files_list) { | |||
| builder_text_files_list_ = files_list; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetShuffleFiles(bool shuffle_files) { | |||
| builder_shuffle_files_ = shuffle_files; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| private: | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int64_t builder_rows_per_buffer_; | |||
| int64_t builder_num_samples_; | |||
| int32_t builder_worker_connector_size_; | |||
| std::vector<std::string> builder_text_files_list_; | |||
| bool builder_shuffle_files_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| }; | |||
| // Constructor of TextFileOp | |||
| // @note The builder class should be used to call this constructor. | |||
| // @param num_workers - number of worker threads reading data from text files. | |||
| // @param rows_per_buffer - number of rows that a full buffer will contain. | |||
| // @param num_samples - number of rows to read. | |||
| // @param worker_connector_size - size of each queue between a worker and the master thread. | |||
| // @param data_schema - the data schema object. | |||
| // @param text_files_list - list of filepaths for the dataset files. | |||
| // @param op_connector_size - size of each queue in the connector that the child operator pulls from. | |||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | |||
| // @param num_devices - number of shards the rows are divided into. | |||
| // @param device_id - the id of the shard to read within num_devices. | |||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| // Default destructor | |||
| ~TextFileOp() = default; | |||
| // Instantiates the internal queues and connectors | |||
| // @return Status - the error code returned | |||
| Status Init(); | |||
| // Class functor operator () override. | |||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - the error code returned. | |||
| Status operator()() override; | |||
| // Overrides base class reset method. Cleans up any state info from its previous execution and | |||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||
| // @return Status - the error code returned. | |||
| Status Reset() override; | |||
| // Get total rows in files. | |||
| // @param files - all text files. | |||
| // @param count - number of rows. | |||
| // @return Status - the error code returned. | |||
| static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count); | |||
| private: | |||
| // The entry point for when workers are launched. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| // @return Status - the error code returned. | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Parses a single row and puts the data into a tensor table. | |||
| // @param line - the content of the row. | |||
| // @param tensor_table - the tensor table to put the parsed data in. | |||
| // @param row - the id of the row filled in the tensor table. | |||
| // @return Status - the error code returned. | |||
| Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row); | |||
| // Reads a text file and loads the data into multiple buffers. | |||
| // @param file - the file to read. | |||
| // @param start_offset - the start offset of file. | |||
| // @param end_offset - the end offset of file. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| // @return Status - the error code returned. | |||
| Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||
| const int32_t worker_id); | |||
| // Calculate number of rows in each shard. | |||
| // @return Status - the error code returned. | |||
| Status CalculateNumRowsPerShard(); | |||
| // Count number of rows in each file. | |||
| // @param filename - text file name. | |||
| // @return int64_t - the total number of rows in file. | |||
| int64_t CountTotalRows(const std::string &file); | |||
| // Notifies the thread which called FillIoBlockQueue to resume execution | |||
| void NotifyToFillIOBlockQueue(); | |||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||
| // @return Status - the error code returned. | |||
| Status WaitToFillIOBlockQueue(); | |||
| // Fill the IOBlockQueue. | |||
| // @param i_keys - keys of the files to fill into the IOBlockQueue. | |||
| // @return Status - the error code returned. | |||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); | |||
| // Select file and push it to the block queue. | |||
| // @param file_name - File name. | |||
| // @param start_offset - Output, the starting row offset to read within this file. | |||
| // @param end_offset - Output, the ending row offset to read within this file. | |||
| // @param pre_count - Total rows of previous files. | |||
| // @return bool - Whether the file needs to be pushed to the block queue. | |||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||
| const int64_t &pre_count); | |||
| // Pops an element from a queue in IOBlockQueue. | |||
| // @param index - the index of the queue to pop from. | |||
| // @param out_block - the popped element. | |||
| // @return Status - the error code returned. | |||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||
| // Pushes an element to a queue in IOBlockQueue. | |||
| // @param index - the index of the queue to push to. | |||
| // @param io_block - the element to push onto the queue. | |||
| // @return Status - the error code returned. | |||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||
| // @return Status - the error code returned. | |||
| Status PostEndOfData(); | |||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||
| // @return Status - the error code returned. | |||
| Status PostEndOfEpoch(int32_t queue_index); | |||
| int32_t device_id_; | |||
| int32_t num_devices_; | |||
| int64_t rows_per_buffer_; | |||
| int64_t num_samples_; | |||
| std::vector<std::string> text_files_list_; | |||
| bool shuffle_files_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| int64_t all_num_rows_; | |||
| int64_t num_rows_per_shard_; | |||
| std::map<std::string, int64_t> filename_numrows_; | |||
| std::unique_ptr<StringIndex> filename_index_; | |||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||
| WaitPost io_block_queue_wait_post_; | |||
| bool finished_reading_dataset_; | |||
| bool load_io_block_queue_; | |||
| bool load_jagged_connector_; | |||
| std::unordered_map<std::string, int32_t> col_name_map_; | |||
| std::unique_ptr<JaggedConnector> jagged_buffer_connector_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ | |||
| @@ -20,8 +20,8 @@ can also create samplers with this module to sample data. | |||
| from .core.configuration import config | |||
| from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \ | |||
| GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \ | |||
| Shuffle, zip | |||
| GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ | |||
| Schema, Shuffle, zip | |||
| from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | |||
| WeightedRandomSampler | |||
| from .engine.serializer_deserializer import serialize, deserialize, show | |||
| @@ -29,5 +29,5 @@ from .engine.serializer_deserializer import serialize, deserialize, show | |||
| __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", | |||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", | |||
| "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", | |||
| "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", | |||
| "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", | |||
| "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"] | |||
| @@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset", | |||
| "ImageFolderDatasetV2", "MnistDataset", | |||
| "MindDataset", "GeneratorDataset", "TFRecordDataset", | |||
| "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", | |||
| "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", | |||
| "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] | |||
| "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", | |||
| "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] | |||
| @@ -29,7 +29,7 @@ from importlib import import_module | |||
| import numpy as np | |||
| from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ | |||
| MindRecordOp, CBatchInfo | |||
| MindRecordOp, TextFileOp, CBatchInfo | |||
| from mindspore._c_expression import typing | |||
| from mindspore import log as logger | |||
| @@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator | |||
| from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ | |||
| check_zip_dataset, check_add_column | |||
| check_zip_dataset, check_add_column, check_textfiledataset | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| try: | |||
| @@ -888,6 +888,29 @@ class SourceDataset(Dataset): | |||
| # No need for __init__ since it is the same as the super's init | |||
| @staticmethod | |||
| def _find_files(patterns): | |||
| """ | |||
| Utility function to search for files with the given glob patterns. | |||
| Args: | |||
| patterns (str or list[str]): string or list of patterns to be searched. | |||
| Returns: | |||
| List, files. | |||
| """ | |||
| def flat(lists): | |||
| return list(np.array(lists).flatten()) | |||
| if not isinstance(patterns, list): | |||
| patterns = [patterns] | |||
| file_list = flat([glob.glob(file, recursive=True) for file in patterns]) | |||
| if file_list: # not empty | |||
| return file_list | |||
| raise ValueError("The list of path names matching the patterns is empty.") | |||
| class DatasetOp(Dataset): | |||
| """ | |||
| @@ -2126,30 +2149,6 @@ class TFRecordDataset(SourceDataset): | |||
| >>> # 3) get all rows from dataset_files with schema file "./schema.json": | |||
| >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json") | |||
| """ | |||
| @staticmethod | |||
| def _find_files(patterns): | |||
| """ | |||
| Utility function to search for files with the given glob patterns. | |||
| Args: | |||
| patterns (str or list[str]): string or list of patterns to be searched. | |||
| Returns: | |||
| List, files. | |||
| """ | |||
| def flat(lists): | |||
| return list(np.array(lists).flatten()) | |||
| if not isinstance(patterns, list): | |||
| patterns = [patterns] | |||
| file_list = flat([glob.glob(file, recursive=True) for file in patterns]) | |||
| if file_list: # not empty | |||
| return file_list | |||
| raise ValueError("The list of path names matching the patterns is empty.") | |||
| @check_tfrecorddataset | |||
| def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, | |||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): | |||
| @@ -2952,3 +2951,82 @@ class CelebADataset(SourceDataset): | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| return args | |||
| class TextFileDataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses datasets stored on disk in text format. | |||
| The generated dataset has one column ['text']. | |||
| Args: | |||
| dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of | |||
| files. The list will be sorted in a lexicographical order. | |||
| num_samples (int, optional): number of samples (rows) to read (default=None, reads the full dataset). | |||
| num_parallel_workers (int, optional): number of workers to read the data | |||
| (default=None, number set in the config). | |||
| shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). | |||
| If shuffle is False, no shuffling will be performed; | |||
| If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL; | |||
| Otherwise, there are two levels of shuffling: | |||
| - Shuffle.GLOBAL: Shuffle both the files and samples. | |||
| - Shuffle.FILES: Shuffle files only. | |||
| num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument should be specified only when num_shards is also specified. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or multiple text files | |||
| >>> dataset = ds.TextFileDataset(dataset_files=dataset_files) | |||
| """ | |||
| @check_textfiledataset | |||
| def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, | |||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_files = self._find_files(dataset_files) | |||
| self.dataset_files.sort() | |||
| self.num_samples = num_samples | |||
| if not isinstance(shuffle, (bool, Shuffle)): | |||
| raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") | |||
| if not isinstance(shuffle, Shuffle): | |||
| if shuffle: | |||
| self.shuffle_level = Shuffle.GLOBAL | |||
| self.shuffle_files = True | |||
| else: | |||
| self.shuffle_level = None | |||
| self.shuffle_files = False | |||
| else: | |||
| self.shuffle_level = shuffle | |||
| self.shuffle_files = True | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| args["dataset_files"] = self.dataset_files | |||
| args["num_samples"] = self.num_samples | |||
| if self.shuffle_files is not None: | |||
| args["shuffle_files"] = self.shuffle_files | |||
| args["shuffle"] = self.shuffle_level | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| return args | |||
| def get_dataset_size(self): | |||
| """ | |||
| Get the number of batches in an epoch. | |||
| Returns: | |||
| Number, number of batches. | |||
| """ | |||
| if self._dataset_size is None: | |||
| num_rows = TextFileOp.get_num_rows(self.dataset_files) | |||
| num_rows = get_num_rows(num_rows, self.num_shards) | |||
| if self.num_samples is None: | |||
| return num_rows | |||
| return min(self.num_samples, num_rows) | |||
| return self._dataset_size | |||
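Putting the size bookkeeping together from the user's side — an illustrative sketch (paths are placeholders, and it assumes `get_num_rows` performs a ceiling division across shards) of how `num_shards` and `num_samples` feed into `get_dataset_size`, using the two sample files from this PR with 3 and 2 rows:

```python
import mindspore.dataset as ds

# With 5 total rows over 2 shards, each shard reports ceil(5 / 2) = 3 rows.
sharded = ds.TextFileDataset(["/path/to/1.txt", "/path/to/2.txt"],
                             num_shards=2, shard_id=1)
print(sharded.get_dataset_size())  # 3 under the assumptions above

# num_samples caps the reported size at min(num_samples, total rows).
capped = ds.TextFileDataset(["/path/to/1.txt", "/path/to/2.txt"], num_samples=2)
print(capped.get_dataset_size())   # 2
```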
| @@ -48,12 +48,16 @@ def alter_tree(node): | |||
| def _alter_node(node): | |||
| """Performing some alteration to a dataset node. A common alteration is to insert a node.""" | |||
| if isinstance(node, de.TFRecordDataset) and node.shuffle_level == de.Shuffle.GLOBAL: | |||
| if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL: | |||
| # Remove the connection between the parent's node to the current node because we are inserting a node. | |||
| if node.output: | |||
| node.output.pop() | |||
| # Perform a fast scan for average rows per file | |||
| avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) | |||
| if isinstance(node, de.TFRecordDataset): | |||
| avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) | |||
| else: | |||
| avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files) | |||
| # Shuffle between 4 files with a minimum size of 10000 rows | |||
| new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000)) | |||
| return new_shuffle | |||
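The buffer-size heuristic applied when the global ShuffleOp is inserted is simple enough to state directly — four files' worth of average rows, floored at 10000:

```python
# Sketch of the shuffle buffer-size heuristic in _alter_node.
def shuffle_buffer_size(total_rows, num_files):
    avg_rows_per_file = total_rows // num_files
    return max(avg_rows_per_file * 4, 10000)

assert shuffle_buffer_size(total_rows=5, num_files=2) == 10000
assert shuffle_buffer_size(total_rows=1_000_000, num_files=10) == 400_000
```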
| @@ -157,6 +161,8 @@ class Iterator: | |||
| op_type = OpName.CIFAR100 | |||
| elif isinstance(dataset, de.CelebADataset): | |||
| op_type = OpName.CELEBA | |||
| elif isinstance(dataset, de.TextFileDataset): | |||
| op_type = OpName.TEXTFILE | |||
| else: | |||
| raise ValueError("Unsupported DatasetOp") | |||
| @@ -849,3 +849,25 @@ def check_add_column(method): | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| def check_textfiledataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||
| # check dataset_files; required argument | |||
| dataset_files = param_dict.get('dataset_files') | |||
| if dataset_files is None: | |||
| raise ValueError("dataset_files is not provided.") | |||
| if not isinstance(dataset_files, (str, list)): | |||
| raise TypeError("dataset_files should be of type str or a list of strings.") | |||
| check_param_type(nreq_param_int, param_dict, int) | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
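A quick sketch of what the new checker accepts and rejects (illustrative only, not a test shipped in this PR):

```python
import pytest
import mindspore.dataset as ds

with pytest.raises(ValueError):  # dataset_files is required
    ds.TextFileDataset(None)
with pytest.raises(TypeError):   # dataset_files must be a str or list of str
    ds.TextFileDataset(42)
```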
| @@ -0,0 +1,20 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| This module is to support nlp augmentations. It includes two parts: | |||
| c_transforms and py_transforms. C_transforms is a high-performance | |||
| text processing module implemented in C++. Py_transforms provides | |||
| additional text augmentations implemented in pure Python. | |||
| """ | |||
| from .utils import as_text | |||
| @@ -0,0 +1,35 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| Some basic functions for nlp. | |||
| """ | |||
| import numpy as np | |||
| def as_text(array, encoding='utf8'): | |||
| """ | |||
| Convert the data of an array to a unicode string. | |||
| Args: | |||
| array (numpy array): Array whose elements are the byte (ASCII) values of each character. | |||
| encoding (str): The charset used for decoding. | |||
| Returns: | |||
| A 'str' object. | |||
| """ | |||
| if not isinstance(array, np.ndarray): | |||
| raise ValueError('input should be a numpy array') | |||
| byte_array = bytearray(list(array)) | |||
| return byte_array.decode(encoding) | |||
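A short usage example: each 'text' tensor comes out of the pipeline as a uint8 array of byte values, and `as_text` decodes it back to a string:

```python
import numpy as np
import mindspore.dataset.transforms.nlp.utils as nlp

# Build a uint8 array of byte values, as the pipeline would produce it.
array = np.frombuffer(b"Be happy every day.", dtype=np.uint8)
assert nlp.as_text(array) == "Be happy every day."
```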
| @@ -65,7 +65,7 @@ SET(DE_UT_SRCS | |||
| cifar_op_test.cc | |||
| celeba_op_test.cc | |||
| take_op_test.cc | |||
| ) | |||
| text_file_op_test.cc) | |||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | |||
| @@ -0,0 +1,112 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "dataset/core/client.h" | |||
| #include "common/common.h" | |||
| #include "common/utils.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/util/status.h" | |||
| namespace common = mindspore::common; | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| class MindDataTestTextFileOp : public UT::DatasetOpTesting { | |||
| }; | |||
| TEST_F(MindDataTestTextFileOp, TestTextFileBasic) { | |||
| // Start with an empty execution tree | |||
| auto tree = std::make_shared<ExecutionTree>(); | |||
| std::string dataset_path; | |||
| dataset_path = datasets_root_path_ + "/testTextFileDataset/1.txt"; | |||
| std::shared_ptr<TextFileOp> op; | |||
| TextFileOp::Builder builder; | |||
| builder.SetTextFilesList({dataset_path}) | |||
| .SetRowsPerBuffer(16) | |||
| .SetNumWorkers(16) | |||
| .SetOpConnectorSize(2); | |||
| Status rc = builder.Build(&op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->AssociateNode(op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->AssignRoot(op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration."; | |||
| rc = tree->Prepare(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = tree->Launch(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator di(tree); | |||
| TensorRow tensor_list; | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int row_count = 0; | |||
| while (!tensor_list.empty()) { | |||
| // Display the tensor by calling the printer on it | |||
| for (int i = 0; i < tensor_list.size(); i++) { | |||
| std::ostringstream ss; | |||
| ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; | |||
| MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; | |||
| } | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| row_count++; | |||
| } | |||
| ASSERT_EQ(row_count, 3); | |||
| } | |||
| TEST_F(MindDataTestTextFileOp, TestTotalRows) { | |||
| std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt"; | |||
| std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt"; | |||
| std::vector<std::string> files; | |||
| files.push_back(tf_file1); | |||
| int64_t total_rows = 0; | |||
| TextFileOp::CountAllFileRows(files, &total_rows); | |||
| ASSERT_EQ(total_rows, 3); | |||
| files.clear(); | |||
| files.push_back(tf_file2); | |||
| TextFileOp::CountAllFileRows(files, &total_rows); | |||
| ASSERT_EQ(total_rows, 2); | |||
| files.clear(); | |||
| files.push_back(tf_file1); | |||
| files.push_back(tf_file2); | |||
| TextFileOp::CountAllFileRows(files, &total_rows); | |||
| ASSERT_EQ(total_rows, 5); | |||
| files.clear(); | |||
| } | |||
| @@ -0,0 +1,3 @@ | |||
| This is a text file. | |||
| Be happy every day. | |||
| Good luck to everyone. | |||
| @@ -0,0 +1,2 @@ | |||
| Another file. | |||
| End of file. | |||
| @@ -0,0 +1,87 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| import mindspore.dataset.transforms.nlp.utils as nlp | |||
| DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" | |||
| DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*" | |||
| def test_textline_dataset_one_file(): | |||
| data = ds.TextFileDataset(DATA_FILE) | |||
| count = 0 | |||
| for i in data.create_dict_iterator(): | |||
| logger.info("{}".format(i["text"])) | |||
| count += 1 | |||
| assert(count == 3) | |||
| def test_textline_dataset_all_file(): | |||
| data = ds.TextFileDataset(DATA_ALL_FILE) | |||
| count = 0 | |||
| for i in data.create_dict_iterator(): | |||
| logger.info("{}".format(i["text"])) | |||
| count += 1 | |||
| assert(count == 5) | |||
| def test_textline_dataset_totext(): | |||
| data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False) | |||
| count = 0 | |||
| line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."] | |||
| for i in data.create_dict_iterator(): | |||
| text = nlp.as_text(i["text"]) | |||
| assert(text == line[count]) | |||
| count += 1 | |||
| assert(count == 5) | |||
| def test_textline_dataset_num_samples(): | |||
| data = ds.TextFileDataset(DATA_FILE, num_samples=2) | |||
| count = 0 | |||
| for i in data.create_dict_iterator(): | |||
| count += 1 | |||
| assert(count == 2) | |||
| def test_textline_dataset_distribution(): | |||
| data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1) | |||
| count = 0 | |||
| for i in data.create_dict_iterator(): | |||
| count += 1 | |||
| assert(count == 3) | |||
| def test_textline_dataset_repeat(): | |||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||
| data = data.repeat(3) | |||
| count = 0 | |||
| line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.", | |||
| "This is a text file.", "Be happy every day.", "Good luck to everyone.", | |||
| "This is a text file.", "Be happy every day.", "Good luck to everyone."] | |||
| for i in data.create_dict_iterator(): | |||
| text = nlp.as_text(i["text"]) | |||
| assert(text == line[count]) | |||
| count += 1 | |||
| assert(count == 9) | |||
| def test_textline_dataset_get_datasetsize(): | |||
| data = ds.TextFileDataset(DATA_FILE) | |||
| size = data.get_dataset_size() | |||
| assert(size == 3) | |||
| if __name__ == "__main__": | |||
| test_textline_dataset_one_file() | |||
| test_textline_dataset_all_file() | |||
| test_textline_dataset_totext() | |||
| test_textline_dataset_num_samples() | |||
| test_textline_dataset_distribution() | |||
| test_textline_dataset_repeat() | |||
| test_textline_dataset_get_datasetsize() | |||