@@ -52,6 +52,7 @@ add_subdirectory(core)
 add_subdirectory(kernels)
 add_subdirectory(engine)
 add_subdirectory(api)
+add_subdirectory(nlp)
 ######################################################################
 ################### Create _c_dataengine Library ######################
@@ -68,6 +69,8 @@ set(submodules
     $<TARGET_OBJECTS:engine-datasetops>
     $<TARGET_OBJECTS:engine-opt>
     $<TARGET_OBJECTS:engine>
+    $<TARGET_OBJECTS:nlp>
+    $<TARGET_OBJECTS:nlp-kernels>
 )
 if (ENABLE_TDTQUE)
@@ -40,6 +40,8 @@
 #include "dataset/kernels/data/type_cast_op.h"
 #include "dataset/kernels/text/jieba_tokenizer_op.h"
 #include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include "dataset/nlp/vocab.h"
+#include "dataset/nlp/kernels/lookup_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
 #include "dataset/engine/datasetops/source/image_folder_op.h"
 #include "dataset/engine/datasetops/source/io_block.h"
@@ -414,10 +416,13 @@ void bindTensorOps5(py::module *m) {
          py::arg("mode") = JiebaMode::kMix)
     .def("add_word",
          [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
   (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
     *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
     .def(py::init<>());
+  (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp",
+                                                                  "Tensor operation to look up each word.")
+    .def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
+    .def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
 }
 void bindSamplerOps(py::module *m) {
@@ -479,6 +484,27 @@ void bindInfoObjects(py::module *m) {
     .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
 }
+void bindVocabObjects(py::module *m) {
+  (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
+    .def_static("from_list",
+                [](const py::list &words) {
+                  std::shared_ptr<Vocab> v;
+                  THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
+                  return v;
+                })
+    .def_static("from_file",
+                [](const std::string &path, const std::string &dlm, int32_t vocab_size) {
+                  std::shared_ptr<Vocab> v;
+                  THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
+                  return v;
+                })
+    .def_static("from_dict", [](const py::dict &words) {
+      std::shared_ptr<Vocab> v;
+      THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
+      return v;
+    });
+}
 // This is where we externalize the C logic as python modules
 PYBIND11_MODULE(_c_dataengine, m) {
   m.doc() = "pybind11 for _c_dataengine";
@@ -543,6 +569,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
   bindSamplerOps(&m);
   bindDatasetOps(&m);
   bindInfoObjects(&m);
+  bindVocabObjects(&m);
 }
 }  // namespace dataset
 }  // namespace mindspore
@@ -0,0 +1,7 @@
+add_subdirectory(kernels)
+add_library(nlp OBJECT
+    vocab.cc
+)
+add_dependencies(nlp nlp-kernels)
@@ -0,0 +1,3 @@
+add_library(nlp-kernels OBJECT
+    lookup_op.cc
+)
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/nlp/kernels/lookup_op.h"
+#include <string>
+namespace mindspore {
+namespace dataset {
+LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
+    : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {}
+Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  RETURN_UNEXPECTED_IF_NULL(vocab_);
+  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "input tensor is not of string type");
+  std::vector<WordIdType> word_ids;
+  word_ids.reserve(input->Size());
+  for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
+    word_ids.push_back(vocab_->Lookup(std::string(*itr), default_id_));
+  }
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_,
+                                        reinterpret_cast<unsigned char *>(word_ids.data())));
+  return Status::OK();
+}
+Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match");
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "input type is not string");
+  outputs[0] = type_;
+  return Status::OK();
+}
+void LookupOp::Print(std::ostream &out) const {
+  out << "LookupOp: "
+      << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n";
+}
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_NLP_KERNELS_LOOKUP_OP_H_
+#define DATASET_NLP_KERNELS_LOOKUP_OP_H_
+#include <memory>
+#include <vector>
+#include <utility>
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+#include "dataset/nlp/vocab.h"
+namespace mindspore {
+namespace dataset {
+class LookupOp : public TensorOp {
+ public:
+  // constructor for lookup, takes in a vocab object
+  // @param std::shared_ptr<Vocab> vocab - vocab to use for the lookup
+  // @param WordIdType default_id - id to return if a word is not in the vocab
+  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk);
+  // perform the actual lookup on each tensor
+  // @param const std::shared_ptr<Tensor> &input
+  // @param std::shared_ptr<Tensor> *output
+  // @return error code
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+  // print method
+  // @param std::ostream out
+  void Print(std::ostream &out) const override;
+  // @param std::vector<DataType> &inputs - vector of input data types
+  // @param std::vector<DataType> &outputs - vector of output data types
+  // @return error code
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
+ private:
+  std::shared_ptr<Vocab> vocab_;
+  WordIdType default_id_;
+  DataType type_;  // type of tensor after lookup
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_NLP_KERNELS_LOOKUP_OP_H_
@@ -0,0 +1,101 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <fstream>
+#include <map>
+#include <utility>
+#include "dataset/nlp/vocab.h"
+namespace mindspore {
+namespace dataset {
+Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) {
+  word2id_ = std::move(word2id);
+  id2word_.resize(word2id_.size());
+  for (auto p : word2id_) {
+    id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
+  }
+}
+WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
+  auto itr = word2id_.find(word);
+  return itr == word2id_.end() ? default_id : itr->second;
+}
+WordType Vocab::Lookup(WordIdType id) const {
+  if (id < kSpecialTokens::num_tokens) {
+    return reserved_token_str_[id];
+  } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
+    return reserved_token_str_[kSpecialTokens::unk];
+  } else {
+    return id2word_[id - kSpecialTokens::num_tokens];
+  }
+}
+Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) {
+  std::unordered_map<WordType, WordIdType> word2id;
+  WordIdType word_id = kSpecialTokens::num_tokens;
+  for (auto word : words) {
+    const std::string s = py::str(word);
+    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word:" + s);
+    word2id[s] = word_id++;
+  }
+  *vocab = std::make_shared<Vocab>(std::move(word2id));
+  return Status::OK();
+}
+Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
+                            std::shared_ptr<Vocab> *vocab) {
+  std::unordered_map<WordType, WordIdType> word2id;
+  WordIdType word_id = kSpecialTokens::num_tokens;
+  std::fstream handle(path, std::ios::in);
+  CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path);
+  std::string word;
+  while (std::getline(handle, word)) {
+    if (!delimiter.empty()) {
+      // if the delimiter is not found, find_first_of returns std::string::npos and substr then keeps the whole line
+      word = word.substr(0, word.find_first_of(delimiter));
+    }
+    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word);
+    word2id[word] = word_id++;
+    // stop once vocab_size words have been read; a negative vocab_size means read the whole file
+    if (word_id == vocab_size + kSpecialTokens::num_tokens) break;
+  }
+  *vocab = std::make_shared<Vocab>(std::move(word2id));
+  return Status::OK();
+}
+Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
+  std::unordered_map<WordType, WordIdType> word2id;
+  std::map<WordIdType, WordType> id2word;
+  for (auto p : words) {
+    WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second);
+    if (word_id < kSpecialTokens::num_tokens) continue;  // skip ids that are reserved
+    std::string word = py::str(p.first);
+    CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id for word:" + word);
+    id2word[word_id] = word;
+  }
+  WordIdType cnt = kSpecialTokens::num_tokens;
+  for (auto p : id2word) {
+    CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word id needs to be continuous starting from 2");
+    word2id[p.second] = p.first;
+  }
+  *vocab = std::make_shared<Vocab>(std::move(word2id));
+  return Status::OK();
+}
+const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_NLP_VOCAB_H_
+#define DATASET_NLP_VOCAB_H_
+#include <string>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+#include "dataset/util/status.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+namespace mindspore {
+namespace dataset {
+namespace py = pybind11;
+using WordIdType = int32_t;
+using WordType = std::string;
+class Vocab {
+ public:
+  // Build a vocab from a python dictionary whose keys are words and whose values are ids; ids need to start from 2
+  // and be unique and continuous
+  // @param const py::dict &words - a dictionary containing word, word id pairs
+  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
+  // @return error code
+  static Status BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab);
+  // Build a vocab from a python list; ids are assigned automatically, starting from 2
+  // @param const py::list &words - a list of strings used to build the vocab
+  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
+  // @return error code
+  static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab);
+  // Build a vocab by reading a vocab file; ids are assigned automatically, starting from 2
+  // @param std::string &path - path to the vocab file, each line is assumed to contain one word
+  // @param std::string &delimiter - delimiter to break each line with
+  // @param int32_t vocab_size - number of words to read from the file
+  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
+  // @return error code
+  static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
+                              std::shared_ptr<Vocab> *vocab);
+  // Lookup the id of a word; if the word doesn't exist in the vocab, return default_id
+  // @param const WordType word - word to look up
+  // @param WordIdType default_id - word id to return when the word is not in the vocab
+  // @return WordIdType, word_id
+  WordIdType Lookup(const WordType &word, WordIdType default_id) const;
+  // Reverse lookup: look up the word based on its id
+  // @param WordIdType id - word id to look up
+  // @return WordType, the word
+  WordType Lookup(WordIdType id) const;
+  // constructor, shouldn't be called directly; can't be private because it is used with std::make_shared()
+  // @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map
+  explicit Vocab(std::unordered_map<WordType, WordIdType> map);
+  // enum type that holds all special tokens, add more if needed
+  enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 };
+  // reversed lookup table for the reserved tokens
+  static const std::vector<WordType> reserved_token_str_;
+ private:
+  std::unordered_map<WordType, WordIdType> word2id_;
+  std::vector<WordType> id2word_;  // reverse lookup
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_NLP_VOCAB_H_
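The id scheme used by this class is easy to misread, so here is a minimal sketch (not part of the patch) of how the ids line up once the Python bindings below are in place. It assumes a build of mindspore that includes this change, so that mindspore.dataset.text exists.

    import mindspore.dataset.text as text

    # Ids 0 and 1 are reserved for the special tokens <pad> and <unk>,
    # so from_list assigns the first user word id 2, the next id 3, and so on.
    vocab = text.Vocab.from_list(["home", "behind"])
    lookup = text.Lookup(vocab)  # with no "unknown" argument, OOV words map to the <unk> id (1)
    # Expected mapping under this scheme:
    #   "home"   -> 2
    #   "behind" -> 3
    #   any other word -> 1 (<unk>)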
@@ -0,0 +1,19 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+mindspore.dataset.text
+"""
+from .c_transforms import *
@@ -0,0 +1,77 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+c transforms for all text related operators
+"""
+import mindspore._c_dataengine as cde
+from .validators import check_lookup, check_from_list, check_from_dict, check_from_file
+class Vocab(cde.Vocab):
+    """
+    Vocab object that is used to look up a word to its id.
+    """
+    def __init__(self):
+        pass
+    @classmethod
+    @check_from_list
+    def from_list(cls, word_list):
+        """
+        Build a vocab object from a list of words.
+        Args:
+            word_list(list): a list of strings, where each element is a word.
+        """
+        return super().from_list(word_list)
+    @classmethod
+    @check_from_file
+    def from_file(cls, file_path, delimiter=None, vocab_size=None):
+        """
+        Build a vocab object from a vocab file.
+        Args:
+            file_path(str): path to the file which contains the vocab list.
+            delimiter(None, str): a delimiter to break up each line in the file; the first element is taken to be the word.
+            vocab_size(None, int): number of words to read from file_path.
+        """
+        return super().from_file(file_path, delimiter, vocab_size)
+    @classmethod
+    @check_from_dict
+    def from_dict(cls, word_dict):
+        """
+        Build a vocab object from a dict.
+        Args:
+            word_dict(dict): a dict containing word, id pairs; ids should start from 2 and be continuous.
+        """
+        return super().from_dict(word_dict)
+class Lookup(cde.LookupOp):
+    """
+    Lookup operator that looks up a word to an id.
+    Args:
+        vocab(Vocab): a Vocab object.
+        unknown(None, int): default id to use for a word that is out of the vocab.
+    """
+    @check_lookup
+    def __init__(self, vocab, unknown=None):
+        if unknown is None:
+            super().__init__(vocab)
+        else:
+            super().__init__(vocab, unknown)
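The "start from 2 and be continuous" rule that from_dict documents is enforced by BuildFromPyDict above. A quick sketch of what that means in practice (illustration only, not part of the patch; it assumes a build that includes this change):

    import mindspore.dataset.text as text

    # Valid: ids 0 and 1 are reserved (<pad>, <unk>) and entries using them are skipped,
    # while the remaining user ids 2, 3, 4 are continuous.
    vocab = text.Vocab.from_dict({"<unk>": 1, "home": 2, "behind": 3, "world": 4})

    # Invalid: ids 2 and 5 are not continuous, so building the vocab fails with
    # "word id needs to be continuous starting from 2".
    # text.Vocab.from_dict({"home": 2, "world": 5})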
@@ -0,0 +1,108 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+validators for text ops
+"""
+from functools import wraps
+import mindspore._c_dataengine as cde
+def check_lookup(method):
+    """A wrapper that wraps a parameter checker around the original function (Lookup operator)."""
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        vocab, unknown = (list(args) + 2 * [None])[:2]
+        if "vocab" in kwargs:
+            vocab = kwargs.get("vocab")
+        if "unknown" in kwargs:
+            unknown = kwargs.get("unknown")
+        if unknown is not None:
+            assert isinstance(unknown, int) and unknown >= 0, "unknown needs to be a non-negative integer"
+        assert isinstance(vocab, cde.Vocab), "vocab is not an instance of cde.Vocab"
+        kwargs["vocab"] = vocab
+        kwargs["unknown"] = unknown
+        return method(self, **kwargs)
+    return new_method
+def check_from_file(method):
| """A wrapper that wrap a parameter checker to the original function(crop operation).""" | |||||
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3]
+        if "file_path" in kwargs:
+            file_path = kwargs.get("file_path")
+        if "delimiter" in kwargs:
+            delimiter = kwargs.get("delimiter")
+        if "vocab_size" in kwargs:
+            vocab_size = kwargs.get("vocab_size")
+        assert isinstance(file_path, str), "file_path needs to be str"
+        if delimiter is not None:
+            assert isinstance(delimiter, str), "delimiter needs to be str"
+        else:
+            delimiter = ""
+        if vocab_size is not None:
+            assert isinstance(vocab_size, int) and vocab_size > 0, "vocab size needs to be a positive integer"
+        else:
+            vocab_size = -1
+        kwargs["file_path"] = file_path
+        kwargs["delimiter"] = delimiter
+        kwargs["vocab_size"] = vocab_size
+        return method(self, **kwargs)
+    return new_method
+def check_from_list(method):
| """A wrapper that wrap a parameter checker to the original function(crop operation).""" | |||||
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        word_list, = (list(args) + [None])[:1]
+        if "word_list" in kwargs:
+            word_list = kwargs.get("word_list")
+        assert isinstance(word_list, list), "word_list needs to be a list of words"
+        for word in word_list:
+            assert isinstance(word, str), "each word in word list needs to be type str"
+        kwargs["word_list"] = word_list
+        return method(self, **kwargs)
+    return new_method
+def check_from_dict(method):
| """A wrapper that wrap a parameter checker to the original function(crop operation).""" | |||||
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        word_dict, = (list(args) + [None])[:1]
+        if "word_dict" in kwargs:
+            word_dict = kwargs.get("word_dict")
+        assert isinstance(word_dict, dict), "word_dict needs to be a dict of word, id pairs"
+        for word, word_id in word_dict.items():
+            assert isinstance(word, str), "each word in word_dict needs to be type str"
+            assert isinstance(word_id, int) and word_id >= 0, "each word id needs to be a non-negative integer"
| kwargs["word_dict"] = word_dict | |||||
| return method(self, **kwargs) | |||||
| return new_method | |||||
| @@ -0,0 +1,14 @@ | |||||
| not,1 | |||||
| all,2 | |||||
| those,3 | |||||
| who,4 | |||||
| wonder,5 | |||||
| are,6 | |||||
| lost,7 | |||||
| Tolkein,8 | |||||
| home,9 | |||||
| is,10 | |||||
| behind,11 | |||||
| world,12 | |||||
| ahead,13 | |||||
| the,14 | |||||
| @@ -0,0 +1,6 @@ | |||||
| home | |||||
| is | |||||
| behind | |||||
| the | |||||
| world | |||||
| ahead | |||||
| @@ -0,0 +1,47 @@ | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.text as text | |||||
+# this file contains "home is behind the world ahead", with one word per line
+DATA_FILE = "../data/dataset/testVocab/words.txt"
+VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
+def test_from_list():
+    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "))
+    lookup = text.Lookup(vocab)
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(input_columns=["text"], operations=lookup)
+    ind = 0
+    res = [2, 1, 4, 5, 6, 7]
+    for d in data.create_dict_iterator():
+        assert d["text"] == res[ind], ind
+        ind += 1
+def test_from_file():
+    vocab = text.Vocab.from_file(VOCAB_FILE, ",")
+    lookup = text.Lookup(vocab)
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(input_columns=["text"], operations=lookup)
+    ind = 0
+    res = [10, 11, 12, 15, 13, 14]
+    for d in data.create_dict_iterator():
+        assert d["text"] == res[ind], ind
+        ind += 1
+def test_from_dict():
+    vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
+    lookup = text.Lookup(vocab, 6)  # out-of-vocab words are mapped to the id of "<unk>" (6)
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(input_columns=["text"], operations=lookup)
+    res = [3, 6, 2, 4, 5, 6]
+    ind = 0
+    for d in data.create_dict_iterator():
+        assert d["text"] == res[ind], ind
+        ind += 1
+if __name__ == '__main__':
+    test_from_list()
+    test_from_file()
+    test_from_dict()