@@ -52,6 +52,7 @@ add_subdirectory(core)
add_subdirectory(kernels)
add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(nlp)
######################################################################
################### Create _c_dataengine Library ######################
@@ -68,6 +69,8 @@ set(submodules
    $<TARGET_OBJECTS:engine-datasetops>
    $<TARGET_OBJECTS:engine-opt>
    $<TARGET_OBJECTS:engine>
    $<TARGET_OBJECTS:nlp>
    $<TARGET_OBJECTS:nlp-kernels>
)
if (ENABLE_TDTQUE)
@@ -40,6 +40,8 @@
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/text/jieba_tokenizer_op.h"
#include "dataset/kernels/text/unicode_char_tokenizer_op.h"
#include "dataset/nlp/vocab.h"
#include "dataset/nlp/kernels/lookup_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
@@ -414,10 +416,13 @@ void bindTensorOps5(py::module *m) {
         py::arg("mode") = JiebaMode::kMix)
    .def("add_word",
         [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
  (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
    *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
    .def(py::init<>());
  (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp",
                                                                  "Tensor operation to look up each word")
    .def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
    .def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
}
void bindSamplerOps(py::module *m) {
@@ -479,6 +484,27 @@ void bindInfoObjects(py::module *m) {
    .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
}
void bindVocabObjects(py::module *m) {
  (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
    .def_static("from_list",
                [](const py::list &words) {
                  std::shared_ptr<Vocab> v;
                  THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
                  return v;
                })
    .def_static("from_file",
                [](const std::string &path, const std::string &dlm, int32_t vocab_size) {
                  std::shared_ptr<Vocab> v;
                  THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
                  return v;
                })
    .def_static("from_dict", [](const py::dict &words) {
      std::shared_ptr<Vocab> v;
      THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
      return v;
    });
}
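These static factories are what the Python layer calls under the hood. A minimal sketch of how they surface, assuming the mindspore.dataset.text wrappers added later in this patch (the vocab.txt path is hypothetical):

    import mindspore.dataset.text as text

    # each factory returns a Vocab handle backed by the C++ object
    v1 = text.Vocab.from_list(["home", "world"])          # ids assigned automatically, starting from 2
    v2 = text.Vocab.from_dict({"home": 2, "world": 3})    # ids validated on the C++ side
    v3 = text.Vocab.from_file("vocab.txt", ",", 100)      # reads at most 100 words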
// This is where we externalize the C logic as python modules
PYBIND11_MODULE(_c_dataengine, m) {
  m.doc() = "pybind11 for _c_dataengine";
@@ -543,6 +569,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
  bindSamplerOps(&m);
  bindDatasetOps(&m);
  bindInfoObjects(&m);
  bindVocabObjects(&m);
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,7 @@
add_subdirectory(kernels)

add_library(nlp OBJECT
    vocab.cc
)

add_dependencies(nlp nlp-kernels)
@@ -0,0 +1,3 @@
add_library(nlp-kernels OBJECT
    lookup_op.cc
)
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/nlp/kernels/lookup_op.h"

#include <string>

namespace mindspore {
namespace dataset {

LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
    : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {}

Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  RETURN_UNEXPECTED_IF_NULL(vocab_);
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "input is not a string tensor");
  std::vector<WordIdType> word_ids;
  word_ids.reserve(input->Size());
  for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
    word_ids.push_back(vocab_->Lookup(std::string(*itr), default_id_));
  }
  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_,
                                        reinterpret_cast<unsigned char *>(word_ids.data())));
  return Status::OK();
}

Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
  CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match");
  CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "input is not a string tensor type");
  outputs[0] = type_;
  return Status::OK();
}

void LookupOp::Print(std::ostream &out) const {
  out << "LookupOp: "
      << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n";
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,62 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_NLP_KERNELS_LOOKUP_OP_H_
#define DATASET_NLP_KERNELS_LOOKUP_OP_H_

#include <memory>
#include <vector>
#include <utility>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
#include "dataset/nlp/vocab.h"

namespace mindspore {
namespace dataset {
class LookupOp : public TensorOp {
 public:
  // Constructor for lookup, takes in a vocab object
  // @param std::shared_ptr<Vocab> vocab - vocab used for the lookup
  // @param WordIdType default_id - id to return when a word is not in the vocab
  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk);

  // Perform the actual lookup on each element of the input tensor
  // @param const std::shared_ptr<Tensor> &input - string tensor to look up
  // @param std::shared_ptr<Tensor> *output - resulting int32 tensor of word ids
  // @return error code
  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

  // Print method
  // @param std::ostream &out - output stream
  void Print(std::ostream &out) const override;

  // @param const std::vector<DataType> &inputs - expected input types
  // @param std::vector<DataType> &outputs - filled in with the output type (int32)
  // @return error code
  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

 private:
  std::shared_ptr<Vocab> vocab_;
  WordIdType default_id_;
  DataType type_;  // type of the tensor after lookup
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_NLP_KERNELS_LOOKUP_OP_H_
@@ -0,0 +1,101 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <fstream>
#include <map>
#include <utility>

#include "dataset/nlp/vocab.h"

namespace mindspore {
namespace dataset {
Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) {
  word2id_ = std::move(word2id);
  id2word_.resize(word2id_.size());
  for (auto p : word2id_) {
    id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
  }
}

WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
  auto itr = word2id_.find(word);
  return itr == word2id_.end() ? default_id : itr->second;
}

WordType Vocab::Lookup(WordIdType id) const {
  if (id < kSpecialTokens::num_tokens) {
    return reserved_token_str_[id];
  } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
    return reserved_token_str_[kSpecialTokens::unk];
  } else {
    return id2word_[id - kSpecialTokens::num_tokens];
  }
}

Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) {
  std::unordered_map<WordType, WordIdType> word2id;
  WordIdType word_id = kSpecialTokens::num_tokens;
  for (auto word : words) {
    const std::string s = py::str(word);
    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word: " + s);
    word2id[s] = word_id++;
  }
  *vocab = std::make_shared<Vocab>(std::move(word2id));
  return Status::OK();
}

Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
                            std::shared_ptr<Vocab> *vocab) {
  std::unordered_map<WordType, WordIdType> word2id;
  WordIdType word_id = kSpecialTokens::num_tokens;
  std::fstream handle(path, std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open: " + path);
  std::string word;
  while (std::getline(handle, word)) {
    if (!delimiter.empty()) {
      // if the delimiter is not found, find_first_of returns std::string::npos and substr keeps the whole line
      word = word.substr(0, word.find_first_of(delimiter));
    }
    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word: " + word);
    word2id[word] = word_id++;
    // stop once vocab_size words have been read; a negative vocab_size means read the whole file
    if (word_id == vocab_size + kSpecialTokens::num_tokens) break;
  }
  *vocab = std::make_shared<Vocab>(std::move(word2id));
  return Status::OK();
}

Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
  std::unordered_map<WordType, WordIdType> word2id;
  std::map<WordIdType, WordType> id2word;
  for (auto p : words) {
    WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second);
    if (word_id < kSpecialTokens::num_tokens) continue;  // skip ids that are reserved
    std::string word = py::str(p.first);
    CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id for word: " + word);
    id2word[word_id] = word;
  }

  WordIdType cnt = kSpecialTokens::num_tokens;
  for (auto p : id2word) {
    CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word ids need to be contiguous, starting from 2");
    word2id[p.second] = p.first;
  }
  *vocab = std::make_shared<Vocab>(std::move(word2id));
  return Status::OK();
}

const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};
}  // namespace dataset
}  // namespace mindspore
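The id space that all the build paths agree on: ids 0 and 1 are reserved for <pad> and <unk>, and real words start at num_tokens (2), which is also the offset used when indexing id2word_. A quick sketch of the resulting mapping, assuming the Python wrappers added below:

    import mindspore.dataset.text as text

    vocab = text.Vocab.from_list(["home", "world"])
    # "home" -> 2, "world" -> 3; ids 0 (<pad>) and 1 (<unk>) stay reserved
    # a word that is not in the vocab falls back to the default id, <unk> (1),
    # unless the Lookup op is constructed with an explicit unknown id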
@@ -0,0 +1,88 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_NLP_VOCAB_H_
#define DATASET_NLP_VOCAB_H_

#include <string>
#include <memory>
#include <unordered_map>
#include <vector>

#include "dataset/util/status.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace mindspore {
namespace dataset {
namespace py = pybind11;

using WordIdType = int32_t;
using WordType = std::string;

class Vocab {
 public:
  // Build a vocab from a python dictionary. Keys are words, values are word ids; ids must start
  // from 2, contain no duplicates, and be contiguous
  // @param const py::dict &words - a dictionary containing word, word id pairs
  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
  // @return error code
  static Status BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab);

  // Build a vocab from a python list; ids are assigned automatically, starting from 2
  // @param const py::list &words - a list of strings used to build the vocab
  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
  // @return error code
  static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab);

  // Build a vocab by reading a vocab file; ids are assigned automatically, starting from 2
  // @param const std::string &path - path to the vocab file, each line is assumed to contain one word
  // @param const std::string &delimiter - delimiter to break each line with
  // @param int32_t vocab_size - number of words to read from the file
  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
  // @return error code
  static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
                              std::shared_ptr<Vocab> *vocab);

  // Look up the id of a word; if the word doesn't exist in the vocab, return default_id
  // @param const WordType &word - word to look up
  // @param WordIdType default_id - word id to return when the word is not in the vocab
  // @return WordIdType, the word id
  WordIdType Lookup(const WordType &word, WordIdType default_id) const;

  // Reverse lookup, find the word based on its id
  // @param WordIdType id - word id to look up
  // @return WordType, the word
  WordType Lookup(WordIdType id) const;

  // Constructor; shouldn't be called directly, but can't be private due to std::make_shared()
  // @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map
  explicit Vocab(std::unordered_map<WordType, WordIdType> map);

  // Enum type that holds all special tokens, add more if needed
  enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 };

  // Reverse lookup table for the reserved tokens
  static const std::vector<WordType> reserved_token_str_;

 private:
  std::unordered_map<WordType, WordIdType> word2id_;
  std::vector<WordType> id2word_;  // reverse lookup
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_NLP_VOCAB_H_
@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
mindspore.dataset.text
"""
from .c_transforms import *
@@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
c transforms for all text related operators
"""
import mindspore._c_dataengine as cde

from .validators import check_lookup, check_from_list, check_from_dict, check_from_file


class Vocab(cde.Vocab):
    """
    Vocab object that is used to look up words.
    """

    def __init__(self):
        pass

    @classmethod
    @check_from_list
    def from_list(cls, word_list):
        """
        Build a vocab object from a list of words.

        Args:
            word_list(list): a list of string where each element is a word.
        """
        return super().from_list(word_list)

    @classmethod
    @check_from_file
    def from_file(cls, file_path, delimiter=None, vocab_size=None):
        """
        Build a vocab object from a vocab file.

        Args:
            file_path(str): path to the file which contains the vocab list.
            delimiter(None, str): a delimiter to break up each line in the file, the first element is taken
                to be the word.
            vocab_size(None, int): number of words to read from file_path.
        """
        return super().from_file(file_path, delimiter, vocab_size)

    @classmethod
    @check_from_dict
    def from_dict(cls, word_dict):
        """
        Build a vocab object from a dict.

        Args:
            word_dict(dict): a dict containing word, id pairs. Ids should start from 2 and be contiguous.
        """
        return super().from_dict(word_dict)


class Lookup(cde.LookupOp):
    """
    Lookup operator that looks up a word to an id.

    Args:
        vocab(Vocab): a Vocab object.
        unknown(None, int): default id to look up a word that is out of vocab.
    """

    @check_lookup
    def __init__(self, vocab, unknown=None):
        if unknown is None:
            super().__init__(vocab)
        else:
            super().__init__(vocab, unknown)
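Taken together, these wrappers compose with a dataset map. A minimal usage sketch, mirroring the tests at the end of this patch (the words.txt path is assumed to exist):

    import mindspore.dataset as ds
    import mindspore.dataset.text as text

    vocab = text.Vocab.from_list(["home", "is", "behind", "the", "world", "ahead"])
    lookup = text.Lookup(vocab)  # unknown words map to <unk> (id 1) by default

    data = ds.TextFileDataset("words.txt", shuffle=False)
    data = data.map(input_columns=["text"], operations=lookup)  # string tensors become int32 ids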
@@ -0,0 +1,108 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
validators for text ops
"""
from functools import wraps

import mindspore._c_dataengine as cde


def check_lookup(method):
    """A wrapper that wraps a parameter checker around the Lookup operation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        vocab, unknown = (list(args) + 2 * [None])[:2]
        if "vocab" in kwargs:
            vocab = kwargs.get("vocab")
        if "unknown" in kwargs:
            unknown = kwargs.get("unknown")
        if unknown is not None:
            assert isinstance(unknown, int) and unknown >= 0, "unknown needs to be a non-negative integer"

        assert isinstance(vocab, cde.Vocab), "vocab is not an instance of cde.Vocab"

        kwargs["vocab"] = vocab
        kwargs["unknown"] = unknown
        return method(self, **kwargs)

    return new_method


def check_from_file(method):
    """A wrapper that wraps a parameter checker around Vocab.from_file."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3]
        if "file_path" in kwargs:
            file_path = kwargs.get("file_path")
        if "delimiter" in kwargs:
            delimiter = kwargs.get("delimiter")
        if "vocab_size" in kwargs:
            vocab_size = kwargs.get("vocab_size")

        assert isinstance(file_path, str), "file_path needs to be str"
        if delimiter is not None:
            assert isinstance(delimiter, str), "delimiter needs to be str"
        else:
            delimiter = ""
        if vocab_size is not None:
            assert isinstance(vocab_size, int) and vocab_size > 0, "vocab_size needs to be a positive integer"
        else:
            vocab_size = -1
        kwargs["file_path"] = file_path
        kwargs["delimiter"] = delimiter
        kwargs["vocab_size"] = vocab_size
        return method(self, **kwargs)

    return new_method


def check_from_list(method):
    """A wrapper that wraps a parameter checker around Vocab.from_list."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        word_list, = (list(args) + [None])[:1]
        if "word_list" in kwargs:
            word_list = kwargs.get("word_list")
        assert isinstance(word_list, list), "word_list needs to be a list of words"
        for word in word_list:
            assert isinstance(word, str), "each word in word_list needs to be type str"

        kwargs["word_list"] = word_list
        return method(self, **kwargs)

    return new_method


def check_from_dict(method):
    """A wrapper that wraps a parameter checker around Vocab.from_dict."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        word_dict, = (list(args) + [None])[:1]
        if "word_dict" in kwargs:
            word_dict = kwargs.get("word_dict")
        assert isinstance(word_dict, dict), "word_dict needs to be a dict of word, id pairs"
        for word, word_id in word_dict.items():
            assert isinstance(word, str), "each word in word_dict needs to be type str"
            assert isinstance(word_id, int) and word_id >= 0, "each word id needs to be a non-negative integer"

        kwargs["word_dict"] = word_dict
        return method(self, **kwargs)

    return new_method
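The validators also normalize optional arguments before the call crosses into C++: a delimiter of None becomes "" (no splitting) and a vocab_size of None becomes -1 (read the whole file). A sketch of the effective calls, assuming a hypothetical vocab.txt:

    import mindspore.dataset.text as text

    # both end up invoking the C++ BuildFromFile with ("vocab.txt", "", -1)
    v = text.Vocab.from_file("vocab.txt")
    v = text.Vocab.from_file(file_path="vocab.txt", delimiter=None, vocab_size=None)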
@@ -0,0 +1,14 @@
not,1
all,2
those,3
who,4
wander,5
are,6
lost,7
Tolkien,8
home,9
is,10
behind,11
world,12
ahead,13
the,14
@@ -0,0 +1,6 @@
home
is
behind
the
world
ahead
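For orientation: with the two reserved ids (<pad> = 0, <unk> = 1), from_file assigns ids in file order starting at 2, so vocab_list.txt yields "not" -> 2 through "the" -> 15. That is where the expected ids in test_from_file below come from:

    # words.txt lines:     home  is  behind  the  world  ahead
    # ids from vocab_list:   10  11      12   15     13     14
    expected = [10, 11, 12, 15, 13, 14]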
@@ -0,0 +1,47 @@
import mindspore.dataset as ds
import mindspore.dataset.text as text

# this file contains "home is behind the world ahead", one word per line
DATA_FILE = "../data/dataset/testVocab/words.txt"
VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"


def test_from_list():
    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "))
    lookup = text.Lookup(vocab)
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(input_columns=["text"], operations=lookup)
    ind = 0
    # "is" is not in the vocab ("IS" is), so it maps to <unk> (1)
    res = [2, 1, 4, 5, 6, 7]
    for d in data.create_dict_iterator():
        assert d["text"] == res[ind], ind
        ind += 1


def test_from_file():
    vocab = text.Vocab.from_file(VOCAB_FILE, ",")
    lookup = text.Lookup(vocab)
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(input_columns=["text"], operations=lookup)
    ind = 0
    res = [10, 11, 12, 15, 13, 14]
    for d in data.create_dict_iterator():
        assert d["text"] == res[ind], ind
        ind += 1


def test_from_dict():
    vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
    lookup = text.Lookup(vocab, 6)  # words not in the vocab map to id 6
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.map(input_columns=["text"], operations=lookup)
    res = [3, 6, 2, 4, 5, 6]
    ind = 0
    for d in data.create_dict_iterator():
        assert d["text"] == res[ind], ind
        ind += 1


if __name__ == '__main__':
    test_from_list()
    test_from_file()
    test_from_dict()