Merge pull request !1932 from jiangzhiwen/dataset/cluetags/v0.5.0-beta
@@ -31,6 +31,7 @@
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_distributed_sample.h"

@@ -72,7 +73,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
  {kCelebA, &DEPipeline::ParseCelebAOp},
  {kRandomData, &DEPipeline::ParseRandomDataOp},
  {kTextFile, &DEPipeline::ParseTextFileOp},
  {kBuildVocab, &DEPipeline::ParseBuildVocabOp}};
  {kBuildVocab, &DEPipeline::ParseBuildVocabOp},
  {kClue, &DEPipeline::ParseClueOp}};

DEPipeline::DEPipeline() : iterator_(nullptr) {
  try {

@@ -1210,6 +1212,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
  *ptr = op;
  return Status::OK();
}

Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
  for (auto p : py::reinterpret_borrow<py::dict>(value)) {
    if (!p.second.is_none()) {

@@ -1236,6 +1239,7 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
  }
  return Status::OK();
}

Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
  for (auto arg : args) {

@@ -1267,5 +1271,45 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
  return Status::OK();
}

Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
  if (!args["dataset_files"].is_none()) {
    (void)builder->SetClueFilesList(ToStringVector(args["dataset_files"]));
  } else {
    RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
  }

  // Optional arguments
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
      } else if (key == "shuffle_files") {
        (void)builder->SetShuffleFiles(ToBool(value));
      } else if (key == "num_samples") {
        (void)builder->SetNumSamples(ToInt(value));
      } else if (key == "num_shards") {
        (void)builder->SetNumDevices(ToInt(value));
      } else if (key == "shard_id") {
        (void)builder->SetDeviceId(ToInt(value));
      } else if (key == "cols_to_keyword") {
        std::map<std::string, std::string> map_dict;
        for (auto p : py::reinterpret_borrow<py::dict>(value)) {
          if (!p.second.is_none()) {
            map_dict.insert({ToString(p.first), ToString(p.second)});
          } else {
            map_dict.insert({ToString(p.first), ToString(p.first)});
          }
        }
        (void)builder->SetColsKeyMap(map_dict);
      }
    }
  }

  std::shared_ptr<ClueOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
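Editor's note: for readers tracing the binding layer, here is a minimal sketch (illustrative values only, not part of the patch) of the keyword-argument dict that the Python layer forwards to DEPipeline::ParseClueOp above. Any key that is absent or None simply falls through the optional-argument loop.

# Hypothetical illustration of the args dict consumed by ParseClueOp.
clue_args = {
    "dataset_files": ["/path/to/train.json"],  # required; None triggers "dataset_files is missing"
    "num_parallel_workers": 4,
    "shuffle_files": True,
    "num_samples": 0,            # 0 means no row cap
    "num_shards": 2,             # forwarded to SetNumDevices
    "shard_id": 0,               # forwarded to SetDeviceId
    "cols_to_keyword": {"sentence1": "sentence1", "label": "label"},
}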
@@ -64,7 +64,8 @@ enum OpName {
  kCelebA,
  kRandomData,
  kTextFile,
  kBuildVocab
  kBuildVocab,
  kClue
};

// The C++ binder class that we expose to the python script.

@@ -166,6 +167,8 @@ class DEPipeline {
  Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

 private:
  // Execution tree that links the dataset operators.
  std::shared_ptr<ExecutionTree> tree_;

@@ -55,6 +55,7 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/gnn/graph.h"

@@ -201,6 +202,18 @@ void bindDatasetOps(py::module *m) {
      THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
      return count;
    });

  (void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
    .def_static("get_num_rows", [](const py::list &files) {
      int64_t count = 0;
      std::vector<std::string> filenames;
      for (auto file : files) {
        file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
      }
      THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
      return count;
    });

  (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
    .def_static("get_num_rows",
                [](const std::string &dir, const std::string &task_type, const std::string &task_mode,

@@ -629,7 +642,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("RANDOMDATA", OpName::kRandomData)
    .value("BUILDVOCAB", OpName::kBuildVocab)
    .value("CELEBA", OpName::kCelebA)
    .value("TEXTFILE", OpName::kTextFile);
    .value("TEXTFILE", OpName::kTextFile)
    .value("CLUE", OpName::kClue);

  (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
    .value("DE_JIEBA_MIX", JiebaMode::kMix)

@@ -19,4 +19,5 @@ add_library(engine-datasetops-source OBJECT
  random_data_op.cc
  celeba_op.cc
  text_file_op.cc
  clue_op.cc
)
@@ -0,0 +1,551 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/source/clue_op.h"

#include <cmath>
#include <string>
#include <sstream>
#include <vector>
#include <fstream>
#include <iomanip>
#include <utility>

#include "dataset/core/config_manager.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/random.h"

namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  builder_num_workers_ = config_manager->num_parallel_workers();
  builder_op_connector_size_ = config_manager->op_connector_size();
  builder_rows_per_buffer_ = config_manager->rows_per_buffer();
  builder_worker_connector_size_ = config_manager->worker_connector_size();
}

Status ClueOp::Builder::ValidateInputs() const {
  std::string err;
  err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
  err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
}

Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
  RETURN_IF_NOT_OK(ValidateInputs());

  // Throttle the number of workers if we have more workers than files!
  if (static_cast<size_t>(builder_num_workers_) > builder_clue_files_list_.size()) {
    builder_num_workers_ = builder_clue_files_list_.size();
    MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
  }

  ColKeyMap ck_map;
  for (auto &p : builder_cols_to_keyword_) {
    ck_map.insert({p.first, split(p.second, '/')});
  }

  std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
    builder_device_id_);
  RETURN_IF_NOT_OK(clue_op->Init());
  *op = std::move(clue_op);
  return Status::OK();
}

std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim) {
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;
  while (getline(ss, item, delim)) {
    res.push_back(item);
  }
  return res;
}

ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
               ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
               bool shuffle_files, int32_t num_device, int32_t device_id)
    : ParallelOp(num_workers, op_connector_size),
      rows_per_buffer_(rows_per_buffer),
      num_rows_per_shard_(0),
      all_num_rows_(0),
      num_samples_(num_samples),
      filename_index_(std::make_unique<StringIndex>()),
      clue_files_list_(std::move(clue_files_list)),
      load_jagged_connector_(true),
      cols_to_keyword_(cols_to_keyword),
      shuffle_files_(shuffle_files),
      finished_reading_dataset_(false),
      num_devices_(num_device),
      device_id_(device_id),
      load_io_block_queue_(true) {
  worker_connector_size_ = worker_connector_size;
}
Status ClueOp::Init() {
  RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));

  // Compute the per-worker queue size in floating point before rounding up, to avoid integer division.
  int32_t safe_queue_size =
    static_cast<int32_t>(std::ceil(static_cast<double>(clue_files_list_.size()) / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  // Set the column name mapping (base class field)
  int count = 0;
  for (auto &p : cols_to_keyword_) {
    column_name_id_map_[p.first] = count;
    count++;
  }

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
  jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  return Status::OK();
}
Status ClueOp::Reset() {
  load_jagged_connector_ = true;
  load_io_block_queue_ = true;

  RETURN_IF_NOT_OK(ParallelOp::Reset());
  NotifyToFillIOBlockQueue();
  return Status::OK();
}

Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
  TensorRow t_row(1, nullptr);
  (*tensor_table)->push_back(std::move(t_row));
  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar()));
  (**tensor_table)[row][0] = std::move(tensor);
  return Status::OK();
}
Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
  nlohmann::json cursor = js;
  for (size_t i = 0; i < key_chain.size(); i++) {
    if (cursor.find(key_chain[i]) != cursor.end()) {
      cursor = cursor[key_chain[i]];
    } else {
      RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]);
    }
  }
  switch (cursor.type()) {
    case nlohmann::detail::value_t::string:
      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::string>()}, TensorShape::CreateScalar()));
      break;
    case nlohmann::detail::value_t::number_integer:
      RETURN_IF_NOT_OK(
        Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
      (*t)->SetItemAt<int32_t>({0}, cursor.get<int32_t>());
      break;
    case nlohmann::detail::value_t::number_unsigned:
      RETURN_IF_NOT_OK(
        Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
      (*t)->SetItemAt<int32_t>({0}, cursor.get<uint32_t>());
      break;
    case nlohmann::detail::value_t::number_float:
      RETURN_IF_NOT_OK(
        Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)));
      (*t)->SetItemAt<float>({0}, cursor.get<float>());
      break;
    case nlohmann::detail::value_t::array:
      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::vector<std::string>>()}, TensorShape::CreateScalar()));
      break;
    default:
      break;
  }
  return Status::OK();
}
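Editor's note: GetValue walks a chain of JSON keys produced by Builder::split (each keyword is split on '/'), which is how nested WSC fields such as target/span1_index are reached. A rough Python analogue, for illustration only:

import json

def get_value(js, key_chain):
    # Walk the nested JSON object along key_chain; raise on a missing key,
    # mirroring the RETURN_STATUS_UNEXPECTED branch above.
    cursor = js
    for key in key_chain:
        if key not in cursor:
            raise KeyError("Failed to find key: " + key)
        cursor = cursor[key]
    return cursor

row = json.loads('{"target": {"span1_index": 3}, "label": "true"}')
assert get_value(row, "target/span1_index".split("/")) == 3  # WSC-style nested keyword
assert get_value(row, ["label"]) == "true"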
Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
                        const int32_t worker_id) {
  std::ifstream handle(file);
  if (!handle.is_open()) {
    RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
  }

  int64_t rows_each_buffer = 0;
  int64_t rows_total = 0;
  std::string line;
  std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
  std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();

  while (getline(handle, line)) {
    if (line.empty()) {
      continue;
    }
    // If we have read up to the end offset of this file, break.
    if (rows_total >= end_offset) {
      break;
    }
    // Skip lines before the start offset.
    if (rows_total < start_offset) {
      rows_total++;
      continue;
    }

    try {
      nlohmann::json js = nlohmann::json::parse(line);
      int cols_count = cols_to_keyword_.size();
      TensorRow t_row(cols_count, nullptr);
      tensor_table->push_back(std::move(t_row));
      int count = 0;
      for (auto &p : cols_to_keyword_) {
        std::shared_ptr<Tensor> tensor;
        RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
        (*tensor_table)[rows_each_buffer][count] = std::move(tensor);
        count++;
      }
    } catch (const std::exception &err) {
      // Catch any exception and convert it to a Status return code
      RETURN_STATUS_UNEXPECTED("Failed to load json file");
    }

    rows_each_buffer++;
    rows_total++;
    if (rows_each_buffer == rows_per_buffer_) {
      cur_buffer->set_tensor_table(std::move(tensor_table));
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));

      cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
      tensor_table = std::make_unique<TensorQTable>();
      rows_each_buffer = 0;
    }
  }

  if (rows_each_buffer > 0) {
    cur_buffer->set_tensor_table(std::move(tensor_table));
    RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
  }
  return Status::OK();
}
Status ClueOp::operator()() {
  RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

  // Launch one thread, responsible for filling the IOBlockQueue
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this)));

  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1)));

  // Must be called after launching workers.
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
  NotifyToFillIOBlockQueue();

  while (!finished_reading_dataset_) {
    int64_t buffer_id = 0;
    int32_t workers_done = 0;
    int64_t rows_read = 0;
    load_io_block_queue_ = true;

    while (workers_done < num_workers_) {
      std::unique_ptr<DataBuffer> buffer;
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
      if (buffer->eoe()) {
        workers_done++;
      } else if (num_samples_ == 0 || rows_read < num_samples_) {
        if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
          int64_t rows_to_remove = buffer->NumRows() - (num_samples_ - rows_read);
          RETURN_IF_NOT_OK(buffer->SliceOff(rows_to_remove));
        }
        rows_read += buffer->NumRows();
        buffer->set_id(buffer_id++);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
      } else {
        // End of epoch; stop accepting rows and stop filling IO blocks.
        load_jagged_connector_ = false;
        load_io_block_queue_ = false;
      }
    }

    std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
      finished_reading_dataset_ = true;
      NotifyToFillIOBlockQueue();
    } else {
      jagged_buffer_connector_->DoReset();
      buffer_id = 0;
    }
  }

  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

  RETURN_IF_NOT_OK(PostEndOfData());
  return Status::OK();
}
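Editor's note: the num_samples bookkeeping in the master loop is easy to misread, so here is a small standalone arithmetic sketch (plain Python, not the DataBuffer API) of what SliceOff is asked to trim when the last buffer would overshoot the sample cap.

def rows_to_remove(rows_read, buffer_rows, num_samples):
    # Mirrors the SliceOff arithmetic above: trim the tail of the buffer that
    # would push the epoch past num_samples; num_samples == 0 disables the cap.
    if num_samples > 0 and rows_read + buffer_rows > num_samples:
        return buffer_rows - (num_samples - rows_read)
    return 0

assert rows_to_remove(rows_read=90, buffer_rows=16, num_samples=100) == 6
assert rows_to_remove(rows_read=90, buffer_rows=16, num_samples=0) == 0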
Status ClueOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();

  std::unique_ptr<FilenameBlock> io_block;
  RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  while (!io_block->eof()) {
    if (!io_block->eoe()) {
      if (load_jagged_connector_) {
        std::string filename;
        RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
        int64_t start_offset = io_block->GetStartOffset();
        int64_t end_offset = io_block->GetEndOffset();
        RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
      }
    } else {
      std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
    }

    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  }
  return Status::OK();
}
// A print method typically used for debugging
void ClueOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <ClueOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
        << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
    for (size_t i = 0; i < clue_files_list_.size(); ++i) {
      out << "  " << clue_files_list_[i];
    }
    out << "\n\n";
  }
}
// Pops an element from a queue in io_block_queues
Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
  return Status::OK();
}

// Pushes an element to a queue in io_block_queues
Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
  return Status::OK();
}

static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::mt19937 rng(seed);
  std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
Status ClueOp::WaitToFillIOBlockQueue() {
  // Must be called first if this is run by a worker spawned by a task group.
  TaskManager::FindMe()->Post();

  std::vector<int64_t> i_keys;
  if (shuffle_files_) {
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      i_keys.push_back(it.key());
    }
  }
  uint32_t seed = 0;
  while (true) {
    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
    io_block_queue_wait_post_.Clear();

    if (finished_reading_dataset_) {
      break;
    }
    if (shuffle_files_) {
      ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
    }
    RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
  }
  return Status::OK();
}
Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        if (!load_io_block_queue_) {
          break;
        }
        auto file_it = filename_index_->Search(*it);
        file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it));
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        if (!load_io_block_queue_) {
          break;
        }
        file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
      }
    }
    for (auto file_info : file_index) {
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto io_block =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(io_block)));
        queue_index = (queue_index + 1) % num_workers_;
      }
      pre_count += filename_numrows_[file_info.first];
    }
    // This shard is complete once pre_count has covered its share of the rows.
    finish = pre_count >= (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}
void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }

bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                      const int64_t &pre_count) {
  *start_offset = 0;
  *end_offset = 0;
  bool push = false;
  int64_t start_index = device_id_ * num_rows_per_shard_;
  if (device_id_ + 1 < 0) {
    MS_LOG(ERROR) << "Device id is invalid";
    return false;
  }

  int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
    *start_offset = start_index - pre_count;
    push = true;
    if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  if (pre_count >= start_index && pre_count < end_index) {
    *start_offset = 0;
    push = true;
    if (pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  return push;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
  for (int i = 0; i < num_workers_; ++i) {
    std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
    RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
  }
  return Status::OK();
}

Status ClueOp::CalculateNumRowsPerShard() {
  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    int64_t count = CountTotalRows(it.value());
    filename_numrows_[it.value()] = count;
    all_num_rows_ += count;
  }
  if (all_num_rows_ == 0) {
    RETURN_STATUS_UNEXPECTED(
      "There is no valid data matching the dataset API CLUEDataset. Please check the file paths and the dataset "
      "API arguments first.");
  }

  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
  MS_LOG(DEBUG) << "Number of rows per shard is " << num_rows_per_shard_;
  return Status::OK();
}

int64_t ClueOp::CountTotalRows(const std::string &file) {
  std::ifstream handle(file);
  if (!handle.is_open()) {
    MS_LOG(ERROR) << "Failed to open file: " << file;
    return 0;
  }

  std::string line;
  int64_t count = 0;
  while (getline(handle, line)) {
    if (!line.empty()) {
      count++;
    }
  }
  return count;
}

// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status ClueOp::PostEndOfData() {
  for (int i = 0; i < num_workers_; ++i) {
    std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
  }
  return Status::OK();
}

Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
  std::shared_ptr<ClueOp> op;
  *count = 0;
  RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op));
  for (const auto &file : files) {
    *count += op->CountTotalRows(file);
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
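Editor's note: CalculateNumRowsPerShard and NeedPushFileToBlockQueue together decide which (file, start_offset, end_offset) ranges a given shard reads. A compact Python model of that planning logic, illustrative only and using the same ceil-based split (function and variable names here are made up, not the MindSpore API):

import math

def plan_shard(file_rows, num_devices, device_id):
    # Given per-file row counts, compute the row ranges this shard should read.
    total = sum(rows for _, rows in file_rows)
    per_shard = math.ceil(total / num_devices)
    start_index = device_id * per_shard
    end_index = (device_id + 1) * per_shard
    plan, pre_count = [], 0
    for name, rows in file_rows:
        lo, hi = max(start_index, pre_count), min(end_index, pre_count + rows)
        if lo < hi:  # this file overlaps the shard's global row range
            plan.append((name, lo - pre_count, hi - pre_count))
        pre_count += rows
    return plan

# Two shards over files of 3 and 3 rows: each shard reads exactly one file.
assert plan_shard([("a.json", 3), ("b.json", 3)], num_devices=2, device_id=1) == [("b.json", 0, 3)]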
@@ -0,0 +1,270 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_

#include <memory>
#include <map>
#include <mutex>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

#include "dataset/util/auto_index.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/source/io_block.h"

namespace mindspore {
namespace dataset {
using StringIndex = AutoIndexObj<std::string>;
using ColKeyMap = std::map<std::string, std::vector<std::string>>;

class JaggedConnector;

class ClueOp : public ParallelOp {
 public:
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args.
    Builder();

    // Default destructor.
    ~Builder() = default;

    // Checks if the inputs of the builder are valid.
    // @return Status - the error code returned.
    Status ValidateInputs() const;

    // Create the final object.
    // @param op - dataset op.
    // @return Status - the error code returned.
    Status Build(std::shared_ptr<ClueOp> *op);

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumDevices(int64_t num_dev) {
      builder_num_devices_ = num_dev;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetDeviceId(int64_t dev_id) {
      builder_device_id_ = dev_id;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetClueFilesList(const std::vector<std::string> &files_list) {
      builder_clue_files_list_ = files_list;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetShuffleFiles(bool shuffle_files) {
      builder_shuffle_files_ = shuffle_files;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumSamples(int64_t num_samples) {
      builder_num_samples_ = num_samples;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetColsKeyMap(const std::map<std::string, std::string> &cols_to_key) {
      builder_cols_to_keyword_ = cols_to_key;
      return *this;
    }

    // Split a string based on a character delimiter.
    // @return std::vector<std::string> - the resulting string vector.
    std::vector<std::string> split(const std::string &s, char delim);

   private:
    int32_t builder_device_id_;
    int32_t builder_num_devices_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    int64_t builder_rows_per_buffer_;
    int64_t builder_num_samples_;
    int32_t builder_worker_connector_size_;
    std::vector<std::string> builder_clue_files_list_;
    bool builder_shuffle_files_;
    std::map<std::string, std::string> builder_cols_to_keyword_;
  };

  // Constructor of ClueOp
  ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
         ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
         bool shuffle_files, int32_t num_devices, int32_t device_id);

  // Default destructor.
  ~ClueOp() = default;

  // A print method typically used for debugging.
  // @param out - The output stream to write output to.
  // @param show_all - A bool to control if you want to show all info or just a summary.
  void Print(std::ostream &out, bool show_all) const override;

  // Instantiates the internal queues and connectors.
  // @return Status - the error code returned.
  Status Init();

  // Class functor operator () override.
  // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work.
  // @return Status - the error code returned.
  Status operator()() override;

  // Overrides base class reset method. Cleans up any state info from its previous execution and
  // reinitializes itself so that it can be executed again, as if it was just created.
  // @return Status - the error code returned.
  Status Reset() override;

  // Get the total number of rows in the files.
  // @param files - all clue files.
  // @param count - number of rows.
  // @return Status - the error code returned.
  static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);

 private:
  // The entry point for when workers are launched.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status WorkerEntry(int32_t worker_id) override;

  // Parses a single row and puts the data into a tensor table.
  // @param line - the content of the row.
  // @param tensor_table - the tensor table to put the parsed data in.
  // @param row - the id of the row filled in the tensor table.
  // @return Status - the error code returned.
  Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);

  // Reads a clue file and loads the data into multiple buffers.
  // @param file - the file to read.
  // @param start_offset - the start offset of the file.
  // @param end_offset - the end offset of the file.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
                  const int32_t worker_id);

  // Pops an element from a queue in IOBlockQueue.
  // @param index - the index of the queue to pop from.
  // @param out_block - the popped element.
  // @return Status - the error code returned.
  Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

  // Pushes an element to a queue in IOBlockQueue.
  // @param index - the index of the queue to push to.
  // @param io_block - the element to push onto the queue.
  // @return Status - the error code returned.
  Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

  // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
  // @return Status - the error code returned.
  Status WaitToFillIOBlockQueue();

  // Fill the IOBlockQueue.
  // @param i_keys - keys of the files to fill into the IOBlockQueue.
  // @return Status - the error code returned.
  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);

  // Notifies the thread which called FillIOBlockQueue to resume execution.
  void NotifyToFillIOBlockQueue();

  // Select a file and push it to the block queue.
  // @param file_name - file name.
  // @param start_offset - the start offset within the file for this shard.
  // @param end_offset - the end offset within the file for this shard.
  // @param pre_count - total rows of the previous files.
  // @return bool - whether the file needs to be pushed to the block queue.
  bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                const int64_t &pre_count);

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
  // pops this control indicator, it will wait until the next epoch starts and then resume execution.
  // @return Status - the error code returned.
  Status PostEndOfEpoch(int32_t queue_index);

  // Calculate the number of rows in each shard.
  // @return Status - the error code returned.
  Status CalculateNumRowsPerShard();

  // Count the number of rows in a file.
  // @param file - clue file name.
  // @return int64_t - the total number of rows in the file.
  int64_t CountTotalRows(const std::string &file);

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
  // When the worker pops this control indicator, it will shut itself down gracefully.
  // @return Status - the error code returned.
  Status PostEndOfData();

  // Retrieves a value from a json object by walking a chain of nested keys.
  // @return Status - the error code returned.
  Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);

  int32_t device_id_;
  bool shuffle_files_;
  bool finished_reading_dataset_;
  int32_t num_devices_;
  int64_t rows_per_buffer_;
  bool load_io_block_queue_;
  int64_t num_rows_per_shard_;
  int64_t all_num_rows_;
  int64_t num_samples_;
  std::map<std::string, int64_t> filename_numrows_;
  std::unique_ptr<StringIndex> filename_index_;
  std::vector<std::string> clue_files_list_;
  WaitPost io_block_queue_wait_post_;
  std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
  QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
  bool load_jagged_connector_;
  ColKeyMap cols_to_keyword_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
@@ -43,7 +43,7 @@ TextFileOp::Builder::Builder()
Status TextFileOp::Builder::ValidateInputs() const {
  std::string err_msg;
  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : "";
  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
  err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
  return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
@@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from .core.configuration import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
    GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
    TextFileDataset, Schema, Shuffle, zip, RandomDataset
    TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
    WeightedRandomSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show

@@ -29,6 +29,6 @@ from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset",
           "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
           "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
           "CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler",
           "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]

@@ -30,7 +30,7 @@ from ..core.configuration import config, ConfigurationManager
__all__ = ["config", "ConfigurationManager", "zip",
           "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
           "VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema", "Schema",
           "DistributedSampler", "PKSampler",

@@ -33,7 +33,7 @@ import copy
import numpy as np

from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
    MindRecordOp, TextFileOp, VOCOp, CocoOp, CBatchInfo
    MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing

from mindspore import log as logger

@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
    check_split
    check_split, check_cluedataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:

@@ -4317,6 +4317,222 @@ class CelebADataset(MappableDataset):
        return self.sampler.is_sharded()


class CLUEDataset(SourceDataset):
| """ | |||
| A source dataset that reads and parses CLUE datasets. | |||
| CLUE, the Chinese Language Understanding Evaluation Benchmark, a collection of datasets, baselines, pre-trained | |||
| models, corpus and leaderboard. Here we bring in classification task of CLUE, which are AFQMC, TNEWS, IFLYTEK, | |||
| CMNLI, WSC and CSL. | |||
| Args: | |||
| dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of | |||
| files. The list will be sorted in a lexicographical order. | |||
| task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'. | |||
| (default=AFQMC). | |||
| usage (str, optional): Need train, test or eval data (default="train"). | |||
| num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). | |||
| num_parallel_workers (int, optional): number of workers to read the data | |||
| (default=None, number set in the config). | |||
| shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). | |||
| If shuffle is False, no shuffling will be performed; | |||
| If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL | |||
| Otherwise, there are two levels of shuffling: | |||
| - Shuffle.GLOBAL: Shuffle both the files and samples. | |||
| - Shuffle.FILES: Shuffle files only. | |||
| num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument should be specified only when num_shards is also specified. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files | |||
| >>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train') | |||
| """ | |||
    @check_cluedataset
    def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
                 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
        super().__init__(num_parallel_workers)
        self.dataset_files = self._find_files(dataset_files)
        self.dataset_files.sort()
        self.num_samples = num_samples
        self.task_dict = {
            'AFQMC': {
                'train': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                },
                'test': {
                    'id': 'id',
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2'
                },
                'eval': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                }
            },
            'CMNLI': {
                'train': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                },
                'test': {
                    'id': 'id',
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2'
                },
                'eval': {
                    'sentence1': 'sentence1',
                    'sentence2': 'sentence2',
                    'label': 'label'
                }
            },
            'CSL': {
                'train': {
                    'id': 'id',
                    'abst': 'abst',
                    'keyword': 'keyword',
                    'label': 'label'
                },
                'test': {
                    'id': 'id',
                    'abst': 'abst',
                    'keyword': 'keyword'
                },
                'eval': {
                    'id': 'id',
                    'abst': 'abst',
                    'keyword': 'keyword',
                    'label': 'label'
                }
            },
            'IFLYTEK': {
                'train': {
                    'label': 'label',
                    'label_des': 'label_des',
                    'sentence': 'sentence'
                },
                'test': {
                    'id': 'id',
                    'sentence': 'sentence',
                },
                'eval': {
                    'label': 'label',
                    'label_des': 'label_des',
                    'sentence': 'sentence'
                }
            },
            'TNEWS': {
                'train': {
                    'label': 'label',
                    'label_desc': 'label_desc',
                    'sentence': 'sentence',
                    'keywords': 'keywords'
                },
                'test': {
                    'id': 'id',
                    'sentence': 'sentence',
                    'keywords': 'keywords'
                },
                'eval': {
                    'label': 'label',
                    'label_desc': 'label_desc',
                    'sentence': 'sentence',
                    'keywords': 'keywords'
                }
            },
            'WSC': {
                'train': {
                    'span1_index': 'target/span1_index',
                    'span2_index': 'target/span2_index',
                    'span1_text': 'target/span1_text',
                    'span2_text': 'target/span2_text',
                    'idx': 'idx',
                    'label': 'label',
                    'text': 'text'
                },
                'test': {
                    'span1_index': 'target/span1_index',
                    'span2_index': 'target/span2_index',
                    'span1_text': 'target/span1_text',
                    'span2_text': 'target/span2_text',
                    'idx': 'idx',
                    'text': 'text'
                },
                'eval': {
                    'span1_index': 'target/span1_index',
                    'span2_index': 'target/span2_index',
                    'span1_text': 'target/span1_text',
                    'span2_text': 'target/span2_text',
                    'idx': 'idx',
                    'label': 'label',
                    'text': 'text'
                }
            }
        }
        self.cols_to_keyword = self.task_dict[task][usage]

        if not isinstance(shuffle, (bool, Shuffle)):
            raise TypeError("shuffle should be of type bool or enum 'Shuffle'.")
        if not isinstance(shuffle, Shuffle):
            if shuffle:
                self.shuffle_level = Shuffle.GLOBAL
                self.shuffle_files = True
            else:
                self.shuffle_level = None
                self.shuffle_files = False
        else:
            self.shuffle_level = shuffle
            self.shuffle_files = True

        self.num_shards = num_shards
        self.shard_id = shard_id

    def get_args(self):
        args = super().get_args()
        args["dataset_files"] = self.dataset_files
        args["num_samples"] = self.num_samples
        if self.shuffle_files is not None:
            args["shuffle_files"] = self.shuffle_files
        args["shuffle"] = self.shuffle_level
        args["num_shards"] = self.num_shards
        args["shard_id"] = self.shard_id
        args["cols_to_keyword"] = self.cols_to_keyword
        return args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Returns:
            Number, number of batches.
        """
        if self._dataset_size is None:
            num_rows = ClueOp.get_num_rows(self.dataset_files)
            num_rows = get_num_rows(num_rows, self.num_shards)
            if self.num_samples is None:
                return num_rows
            return min(self.num_samples, num_rows)
        return self._dataset_size

    def is_shuffled(self):
        return self.shuffle_files

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1
        return False
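Editor's note: a hypothetical end-to-end use of the new API (the file path is a placeholder), showing a two-shard, file-shuffled read of AFQMC training data; the column names come from task_dict['AFQMC']['train'] above.

import mindspore.dataset as ds

train_files = ["/data/clue/afqmc/train.json"]  # placeholder path
data = ds.CLUEDataset(train_files, task='AFQMC', usage='train',
                      shuffle=ds.Shuffle.FILES, num_shards=2, shard_id=0)
for row in data.create_dict_iterator():
    # Each row is a dict keyed by the task's column names.
    print(row["sentence1"], row["sentence2"], row["label"])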
class TextFileDataset(SourceDataset):
    """
    A source dataset that reads and parses datasets stored on disk in text format.

@@ -50,7 +50,8 @@ def alter_tree(node):

def _alter_node(node):
    """Performing some alteration to a dataset node. A common alteration is to insert a node."""
    if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL:
    if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \
            and node.shuffle_level == de.Shuffle.GLOBAL:
        # Remove the connection between the parent's node to the current node because we are inserting a node.
        if node.output:
            node.output.pop()

@@ -179,6 +180,8 @@ class Iterator:
            op_type = OpName.TEXTFILE
        elif isinstance(dataset, de.BuildVocabDataset):
            op_type = OpName.BUILDVOCAB
        elif isinstance(dataset, de.CLUEDataset):
            op_type = OpName.CLUE
        else:
            raise ValueError("Unsupported DatasetOp")
@@ -1075,6 +1075,41 @@ def check_add_column(method):
    return new_method


def check_cluedataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset (CLUEDataset)."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_files; required argument
        dataset_files = param_dict.get('dataset_files')
        if dataset_files is None:
            raise ValueError("dataset_files is not provided.")
        if not isinstance(dataset_files, (str, list)):
            raise TypeError("dataset_files should be of type str or a list of strings.")

        # check task
        task_param = param_dict.get('task')
        if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
            raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")

        # check usage
        usage_param = param_dict.get('usage')
        if usage_param not in ['train', 'test', 'eval']:
            raise ValueError("usage should be train, test or eval")

        check_param_type(nreq_param_int, param_dict, int)

        check_sampler_shuffle_shard_options(param_dict)

        return method(*args, **kwargs)

    return new_method
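Editor's note: to make the checker's behavior concrete, a small hedged sketch (placeholder path; the exact exception types for the int-parameter check are assumed, so both are caught) of calls that check_cluedataset rejects before any C++ op is built.

import mindspore.dataset as ds

files = ["/data/clue/afqmc/train.json"]  # placeholder path
for bad in ({"task": "afqmc"},       # task names are upper-case
            {"usage": "dev"},        # usage must be train, test or eval
            {"num_samples": "3"}):   # int parameters are type-checked
    try:
        ds.CLUEDataset(files, **bad)
    except (TypeError, ValueError) as e:
        print(type(e).__name__, ":", e)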
def check_textfiledataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset (TextFileDataset)."""

@@ -65,6 +65,7 @@ SET(DE_UT_SRCS
  cifar_op_test.cc
  celeba_op_test.cc
  take_op_test.cc
  clue_op_test.cc
  text_file_op_test.cc
  filter_op_test.cc
  concat_op_test.cc
@@ -0,0 +1,117 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include <vector>

#include "dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/util/status.h"

namespace common = mindspore::common;

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestCLUEOp : public UT::DatasetOpTesting {
};

TEST_F(MindDataTestCLUEOp, TestCLUEBasic) {
  // Start with an empty execution tree
  auto tree = std::make_shared<ExecutionTree>();

  std::string dataset_path;
  dataset_path = datasets_root_path_ + "/testCLUE/afqmc/train.json";
  std::map<std::string, std::string> key_map;
  key_map["sentence1"] = "sentence1";
  key_map["sentence2"] = "sentence2";
  key_map["label"] = "label";

  std::shared_ptr<ClueOp> op;
  ClueOp::Builder builder;
  builder.SetClueFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetNumWorkers(16)
    .SetOpConnectorSize(2)
    .SetColsKeyMap(key_map);

  Status rc = builder.Build(&op);
  ASSERT_TRUE(rc.IsOk());

  rc = tree->AssociateNode(op);
  ASSERT_TRUE(rc.IsOk());

  rc = tree->AssignRoot(op);
  ASSERT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration.";
  rc = tree->Prepare();
  ASSERT_TRUE(rc.IsOk());

  rc = tree->Launch();
  ASSERT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator di(tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  ASSERT_TRUE(rc.IsOk());

  int row_count = 0;
  while (!tensor_list.empty()) {
    // Display the tensor by calling the printer on it
    for (size_t i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
    }

    rc = di.FetchNextTensorRow(&tensor_list);
    ASSERT_TRUE(rc.IsOk());
    row_count++;
  }
  ASSERT_EQ(row_count, 3);
}
TEST_F(MindDataTestCLUEOp, TestTotalRows) {
  std::string clue_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
  std::string clue_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
  std::vector<std::string> files;
  files.push_back(clue_file1);
  int64_t total_rows = 0;
  ClueOp::CountAllFileRows(files, &total_rows);
  ASSERT_EQ(total_rows, 3);
  files.clear();

  files.push_back(clue_file2);
  ClueOp::CountAllFileRows(files, &total_rows);
  ASSERT_EQ(total_rows, 3);
  files.clear();

  files.push_back(clue_file1);
  files.push_back(clue_file2);
  ClueOp::CountAllFileRows(files, &total_rows);
  ASSERT_EQ(total_rows, 6);
  files.clear();
}
@@ -0,0 +1,3 @@
{"sentence1": "你有花呗吗", "sentence2": "我的花呗没额度了", "label": "0"}
{"sentence1": "吃饭能用花呗吗", "sentence2": "花呗太方便了", "label": "0"}
{"sentence1": "蚂蚁花呗支付金额有什么限制", "sentence2": "我到实体店消费用花呗支付受金额限制", "label": "1"}

@@ -0,0 +1,3 @@
{"id": 0, "sentence1": "借呗取消的时间", "sentence2": "蚂蚁借呗恢复的月数"}
{"id": 1, "sentence1": "网商贷用什么方法转变成借呗", "sentence2": "什么手段能将网商贷切换为借呗"}
{"id": 2, "sentence1": "我的借呗为什么开通不了", "sentence2": "我为啥没法开通借呗"}

@@ -0,0 +1,3 @@
{"sentence1": "蚂蚁借呗等额还款能否换成先息后本", "sentence2": "借呗可以先息到期还本吗", "label": "0"}
{"sentence1": "蚂蚁花呗说我违约了", "sentence2": "蚂蚁花呗违约行为是啥", "label": "0"}
{"sentence1": "帮我看看本月花呗账单结清了没", "sentence2": "上月的花呗账单", "label": "0"}

@@ -0,0 +1,3 @@
{"sentence1": "每个人都有权利", "sentence2": "每个人都有福利", "label": "neutral"}
{"sentence1": "有时候我喜欢他,但我也喜欢看到有人打他", "sentence2": "说实话,我有点喜欢他,但还是喜欢看到有人打他。", "label": "entailment"}
{"sentence1": "我最喜欢的餐馆是离你最近的一家", "sentence2": "我最喜欢的餐馆离你家至少一百英里远。", "label": "contradiction"}

@@ -0,0 +1,3 @@
{"id": 0, "sentence1": "今天,全球都在看着最新航天飞机的处女航。", "sentence2": "全世界都在看最新的航天飞机发射。"}
{"id": 1, "sentence1": "而我们把竹篮放在一个地方,把玻璃瓶放在另一处,把书放在另一处,满了要把它放到车里", "sentence2": "我们没有分开任何东西,都把它全扔进一个箱子里。"}
{"id": 2, "sentence1": "她占用了我的很多时间,她给我读了很多关于灵异的故事,我觉得很无聊。", "sentence2": "我喜欢和她一起读鬼故事。"}

@@ -0,0 +1,3 @@
{"sentence1": "你应该给这件衣服定一个价格。", "sentence2": "不同的衣服有不同的价格。", "label": "neutral"}
{"sentence1": "我怎么知道他要说什么", "sentence2": "他说什么我并不知道。", "label": "entailment"}
{"sentence1": "向左。", "sentence2": "向右。", "label": "contradiction"}

@@ -0,0 +1,3 @@
{"id": 1, "abst": "这是第一段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
{"id": 2, "abst": "这是第二段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
{"id": 3, "abst": "这是第三段很长的文本", "keyword": ["1", "2", "3"], "label": "0"}

@@ -0,0 +1,3 @@
{"id": 2415, "abst": "长文本1", "keyword": ["关键词1", "关键词2"]}
{"id": 2565, "abst": "长文本2", "keyword": ["关键词1", "关键词2", "关键词3"]}
{"id": 2625, "abst": "长文本3", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"]}

@@ -0,0 +1,3 @@
{"id": 1, "abst": "这是一段长文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "0"}
{"id": 2, "abst": "这是一段长文本", "keyword": ["关键词5", "关键词6", "关键词7", "关键词8"], "label": "0"}
{"id": 3, "abst": "这是一段长文本", "keyword": ["关键词9", "关键词10", "关键词11", "关键词12"], "label": "0"}

@@ -0,0 +1,3 @@
{"label": "110", "label_des": "社区超市", "sentence": "这是第一段文本"}
{"label": "70", "label_des": "工具", "sentence": "这是第二段文本"}
{"label": "10", "label_des": "社区服务", "sentence": "这是第三段文本"}

@@ -0,0 +1,3 @@
{"id": 0, "sentence": "文本1"}
{"id": 1, "sentence": "文本2"}
{"id": 2, "sentence": "文本3"}

@@ -0,0 +1,3 @@
{"label": "11", "label_des": "薅羊毛", "sentence": "第一个文本"}
{"label": "95", "label_des": "借贷", "sentence": "第二个文本"}
{"label": "74", "label_des": "违章", "sentence": "第三个文本"}

@@ -0,0 +1,3 @@
{"label": "102", "label_desc": "news_entertainment", "sentence": "新闻1", "keywords": "关键词一,关键词二,关键词三,关键词四"}
{"label": "110", "label_desc": "news_military", "sentence": "新闻2", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}
{"label": "104", "label_desc": "news_finance", "sentence": "新闻3", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}

@@ -0,0 +1,3 @@
{"id": 0, "sentence": "新闻1", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5"}
{"id": 1, "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4"}
{"id": 2, "sentence": "新闻3", "keywords": ""}

@@ -0,0 +1,3 @@
{"label": "108", "label_desc": "news_edu", "sentence": "新闻1", "keywords": ""}
| {"label": "104", "label_desc": "news_finance", "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5,关键词6"} | |||
| {"label": "106", "label_desc": "news_house", "sentence": "新闻3", "keywords": ""} | |||
| @@ -0,0 +1,3 @@ | |||
| {"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"} | |||
| {"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"} | |||
| {"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"} | |||
| @@ -0,0 +1,3 @@ | |||
| {"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?"} | |||
| {"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场"} | |||
| {"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业"} | |||
| @@ -0,0 +1,3 @@ | |||
| {"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"} | |||
| {"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"} | |||
| {"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"} | |||
| @@ -0,0 +1,355 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| def test_clue(): | |||
| """ | |||
| Test CLUEDataset chained with the generic repeat and skip operations | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False) | |||
| data = data.repeat(2) | |||
| data = data.skip(3) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| def test_clue_num_shards(): | |||
| """ | |||
| Test num_shards param of CLUE dataset | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 1 | |||
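| # Sharding arithmetic: train.json holds 3 rows and num_shards=3, so each | |||
| # shard (here shard_id=1) yields exactly one row, matching the assertion. | |||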
| def test_clue_num_samples(): | |||
| """ | |||
| Test num_samples param of CLUE dataset | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| count += 1 | |||
| assert count == 2 | |||
| def test_textline_dataset_get_datasetsize(): | |||
| """ | |||
| Test get_dataset_size on the CLUE train file loaded as a text-line | |||
| dataset (one JSON record per line, so the size equals the row count) | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' | |||
| data = ds.TextFileDataset(TRAIN_FILE) | |||
| size = data.get_dataset_size() | |||
| assert size == 3 | |||
| def test_clue_afqmc(): | |||
| """ | |||
| Test AFQMC for train, test and evaluation | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' | |||
| TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json' | |||
| EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json' | |||
| # train | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # test | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'id': d['id'], | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # evaluation | |||
| buffer = [] | |||
| data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| def test_clue_cmnli(): | |||
| """ | |||
| Test CMNLI for train, test and evaluation | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json' | |||
| TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json' | |||
| EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json' | |||
| # train | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # test | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'id': d['id'], | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # eval | |||
| buffer = [] | |||
| data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'sentence1': d['sentence1'].item().decode("utf8"), | |||
| 'sentence2': d['sentence2'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| def test_clue_csl(): | |||
| """ | |||
| Test CSL for train, test and evaluation | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json' | |||
| TEST_FILE = '../data/dataset/testCLUE/csl/test.json' | |||
| EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json' | |||
| # train | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'id': d['id'], | |||
| 'abst': d['abst'].item().decode("utf8"), | |||
| 'keyword': [i.item().decode("utf8") for i in d['keyword']], | |||
| 'label': d['label'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # test | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'id': d['id'], | |||
| 'abst': d['abst'].item().decode("utf8"), | |||
| 'keyword': [i.item().decode("utf8") for i in d['keyword']], | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # eval | |||
| buffer = [] | |||
| data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'id': d['id'], | |||
| 'abst': d['abst'].item().decode("utf8"), | |||
| 'keyword': [i.item().decode("utf8") for i in d['keyword']], | |||
| 'label': d['label'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| def test_clue_iflytek(): | |||
| """ | |||
| Test IFLYTEK for train, test and evaluation | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json' | |||
| TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json' | |||
| EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json' | |||
| # train | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'label_des': d['label_des'].item().decode("utf8"), | |||
| 'sentence': d['sentence'].item().decode("utf8"), | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # test | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'id': d['id'], | |||
| 'sentence': d['sentence'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # eval | |||
| buffer = [] | |||
| data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'label_des': d['label_des'].item().decode("utf8"), | |||
| 'sentence': d['sentence'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| def test_clue_tnews(): | |||
| """ | |||
| Test TNEWS for train, test and evaluation | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json' | |||
| TEST_FILE = '../data/dataset/testCLUE/tnews/test.json' | |||
| EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json' | |||
| # train | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'label_desc': d['label_desc'].item().decode("utf8"), | |||
| 'sentence': d['sentence'].item().decode("utf8"), | |||
| 'keywords': | |||
| d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords'] | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # test | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'id': d['id'], | |||
| 'sentence': d['sentence'].item().decode("utf8"), | |||
| 'keywords': | |||
| d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords'] | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # eval | |||
| buffer = [] | |||
| data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False) | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'label_desc': d['label_desc'].item().decode("utf8"), | |||
| 'sentence': d['sentence'].item().decode("utf8"), | |||
| 'keywords': | |||
| d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords'] | |||
| }) | |||
| assert len(buffer) == 3 | |||
| def test_clue_wsc(): | |||
| """ | |||
| Test WSC for train, test and evaluation | |||
| """ | |||
| TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json' | |||
| TEST_FILE = '../data/dataset/testCLUE/wsc/test.json' | |||
| EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json' | |||
| # train | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train') | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'span1_index': d['span1_index'], | |||
| 'span2_index': d['span2_index'], | |||
| 'span1_text': d['span1_text'].item().decode("utf8"), | |||
| 'span2_text': d['span2_text'].item().decode("utf8"), | |||
| 'idx': d['idx'], | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'text': d['text'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # test | |||
| buffer = [] | |||
| data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test') | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'span1_index': d['span1_index'], | |||
| 'span2_index': d['span2_index'], | |||
| 'span1_text': d['span1_text'].item().decode("utf8"), | |||
| 'span2_text': d['span2_text'].item().decode("utf8"), | |||
| 'idx': d['idx'], | |||
| 'text': d['text'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| # eval | |||
| buffer = [] | |||
| data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval') | |||
| for d in data.create_dict_iterator(): | |||
| buffer.append({ | |||
| 'span1_index': d['span1_index'], | |||
| 'span2_index': d['span2_index'], | |||
| 'span1_text': d['span1_text'].item().decode("utf8"), | |||
| 'span2_text': d['span2_text'].item().decode("utf8"), | |||
| 'idx': d['idx'], | |||
| 'label': d['label'].item().decode("utf8"), | |||
| 'text': d['text'].item().decode("utf8") | |||
| }) | |||
| assert len(buffer) == 3 | |||
| if __name__ == "__main__": | |||
| test_clue() | |||
| test_clue_afqmc() | |||
| test_clue_cmnli() | |||
| test_clue_csl() | |||
| test_clue_iflytek() | |||
| test_clue_tnews() | |||
| test_clue_wsc() | |||
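| # A minimal standalone usage sketch, not part of the PR: CLUE rows compose | |||
| # with the generic dataset ops such as repeat and batch; the batch size here | |||
| # is an assumption chosen to fit the 3-row fixture used throughout the tests. | |||
| import mindspore.dataset as ds | |||
| | |||
| data = ds.CLUEDataset('../data/dataset/testCLUE/afqmc/train.json', | |||
|                       task='AFQMC', usage='train', shuffle=False) | |||
| data = data.repeat(2).batch(3)  # 3 rows -> 6 rows -> 2 batches of 3 | |||
| for row in data.create_dict_iterator(): | |||
|     print(row['sentence1'])  # each column now arrives as a batched array | |||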