| @@ -103,6 +103,8 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||
| @@ -1217,6 +1219,27 @@ ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, boo | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| IWSLT2016Dataset::IWSLT2016Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::vector<std::vector<char>> &language_pair, | |||
| const std::vector<char> &valid_set, const std::vector<char> &test_set, | |||
| int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| auto ds = std::make_shared<IWSLT2016Node>(CharToString(dataset_dir), CharToString(usage), | |||
| VectorCharToString(language_pair), CharToString(valid_set), | |||
| CharToString(test_set), num_samples, shuffle, num_shards, shard_id, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| IWSLT2017Dataset::IWSLT2017Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::vector<std::vector<char>> &language_pair, int64_t num_samples, | |||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| auto ds = | |||
| std::make_shared<IWSLT2017Node>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(language_pair), | |||
| num_samples, shuffle, num_shards, shard_id, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) { | |||
| auto sampler_obj = sampler ? sampler->Parse() : nullptr; | |||
| @@ -41,6 +41,8 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h" | |||
| @@ -272,6 +274,33 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(IWSLT2016Node, 2, ([](const py::module *m) { | |||
| (void)py::class_<IWSLT2016Node, DatasetNode, std::shared_ptr<IWSLT2016Node>>( | |||
| *m, "IWSLT2016Node", "to create an IWSLT2016Node") | |||
| .def(py::init([](std::string dataset_dir, std::string usage, std::vector<std::string> language_pair, | |||
| std::string valid_set, std::string test_set, int64_t num_samples, int32_t shuffle, | |||
| int32_t num_shards, int32_t shard_id) { | |||
| std::shared_ptr<IWSLT2016Node> iwslt2016 = std::make_shared<IWSLT2016Node>( | |||
| dataset_dir, usage, language_pair, valid_set, test_set, num_samples, toShuffleMode(shuffle), | |||
| num_shards, shard_id, nullptr); | |||
| THROW_IF_ERROR(iwslt2016->ValidateParams()); | |||
| return iwslt2016; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) { | |||
| (void)py::class_<IWSLT2017Node, DatasetNode, std::shared_ptr<IWSLT2017Node>>( | |||
| *m, "IWSLT2017Node", "to create an IWSLT2017Node") | |||
| .def(py::init([](std::string dataset_dir, std::string usage, std::vector<std::string> language_pair, | |||
| int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id) { | |||
| std::shared_ptr<IWSLT2017Node> iwslt2017 = | |||
| std::make_shared<IWSLT2017Node>(dataset_dir, usage, language_pair, num_samples, | |||
| toShuffleMode(shuffle), num_shards, shard_id, nullptr); | |||
| THROW_IF_ERROR(iwslt2017->ValidateParams()); | |||
| return iwslt2017; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<KMnistNode, DatasetNode, std::shared_ptr<KMnistNode>>(*m, "KMnistNode", | |||
| "to create a KMnistNode") | |||
| @@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| fashion_mnist_op.cc | |||
| flickr_op.cc | |||
| image_folder_op.cc | |||
| iwslt_op.cc | |||
| io_block.cc | |||
| kmnist_op.cc | |||
| lj_speech_op.cc | |||
| @@ -0,0 +1,497 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/datasetops/source/iwslt_op.h" | |||
| #include <fstream> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "utils/file_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| IWSLTOp::IWSLTOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::unique_ptr<DataSchema> data_schema, | |||
| IWSLTType type, const std::string &dataset_dir, const std::string &usage, | |||
| const std::vector<std::string> &language_pair, const std::string &valid_set, | |||
| const std::string &test_set) | |||
| : NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices, | |||
| device_id), | |||
| iwslt_type_(type), | |||
| data_schema_(std::move(data_schema)), | |||
| dataset_dir_(dataset_dir), | |||
| usage_(usage), | |||
| language_pair_(language_pair), | |||
| valid_set_(valid_set), | |||
| test_set_(test_set) {} | |||
| Status IWSLTOp::Init() { | |||
| RETURN_IF_NOT_OK(this->GetFiles()); | |||
| RETURN_IF_NOT_OK(filename_index_->insert(src_target_file_list_)); | |||
| int32_t safe_queue_size = static_cast<int32_t>(std::ceil(src_target_file_list_.size() / num_workers_) + 1); | |||
| io_block_queues_.Init(num_workers_, safe_queue_size); | |||
| jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_); | |||
| return Status::OK(); | |||
| } | |||
| std::vector<std::string> IWSLTOp::Split(const std::string &s, const std::string &delim) { | |||
| std::vector<std::string> res; | |||
| std::string::size_type pos1 = 0; | |||
| std::string::size_type pos2 = s.find(delim); | |||
| while (std::string::npos != pos2) { | |||
| res.push_back(s.substr(pos1, pos2 - pos1)); | |||
| pos1 = pos2 + delim.size(); | |||
| pos2 = s.find(delim, pos1); | |||
| } | |||
| if (pos1 != s.length()) { | |||
| res.push_back(s.substr(pos1)); | |||
| } | |||
| return res; | |||
| } | |||
| Status IWSLTOp::Trim(std::string *text, const std::string &character) { | |||
| RETURN_UNEXPECTED_IF_NULL(text); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!text->empty(), "Invalid file, read an empty line."); | |||
| (void)text->erase(0, text->find_first_not_of(character)); | |||
| (void)text->erase(text->find_last_not_of(character) + 1); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLTOp::LoadTensor(const std::string &line, TensorRow *out_row, size_t index) { | |||
| RETURN_UNEXPECTED_IF_NULL(out_row); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); | |||
| (*out_row)[index] = std::move(tensor); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLTOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) { | |||
| std::ifstream handle(file); | |||
| std::string line; | |||
| if (!handle.is_open()) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + " file: " + file); | |||
| } | |||
| int64_t rows_total = 0; | |||
| while (getline(handle, line)) { | |||
| if (line.empty()) { | |||
| continue; | |||
| } | |||
| // If read to the end offset of this file, break. | |||
| if (rows_total >= end_offset) { | |||
| break; | |||
| } | |||
| // Skip line before start offset. | |||
| if (rows_total < start_offset) { | |||
| rows_total++; | |||
| continue; | |||
| } | |||
| const int kColumnSize = 2; | |||
| TensorRow tRow(kColumnSize, nullptr); | |||
| tRow.setPath({file, file}); | |||
| // Remove the newline character. | |||
| RETURN_IF_NOT_OK(Trim(&line, "\n")); | |||
| RETURN_IF_NOT_OK(Trim(&line, "\r")); | |||
| std::vector<std::string> sentence_list = Split(line, "#*$"); | |||
| if (!sentence_list.empty() && sentence_list.size() == kColumnSize) { | |||
| RETURN_IF_NOT_OK(LoadTensor(sentence_list[0], &tRow, 0)); | |||
| RETURN_IF_NOT_OK(LoadTensor(sentence_list[1], &tRow, 1)); | |||
| RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow))); | |||
| rows_total++; | |||
| } | |||
| } | |||
| handle.close(); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLTOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||
| int32_t queue_index = 0; | |||
| int64_t pre_count = 0; | |||
| int64_t start_offset = 0; | |||
| int64_t end_offset = 0; | |||
| bool finish = false; | |||
| while (!finish) { | |||
| std::vector<std::pair<std::string, int64_t>> file_index; | |||
| if (!i_keys.empty()) { | |||
| for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { | |||
| { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| } | |||
| file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it)); | |||
| } | |||
| } else { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| { | |||
| if (!load_io_block_queue_) { | |||
| break; | |||
| } | |||
| } | |||
| file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key())); | |||
| } | |||
| } | |||
| for (auto file_info : file_index) { | |||
| if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { | |||
| auto ioBlock = | |||
| std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); | |||
| RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); | |||
| queue_index = (queue_index + 1) % num_workers_; | |||
| } | |||
| pre_count += filename_numrows_[file_info.first]; | |||
| } | |||
| if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { | |||
| finish = false; | |||
| } else { | |||
| finish = true; | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); | |||
| return Status::OK(); | |||
| } | |||
| void IWSLTOp::Print(std::ostream &out, bool show_all) const { | |||
| if (!show_all) { | |||
| // Call the super class for displaying any common 1-liner info. | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal 1-liner info for this op. | |||
| out << "\n"; | |||
| } else { | |||
| // Call the super class for displaying any common detailed info. | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff. | |||
| out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | |||
| << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nIWSLT files list:\n"; | |||
| for (int i = 0; i < src_target_file_list_.size(); ++i) { | |||
| out << " " << src_target_file_list_[i]; | |||
| } | |||
| out << "\nData Schema:\n"; | |||
| out << *data_schema_ << "\n\n"; | |||
| } | |||
| } | |||
| int64_t IWSLTOp::CountFileRows(const std::string &file) { | |||
| std::ifstream handle(file); | |||
| if (!handle.is_open()) { | |||
| MS_LOG(ERROR) << "Invalid file, failed to open file: " << file; | |||
| return 0; | |||
| } | |||
| std::string line; | |||
| int64_t count = 0; | |||
| while (getline(handle, line)) { | |||
| if (!line.empty()) { | |||
| count++; | |||
| } | |||
| } | |||
| handle.close(); | |||
| return count; | |||
| } | |||
| Status IWSLTOp::CalculateNumRowsPerShard() { | |||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||
| int64_t count = CountFileRows(it.value()); | |||
| filename_numrows_[it.value()] = count; | |||
| num_rows_ += count; | |||
| } | |||
| if (num_rows_ == 0) { | |||
| std::stringstream ss; | |||
| for (int i = 0; i < src_target_file_list_.size(); ++i) { | |||
| ss << " " << src_target_file_list_[i]; | |||
| } | |||
| std::string file_list = ss.str(); | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) + | |||
| "Dataset API can't read the data file (interface mismatch or no data found). Check " + | |||
| DatasetName() + ": " + file_list); | |||
| } | |||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_)); | |||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLTOp::ComputeColMap() { | |||
| // Set the column name mapping (base class field). | |||
| if (column_name_id_map_.empty()) { | |||
| for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { | |||
| column_name_id_map_[data_schema_->Column(i).Name()] = i; | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Column name map is already set!"; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLTOp::CountTotalRows(IWSLTType type, const std::string &dataset_dir, const std::string &usage, | |||
| const std::vector<std::string> &language_pair, const std::string &valid_set, | |||
| const std::string &test_set, int64_t *count) { | |||
| RETURN_UNEXPECTED_IF_NULL(count); | |||
| int32_t num_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| int32_t connector_que_size = GlobalContext::config_manager()->op_connector_size(); | |||
| int32_t worker_connector_size = GlobalContext::config_manager()->worker_connector_size(); | |||
| const int32_t shard_id = 0; | |||
| const int32_t num_shards = 1; | |||
| const int64_t num_samples = 0; | |||
| bool shuffle_files = false; | |||
| // Do internal Schema generation. | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| // Create and initialize. | |||
| std::shared_ptr<IWSLTOp> op = std::make_shared<IWSLTOp>( | |||
| num_workers, num_samples, worker_connector_size, connector_que_size, shuffle_files, num_shards, shard_id, | |||
| std::move(schema), type, dataset_dir, usage, language_pair, valid_set, test_set); | |||
| RETURN_IF_NOT_OK(op->Init()); | |||
| *count = 0; | |||
| std::vector<std::string> file_list = op->FileNames(); | |||
| for (auto file : file_list) { | |||
| *count += op->CountFileRows(file); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status LoadXmlDocument(XMLDocument *xml_document, const std::string &file_path, XMLElement **doc) { | |||
| RETURN_UNEXPECTED_IF_NULL(xml_document); | |||
| XMLError e = xml_document->LoadFile(common::SafeCStr(file_path)); | |||
| if (e != XMLError::XML_SUCCESS) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to load xml file: " + file_path); | |||
| } | |||
| XMLElement *root = xml_document->RootElement(); | |||
| if (root == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, failed to load root element for xml file."); | |||
| } | |||
| XMLElement *firstChild = root->FirstChildElement(); | |||
| if (firstChild == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, no first child found in " + file_path); | |||
| } | |||
| *doc = firstChild->FirstChildElement("doc"); | |||
| if (*doc == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, no doc found in " + file_path); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLTOp::CleanXmlFile(const std::string &src_file_path, const std::string &target_file_path, | |||
| const std::string &new_file_path) { | |||
| XMLDocument xml_document1, xml_document2; | |||
| XMLElement *src_doc, *target_doc; | |||
| RETURN_IF_NOT_OK(LoadXmlDocument(&xml_document1, src_file_path, &src_doc)); | |||
| RETURN_IF_NOT_OK(LoadXmlDocument(&xml_document2, target_file_path, &target_doc)); | |||
| std::string src_content, target_content; | |||
| std::ofstream new_file(new_file_path); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(new_file.is_open(), "Invalid file, failed to open file: " + new_file_path); | |||
| while (src_doc && target_doc) { | |||
| XMLElement *src_seg = src_doc->FirstChildElement("seg"); | |||
| XMLElement *target_seg = target_doc->FirstChildElement("seg"); | |||
| while (src_seg && target_seg) { | |||
| src_content = src_seg->GetText(); | |||
| target_content = target_seg->GetText(); | |||
| RETURN_IF_NOT_OK(Trim(&src_content, " ")); | |||
| RETURN_IF_NOT_OK(Trim(&target_content, " ")); | |||
| src_seg = src_seg->NextSiblingElement(); | |||
| target_seg = target_seg->NextSiblingElement(); | |||
| new_file << (src_content + "#*$" + target_content + "\n"); | |||
| } | |||
| src_doc = src_doc->NextSiblingElement(); | |||
| target_doc = target_doc->NextSiblingElement(); | |||
| } | |||
| new_file.close(); | |||
| return Status::OK(); | |||
| } | |||
| bool IWSLTOp::IsContainTags(const std::string &content) { | |||
| std::vector<std::string> xml_tags = {"<url", "<keywords", "<talkid", "<description", "<reviewer", | |||
| "<translator", "<title", "<speaker", "<doc", "</doc"}; | |||
| int i = 0; | |||
| int size = xml_tags.size(); | |||
| while (i < size) { | |||
| if (content.find(xml_tags[i]) != std::string::npos) { | |||
| return true; | |||
| } | |||
| i++; | |||
| } | |||
| return false; | |||
| } | |||
| Status IWSLTOp::CleanTagFile(const std::string &src_file_path, const std::string &target_file_path, | |||
| const std::string &new_file_path) { | |||
| std::ifstream src_handle(src_file_path); | |||
| std::ifstream target_handle(target_file_path); | |||
| std::ofstream new_file(new_file_path, std::ios::trunc); | |||
| std::string src_content, target_content; | |||
| while (getline(src_handle, src_content)) { | |||
| while (getline(target_handle, target_content)) { | |||
| if (!IsContainTags(src_content) && !IsContainTags(target_content)) { | |||
| RETURN_IF_NOT_OK(Trim(&src_content, " ")); | |||
| RETURN_IF_NOT_OK(Trim(&target_content, " ")); | |||
| new_file << (src_content + "#*$" + target_content + "\n"); | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| new_file.close(); | |||
| src_handle.close(); | |||
| target_handle.close(); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLTOp::GenerateNewFile(const std::vector<std::string> &src_file_list, | |||
| const std::vector<std::string> &target_file_list, | |||
| std::vector<std::string> *src_target_file_list) { | |||
| RETURN_UNEXPECTED_IF_NULL(src_target_file_list); | |||
| std::string::size_type position; | |||
| std::string new_path; | |||
| std::string src_path, target_path; | |||
| for (int i = 0; i < src_file_list.size(); i++) { | |||
| src_path = src_file_list[i]; | |||
| target_path = target_file_list[i]; | |||
| // Add new train file name. | |||
| position = src_path.find(".tags"); | |||
| if (position != std::string::npos) { | |||
| new_path = src_path; | |||
| const int kTagSize = 5; | |||
| const int kSuffixSize = 3; | |||
| new_path = new_path.replace(new_path.find(".tags"), kTagSize, ""); | |||
| new_path = new_path.substr(0, new_path.length() - kSuffixSize); | |||
| // Write data to the new file path. | |||
| RETURN_IF_NOT_OK(CleanTagFile(src_path, target_path, new_path)); | |||
| src_target_file_list->push_back(new_path); | |||
| } else { | |||
| // Add new valid or test file name. | |||
| // Delete suffix. | |||
| const int kSuffixXMLSize = 7; | |||
| new_path = src_path; | |||
| new_path = new_path.substr(0, new_path.length() - kSuffixXMLSize); | |||
| // Write data to the new file path. | |||
| RETURN_IF_NOT_OK(CleanXmlFile(src_path, target_path, new_path)); | |||
| src_target_file_list->push_back(new_path); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::string IWSLTOp::GenerateIWSLT2016TagsFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &suffix) { | |||
| Path src_language_path(src_language); | |||
| Path target_language_path(target_language); | |||
| Path sub_dir(src_language + "-" + target_language); | |||
| Path file_name("train.tags." + src_language + "-" + target_language + "." + suffix); | |||
| Path file_path = dir / "texts" / src_language_path / target_language_path / sub_dir / file_name; | |||
| return file_path.ToString(); | |||
| } | |||
| std::string IWSLTOp::GenerateIWSLT2016XMLFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &set_type, | |||
| const std::string &suffix) { | |||
| Path src_language_path(src_language); | |||
| Path target_language_path(target_language); | |||
| Path sub_dir(src_language + "-" + target_language); | |||
| Path file_name("IWSLT16.TED." + set_type + "." + src_language + "-" + target_language + "." + suffix + ".xml"); | |||
| Path file_path = dir / "texts" / src_language_path / target_language_path / sub_dir / file_name; | |||
| return file_path.ToString(); | |||
| } | |||
| std::string IWSLTOp::GenerateIWSLT2017TagsFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &suffix) { | |||
| Path sub_const_dir("texts"); | |||
| Path sub_src_language_dir("DeEnItNlRo"); | |||
| Path sub_tgt_language_dir("DeEnItNlRo"); | |||
| Path sub_src_tgt_dir("DeEnItNlRo-DeEnItNlRo"); | |||
| Path file_name("train.tags." + src_language + "-" + target_language + "." + suffix); | |||
| Path file_path = dir / sub_const_dir / sub_src_language_dir / sub_tgt_language_dir / sub_src_tgt_dir / file_name; | |||
| return file_path.ToString(); | |||
| } | |||
| std::string IWSLTOp::GenerateIWSLT2017XMLFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &set_type, | |||
| const std::string &suffix) { | |||
| Path sub_const_dir("texts"); | |||
| Path sub_src_language_dir("DeEnItNlRo"); | |||
| Path sub_tgt_language_dir("DeEnItNlRo"); | |||
| Path sub_src_tgt_dir("DeEnItNlRo-DeEnItNlRo"); | |||
| Path file_name("IWSLT17.TED." + set_type + "." + src_language + "-" + target_language + "." + suffix + ".xml"); | |||
| Path file_path = dir / sub_const_dir / sub_src_language_dir / sub_tgt_language_dir / sub_src_tgt_dir / file_name; | |||
| return file_path.ToString(); | |||
| } | |||
| Status IWSLTOp::GetFiles() { | |||
| std::vector<std::string> src_path_list; | |||
| std::vector<std::string> target_path_list; | |||
| auto real_dataset_dir = FileUtils::GetRealPath(dataset_dir_.data()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(real_dataset_dir.has_value(), "Get real path failed: " + dataset_dir_); | |||
| Path root_dir(real_dataset_dir.value()); | |||
| if (iwslt_type_ == kIWSLT2016) { | |||
| if (usage_ == "train" || usage_ == "all") { | |||
| src_path_list.push_back( | |||
| GenerateIWSLT2016TagsFileName(root_dir, language_pair_[0], language_pair_[1], language_pair_[0])); | |||
| target_path_list.push_back( | |||
| GenerateIWSLT2016TagsFileName(root_dir, language_pair_[0], language_pair_[1], language_pair_[1])); | |||
| } | |||
| if (usage_ == "valid" || usage_ == "all") { | |||
| src_path_list.push_back( | |||
| GenerateIWSLT2016XMLFileName(root_dir, language_pair_[0], language_pair_[1], valid_set_, language_pair_[0])); | |||
| target_path_list.push_back( | |||
| GenerateIWSLT2016XMLFileName(root_dir, language_pair_[0], language_pair_[1], valid_set_, language_pair_[1])); | |||
| } | |||
| if (usage_ == "test" || usage_ == "all") { | |||
| src_path_list.push_back( | |||
| GenerateIWSLT2016XMLFileName(root_dir, language_pair_[0], language_pair_[1], test_set_, language_pair_[0])); | |||
| target_path_list.push_back( | |||
| GenerateIWSLT2016XMLFileName(root_dir, language_pair_[0], language_pair_[1], test_set_, language_pair_[1])); | |||
| } | |||
| } else { | |||
| if (usage_ == "train" || usage_ == "all") { | |||
| src_path_list.push_back( | |||
| GenerateIWSLT2017TagsFileName(root_dir, language_pair_[0], language_pair_[1], language_pair_[0])); | |||
| target_path_list.push_back( | |||
| GenerateIWSLT2017TagsFileName(root_dir, language_pair_[0], language_pair_[1], language_pair_[1])); | |||
| } | |||
| if (usage_ == "valid" || usage_ == "all") { | |||
| src_path_list.push_back( | |||
| GenerateIWSLT2017XMLFileName(root_dir, language_pair_[0], language_pair_[1], valid_set_, language_pair_[0])); | |||
| target_path_list.push_back( | |||
| GenerateIWSLT2017XMLFileName(root_dir, language_pair_[0], language_pair_[1], valid_set_, language_pair_[1])); | |||
| } | |||
| if (usage_ == "test" || usage_ == "all") { | |||
| src_path_list.push_back( | |||
| GenerateIWSLT2017XMLFileName(root_dir, language_pair_[0], language_pair_[1], test_set_, language_pair_[0])); | |||
| target_path_list.push_back( | |||
| GenerateIWSLT2017XMLFileName(root_dir, language_pair_[0], language_pair_[1], test_set_, language_pair_[1])); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(GenerateNewFile(src_path_list, target_path_list, &src_target_file_list_)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,236 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IWSLT_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IWSLT_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "./tinyxml2.h" | |||
| #include "debug/common.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" | |||
| #include "minddata/dataset/engine/jagged_connector.h" | |||
| using tinyxml2::XMLDocument; | |||
| using tinyxml2::XMLElement; | |||
| using tinyxml2::XMLError; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class JaggedConnector; | |||
| /// \class IWSLTOp. | |||
| /// \brief A Op derived class to represent IWSLT Op. | |||
| class IWSLTOp : public NonMappableLeafOp { | |||
| public: | |||
| enum IWSLTType { kIWSLT2016, kIWSLT2017 }; | |||
| /// \brief Constructor of IWSLTOp. | |||
| /// \param[in] num_workers Number of worker threads reading data from yelp_review files. | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// \param[in] worker_connector_size Size of each internal queue. | |||
| /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. | |||
| /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. | |||
| /// \param[in] num_devices Number of devices that the dataset should be divided into. | |||
| /// \param[in] device_id The device ID within num_devices. | |||
| /// \param[in] data_schema Schema of dataset. | |||
| /// \param[in] type Type of data set read. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Usage of this dataset, can be "train", "test", "valid" or "all" data. | |||
| /// \param[in] language_pair List containing src and tgt language. | |||
| /// \param[in] valid_set A string to identify validation set. | |||
| /// \param[in] test_set A string to identify test set. | |||
| IWSLTOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::unique_ptr<DataSchema>, IWSLTType type, | |||
| const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &language_pair, | |||
| const std::string &valid_set, const std::string &test_set); | |||
| /// \brief Destructor. | |||
| ~IWSLTOp() = default; | |||
| /// \brief A print method typically used for debugging. | |||
| /// \param[out] out The output stream to write output to. | |||
| /// \param[in] show_all A bool to control if you want to show all info or just a summary. | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| /// \brief Instantiates the internal queues and connectors. | |||
| /// \return Status The error code returned. | |||
| Status Init() override; | |||
| /// \brief Function to count the number of samples in the IWSLT dataset. | |||
| /// \param[in] type IWSLT data set version, which can be kIWSLT2016 and kIWSLT2017. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of IWSLT2017, can be "train", "valid", "test" or "all". | |||
| /// \param[in] language_pair List containing src and tgt language. | |||
| /// \param[in] valid_set A string to identify validation set. | |||
| /// \param[in] test_set A string to identify test set. | |||
| /// \param[out] count The total number of rows in file. | |||
| /// \return Status The status code returned. | |||
| static Status CountTotalRows(IWSLTType type, const std::string &dataset_dir, const std::string &usage, | |||
| const std::vector<std::string> &language_pair, const std::string &valid_set, | |||
| const std::string &test_set, int64_t *count); | |||
| /// \brief Op name getter. | |||
| /// \return std::string Name of the current Op. | |||
| std::string Name() const override { return "IWSLTOp"; } | |||
| /// \brief DatasetName name getter. | |||
| /// \param[in] upper If true, the return value is uppercase, otherwise, it is lowercase. | |||
| /// \return std::string DatasetName of the current Op. | |||
| std::string DatasetName(bool upper = false) const { return upper ? "IWSLT" : "iwslt"; } | |||
| // \brief File names getter. | |||
| // \return std::vector<std::string> Vector of the input file names. | |||
| std::vector<std::string> FileNames() { return src_target_file_list_; } | |||
| private: | |||
| /// \brief Split string based on a character delimiter. | |||
| /// \param[in] s The input string. | |||
| /// \param[in] delim The delimiter. | |||
| /// \return std::vector<std::string> The result after segmentation. | |||
| std::vector<std::string> Split(const std::string &s, const std::string &delim); | |||
| /// \brief Remove the characters specified at the beginning and end. | |||
| /// \param[in] text The input string. | |||
| /// \param[in] character The removed character. | |||
| /// \return Status The status code returned. | |||
| Status Trim(std::string *text, const std::string &character); | |||
| /// \brief Function to count the number of samples in one data file. | |||
| /// \param[in] file Path to the data file. | |||
| /// \return int64_t The total number of rows in file. | |||
| int64_t CountFileRows(const std::string &file); | |||
| /// \brief Parses a single row and puts the data into a tensor table. | |||
| /// \param[in] line The content of the row. | |||
| /// \param[out] out_row Output tensor. | |||
| /// \param[in] index The id of the row filled in the tensor table. | |||
| /// \return Status The status code returned. | |||
| Status LoadTensor(const std::string &line, TensorRow *out_row, size_t index); | |||
| /// \brief Reads a IWSLT file and loads the data into multiple TensorRows. | |||
| /// \param[in] file The file to read. | |||
| /// \param[in] start_offset The start offset of file. | |||
| /// \param[in] end_offset The end offset of file. | |||
| /// \param[in] worker_id The id of the worker that is executing this function. | |||
| /// \return Status The status code returned. | |||
| Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; | |||
| /// \brief Fill the IOBlockQueue. | |||
| /// \param[in] i_keys Keys of file to fill to the IOBlockQueue. | |||
| /// \return Status The status code returned. | |||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override; | |||
| /// \brief Calculate number of rows in each shard. | |||
| /// \return Status The status code returned. | |||
| Status CalculateNumRowsPerShard() override; | |||
| /// \brief Private function for computing the assignment of the column name map. | |||
| /// \return Status The status code returned. | |||
| Status ComputeColMap() override; | |||
| /// \brief Write the data of the source file and the target file to a new file after cleaning. | |||
| /// \param[in] src_file_path Source file path. | |||
| /// \param[in] target_file_path Target file path. | |||
| /// \param[in] new_file_path Write data to new file path. | |||
| /// \return Status The status code returned. | |||
| Status CleanXmlFile(const std::string &src_file_path, const std::string &target_file_path, | |||
| const std::string &new_file_path); | |||
| /// \brief Determine whether the centent contains the specified label. | |||
| /// \param[in] content This content to be determined. | |||
| /// \return bool If it contains, return true, otherwise, return false. | |||
| bool IsContainTags(const std::string &content); | |||
| /// \brief Write the data of the source file and the target file to a new file after cleaning. | |||
| /// \param[in] src_file_path Source file path. | |||
| /// \param[in] target_file_path Target file path. | |||
| /// \param[in] new_file_path Write data to new file path. | |||
| /// \return Status The status code returned. | |||
| Status CleanTagFile(const std::string &file_path, const std::string &target_file_path, | |||
| const std::string &new_file_path); | |||
| // \brief Get all files in the dataset_dir_. | |||
| // \return Status The status code returned. | |||
| Status GetFiles(); | |||
| /// \brief Generate IWSLT2016 training data set file list. | |||
| /// \param[in] dir The directory where the files are stored. | |||
| /// \param[in] src_language The source language type. | |||
| /// \param[in] target_language The target language type. | |||
| /// \param[in] suffix The file suffix. | |||
| /// \return std::string The file path. | |||
| std::string GenerateIWSLT2016TagsFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &suffix); | |||
| /// \brief Generate IWSLT2016 valid data set or test data set file list. | |||
| /// \param[in] dir The directory where the files are stored. | |||
| /// \param[in] src_language The source language type. | |||
| /// \param[in] target_language The target language type. | |||
| /// \param[in] set_type The type of data set read. | |||
| /// \param[in] suffix The file suffix. | |||
| /// \return std::string The file path. | |||
| std::string GenerateIWSLT2016XMLFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &set_type, | |||
| const std::string &suffix); | |||
| /// \brief Generate IWSLT2017 training data set file list. | |||
| /// \param[in] dir The directory where the files are stored. | |||
| /// \param[in] src_language The source language type. | |||
| /// \param[in] target_language The target language type. | |||
| /// \param[in] suffix The file suffix. | |||
| /// \return std::string The file path. | |||
| std::string GenerateIWSLT2017TagsFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &suffix); | |||
| /// \brief Generate IWSLT2016 valid data set or test data set file list. | |||
| /// \param[in] dir The directory where the files are stored. | |||
| /// \param[in] src_language The source language type. | |||
| /// \param[in] target_language The target language type. | |||
| /// \param[in] set_type The type of data set read. | |||
| /// \param[in] suffix The file suffix. | |||
| /// \return std::string The file path. | |||
| std::string GenerateIWSLT2017XMLFileName(Path dir, const std::string &src_language, | |||
| const std::string &target_language, const std::string &set_type, | |||
| const std::string &suffix); | |||
| /// \brief Generate new file path and write data. | |||
| /// \param[in] src_path_list The source file path. | |||
| /// \param[in] target_path_list The target file path. | |||
| /// \param[out] src_target_file_list The newly generated file path list. | |||
| /// \return Status The status code returned. | |||
| Status GenerateNewFile(const std::vector<std::string> &src_file_list, | |||
| const std::vector<std::string> &target_file_list, | |||
| std::vector<std::string> *src_target_file_list); | |||
| IWSLTType iwslt_type_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| std::vector<std::string> src_target_file_list_; | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| std::vector<std::string> language_pair_; | |||
| std::string valid_set_; | |||
| std::string test_set_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_IWSLT_OP_H_ | |||
| @@ -180,6 +180,30 @@ Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::st | |||
| return Status::OK(); | |||
| } | |||
| Status ValidateMapKey(const std::string &dataset_name, const std::string &key, | |||
| const std::map<std::string, std::vector<std::string>> &map) { | |||
| if (map.find(key) == map.end()) { | |||
| std::string init; | |||
| std::string mode = std::accumulate(map.begin(), map.end(), init, | |||
| [](std::string a, auto b) { return std::move(a) + " " + std::move(b.first); }); | |||
| std::string err_msg = dataset_name + ": " + key + " does not match any key in [" + mode + " ]"; | |||
| LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status ValidateMapValue(const std::string &dataset_name, const std::string &str, | |||
| const std::vector<std::string> &valid_strings) { | |||
| if (find(valid_strings.begin(), valid_strings.end(), str) == valid_strings.end()) { | |||
| std::string init; | |||
| std::string mode = std::accumulate(valid_strings.begin(), valid_strings.end(), init, | |||
| [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); | |||
| std::string err_msg = dataset_name + ": " + str + " does not match any string in [" + mode + " ]"; | |||
| LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) { | |||
| if (shuffle) { | |||
| if (num_shards > 1) { | |||
| @@ -93,6 +93,8 @@ constexpr char kFashionMnistNode[] = "FashionMnistDataset"; | |||
| constexpr char kFlickrNode[] = "FlickrDataset"; | |||
| constexpr char kGeneratorNode[] = "GeneratorDataset"; | |||
| constexpr char kImageFolderNode[] = "ImageFolderDataset"; | |||
| constexpr char kIWSLT2016Node[] = "IWSLT2016Dataset"; | |||
| constexpr char kIWSLT2017Node[] = "IWSLT2017Dataset"; | |||
| constexpr char kKMnistNode[] = "KMnistDataset"; | |||
| constexpr char kLJSpeechNode[] = "LJSpeechDataset"; | |||
| constexpr char kManifestNode[] = "ManifestDataset"; | |||
| @@ -137,6 +139,12 @@ Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::st | |||
| // Helper function to validate dataset directory parameter | |||
| Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir); | |||
| Status ValidateMapKey(const std::string &dataset_name, const std::string &key, | |||
| const std::map<std::string, std::vector<std::string>> &map); | |||
| Status ValidateMapValue(const std::string &dataset_name, const std::string &str, | |||
| const std::vector<std::string> &valid_strings); | |||
| /// \brief Function to create a sampler for non-mappable dataset (to be used by cache op later). | |||
| /// \notes Non-mappable dataset does not directly support a sampler. It has provided sampling arguments (shuffle, | |||
| /// num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in the pipeline contains | |||
| @@ -19,6 +19,8 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | |||
| fashion_mnist_node.cc | |||
| flickr_node.cc | |||
| image_folder_node.cc | |||
| iwslt2016_node.cc | |||
| iwslt2017_node.cc | |||
| kmnist_node.cc | |||
| lj_speech_node.cc | |||
| manifest_node.cc | |||
| @@ -0,0 +1,192 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h" | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "debug/common.h" | |||
| #include "minddata/dataset/engine/datasetops/source/iwslt_op.h" | |||
| #include "utils/file_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for IWSLT2016Node. | |||
| IWSLT2016Node::IWSLT2016Node(const std::string &dataset_dir, const std::string &usage, | |||
| const std::vector<std::string> &language_pair, const std::string &valid_set, | |||
| const std::string &test_set, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, | |||
| int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| dataset_dir_(dataset_dir), | |||
| usage_(usage), | |||
| language_pair_(language_pair), | |||
| valid_set_(valid_set), | |||
| test_set_(test_set), | |||
| num_samples_(num_samples), | |||
| shuffle_(shuffle), | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. | |||
| // User discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work | |||
| // if the num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to | |||
| // return num_shards. Once PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| support_language_pair_map_["en"] = {"ar", "de", "fr", "cs"}; | |||
| support_language_pair_map_["ar"] = {"en"}; | |||
| support_language_pair_map_["fr"] = {"en"}; | |||
| support_language_pair_map_["de"] = {"en"}; | |||
| support_language_pair_map_["cs"] = {"en"}; | |||
| } | |||
| std::shared_ptr<DatasetNode> IWSLT2016Node::Copy() { | |||
| auto node = std::make_shared<IWSLT2016Node>(dataset_dir_, usage_, language_pair_, valid_set_, test_set_, num_samples_, | |||
| shuffle_, num_shards_, shard_id_, cache_); | |||
| return node; | |||
| } | |||
| void IWSLT2016Node::Print(std::ostream &out) const { | |||
| out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + | |||
| ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); | |||
| } | |||
| Status IWSLT2016Node::ValidateParams() { | |||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("IWSLT2016Node", dataset_dir_)); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("IWSLT2016Node", usage_, {"train", "valid", "test", "all"})); | |||
| const int kLanguagePairSize = 2; | |||
| if (language_pair_.size() != kLanguagePairSize) { | |||
| std::string err_msg = | |||
| "IWSLT2016Node: language_pair expecting size 2, but got: " + std::to_string(language_pair_.size()); | |||
| LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateMapKey("IWSLT2016Node", language_pair_[0], support_language_pair_map_)); | |||
| RETURN_IF_NOT_OK(ValidateMapValue("IWSLT2016Node", language_pair_[1], support_language_pair_map_[language_pair_[0]])); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("IWSLT2016Node", valid_set_, | |||
| {"dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"})); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("IWSLT2016Node", test_set_, | |||
| {"dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"})); | |||
| if (num_samples_ < 0) { | |||
| std::string err_msg = "IWSLT2016Node: Invalid number of samples: " + std::to_string(num_samples_); | |||
| LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateDatasetShardParams("IWSLT2016Node", num_shards_, shard_id_)); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2016Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| // Do internal Schema generation. | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("translation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1))); | |||
| std::shared_ptr<IWSLTOp> iwslt_op = std::make_shared<IWSLTOp>( | |||
| num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, | |||
| std::move(schema), IWSLTOp::IWSLTType::kIWSLT2016, dataset_dir_, usage_, language_pair_, valid_set_, test_set_); | |||
| RETURN_IF_NOT_OK(iwslt_op->Init()); | |||
| // If a global shuffle is used for IWSLT, it will inject a shuffle op over the IWSLT. | |||
| // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be | |||
| // built. This is achieved in the cache transform pass where we call MakeSimpleProducer to reset IWSLT's | |||
| // shuffle option to false. | |||
| if (shuffle_ == ShuffleMode::kGlobal) { | |||
| // Inject ShuffleOp. | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset. | |||
| RETURN_IF_NOT_OK(IWSLTOp::CountTotalRows(IWSLTOp::IWSLTType::kIWSLT2016, dataset_dir_, usage_, language_pair_, | |||
| valid_set_, test_set_, &num_rows)); | |||
| // Add the shuffle op after this op. | |||
| RETURN_IF_NOT_OK( | |||
| AddShuffleOp(iwslt_op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); | |||
| shuffle_op->SetTotalRepeats(GetTotalRepeats()); | |||
| shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| iwslt_op->SetTotalRepeats(GetTotalRepeats()); | |||
| iwslt_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(iwslt_op); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2016Node::GetShardId(int32_t *shard_id) { | |||
| *shard_id = shard_id_; | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2016Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(IWSLTOp::CountTotalRows(IWSLTOp::IWSLTType::kIWSLT2016, dataset_dir_, usage_, language_pair_, | |||
| valid_set_, test_set_, &num_rows)); | |||
| sample_size = num_samples_; | |||
| num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_))); | |||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2016Node::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["usage"] = usage_; | |||
| args["language_pair"] = language_pair_; | |||
| args["valid_set"] = valid_set_; | |||
| args["test_set"] = test_set_; | |||
| args["num_samples"] = num_samples_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2016Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2016Node::MakeSimpleProducer() { | |||
| shard_id_ = 0; | |||
| num_shards_ = 1; | |||
| shuffle_ = ShuffleMode::kFalse; | |||
| num_samples_ = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,137 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IWSLT2016_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IWSLT2016_NODE_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class IWSLT2016Node. | |||
| /// \brief A Node derived class to represent IWSLT2016Node. | |||
| class IWSLT2016Node : public NonMappableSourceNode { | |||
| public: | |||
| /// \brief Constructor of IWSLT2016Node. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of IWSLT2016, can be "train", "test", "valid" or "all" data. | |||
| /// \param[in] language_pair List containing src and tgt language. | |||
| /// \param[in] valid_set A string to identify validation set. | |||
| /// \param[in] test_set A string to identify test set. | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// \param[in] shuffle The mode for shuffling data every epoch. | |||
| /// Can be any of: | |||
| /// ShuffleMode::kFalse - No shuffling is performed. | |||
| /// ShuffleMode::kFiles - Shuffle files only. | |||
| /// ShuffleMode::kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified. | |||
| /// \param[in] cache Tensor cache to use. | |||
| IWSLT2016Node(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &language_pair, | |||
| const std::string &valid_set, const std::string &test_set, int64_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache); | |||
| /// \brief Destructor. | |||
| ~IWSLT2016Node() = default; | |||
| /// \brief Node name getter. | |||
| /// \return std::string Name of the current node. | |||
| std::string Name() const override { return kIWSLT2016Node; } | |||
| /// \brief Print the description. | |||
| /// \param[out] out The output stream to write output to. | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object. | |||
| /// \return std::shared_ptr<DatasetNode> A shared pointer to the new copy. | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief A base class override function to create the required runtime dataset op objects for this class. | |||
| /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. | |||
| /// \return Status Status::OK() if build successfully. | |||
| Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; | |||
| /// \brief Parameters validation. | |||
| /// \return Status Status::OK() if all the parameters are valid. | |||
| Status ValidateParams() override; | |||
| /// \brief Get the shard id of node. | |||
| /// \param[in] shard_id The shard id. | |||
| /// \return Status Status::OK() if get shard id successfully. | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize. | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter. | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size The size of the dataset. | |||
| /// \return Status The status code returned. | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions. | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| const std::vector<std::string> &LanguagePair() const { return language_pair_; } | |||
| const std::string &ValidSet() const { return valid_set_; } | |||
| const std::string &TestSet() const { return test_set_; } | |||
| int64_t NumSamples() const { return num_samples_; } | |||
| ShuffleMode Shuffle() const { return shuffle_; } | |||
| int32_t NumShards() const { return num_shards_; } | |||
| int32_t ShardId() const { return shard_id_; } | |||
| /// \brief Get the arguments of node. | |||
| /// \param[out] out_json JSON string of all attributes. | |||
| /// \return Status The status code returned. | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief IWSLT2016 by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we setup the sampler for a leaf node that does not use sampling. | |||
| /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. | |||
| /// \param[in] sampler The sampler to setup. | |||
| /// \return Status The status code returned. | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief If a cache has been added into the ascendant tree over this clue node, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so | |||
| /// that this clue node will produce the full set of data into the cache. | |||
| /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. | |||
| /// \return Status The status code returned. | |||
| Status MakeSimpleProducer() override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| std::vector<std::string> language_pair_; | |||
| std::string valid_set_; | |||
| std::string test_set_; | |||
| int64_t num_samples_; | |||
| ShuffleMode shuffle_; | |||
| int32_t num_shards_; | |||
| int32_t shard_id_; | |||
| std::map<std::string, std::vector<std::string>> support_language_pair_map_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IWSLT2016_NODE_H_ | |||
| @@ -0,0 +1,183 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h" | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "debug/common.h" | |||
| #include "minddata/dataset/engine/datasetops/source/iwslt_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for IWSLT2017Node. | |||
| IWSLT2017Node::IWSLT2017Node(const std::string &dataset_dir, const std::string &usage, | |||
| const std::vector<std::string> &language_pair, int64_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| dataset_dir_(dataset_dir), | |||
| usage_(usage), | |||
| language_pair_(std::move(language_pair)), | |||
| valid_set_("dev2010"), | |||
| test_set_("tst2010"), | |||
| num_samples_(num_samples), | |||
| shuffle_(shuffle), | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. | |||
| // User discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work | |||
| // if the num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to | |||
| // return num_shards. Once PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| support_language_pair_map_["en"] = {"nl", "de", "it", "ro"}; | |||
| support_language_pair_map_["ro"] = {"de", "en", "nl", "it"}; | |||
| support_language_pair_map_["de"] = {"ro", "en", "nl", "it"}; | |||
| support_language_pair_map_["it"] = {"en", "nl", "de", "ro"}; | |||
| support_language_pair_map_["nl"] = {"de", "en", "it", "ro"}; | |||
| } | |||
| std::shared_ptr<DatasetNode> IWSLT2017Node::Copy() { | |||
| auto node = std::make_shared<IWSLT2017Node>(dataset_dir_, usage_, language_pair_, num_samples_, shuffle_, num_shards_, | |||
| shard_id_, cache_); | |||
| return node; | |||
| } | |||
| void IWSLT2017Node::Print(std::ostream &out) const { | |||
| out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + | |||
| ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); | |||
| } | |||
| Status IWSLT2017Node::ValidateParams() { | |||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("IWSLT2017Node", dataset_dir_)); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("IWSLT2017Node", usage_, {"train", "valid", "test", "all"})); | |||
| const int kLanguagePairSize = 2; | |||
| if (language_pair_.size() != kLanguagePairSize) { | |||
| std::string err_msg = | |||
| "IWSLT2017Node: language_pair expecting size 2, but got: " + std::to_string(language_pair_.size()); | |||
| LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateMapKey("IWSLT2017Node", language_pair_[0], support_language_pair_map_)); | |||
| RETURN_IF_NOT_OK(ValidateMapValue("IWSLT2017Node", language_pair_[1], support_language_pair_map_[language_pair_[0]])); | |||
| if (num_samples_ < 0) { | |||
| std::string err_msg = "IWSLT2017Node: Invalid number of samples: " + std::to_string(num_samples_); | |||
| LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| RETURN_IF_NOT_OK(ValidateDatasetShardParams("IWSLT2017Node", num_shards_, shard_id_)); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2017Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| // Do internal Schema generation. | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("translation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1))); | |||
| std::shared_ptr<IWSLTOp> iwslt_op = std::make_shared<IWSLTOp>( | |||
| num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, | |||
| std::move(schema), IWSLTOp::IWSLTType::kIWSLT2017, dataset_dir_, usage_, language_pair_, valid_set_, test_set_); | |||
| RETURN_IF_NOT_OK(iwslt_op->Init()); | |||
| // If a global shuffle is used for IWSLT, it will inject a shuffle op over the IWSLT. | |||
| // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be | |||
| // built.This is achieved in the cache transform pass where we call MakeSimpleProducer to reset IWSLT's | |||
| // shuffle option to false. | |||
| if (shuffle_ == ShuffleMode::kGlobal) { | |||
| // Inject ShuffleOp. | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset. | |||
| RETURN_IF_NOT_OK(IWSLTOp::CountTotalRows(IWSLTOp::IWSLTType::kIWSLT2017, dataset_dir_, usage_, language_pair_, | |||
| valid_set_, test_set_, &num_rows)); | |||
| // Add the shuffle op after this op. | |||
| RETURN_IF_NOT_OK( | |||
| AddShuffleOp(iwslt_op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); | |||
| shuffle_op->SetTotalRepeats(GetTotalRepeats()); | |||
| shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| iwslt_op->SetTotalRepeats(GetTotalRepeats()); | |||
| iwslt_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(iwslt_op); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2017Node::GetShardId(int32_t *shard_id) { | |||
| *shard_id = shard_id_; | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2017Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(IWSLTOp::CountTotalRows(IWSLTOp::IWSLTType::kIWSLT2017, dataset_dir_, usage_, language_pair_, | |||
| valid_set_, test_set_, &num_rows)); | |||
| sample_size = num_samples_; | |||
| num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_))); | |||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2017Node::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["usage"] = usage_; | |||
| args["language_pair"] = language_pair_; | |||
| args["num_samples"] = num_samples_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2017Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
| return Status::OK(); | |||
| } | |||
| Status IWSLT2017Node::MakeSimpleProducer() { | |||
| shard_id_ = 0; | |||
| num_shards_ = 1; | |||
| shuffle_ = ShuffleMode::kFalse; | |||
| num_samples_ = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IWSLT2017_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IWSLT2017_NODE_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class IWSLT2017Node. | |||
| /// \brief A Node derived class to represent IWSLT2017Node. | |||
| class IWSLT2017Node : public NonMappableSourceNode { | |||
| public: | |||
| /// \brief Constructor of IWSLT2017Node. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of IWSLT2017, can be "train", "test", "valid" or "all" data. | |||
| /// \param[in] language_pair List containing src and tgt language. | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// \param[in] shuffle The mode for shuffling data every epoch. | |||
| /// Can be any of: | |||
| /// ShuffleMode::kFalse - No shuffling is performed. | |||
| /// ShuffleMode::kFiles - Shuffle files only. | |||
| /// ShuffleMode::kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified. | |||
| /// \param[in] cache Tensor cache to use. | |||
| IWSLT2017Node(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &language_pair, | |||
| int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, | |||
| std::shared_ptr<DatasetCache> cache); | |||
| /// \brief Destructor. | |||
| ~IWSLT2017Node() = default; | |||
| /// \brief Node name getter. | |||
| /// \return std::string Name of the current node. | |||
| std::string Name() const override { return kIWSLT2017Node; } | |||
| /// \brief Print the description. | |||
| /// \param[out] out The output stream to write output to. | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object. | |||
| /// \return std::shared_ptr<DatasetNode> A shared pointer to the new copy. | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief A base class override function to create the required runtime dataset op objects for this class. | |||
| /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. | |||
| /// \return Status Status::OK() if build successfully. | |||
| Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; | |||
| /// \brief Parameters validation. | |||
| /// \return Status Status::OK() if all the parameters are valid. | |||
| Status ValidateParams() override; | |||
| /// \brief Get the shard id of node. | |||
| /// \param[in] shard_id The shard id. | |||
| /// \return Status Status::OK() if get shard id successfully. | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize. | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter. | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset. | |||
| /// \return Status The status code returned. | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions. | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| const std::vector<std::string> &LanguagePair() const { return language_pair_; } | |||
| int64_t NumSamples() const { return num_samples_; } | |||
| ShuffleMode Shuffle() const { return shuffle_; } | |||
| int32_t NumShards() const { return num_shards_; } | |||
| int32_t ShardId() const { return shard_id_; } | |||
| /// \brief Get the arguments of node. | |||
| /// \param[out] out_json JSON string of all attributes. | |||
| /// \return Status The status code returned. | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief IWSLT by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we setup the sampler for a leaf node that does not use sampling. | |||
| /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. | |||
| /// \param[in] sampler The sampler to setup. | |||
| /// \return Status The status code returned. | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief If a cache has been added into the ascendant tree over this clue node, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so | |||
| /// that this clue node will produce the full set of data into the cache. | |||
| /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. | |||
| /// \return Status The status code returned. | |||
| Status MakeSimpleProducer() override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| std::vector<std::string> language_pair_; | |||
| std::string valid_set_; | |||
| std::string test_set_; | |||
| int64_t num_samples_; | |||
| ShuffleMode shuffle_; | |||
| int32_t num_shards_; | |||
| int32_t shard_id_; | |||
| std::map<std::string, std::vector<std::string>> support_language_pair_map_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IWSLT2017_NODE_H_ | |||
| @@ -2510,6 +2510,122 @@ inline std::shared_ptr<ImageFolderDataset> MS_API ImageFolder(const std::string | |||
| MapStringToChar(class_indexing), cache); | |||
| } | |||
| /// \class IWSLT2016Dataset. | |||
| /// \brief A source dataset for reading and parsing IWSLT2016 dataset. | |||
| class MS_API IWSLT2016Dataset : public Dataset { | |||
| public: | |||
| /// \brief Constructor of IWSLT2016Dataset. | |||
| /// \note The generated dataset has two columns ["text", "translation"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of IWSLT2016, can be "train", "valid", "test" or "all". | |||
| /// \param[in] language_pair List containing src and tgt language. | |||
| /// \param[in] valid_set A string to identify validation set. | |||
| /// \param[in] test_set A string to identify test set. | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// \param[in] shuffle The mode for shuffling data every epoch. | |||
| /// Can be any of: | |||
| /// ShuffleMode::kFalse - No shuffling is performed. | |||
| /// ShuffleMode::kFiles - Shuffle files only. | |||
| /// ShuffleMode::kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified. | |||
| /// \param[in] cache Tensor cache to use. | |||
| IWSLT2016Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::vector<std::vector<char>> &language_pair, const std::vector<char> &valid_set, | |||
| const std::vector<char> &test_set, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, | |||
| int32_t shard_id, const std::shared_ptr<DatasetCache> &cache); | |||
| /// \brief Destructor of IWSLT2016Dataset. | |||
| ~IWSLT2016Dataset() = default; | |||
| }; | |||
| /// \brief Function to create a IWSLT2016Dataset. | |||
| /// \note The generated dataset has two columns ["text", "translation"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of IWSLT2016, can be "train", "valid", "test" or "all" (default = "all"). | |||
| /// \param[in] language_pair List containing src and tgt language (Default = {"de", "en"}). | |||
| /// \param[in] valid_set A string to identify validation set (Default = "tst2013"). | |||
| /// \param[in] test_set A string to identify test set (Default = "tst2014"). | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// (Default = 0, means all samples). | |||
| /// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode::kGlobal). | |||
| /// Can be any of: | |||
| /// ShuffleMode::kFalse - No shuffling is performed. | |||
| /// ShuffleMode::kFiles - Shuffle files only. | |||
| /// ShuffleMode::kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1). | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified (Default = 0). | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the IWSLT2016Dataset. | |||
| inline std::shared_ptr<IWSLT2016Dataset> MS_API | |||
| IWSLT2016(const std::string &dataset_dir, const std::string &usage = "all", | |||
| const std::vector<std::string> &language_pair = {"de", "en"}, const std::string &valid_set = "tst2013", | |||
| const std::string &test_set = "tst2014", int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, | |||
| int32_t num_shards = 1, int32_t shard_id = 0, const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<IWSLT2016Dataset>(StringToChar(dataset_dir), StringToChar(usage), | |||
| VectorStringToChar(language_pair), StringToChar(valid_set), | |||
| StringToChar(test_set), num_samples, shuffle, num_shards, shard_id, cache); | |||
| } | |||
| /// \class IWSLT2017Dataset. | |||
| /// \brief A source dataset for reading and parsing IWSLT2017 dataset. | |||
| class MS_API IWSLT2017Dataset : public Dataset { | |||
| public: | |||
| /// \brief Constructor of IWSLT2017Dataset. | |||
| /// \note The generated dataset has two columns ["text", "translation"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of IWSLT2017, can be "train", "valid", "test" or "all". | |||
| /// \param[in] language_pair List containing src and tgt language. | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// \param[in] shuffle The mode for shuffling data every epoch. | |||
| /// Can be any of: | |||
| /// ShuffleMode::kFalse - No shuffling is performed. | |||
| /// ShuffleMode::kFiles - Shuffle files only. | |||
| /// ShuffleMode::kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified. | |||
| /// \param[in] cache Tensor cache to use. | |||
| /// \return Shared pointer to the IWSLT2017Dataset. | |||
| IWSLT2017Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::vector<std::vector<char>> &language_pair, int64_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache); | |||
| /// \brief Destructor of IWSLT2017Dataset. | |||
| ~IWSLT2017Dataset() = default; | |||
| }; | |||
| /// \brief Function to create a IWSLT2017Dataset. | |||
| /// \note The generated dataset has two columns ["text", "translation"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of IWSLT2017, can be "train", "valid", "test" or "all" (default = "all"). | |||
| /// \param[in] language_pair List containing src and tgt language (Default = {"de", "en"}). | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// (Default = 0, means all samples). | |||
| /// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode::kGlobal). | |||
| /// Can be any of: | |||
| /// ShuffleMode::kFalse - No shuffling is performed. | |||
| /// ShuffleMode::kFiles - Shuffle files only. | |||
| /// ShuffleMode::kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1). | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified (Default = 0). | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the IWSLT2017Dataset. | |||
| inline std::shared_ptr<IWSLT2017Dataset> MS_API IWSLT2017(const std::string &dataset_dir, | |||
| const std::string &usage = "all", | |||
| const std::vector<std::string> &language_pair = {"de", "en"}, | |||
| int64_t num_samples = 0, | |||
| ShuffleMode shuffle = ShuffleMode::kGlobal, | |||
| int32_t num_shards = 1, int32_t shard_id = 0, | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<IWSLT2017Dataset>(StringToChar(dataset_dir), StringToChar(usage), | |||
| VectorStringToChar(language_pair), num_samples, shuffle, num_shards, | |||
| shard_id, cache); | |||
| } | |||
| /// \class KMnistDataset. | |||
| /// \brief A source dataset for reading and parsing KMnist dataset. | |||
| class MS_API KMnistDataset : public Dataset { | |||
| @@ -395,6 +395,28 @@ def check_valid_str(value, valid_strings, arg_name=""): | |||
| raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings))) | |||
| def check_valid_list_tuple(value, valid_list_tuple, data_type, arg_name=""): | |||
| """ | |||
| Validate value in valid_list_tuple. | |||
| Args: | |||
| value (Union[list, tuple]): the value to be validated. | |||
| valid_strings (Union[list, tuple]): name of columns. | |||
| type (tuple): tuple of all valid types for value. | |||
| arg_name (str): the names of value. | |||
| Returns: | |||
| Exception: when the value is not correct, otherwise nothing. | |||
| """ | |||
| valid_length = len(valid_list_tuple[0]) | |||
| type_check(value, (list, tuple), arg_name) | |||
| type_check_list(value, data_type, arg_name) | |||
| if len(value) != valid_length: | |||
| raise ValueError("Input {0} is a list or tuple of length {1}.".format(arg_name, valid_length)) | |||
| if value not in valid_list_tuple: | |||
| raise ValueError( | |||
| "Input {0}{1} is not within the valid set of {2}.".format(arg_name, value, valid_list_tuple)) | |||
| def check_columns(columns, name): | |||
| """ | |||
| Validate strings in column_names. | |||
| @@ -71,7 +71,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||
| check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \ | |||
| check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \ | |||
| check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset, check_svhn_dataset, \ | |||
| check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset | |||
| check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \ | |||
| check_iwslt2017_dataset | |||
| from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ | |||
| get_prefetch_size, get_auto_offload | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| @@ -3712,6 +3713,228 @@ class ImageFolderDataset(MappableDataset): | |||
| return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing) | |||
| class IWSLT2016Dataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses IWSLT2016 datasets. | |||
| The generated dataset has two columns: :py:obj:`[text, translation]`. | |||
| The tensor of column :py:obj: `text` is of the string type. | |||
| The tensor of column :py:obj: `translation` is of the string type. | |||
| Args: | |||
| dataset_dir (str): Path to the root directory that contains the dataset. | |||
| usage (str, optional): Acceptable usages include "train", "valid", "test" and "all" (default=None, all samples). | |||
| language_pair (sequence, optional): Sequence containing source and target language, supported values are | |||
| (`en`, `fr`), ("en", "de"), ("en", "cs"), ("en", "ar"), ("fr", "en"), ("de", "en"), ("cs", "en"), | |||
| ("ar", "en") (default=("de", "en")). | |||
| valid_set (str, optional): A string to identify validation set, when usage is valid or all, the validation set | |||
| of valid_set type will be read, supported values are "dev2010", "tst2010", "tst2011", "tst2012", "tst2013" | |||
| and "tst2014" (default="tst2013"). | |||
| test_set (str, optional): A string to identify test set, when usage is test or all, the test set of test_set | |||
| type will be read, supported values are "dev2010", "tst2010", "tst2011", "tst2012", "tst2013" and "tst2014" | |||
| (default="tst2014"). | |||
| num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset). | |||
| shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch | |||
| (default=Shuffle.GLOBAL). | |||
| If shuffle is False, no shuffling will be performed; | |||
| If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL | |||
| Otherwise, there are two levels of shuffling: | |||
| - Shuffle.GLOBAL: Shuffle both the files and samples. | |||
| - Shuffle.FILES: Shuffle files only. | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| When this argument is specified, `num_samples` reflects the max sample number of per shard. | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| num_parallel_workers (int, optional): Number of workers to read the data | |||
| (default=None, number set in the config). | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If dataset_dir does not contain data files. | |||
| RuntimeError: If num_parallel_workers exceeds the max thread numbers. | |||
| RuntimeError: If num_shards is specified but shard_id is None. | |||
| RuntimeError: If shard_id is specified but num_shards is None. | |||
| Examples: | |||
| >>> iwslt2016_dataset_dir = "/path/to/iwslt2016_dataset_dir" | |||
| >>> dataset = ds.IWSLT2016Dataset(dataset_files=iwslt2016_dataset_dir, usage='all', | |||
| ... language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014') | |||
| About IWSLT2016 dataset: | |||
| IWSLT is an international oral translation conference, a major annual scientific conference dedicated to all aspects | |||
| of oral translation. The MT task of the IWSLT evaluation activity constitutes a data set, which can be publicly | |||
| obtained through the WIT3 website wit3.fbk.eu. The IWSLT2016 data set includes translations from English to Arabic, | |||
| Czech, French, and German, and translations from Arabic, Czech, French, and German to English. | |||
| You can unzip the original IWSLT2016 dataset files into this directory structure and read by MindSpore's API. After | |||
| decompression, you also need to decompress the data set to be read in the specified folder. For example, if you want | |||
| to read the data set of de-en, you need to unzip the tgz file in the de/en directory, the data set is in the | |||
| unzipped folder. | |||
| .. code-block:: | |||
| . | |||
| └── iwslt2016_dataset_directory | |||
| ├── subeval_files | |||
| └── texts | |||
| ├── ar | |||
| │ └── en | |||
| │ └── ar-en | |||
| ├── cs | |||
| │ └── en | |||
| │ └── cs-en | |||
| ├── de | |||
| │ └── en | |||
| │ └── de-en | |||
| │ ├── IWSLT16.TED.dev2010.de-en.de.xml | |||
| │ ├── train.tags.de-en.de | |||
| │ ├── ... | |||
| ├── en | |||
| │ ├── ar | |||
| │ │ └── en-ar | |||
| │ ├── cs | |||
| │ │ └── en-cs | |||
| │ ├── de | |||
| │ │ └── en-de | |||
| │ └── fr | |||
| │ └── en-fr | |||
| └── fr | |||
| └── en | |||
| └── fr-en | |||
| Citation: | |||
| .. code-block:: | |||
| @inproceedings{cettoloEtAl:EAMT2012, | |||
| Address = {Trento, Italy}, | |||
| Author = {Mauro Cettolo and Christian Girardi and Marcello Federico}, | |||
| Booktitle = {Proceedings of the 16$^{th}$ Conference of the European Association for Machine Translation | |||
| (EAMT)}, | |||
| Date = {28-30}, | |||
| Month = {May}, | |||
| Pages = {261--268}, | |||
| Title = {WIT$^3$: Web Inventory of Transcribed and Translated Talks}, | |||
| Year = {2012}} | |||
| """ | |||
| @check_iwslt2016_dataset | |||
| def __init__(self, dataset_dir, usage=None, language_pair=None, valid_set=None, test_set=None, | |||
| num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, | |||
| cache=None): | |||
| super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, | |||
| num_shards=num_shards, shard_id=shard_id, cache=cache) | |||
| self.dataset_dir = dataset_dir | |||
| self.usage = replace_none(usage, 'all') | |||
| self.language_pair = replace_none(language_pair, ["de", "en"]) | |||
| self.valid_set = replace_none(valid_set, 'tst2013') | |||
| self.test_set = replace_none(test_set, 'tst2014') | |||
| def parse(self, children=None): | |||
| return cde.IWSLT2016Node(self.dataset_dir, self.usage, self.language_pair, self.valid_set, self.test_set, | |||
| self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id) | |||
| class IWSLT2017Dataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses IWSLT2017 datasets. | |||
| The generated dataset has two columns: :py:obj:`[text, translation]`. | |||
| The tensor of column :py:obj:`text` is of the string type. | |||
| The tensor of column :py:obj:`translation` is of the string type. | |||
| Args: | |||
| dataset_dir (str): Path to the root directory that contains the dataset. | |||
| usage (str, optional): Acceptable usages include "train", "valid", "test" and "all" (default=None, all samples). | |||
| language_pair (list, optional): List containing src and tgt language, supported values are ("en", "nl"), | |||
| ("en", "de"), ("en", "it"), ("en", "ro"), ("nl", "en"), ("nl", "de"), ("nl", "it"), ("nl", "ro"), | |||
| ("de", "en"), ("de", "nl"), ("de", "it"), ("de", "ro"), ("it", "en"), ("it", "nl"), ("it", "de"), | |||
| ("it", "ro"), (`ro`, `en`), (`ro`, `nl`), (`ro`, `de`), (`ro`, `it`) (default=(`de`, `en`)). | |||
| num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset). | |||
| shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch | |||
| (default=Shuffle.GLOBAL). | |||
| If shuffle is False, no shuffling will be performed; | |||
| If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL | |||
| Otherwise, there are two levels of shuffling: | |||
| - Shuffle.GLOBAL: Shuffle both the files and samples. | |||
| - Shuffle.FILES: Shuffle files only. | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| When this argument is specified, `num_samples` reflects the max sample number of per shard. | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| num_parallel_workers (int, optional): Number of workers to read the data | |||
| (default=None, number set in the config). | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If dataset_dir does not contain data files. | |||
| RuntimeError: If num_parallel_workers exceeds the max thread numbers. | |||
| RuntimeError: If num_shards is specified but shard_id is None. | |||
| RuntimeError: If shard_id is specified but num_shards is None. | |||
| Examples: | |||
| >>> iwslt2017_dataset_dir = "/path/to/iwslt207_dataset_dir" | |||
| >>> dataset = ds.IWSLT2017Dataset(dataset_files=iwslt2017_dataset_dir, usage='all', language_pair=('de', 'en')) | |||
| About IWSLT2017 dataset: | |||
| IWSLT is an international oral translation conference, a major annual scientific conference dedicated to all aspects | |||
| of oral translation. The MT task of the IWSLT evaluation activity constitutes a data set, which can be publicly | |||
| obtained through the WIT3 website wit3.fbk.eu. The IWSLT2017 data set involves German, English, Italian, Dutch, and | |||
| Romanian. The data set includes translations in any two different languages. | |||
| You can unzip the original IWSLT2017 dataset files into this directory structure and read by MindSpore's API. You | |||
| need to decompress the dataset package in texts/DeEnItNlRo/DeEnItNlRo directory to get the DeEnItNlRo-DeEnItNlRo | |||
| subdirectory. | |||
| .. code-block:: | |||
| . | |||
| └── iwslt2017_dataset_directory | |||
| └── DeEnItNlRo | |||
| └── DeEnItNlRo | |||
| └── DeEnItNlRo-DeEnItNlRo | |||
| ├── IWSLT17.TED.dev2010.de-en.de.xml | |||
| ├── train.tags.de-en.de | |||
| ├── ... | |||
| Citation: | |||
| .. code-block:: | |||
| @inproceedings{cettoloEtAl:EAMT2012, | |||
| Address = {Trento, Italy}, | |||
| Author = {Mauro Cettolo and Christian Girardi and Marcello Federico}, | |||
| Booktitle = {Proceedings of the 16$^{th}$ Conference of the European Association for Machine Translation | |||
| (EAMT)}, | |||
| Date = {28-30}, | |||
| Month = {May}, | |||
| Pages = {261--268}, | |||
| Title = {WIT$^3$: Web Inventory of Transcribed and Translated Talks}, | |||
| Year = {2012}} | |||
| """ | |||
| @check_iwslt2017_dataset | |||
| def __init__(self, dataset_dir, usage=None, language_pair=None, num_samples=None, shuffle=Shuffle.GLOBAL, | |||
| num_shards=None, shard_id=None, num_parallel_workers=None, cache=None): | |||
| super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, | |||
| num_shards=num_shards, shard_id=shard_id, cache=cache) | |||
| self.dataset_dir = dataset_dir | |||
| self.usage = replace_none(usage, 'all') | |||
| self.language_pair = replace_none(language_pair, ["de", "en"]) | |||
| def parse(self, children=None): | |||
| return cde.IWSLT2017Node(self.dataset_dir, self.usage, self.language_pair, self.num_samples, | |||
| self.shuffle_flag, self.num_shards, self.shard_id) | |||
| class KMnistDataset(MappableDataset): | |||
| """ | |||
| A source dataset for reading and parsing the KMNIST dataset. | |||
| @@ -26,7 +26,8 @@ from mindspore._c_expression import typing | |||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | |||
| INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | |||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \ | |||
| check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id | |||
| check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id, \ | |||
| check_valid_list_tuple | |||
| from . import datasets | |||
| from . import samplers | |||
| @@ -62,6 +63,111 @@ def check_imagefolderdataset(method): | |||
| return new_method | |||
| def check_iwslt2016_dataset(method): | |||
| """A wrapper that wraps a parameter checker around the original Dataset(IWSLT2016dataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||
| dataset_dir = param_dict.get('dataset_dir') | |||
| check_dir(dataset_dir) | |||
| # check usage | |||
| usage = param_dict.get('usage') | |||
| if usage is not None: | |||
| check_valid_str(usage, ["train", "test", "valid", "all"], "usage") | |||
| support_language_pair = [ | |||
| ['en', 'ar'], ['en', 'ar'], ['en', 'de'], ['en', 'fr'], ['en', 'cs'], ['ar', 'en'], ['fr', 'en'], | |||
| ['de', 'en'], ['cs', 'en'] | |||
| ] | |||
| support_language_pair_tuple = ( | |||
| ('en', 'ar'), ('en', 'ar'), ('en', 'de'), ('en', 'fr'), ('en', 'cs'), ('ar', 'en'), ('fr', 'en'), | |||
| ('de', 'en'), ('cs', 'en') | |||
| ) | |||
| support_set_type = ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"] | |||
| # check language_pair | |||
| language_pair = param_dict.get('language_pair') | |||
| if language_pair is not None: | |||
| if isinstance(language_pair, (list,)): | |||
| check_valid_list_tuple(language_pair, support_language_pair, (str,), "language_pair") | |||
| elif isinstance(language_pair, (tuple,)): | |||
| check_valid_list_tuple(language_pair, support_language_pair_tuple, (str,), "language_pair") | |||
| else: | |||
| raise TypeError("language_pair should be a type list or tuple of length 2.") | |||
| # check valid_set | |||
| valid_set = param_dict.get('valid_set') | |||
| if valid_set is not None: | |||
| check_valid_str(valid_set, support_set_type, "valid_set") | |||
| # check test_set | |||
| test_set = param_dict.get('test_set') | |||
| if test_set is not None: | |||
| check_valid_str(test_set, support_set_type, "test_set") | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_iwslt2017_dataset(method): | |||
| """A wrapper that wraps a parameter checker around the original Dataset(IWSLT2017dataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||
| dataset_dir = param_dict.get('dataset_dir') | |||
| check_dir(dataset_dir) | |||
| # check usage | |||
| usage = param_dict.get('usage') | |||
| if usage is not None: | |||
| check_valid_str(usage, ["train", "test", "valid", "all"], "usage") | |||
| support_language_pair = [ | |||
| ['en', 'nl'], ['en', 'de'], ['en', 'it'], ['en', 'ro'], ['ro', 'de'], ['ro', 'en'], ['ro', 'nl'], | |||
| ['ro', 'it'], ['de', 'ro'], ['de', 'en'], ['de', 'nl'], ['de', 'it'], ['it', 'en'], ['it', 'nl'], | |||
| ['it', 'de'], ['it', 'ro'], ['nl', 'de'], ['nl', 'en'], ['nl', 'it'], ['nl', 'ro'] | |||
| ] | |||
| support_language_pair_tuple = ( | |||
| ('en', 'nl'), ('en', 'de'), ('en', 'it'), ('en', 'ro'), ('ro', 'de'), ('ro', 'en'), ('ro', 'nl'), | |||
| ('ro', 'it'), ('de', 'ro'), ('de', 'en'), ('de', 'nl'), ('de', 'it'), ('it', 'en'), ('it', 'nl'), | |||
| ('it', 'de'), ('it', 'ro'), ('nl', 'de'), ('nl', 'en'), ('nl', 'it'), ('nl', 'ro') | |||
| ) | |||
| # check language_pair | |||
| language_pair = param_dict.get('language_pair') | |||
| if language_pair is not None: | |||
| if isinstance(language_pair, (list,)): | |||
| check_valid_list_tuple(language_pair, support_language_pair, (str,), "language_pair") | |||
| elif isinstance(language_pair, (tuple,)): | |||
| check_valid_list_tuple(language_pair, support_language_pair_tuple, (str,), "language_pair") | |||
| else: | |||
| raise TypeError("language_pair should be a type list or tuple of length 2.") | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_mnist_cifar_dataset(method): | |||
| """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset).""" | |||
| @@ -29,6 +29,7 @@ SET(DE_UT_SRCS | |||
| c_api_dataset_fashion_mnist_test.cc | |||
| c_api_dataset_flickr_test.cc | |||
| c_api_dataset_iterator_test.cc | |||
| c_api_dataset_iwslt_test.cc | |||
| c_api_dataset_kmnist_test.cc | |||
| c_api_dataset_lj_speech_test.cc | |||
| c_api_dataset_manifest_test.cc | |||
| @@ -0,0 +1,985 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/include/dataset/datasets.h" | |||
| using namespace mindspore::dataset; | |||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | |||
| protected: | |||
| }; | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: read IWSLT2016Dataset data and get data. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Code schreiben macht Freude.", "Writing code is a joy."}, | |||
| {"Ich hoffe in Zukunft weniger Überstunden machen zu können.", "I hope to work less overtime in the future."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 2 samples. | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: read IWSLT2016Dataset data and get data (usage=valid). | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetUsageValidBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetUsageValidBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "valid", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = {{"heute hat es geregnet.", "it rained today."}, | |||
| {"Leih mir ein Stück Papier.", "Lend me a piece of paper."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 2 samples. | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: read IWSLT2016Dataset data and get data (usage=test). | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetUsageTestBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetUsageTestBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "test", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Ich mag dich.", "I like you."}, {"Ich gebe dir eine Schultasche.", "I will give you a schoolbag."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 2 samples. | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: read IWSLT2016Dataset data and get data (usage=all). | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetUsageAllBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetUsageAllBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "all", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Code schreiben macht Freude.", "Writing code is a joy."}, | |||
| {"heute hat es geregnet.", "it rained today."}, | |||
| {"Ich mag dich.", "I like you."}, | |||
| {"Ich hoffe in Zukunft weniger Überstunden machen zu können.", "I hope to work less overtime in the future."}, | |||
| {"Leih mir ein Stück Papier.", "Lend me a piece of paper."}, | |||
| {"Ich gebe dir eine Schultasche.", "I will give you a schoolbag."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 6 samples. | |||
| EXPECT_EQ(i, 6); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: includes tests for shape, type, size. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetGetters) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetGetters."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::vector<DataType> types = ToDETypes(ds->GetOutputTypes()); | |||
| std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes()); | |||
| EXPECT_EQ(types.size(), 2); | |||
| EXPECT_EQ(types[0].ToString(), "string"); | |||
| EXPECT_EQ(types[1].ToString(), "string"); | |||
| EXPECT_EQ(shapes.size(), 2); | |||
| EXPECT_EQ(shapes[0].ToString(), "<>"); | |||
| EXPECT_EQ(shapes[1].ToString(), "<>"); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_EQ(ds->GetColumnNames(), column_names); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 2); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: test whether the interface meets expectations when NumSamples is equal to 2. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetNumSamples) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetNumSamples."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 2, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Code schreiben macht Freude.", "Writing code is a joy."}, | |||
| {"Ich hoffe in Zukunft weniger Überstunden machen zu können.", "I hope to work less overtime in the future."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 2 samples. | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: test interface in a distributed state. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetDistribution) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetDistribution."; | |||
| // Create a IWSLT2016Dataset. | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse, 2); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Code schreiben macht Freude.", "Writing code is a joy."}, | |||
| {"Ich hoffe in Zukunft weniger Überstunden machen zu können.", "I hope to work less overtime in the future."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 1 samples. | |||
| EXPECT_EQ(i, 1); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: test the wrong input. | |||
| /// Expectation: unable to read in data. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetFail."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| // Create a IWSLT2016 Dataset with not exist file. | |||
| std::shared_ptr<Dataset> ds0 = IWSLT2016("invalid_dir", "train", {"de", "en"}, "tst2013", "tst2014"); | |||
| EXPECT_NE(ds0, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter0 = ds0->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter0, nullptr); | |||
| // Create a IWSLT2016 Dataset with invalid usage. | |||
| std::shared_ptr<Dataset> ds1 = IWSLT2016(dataset_dir, "invalid_usage", {"de", "en"}, "tst2013", "tst2014"); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter1 = ds1->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter1, nullptr); | |||
| // Create a IWSLT2016 Dataset with invalid language_pair[0] (src_language). | |||
| std::shared_ptr<Dataset> ds2 = IWSLT2016(dataset_dir, "train", {"invalid", "en"}, "tst2013", "tst2014"); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter2 = ds2->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter2, nullptr); | |||
| // Create a IWSLT2016 Dataset with invalid language_pair[1] (target_language). | |||
| std::shared_ptr<Dataset> ds3 = IWSLT2016(dataset_dir, "train", {"de", "invalid"}, "tst2013", "tst2014"); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter3 = ds3->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input | |||
| EXPECT_EQ(iter3, nullptr); | |||
| // Create a IWSLT2016 Dataset with invalid valid_set. | |||
| std::shared_ptr<Dataset> ds4 = IWSLT2016(dataset_dir, "train", {"de", "en"}, "invalid", "tst2014"); | |||
| EXPECT_NE(ds4, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter4 = ds4->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input | |||
| EXPECT_EQ(iter4, nullptr); | |||
| // Create a IWSLT2016 Dataset with invalid test_set. | |||
| std::shared_ptr<Dataset> ds5 = IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "invalid"); | |||
| EXPECT_NE(ds5, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter5 = ds5->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter5, nullptr); | |||
| // Test invalid num_samples < -1. | |||
| std::shared_ptr<Dataset> ds6 = IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", -1); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter6 = ds6->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter6, nullptr); | |||
| // Test invalid num_shards < 1. | |||
| std::shared_ptr<Dataset> ds7 = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse, 0); | |||
| EXPECT_NE(ds7, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter7 = ds7->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter7, nullptr); | |||
| // Test invalid shard_id >= num_shards. | |||
| std::shared_ptr<Dataset> ds8 = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse, 2, 2); | |||
| EXPECT_NE(ds8, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter8 = ds8->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter8, nullptr); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: test IWSLT2016 Dataset interface in pipeline. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetBasicWithPipeline) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetBasicWithPipeline."; | |||
| // Create two IWSLT2016 Dataset, with single IWSLT2016 file. | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::shared_ptr<Dataset> ds1 = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse); | |||
| std::shared_ptr<Dataset> ds2 = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse); | |||
| EXPECT_NE(ds1, nullptr); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create two Repeat operation on ds. | |||
| int32_t repeat_num = 2; | |||
| ds1 = ds1->Repeat(repeat_num); | |||
| EXPECT_NE(ds1, nullptr); | |||
| repeat_num = 3; | |||
| ds2 = ds2->Repeat(repeat_num); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create two Project operation on ds. | |||
| std::vector<std::string> column_project = {"text"}; | |||
| ds1 = ds1->Project(column_project); | |||
| EXPECT_NE(ds1, nullptr); | |||
| ds2 = ds2->Project(column_project); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create a Concat operation on the ds. | |||
| ds1 = ds1->Concat({ds2}); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds1->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto text = row["text"]; | |||
| MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| // Expect 10 samples. | |||
| EXPECT_EQ(i, 10); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: test IWSLT2016 Dataset interface with different ShuffleMode. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetShuffleFilesA) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetShuffleFilesA."; | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(130); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(4); | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "all", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFiles); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Ich mag dich.", "I like you."}, | |||
| {"Code schreiben macht Freude.", "Writing code is a joy."}, | |||
| {"heute hat es geregnet.", "it rained today."}, | |||
| {"Ich gebe dir eine Schultasche.", "I will give you a schoolbag."}, | |||
| {"Ich hoffe in Zukunft weniger Überstunden machen zu können.", "I hope to work less overtime in the future."}, | |||
| {"Leih mir ein Stück Papier.", "Lend me a piece of paper."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 6 samples. | |||
| EXPECT_EQ(i, 6); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: test IWSLT2016 Dataset interface with different ShuffleMode. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2016DatasetShuffleFilesB) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2016DatasetShuffleFilesB."; | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(130); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(4); | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "all", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kInfile); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Code schreiben macht Freude.", "Writing code is a joy."}, | |||
| {"heute hat es geregnet.", "it rained today."}, | |||
| {"Ich mag dich.", "I like you."}, | |||
| {"Ich hoffe in Zukunft weniger Überstunden machen zu können.", "I hope to work less overtime in the future."}, | |||
| {"Leih mir ein Stück Papier.", "Lend me a piece of paper."}, | |||
| {"Ich gebe dir eine Schultasche.", "I will give you a schoolbag."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 6 samples. | |||
| EXPECT_EQ(i, 6); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: Test IWSLT2016 Dataset. | |||
| /// Description: test IWSLT2016 Dataset interface with different ShuffleMode. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TesIWSLT2016DatasetShuffleFilesGlobal) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TesIWSLT2016DatasetShuffleFilesGlobal."; | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(130); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(4); | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2016"; | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| std::shared_ptr<Dataset> ds = | |||
| IWSLT2016(dataset_dir, "all", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kGlobal); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Ich mag dich.", "I like you."}, | |||
| {"Code schreiben macht Freude.", "Writing code is a joy."}, | |||
| {"heute hat es geregnet.", "it rained today."}, | |||
| {"Leih mir ein Stück Papier.", "Lend me a piece of paper."}, | |||
| {"Ich gebe dir eine Schultasche.", "I will give you a schoolbag."}, | |||
| {"Ich hoffe in Zukunft weniger Überstunden machen zu können.", "I hope to work less overtime in the future."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 6 samples. | |||
| EXPECT_EQ(i, 6); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: Test IWSLT2017 Dataset. | |||
| /// Description: read IWSLT2017Dataset data and get data. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2017DatasetBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2017DatasetBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2017"; | |||
| std::shared_ptr<Dataset> ds = IWSLT2017(dataset_dir, "train", {"de", "en"}, 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Schönes Wetter heute.", "The weather is nice today."}, | |||
| {"Ich bin heute gut gelaunt.", "I am in a good mood today."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 2 samples | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2017 Dataset. | |||
| /// Description: read IWSLT2017Dataset data and get data (usage=valid). | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2017DatasetUsageValidBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2017DatasetUsageValidBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2017"; | |||
| std::shared_ptr<Dataset> ds = IWSLT2017(dataset_dir, "valid", {"de", "en"}, 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Ich kann meinen Code nicht zu Ende schreiben.", "I can't finish writing my code."}, | |||
| {"Vielleicht muss ich Überstunden machen.", "I might have to work overtime."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 2 samples. | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2017 Dataset. | |||
| /// Description: read IWSLT2017Dataset data and get data (usage=test). | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2017DatasetUsageTestBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2017DatasetUsageTestBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2017"; | |||
| std::shared_ptr<Dataset> ds = IWSLT2017(dataset_dir, "test", {"de", "en"}, 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Heute gehe ich ins Labor.", "Today i'm going to the lab."}, | |||
| {"Ich schlafe jetzt wieder ein.", "I am going back to sleep now."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 2 samples. | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2017 Dataset. | |||
| /// Description: read IWSLT2017Dataset data and get data (usage=all). | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2017DatasetUsageAllBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2017DatasetUsageAllBasic."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2017"; | |||
| std::shared_ptr<Dataset> ds = IWSLT2017(dataset_dir, "all", {"de", "en"}, 0, ShuffleMode::kFalse); | |||
| std::vector<std::string> column_names = {"text", "translation"}; | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::vector<std::string>> expected_result = { | |||
| {"Schönes Wetter heute.", "The weather is nice today."}, | |||
| {"Ich kann meinen Code nicht zu Ende schreiben.", "I can't finish writing my code."}, | |||
| {"Heute gehe ich ins Labor.", "Today i'm going to the lab."}, | |||
| {"Ich bin heute gut gelaunt.", "I am in a good mood today."}, | |||
| {"Vielleicht muss ich Überstunden machen.", "I might have to work overtime."}, | |||
| {"Ich schlafe jetzt wieder ein.", "I am going back to sleep now."}}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| for (int j = 0; j < column_names.size(); j++) { | |||
| auto text = row[column_names[j]]; | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); | |||
| } | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| // Expect 6 samples. | |||
| EXPECT_EQ(i, 6); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test IWSLT2017 Dataset. | |||
| /// Description: test the wrong input. | |||
| /// Expectation: unable to read in data. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2017DatasetFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2017DatasetFail."; | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2017"; | |||
| // Create a IWSLT2017 Dataset with not exist file. | |||
| std::shared_ptr<Dataset> ds0 = IWSLT2017("invalid_dir", "train", {"de", "en"}); | |||
| EXPECT_NE(ds0, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter0 = ds0->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter0, nullptr); | |||
| // Create a IWSLT2017 Dataset with invalid usage. | |||
| std::shared_ptr<Dataset> ds1 = IWSLT2017(dataset_dir, "invalid_usage", {"de", "en"}); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter1 = ds1->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter1, nullptr); | |||
| // Create a IWSLT2017 Dataset with invalid language_pair[0](src_language). | |||
| std::shared_ptr<Dataset> ds2 = IWSLT2017(dataset_dir, "train", {"invalid", "en"}); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter2 = ds2->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input | |||
| EXPECT_EQ(iter2, nullptr); | |||
| // Create a IWSLT2017 Dataset with invalid language_pair[1](target_language. | |||
| std::shared_ptr<Dataset> ds3 = IWSLT2016(dataset_dir, "train", {"de", "invalid"}, "tst2013", "tst2014"); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter3 = ds3->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter3, nullptr); | |||
| // Test invalid num_samples < -1. | |||
| std::shared_ptr<Dataset> ds4 = IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", -1); | |||
| EXPECT_NE(ds4, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter4 = ds4->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter4, nullptr); | |||
| // Test invalid num_shards < 1. | |||
| std::shared_ptr<Dataset> ds5 = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse, 0); | |||
| EXPECT_NE(ds5, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter5 = ds5->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter5, nullptr); | |||
| // Test invalid shard_id >= num_shards. | |||
| std::shared_ptr<Dataset> ds6 = | |||
| IWSLT2016(dataset_dir, "train", {"de", "en"}, "tst2013", "tst2014", 0, ShuffleMode::kFalse, 2, 2); | |||
| EXPECT_NE(ds6, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter6 = ds6->CreateIterator(); | |||
| // Expect failure: invalid IWSLT input. | |||
| EXPECT_EQ(iter6, nullptr); | |||
| } | |||
| /// Feature: Test IWSLT2017 Dataset. | |||
| /// Description: test IWSLT2017 Dataset interface in pipeline. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestIWSLT2017DatasetBasicWithPipeline) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIWSLT2017DatasetBasicWithPipeline."; | |||
| // Create two IWSLT2017 Dataset, with single IWSLT2017 file. | |||
| std::string dataset_dir = datasets_root_path_ + "/testIWSLT/IWSLT2017"; | |||
| std::shared_ptr<Dataset> ds1 = IWSLT2017(dataset_dir, "train", {"de", "en"}, 0, ShuffleMode::kFalse); | |||
| std::shared_ptr<Dataset> ds2 = IWSLT2017(dataset_dir, "train", {"de", "en"}, 0, ShuffleMode::kFalse); | |||
| EXPECT_NE(ds1, nullptr); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create two Repeat operation on ds. | |||
| int32_t repeat_num = 2; | |||
| ds1 = ds1->Repeat(repeat_num); | |||
| EXPECT_NE(ds1, nullptr); | |||
| repeat_num = 3; | |||
| ds2 = ds2->Repeat(repeat_num); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create two Project operation on ds. | |||
| std::vector<std::string> column_project = {"text"}; | |||
| ds1 = ds1->Project(column_project); | |||
| EXPECT_NE(ds1, nullptr); | |||
| ds2 = ds2->Project(column_project); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create a Concat operation on the ds. | |||
| ds1 = ds1->Concat({ds2}); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds1->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto text = row["text"]; | |||
| MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| // Expect 10 samples. | |||
| EXPECT_EQ(i, 10); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <srcset setid="iwslt2016-tst2013" srclang="german"> | |||
| <doc docid="test" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description.</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test talkid</talkid> | |||
| <title>test title</title> | |||
| <seg id="1"> heute hat es geregnet. </seg> | |||
| <seg id="2"> Leih mir ein Stück Papier. </seg> | |||
| </doc> | |||
| </srcset> | |||
| </mteval> | |||
| @@ -0,0 +1,14 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <refset setid="iwslt2016-tst2013" srclang="german" trglang="english" refid="ref"> | |||
| <doc docid="test" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description.</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test talkid</talkid> | |||
| <title>test title</title> | |||
| <seg id="1"> it rained today. </seg> | |||
| <seg id="2"> Lend me a piece of paper. </seg> | |||
| </doc> | |||
| </refset> | |||
| </mteval> | |||
| @@ -0,0 +1,14 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <srcset setid="iwslt2016-tst2014" srclang="german"> | |||
| <doc docid="test" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description.</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test talkid</talkid> | |||
| <title>test title</title> | |||
| <seg id="1"> Ich mag dich. </seg> | |||
| <seg id="2"> Ich gebe dir eine Schultasche. </seg> | |||
| </doc> | |||
| </srcset> | |||
| </mteval> | |||
| @@ -0,0 +1,14 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <refset setid="iwslt2016-tst2014" srclang="german" trglang="english" refid="ref"> | |||
| <doc docid="test" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description.</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test talkid</talkid> | |||
| <title>test title</title> | |||
| <seg id="1"> I like you. </seg> | |||
| <seg id="2"> I will give you a schoolbag. </seg> | |||
| </doc> | |||
| </refset> | |||
| </mteval> | |||
| @@ -0,0 +1,10 @@ | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <keywords>test keywords</keywords> | |||
| <speaker>test speaker</speaker> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <description>test description.</description> | |||
| Code schreiben macht Freude. | |||
| Ich hoffe in Zukunft weniger Überstunden machen zu können. | |||
| <translator href="https://gitee.com/mindspore/mindspore">test translator</translator> | |||
| <reviewer href="https://gitee.com/mindspore/mindspore">test reviewer</reviewer> | |||
| @@ -0,0 +1,10 @@ | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <keywords>test keywords</keywords> | |||
| <speaker>test speaker</speaker> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <description>test description.</description> | |||
| Writing code is a joy. | |||
| I hope to work less overtime in the future. | |||
| <reviewer></reviewer> | |||
| <translator></translator> | |||
| @@ -0,0 +1,16 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <srcset setid="iwslt2017-dev2010" srclang="german"> | |||
| <doc docid="test" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <reviewer href="https://gitee.com/mindspore/mindspore">test reviewer</reviewer> | |||
| <translator href="https://gitee.com/mindspore/mindspore">test translator</translator> | |||
| <seg id="1"> Ich kann meinen Code nicht zu Ende schreiben. </seg> | |||
| <seg id="2"> Vielleicht muss ich Überstunden machen. </seg> | |||
| </doc> | |||
| </srcset> | |||
| </mteval> | |||
| @@ -0,0 +1,16 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <refset setid="iwslt2017-dev2010" srclang="german" trglang="english" refid="ref"> | |||
| <doc docid="535" genre="lectures"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <reviewer></reviewer> | |||
| <translator></translator> | |||
| <seg id="1"> I can't finish writing my code. </seg> | |||
| <seg id="2"> I might have to work overtime. </seg> | |||
| </doc> | |||
| </refset> | |||
| </mteval> | |||
| @@ -0,0 +1,16 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <srcset setid="iwslt2017-tst2010" srclang="german"> | |||
| <doc docid="test" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <reviewer href="https://gitee.com/mindspore/mindspore">test reviewer</reviewer> | |||
| <translator href="https://gitee.com/mindspore/mindspore">test translator</translator> | |||
| <seg id="1"> Heute gehe ich ins Labor. </seg> | |||
| <seg id="2"> Ich schlafe jetzt wieder ein. </seg> | |||
| </doc> | |||
| </srcset> | |||
| </mteval> | |||
| @@ -0,0 +1,16 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <mteval> | |||
| <refset setid="iwslt2017-tst2010" srclang="german" trglang="english" refid="ref"> | |||
| <doc docid="test" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <description>test description</description> | |||
| <keywords>test keywords</keywords> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <reviewer></reviewer> | |||
| <translator></translator> | |||
| <seg id="1"> Today i'm going to the lab. </seg> | |||
| <seg id="2"> I am going back to sleep now. </seg> | |||
| </doc> | |||
| </refset> | |||
| </mteval> | |||
| @@ -0,0 +1,12 @@ | |||
| <doc docid="1" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <keywords>test keywords</keywords> | |||
| <speaker>test speaker</speaker> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <description>test description.</description> | |||
| Schönes Wetter heute. | |||
| Ich bin heute gut gelaunt. | |||
| <reviewer href="https://gitee.com/mindspore/mindspore">test name</reviewer> | |||
| <translator href="https://gitee.com/mindspore/mindspore">test name</translator> | |||
| </doc> | |||
| @@ -0,0 +1,12 @@ | |||
| <doc docid="1" genre="test"> | |||
| <url>https://gitee.com/mindspore/mindspore</url> | |||
| <keywords>test keywords</keywords> | |||
| <speaker>test speaker</speaker> | |||
| <talkid>test number</talkid> | |||
| <title>test title</title> | |||
| <description>test description.</description> | |||
| The weather is nice today. | |||
| I am in a good mood today. | |||
| <reviewer href="https://gitee.com/mindspore/mindspore">test reviewer</reviewer> | |||
| <translator href="https://gitee.com/mindspore/mindspore">test translator</translator> | |||
| </doc> | |||
| @@ -0,0 +1,261 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| DATA_IWSLT2016_DIR = '../data/dataset/testIWSLT/IWSLT2016' | |||
| DATA_IWSLT2017_DIR = '../data/dataset/testIWSLT/IWSLT2017' | |||
| def test_iwslt2016_dataset_basic(): | |||
| """ | |||
| Feature: Test IWSLT2016 Dataset. | |||
| Description: read data from a single file. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| buffer = [] | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| data = data.repeat(2) | |||
| data = data.skip(2) | |||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| buffer.append(d) | |||
| assert len(buffer) == 2 | |||
| def test_iwslt2016_dataset_quoted(): | |||
| """ | |||
| Feature: Test get the IWSLT2016 Dataset. | |||
| Description: read IWSLT2016 data and get data. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| buffer.extend([d['text'].item().decode("utf8"), | |||
| d['translation'].item().decode("utf8")]) | |||
| assert buffer == ["Code schreiben macht Freude.", | |||
| "Writing code is a joy.", | |||
| "Ich hoffe in Zukunft weniger Überstunden machen zu können.", | |||
| "I hope to work less overtime in the future."] | |||
| def test_iwslt2016_dataset_usage_all(): | |||
| """ | |||
| Feature: Test IWSLT2016 Dataset (usage=all). | |||
| Description: read train data and test data. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| buffer = [] | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='all', language_pair=["de", "en"], valid_set='tst2013', | |||
| test_set='tst2014', shuffle=False) | |||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| buffer.append(d) | |||
| assert len(buffer) == 6 | |||
| def test_iwslt2016_dataset_get_datasetsize(): | |||
| """ | |||
| Feature: Test Getters. | |||
| Description: test get_dataset_size of IWSLT2016 dataset. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| size = data.get_dataset_size() | |||
| assert size == 2 | |||
| def test_iwslt2016_dataset_distribution(): | |||
| """ | |||
| Feature: Test IWSLT2016Dataset in distribution. | |||
| Description: test in a distributed state. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='train', language_pair=["de", "en"], shuffle=False, | |||
| num_shards=2, shard_id=0) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count += 1 | |||
| assert count == 1 | |||
| def test_iwslt2016_dataset_num_samples(): | |||
| """ | |||
| Feature: Test IWSLT2016 Dataset (num_samples=2). | |||
| Description: test get num_samples. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='train', language_pair=["de", "en"], shuffle=False, | |||
| num_samples=2) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count += 1 | |||
| assert count == 2 | |||
| def test_iwslt2016_dataset_exception(): | |||
| """ | |||
| Feature: Error Test. | |||
| Description: test the wrong input. | |||
| Expectation: unable to read in data. | |||
| """ | |||
| def exception_func(item): | |||
| raise Exception("Error occur!") | |||
| try: | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1) | |||
| for _ in data.create_dict_iterator(): | |||
| pass | |||
| assert False | |||
| except RuntimeError as e: | |||
| assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) | |||
| try: | |||
| data = ds.IWSLT2016Dataset(DATA_IWSLT2016_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| data = data.map(operations=exception_func, input_columns=["translation"], num_parallel_workers=1) | |||
| for _ in data.create_dict_iterator(): | |||
| pass | |||
| assert False | |||
| except RuntimeError as e: | |||
| assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) | |||
| def test_iwslt2017_dataset_basic(): | |||
| """ | |||
| Feature: Test IWSLT2017 Dataset. | |||
| Description: read data from a single file. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| buffer = [] | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| data = data.repeat(2) | |||
| data = data.skip(2) | |||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| buffer.append(d) | |||
| assert len(buffer) == 2 | |||
| def test_iwslt2017_dataset_quoted(): | |||
| """ | |||
| Feature: Test get the IWSLT2017 Dataset. | |||
| Description: read IWSLT2017 data and get data. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| buffer = [] | |||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| buffer.extend([d['text'].item().decode("utf8"), | |||
| d['translation'].item().decode("utf8")]) | |||
| assert buffer == ["Schönes Wetter heute.", | |||
| "The weather is nice today.", | |||
| "Ich bin heute gut gelaunt.", | |||
| "I am in a good mood today."] | |||
| def test_iwslt2017_dataset_usage_all(): | |||
| """ | |||
| Feature: Test IWSLT2017 Dataset(usage=all). | |||
| Description: read train data and test data. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| buffer = [] | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='all', language_pair=["de", "en"], shuffle=False) | |||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| buffer.append(d) | |||
| assert len(buffer) == 6 | |||
| def test_iwslt2017_dataset_get_datasetsize(): | |||
| """ | |||
| Feature: Test Getters. | |||
| Description: test get_dataset_size of IWSLT2017 dataset. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| size = data.get_dataset_size() | |||
| assert size == 2 | |||
| def test_iwslt2017_dataset_distribution(): | |||
| """ | |||
| Feature: Test IWSLT2017Dataset in distribution. | |||
| Description: test in a distributed state. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='train', language_pair=["de", "en"], shuffle=False, | |||
| num_shards=2, shard_id=0) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count += 1 | |||
| assert count == 1 | |||
| def test_iwslt2017_dataset_num_samples(): | |||
| """ | |||
| Feature: Test IWSLT2017 Dataset (num_samples=2). | |||
| Description: test get num_samples. | |||
| Expectation: the data is processed successfully. | |||
| """ | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='train', language_pair=["de", "en"], shuffle=False, | |||
| num_samples=2) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count += 1 | |||
| assert count == 2 | |||
| def test_iwslt2017_dataset_exception(): | |||
| """ | |||
| Feature: Error Test. | |||
| Description: test the wrong input. | |||
| Expectation: unable to read in data. | |||
| """ | |||
| def exception_func(item): | |||
| raise Exception("Error occur!") | |||
| try: | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1) | |||
| for _ in data.create_dict_iterator(): | |||
| pass | |||
| assert False | |||
| except RuntimeError as e: | |||
| assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) | |||
| try: | |||
| data = ds.IWSLT2017Dataset(DATA_IWSLT2017_DIR, usage='train', language_pair=["de", "en"], shuffle=False) | |||
| data = data.map(operations=exception_func, input_columns=["translation"], num_parallel_workers=1) | |||
| for _ in data.create_dict_iterator(): | |||
| pass | |||
| assert False | |||
| except RuntimeError as e: | |||
| assert "map operation: [PyFunc] failed. The corresponding data files" in str(e) | |||
| if __name__ == "__main__": | |||
| test_iwslt2016_dataset_basic() | |||
| test_iwslt2016_dataset_quoted() | |||
| test_iwslt2016_dataset_usage_all() | |||
| test_iwslt2016_dataset_get_datasetsize() | |||
| test_iwslt2016_dataset_distribution() | |||
| test_iwslt2016_dataset_num_samples() | |||
| test_iwslt2016_dataset_exception() | |||
| test_iwslt2017_dataset_basic() | |||
| test_iwslt2017_dataset_quoted() | |||
| test_iwslt2017_dataset_usage_all() | |||
| test_iwslt2017_dataset_get_datasetsize() | |||
| test_iwslt2017_dataset_distribution() | |||
| test_iwslt2017_dataset_num_samples() | |||
| test_iwslt2017_dataset_exception() | |||