Merge pull request !2317 from ZiruiWu/vocab_rework (tag: v0.5.0-beta)
| @@ -1283,18 +1283,18 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas | |||
| py::tuple tp = py::reinterpret_borrow<py::tuple>(value); | |||
| if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0])); | |||
| if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1])); | |||
| } | |||
| if (key == "top_k") { | |||
| } else if (key == "top_k") { | |||
| builder->SetTopK(py::reinterpret_borrow<py::int_>(value)); | |||
| } | |||
| if (key == "columns") { | |||
| } else if (key == "columns") { | |||
| (void)builder->SetColumnNames(ToStringVector(value)); | |||
| } | |||
| if (key == "vocab") { | |||
| } else if (key == "vocab") { | |||
| (void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>()); | |||
| } | |||
| if (key == "num_parallel_workers") { | |||
| } else if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "special_first") { | |||
| (void)builder->SetSpecialFirst(ToBool(value)); | |||
| } else if (key == "special_tokens") { | |||
| (void)builder->SetSpecialTokens(ToStringVector(value)); | |||
| } | |||
| } | |||
| } | |||
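For context, the keys handled by ParseBuildVocabOp above come straight from the Python layer: BuildVocabDataset.get_args() (changed further down in this diff) emits a plain dict that the C++ side walks key by key. A rough, hedged sketch of that dict's shape follows; the concrete values are illustrative only, not taken from the patch.

    import mindspore.dataset.text as text

    # Hedged sketch of the args dict consumed by ParseBuildVocabOp; keys mirror get_args() below,
    # values are illustrative only.
    args = {
        "columns": ["text"],
        "freq_range": (1, 10),                   # parsed into SetMinFreq / SetMaxFreq
        "top_k": 5000,
        "vocab": text.Vocab(),                   # empty Vocab handle that the op fills in
        "num_parallel_workers": 4,
        "special_first": True,
        "special_tokens": ["<pad>", "<unk>"],
        "prefetch_size": None,
    }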
| @@ -673,15 +673,16 @@ void bindVocabObjects(py::module *m) { | |||
| (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab") | |||
| .def(py::init<>()) | |||
| .def_static("from_list", | |||
| [](const py::list &words) { | |||
| [](const py::list &words, const py::list &special_tokens, bool special_first) { | |||
| std::shared_ptr<Vocab> v; | |||
| THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v)); | |||
| THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v)); | |||
| return v; | |||
| }) | |||
| .def_static("from_file", | |||
| [](const std::string &path, const std::string &dlm, int32_t vocab_size) { | |||
| [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens, | |||
| bool special_first) { | |||
| std::shared_ptr<Vocab> v; | |||
| THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v)); | |||
| THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v)); | |||
| return v; | |||
| }) | |||
| .def_static("from_dict", [](const py::dict &words) { | |||
| @@ -27,13 +27,16 @@ namespace mindspore { | |||
| namespace dataset { | |||
| BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names, | |||
| std::pair<int64_t, int64_t> freq_r, int64_t top_k, int32_t num_workers, int32_t op_conn_size) | |||
| std::pair<int64_t, int64_t> freq_r, int64_t top_k, const std::vector<std::string> &tokens, | |||
| bool prepend, int32_t num_workers, int32_t op_conn_size) | |||
| : ParallelOp(num_workers, op_conn_size), | |||
| interval_(op_conn_size * num_workers), | |||
| vocab_(vocab), | |||
| col_names_(col_names), | |||
| freq_range_(freq_r), | |||
| top_k_(top_k) { | |||
| top_k_(top_k), | |||
| special_tokens_(tokens), | |||
| special_first_(prepend) { | |||
| // init two queues for thread sync | |||
| distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size); | |||
| collector_queue_ = | |||
| @@ -129,7 +132,7 @@ Status BuildVocabOp::CollectorThread() { | |||
| } // all frequencies are obtained | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty"); | |||
| std::vector<std::string> words; | |||
| // make sure enough is reserved | |||
| // make sure enough is reserved, this will become a partially sorted list eventually | |||
| words.reserve(wrkr_map->size()); | |||
| for (auto it = word_cnt_.begin(); it != word_cnt_.end();) { | |||
| @@ -140,6 +143,15 @@ Status BuildVocabOp::CollectorThread() { | |||
| it = word_cnt_.erase(it); | |||
| } | |||
| } | |||
| std::string err_msg; | |||
| for (const std::string &sp_tk : special_tokens_) { | |||
| // if a special token already exists in the dataset, record it so the error below can report it | |||
| err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : ""); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These special words are already in the dataset: " + err_msg + "."); | |||
| int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_); | |||
| if (num_words == 0) { | |||
| MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second | |||
| @@ -152,9 +164,19 @@ Status BuildVocabOp::CollectorThread() { | |||
| int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2]; | |||
| return f1 == f2 ? w1 < w2 : f1 > f2; | |||
| }); | |||
| if (special_first_) { | |||
| for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk); | |||
| } | |||
| for (int64_t i = 0; i < num_words; i++) { | |||
| vocab_->append_word(words[i]); | |||
| } | |||
| if (!special_first_) { | |||
| for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk); | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); | |||
| // then use std::nth_element to partial sort | |||
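Taken together with the append_word change in vocab.cc below (ids are simply assigned in insertion order), the special_first branches above fix the final id layout: prepended special tokens occupy ids 0..n-1 and shift the top-k words up by n, while appended special tokens take the ids after the last kept word. A minimal Python sketch of that ordering (illustrative tokens, not the C++ implementation itself):

    # Mirrors the append order used in CollectorThread; append_word assigns id = current vocab size.
    def vocab_layout(top_k_words, special_tokens, special_first):
        order = (special_tokens + top_k_words) if special_first else (top_k_words + special_tokens)
        return {word: idx for idx, word in enumerate(order)}

    print(vocab_layout(["the", "of"], ["<pad>", "<unk>"], True))    # {'<pad>': 0, '<unk>': 1, 'the': 2, 'of': 3}
    print(vocab_layout(["the", "of"], ["<pad>", "<unk>"], False))   # {'the': 0, 'of': 1, '<pad>': 2, '<unk>': 3}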
| @@ -166,16 +188,17 @@ Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be a positive number"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0, | |||
| "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)"); | |||
| (*op) = std::make_shared<BuildVocabOp>(builder_vocab_, builder_col_names_, | |||
| std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_, | |||
| builder_num_workers_, builder_connector_size_); | |||
| (*op) = std::make_shared<BuildVocabOp>( | |||
| builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_, | |||
| builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_); | |||
| return Status::OK(); | |||
| } | |||
| BuildVocabOp::Builder::Builder() | |||
| : builder_top_k_(std::numeric_limits<int64_t>::max()), | |||
| builder_min_freq_(0), | |||
| builder_max_freq_(std::numeric_limits<int64_t>::max()) { | |||
| builder_max_freq_(std::numeric_limits<int64_t>::max()), | |||
| builder_special_first_(true) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_connector_size_ = cfg->op_connector_size(); | |||
| @@ -88,12 +88,26 @@ class BuildVocabOp : public ParallelOp { | |||
| return *this; | |||
| } | |||
| // set special tokens | |||
| // @param const std::vector<std::string> & tokens - the special tokens to be added to the vocab | |||
| // @return Builder & reference to builder class object | |||
| Builder &SetSpecialTokens(const std::vector<std::string> &tokens) { | |||
| builder_speical_tokens_ = tokens; | |||
| return *this; | |||
| } | |||
| // set vocab object | |||
| Builder &SetVocab(std::shared_ptr<Vocab> vocab) { | |||
| builder_vocab_ = vocab; | |||
| return *this; | |||
| } | |||
| // set special tokens first (or last) | |||
| Builder &SetSpecialFirst(bool prepend) { | |||
| builder_special_first_ = prepend; | |||
| return *this; | |||
| } | |||
| // The builder "build" method creates the final object. | |||
| // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| @@ -104,13 +118,16 @@ class BuildVocabOp : public ParallelOp { | |||
| int32_t builder_connector_size_; | |||
| int64_t builder_min_freq_; | |||
| int64_t builder_max_freq_; | |||
| bool builder_special_first_; | |||
| std::vector<std::string> builder_col_names_; | |||
| std::vector<std::string> builder_speical_tokens_; | |||
| std::shared_ptr<Vocab> builder_vocab_; | |||
| int64_t builder_top_k_; | |||
| }; | |||
| BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names, std::pair<int64_t, int64_t> freq_range, | |||
| int64_t top_k, int32_t num_workers, int32_t op_connector_size); | |||
| int64_t top_k, const std::vector<std::string> &tokens, bool prepend, int32_t num_workers, | |||
| int32_t op_connector_size); | |||
| ~BuildVocabOp() = default; | |||
| @@ -137,9 +154,11 @@ class BuildVocabOp : public ParallelOp { | |||
| private: | |||
| const int32_t interval_; | |||
| bool special_first_; | |||
| std::shared_ptr<Vocab> vocab_; | |||
| std::vector<std::string> col_names_; | |||
| std::vector<int32_t> col_ids_; | |||
| std::vector<std::string> special_tokens_; | |||
| // pair = {min_f, max_f} | |||
| // make sure that 0<= min_f < max_f <= int32_max in the builder | |||
| std::pair<int64_t, int64_t> freq_range_; | |||
| @@ -33,7 +33,7 @@ class LookupOp : public TensorOp { | |||
| // constructor for lookup, takes in a vocab object | |||
| // @param std::shared_ptr<Vocab> vocab - vocab object used for the lookup | |||
| // @param WordIdType default_id, id to lookup if a word is not in vocab | |||
| explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk); | |||
| explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1); | |||
| ~LookupOp() = default; | |||
| @@ -14,7 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <fstream> | |||
| #include <map> | |||
| #include <unordered_set> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "dataset/text/vocab.h" | |||
| @@ -28,41 +29,38 @@ WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const { | |||
| return itr == word2id_.end() ? default_id : itr->second; | |||
| } | |||
| WordType Vocab::Lookup(WordIdType id) { | |||
| // this operation is most likely only done once, since reverse lookup is only needed when training is done | |||
| // hence, the worst case of inserting while keep looking up isn't likely to happen | |||
| if (id2word_.size() != word2id_.size() && (id - kSpecialTokens::num_tokens >= id2word_.size())) { | |||
| id2word_.clear(); | |||
| id2word_.reserve(word2id_.size()); | |||
| for (auto p : word2id_) { | |||
| id2word_[p.second - kSpecialTokens::num_tokens] = p.first; | |||
| } | |||
| } | |||
| if (id < kSpecialTokens::num_tokens) { | |||
| return reserved_token_str_[id]; | |||
| } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) { | |||
| return reserved_token_str_[kSpecialTokens::unk]; | |||
| } else { | |||
| return id2word_[id - kSpecialTokens::num_tokens]; | |||
| } | |||
| } | |||
| Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) { | |||
| Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, | |||
| std::shared_ptr<Vocab> *vocab) { | |||
| // check of duplication on both words and special_tokens will be performed in python | |||
| // special_tokens and words both need to be unique, and shouldn't overlap | |||
| std::unordered_map<WordType, WordIdType> word2id; | |||
| WordIdType word_id = kSpecialTokens::num_tokens; | |||
| // if special tokens are prepended, normal word ids will start from the number of special tokens | |||
| WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0; | |||
| for (auto word : words) { | |||
| const std::string s = py::str(word); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word:" + s); | |||
| word2id[s] = word_id++; | |||
| word2id[py::str(word)] = word_id++; | |||
| } | |||
| word_id = prepend_special ? 0 : word2id.size(); | |||
| for (auto special_token : special_tokens) { | |||
| word2id[py::str(special_token)] = word_id++; | |||
| } | |||
| *vocab = std::make_shared<Vocab>(std::move(word2id)); | |||
| return Status::OK(); | |||
| } | |||
| Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, | |||
| std::shared_ptr<Vocab> *vocab) { | |||
| const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab) { | |||
| // python validator checks special_tokens doesn't contain any duplicate words | |||
| std::unordered_set<std::string> specials; | |||
| // used to check that words in file don't contain any special token that already exists | |||
| for (auto word : special_tokens) { | |||
| specials.insert(py::str(word)); | |||
| } | |||
| WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0; | |||
| std::unordered_map<WordType, WordIdType> word2id; | |||
| WordIdType word_id = kSpecialTokens::num_tokens; | |||
| std::fstream handle(path, std::ios::in); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path); | |||
| std::string word; | |||
| @@ -71,40 +69,35 @@ Status Vocab::BuildFromFile(const std::string &path, const std::string &delimite | |||
| // if delimiter is not found, find_first_of would return std::string::npos which is -1 | |||
| word = word.substr(0, word.find_first_of(delimiter)); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + "."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens."); | |||
| word2id[word] = word_id++; | |||
| // break when enough rows have been read; if vocab_size is smaller than 0, all rows are read | |||
| if (word_id == vocab_size + kSpecialTokens::num_tokens) break; | |||
| if (word2id.size() == vocab_size) break; | |||
| } | |||
| word_id = prepend_special ? 0 : word2id.size(); | |||
| for (auto special_token : special_tokens) { | |||
| word2id[py::str(special_token)] = word_id++; | |||
| } | |||
| *vocab = std::make_shared<Vocab>(std::move(word2id)); | |||
| return Status::OK(); | |||
| } | |||
| Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) { | |||
| std::unordered_map<WordType, WordIdType> word2id; | |||
| std::map<WordIdType, WordType> id2word; | |||
| for (auto p : words) { | |||
| WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second); | |||
| if (word_id < kSpecialTokens::num_tokens) continue; // skip id that are reserved | |||
| std::string word = py::str(p.first); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id:" + word); | |||
| id2word[word_id] = word; | |||
| } | |||
| WordIdType cnt = kSpecialTokens::num_tokens; | |||
| for (auto p : id2word) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word id needs to be continuous starting from 2"); | |||
| word2id[p.second] = p.first; | |||
| word2id[py::str(p.first)] = py::reinterpret_borrow<py::int_>(p.second); | |||
| } | |||
| *vocab = std::make_shared<Vocab>(std::move(word2id)); | |||
| return Status::OK(); | |||
| } | |||
| const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"}; | |||
| void Vocab::append_word(const std::string &word) { | |||
| if (word2id_.find(word) == word2id_.end()) { | |||
| word2id_[word] = word2id_.size() + kSpecialTokens::num_tokens; | |||
| word2id_[word] = word2id_.size(); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| @@ -45,7 +45,8 @@ class Vocab { | |||
| // @param const py::list &words - a list of string, used to build vocab; ids are assigned from 0 (after any prepended special tokens) | |||
| // @param std::shared_ptr<Vocab> *vocab - return value, vocab object | |||
| // @return error code | |||
| static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab); | |||
| static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, | |||
| std::shared_ptr<Vocab> *vocab); | |||
| // Build a vocab from reading a vocab file, ids are automatically assigned, starting from 0 | |||
| // @param std::string &path - path to vocab file , each line is assumed to contain 1 word | |||
| @@ -54,7 +55,7 @@ class Vocab { | |||
| // @param std::shared_ptr<Vocab> *vocab - return value, vocab object | |||
| // @return error code | |||
| static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, | |||
| std::shared_ptr<Vocab> *vocab); | |||
| const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab); | |||
| // Lookup the id of a word, if word doesn't exist in vocab, return default_id | |||
| // @param const WordType word - word to look up | |||
| @@ -80,15 +81,8 @@ class Vocab { | |||
| // destructor | |||
| ~Vocab() = default; | |||
| // enum type that holds all special tokens, add more if needed | |||
| enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 }; | |||
| // reversed lookup table for the reserved tokens | |||
| static const std::vector<WordType> reserved_token_str_; | |||
| private: | |||
| std::unordered_map<WordType, WordIdType> word2id_; | |||
| std::vector<WordType> id2word_; // reverse lookup | |||
| }; | |||
| } // namespace dataset | |||
| @@ -894,9 +894,9 @@ class Dataset: | |||
| return ProjectDataset(self, columns) | |||
| def build_vocab(self, vocab, columns, freq_range, top_k): | |||
| def build_vocab(self, vocab, columns, freq_range, top_k, special_tokens, special_first): | |||
| """ Internal function for building a vocab""" | |||
| return BuildVocabDataset(self, vocab, columns, freq_range, top_k) | |||
| return BuildVocabDataset(self, vocab, columns, freq_range, top_k, special_tokens, special_first) | |||
| def apply(self, apply_func): | |||
| """ | |||
| @@ -4869,9 +4869,15 @@ class BuildVocabDataset(DatasetOp): | |||
| top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are | |||
| taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken (default=None, | |||
| all words are included). | |||
| special_tokens(list): a list of strings, each one a special token, e.g. ["<pad>", "<unk>"] | |||
| (default=None, no special tokens will be added). | |||
| special_first(bool): whether special_tokens will be prepended or appended to the vocab. If special_tokens is | |||
| specified and special_first is set to None, special_tokens will be prepended (default=None). | |||
| prefetch_size (int, optional): prefetch number of records ahead of the user's request (default=None). | |||
| """ | |||
| def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None): | |||
| def __init__(self, input_dataset, vocab, columns, freq_range, top_k, special_tokens, special_first, | |||
| prefetch_size=None): | |||
| super().__init__() | |||
| self.columns = columns | |||
| self.input.append(input_dataset) | |||
| @@ -4879,6 +4885,8 @@ class BuildVocabDataset(DatasetOp): | |||
| self.vocab = vocab | |||
| self.freq_range = freq_range | |||
| self.top_k = top_k | |||
| self.special_tokens = special_tokens | |||
| self.special_first = special_first | |||
| input_dataset.output.append(self) | |||
| def get_args(self): | |||
| @@ -4888,6 +4896,8 @@ class BuildVocabDataset(DatasetOp): | |||
| args["freq_range"] = self.freq_range | |||
| args["prefetch_size"] = self.prefetch_size | |||
| args["top_k"] = self.top_k | |||
| args["special_tokens"] = self.special_tokens | |||
| args["special_first"] = self.special_first | |||
| return args | |||
| def __deepcopy__(self, memodict): | |||
| @@ -4904,4 +4914,7 @@ class BuildVocabDataset(DatasetOp): | |||
| new_op.freq_range = copy.deepcopy(self.freq_range, memodict) | |||
| new_op.top_k = copy.deepcopy(self.top_k, memodict) | |||
| new_op.vocab = self.vocab | |||
| new_op.special_tokens = copy.deepcopy(self.special_tokens) | |||
| new_op.special_first = copy.deepcopy(self.special_first) | |||
| return new_op | |||
| @@ -28,10 +28,12 @@ from .validators import check_lookup, check_jieba_add_dict, \ | |||
| class Lookup(cde.LookupOp): | |||
| """ | |||
| Lookup operator that looks up a word to an id | |||
| Lookup operator that looks up a word to an id. | |||
| Args: | |||
| vocab(Vocab): a Vocab object. | |||
| unknown(int): default id to lookup a word that is out of vocab (default is None). | |||
| unknown(int, optional): default id to lookup a word that is out of vocab. If no argument is passed, 1 will be | |||
| used as the default id, which is the convention for the unknown token <unk>. Otherwise, the user is strongly | |||
| encouraged to pass in the id of <unk> (default=None). | |||
| """ | |||
| @check_lookup | |||
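A hedged usage sketch of the new default: with special_tokens=["<pad>", "<unk>"] prepended (the default ordering), <unk> lands on id 1, so Lookup's implicit default works; if the special tokens are appended instead, the id of <unk> should be passed explicitly.

    import mindspore.dataset.text as text

    vocab = text.Vocab.from_list(["w1", "w2", "w3"], ["<pad>", "<unk>"], True)
    lookup = text.Lookup(vocab)        # out-of-vocab words map to id 1, i.e. <unk>, by convention
    vocab2 = text.Vocab.from_list(["w1", "w2", "w3"], ["<pad>", "<unk>"], False)
    lookup2 = text.Lookup(vocab2, 4)   # specials are appended here, so <unk> has id 4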
| @@ -25,12 +25,13 @@ from .validators import check_from_file, check_from_list, check_from_dict, check | |||
| class Vocab(cde.Vocab): | |||
| """ | |||
| Vocab object that is used for lookup word. | |||
| Vocab object that is used to look up a word. It contains a map that maps each word (str) to an id (int). | |||
| """ | |||
| @classmethod | |||
| @check_from_dataset | |||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None): | |||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, | |||
| special_first=None): | |||
| """ | |||
| Build a vocab from a dataset. This would collect all unique words in a dataset and return a vocab within | |||
| the frequency range specified by the user in freq_range. The user would be warned if no words fall into the frequency range. | |||
| @@ -49,11 +50,16 @@ class Vocab(cde.Vocab): | |||
| top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are | |||
| taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default=None | |||
| all words are included). | |||
| special_tokens(list, optional): a list of strings, each one a special token, e.g. ["<pad>", "<unk>"] | |||
| (default=None, no special tokens will be added). | |||
| special_first(bool, optional): whether special_tokens will be prepended or appended to the vocab. If special_tokens | |||
| is specified and special_first is set to None, special_tokens will be prepended (default=None). | |||
| return: | |||
| text.Vocab: Vocab object built from dataset. | |||
| """ | |||
| vocab = Vocab() | |||
| root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k) | |||
| root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) | |||
| for d in root.create_dict_iterator(): | |||
| if d is not None: | |||
| raise ValueError("from_dataset should receive data other than None.") | |||
| @@ -61,17 +67,21 @@ class Vocab(cde.Vocab): | |||
| @classmethod | |||
| @check_from_list | |||
| def from_list(cls, word_list): | |||
| def from_list(cls, word_list, special_tokens=None, special_first=None): | |||
| """ | |||
| build a vocab object from a list of words. | |||
| Args: | |||
| word_list(list): a list of string where each element is a word. | |||
| word_list(list): a list of words, where each element is of type string. | |||
| special_tokens(list, optional): a list of strings, each one a special token, e.g. ["<pad>", "<unk>"] | |||
| (default=None, no special tokens will be added). | |||
| special_first(bool, optional): whether special_tokens will be prepended or appended to the vocab. If special_tokens | |||
| is specified and special_first is set to None, special_tokens will be prepended (default=None). | |||
| """ | |||
| return super().from_list(word_list) | |||
| return super().from_list(word_list, special_tokens, special_first) | |||
| @classmethod | |||
| @check_from_file | |||
| def from_file(cls, file_path, delimiter=None, vocab_size=None): | |||
| def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None): | |||
| """ | |||
| build a vocab object from a vocab file. | |||
| Args: | |||
| @@ -79,8 +89,12 @@ class Vocab(cde.Vocab): | |||
| delimiter(str, optional): a delimiter to break up each line in file, the first element is taken to be | |||
| the word (default=None). | |||
| vocab_size(int, optional): number of words to read from file_path (default=None, all words are taken). | |||
| special_tokens(list, optional): a list of strings, each one a special token, e.g. ["<pad>", "<unk>"] | |||
| (default=None, no special tokens will be added). | |||
| special_first(bool, optional): whether special_tokens will be prepended or appended to the vocab. If special_tokens | |||
| is specified and special_first is set to None, special_tokens will be prepended (default=None). | |||
| """ | |||
| return super().from_file(file_path, delimiter, vocab_size) | |||
| return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first) | |||
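A hedged sketch of from_file with the new parameters (the path is illustrative): the first delimiter-separated field of each line is taken as the word, vocab_size caps how many lines are read, and with special_first=False the special tokens receive the ids after the words from the file.

    import mindspore.dataset.text as text

    # illustrative path; each line holds "word,..." and only the field before "," is kept
    vocab = text.Vocab.from_file("/path/to/vocab_list.txt", delimiter=",", vocab_size=None,
                                 special_tokens=["<pad>", "<unk>"], special_first=False)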
| @classmethod | |||
| @check_from_dict | |||
| @@ -88,7 +102,8 @@ class Vocab(cde.Vocab): | |||
| """ | |||
| build a vocab object from a dict. | |||
| Args: | |||
| word_dict(dict): dict contains word, id pairs. id should start from 2 and be continuous. | |||
| word_dict(dict): dict contains word, id pairs where word should be str and id int. id is recommended to | |||
| start from 0 and be continuous. ValueError will be raised if id is negative. | |||
| """ | |||
| return super().from_dict(word_dict) | |||
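A hedged sketch pairing from_dict with Lookup: the ids are taken verbatim (non-negative ints, ideally starting from 0), so nothing guarantees that id 1 is the unknown token; pass the id of <unk> to Lookup explicitly, as the from_dict tutorial test below also does.

    import mindspore.dataset.text as text

    vocab = text.Vocab.from_dict({"<pad>": 0, "<unk>": 1, "home": 2, "world": 3})
    lookup = text.Lookup(vocab, 1)   # 1 is the id of <unk> in this particular dict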
| @@ -23,6 +23,21 @@ import mindspore._c_dataengine as cde | |||
| from ..transforms.validators import check_uint32, check_pos_int64 | |||
| def check_unique_list_of_words(words, arg_name): | |||
| """Check that words is a list and each element is a str without any duplication""" | |||
| if not isinstance(words, list): | |||
| raise ValueError(arg_name + " needs to be a list of words of type string.") | |||
| words_set = set() | |||
| for word in words: | |||
| if not isinstance(word, str): | |||
| raise ValueError("each word in " + arg_name + " needs to be type str.") | |||
| if word in words_set: | |||
| raise ValueError(arg_name + " contains duplicate word: " + word + ".") | |||
| words_set.add(word) | |||
| return words_set | |||
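A quick, hedged illustration of what this helper accepts and rejects (inputs are made up):

    check_unique_list_of_words(["<pad>", "<unk>"], "special_tokens")   # returns {"<pad>", "<unk>"}
    check_unique_list_of_words(["a", "a"], "word_list")                # ValueError: duplicate word
    check_unique_list_of_words("not a list", "word_list")              # ValueError: needs to be a list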
| def check_lookup(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| @@ -52,13 +67,17 @@ def check_from_file(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3] | |||
| file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5] | |||
| if "file_path" in kwargs: | |||
| file_path = kwargs.get("file_path") | |||
| if "delimiter" in kwargs: | |||
| delimiter = kwargs.get("delimiter") | |||
| if "vocab_size" in kwargs: | |||
| vocab_size = kwargs.get("vocab_size") | |||
| if "special_tokens" in kwargs: | |||
| special_tokens = kwargs.get("special_tokens") | |||
| if "special_first" in kwargs: | |||
| special_first = kwargs.get("special_first") | |||
| if not isinstance(file_path, str): | |||
| raise ValueError("file_path needs to be str.") | |||
| @@ -73,9 +92,24 @@ def check_from_file(method): | |||
| raise ValueError("vocab size needs to be a positive integer.") | |||
| else: | |||
| vocab_size = -1 | |||
| if special_first is None: | |||
| special_first = True | |||
| if not isinstance(special_first, bool): | |||
| raise ValueError("special_first needs to be a boolean value") | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||
| kwargs["file_path"] = file_path | |||
| kwargs["delimiter"] = delimiter | |||
| kwargs["vocab_size"] = vocab_size | |||
| kwargs["special_tokens"] = special_tokens | |||
| kwargs["special_first"] = special_first | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| @@ -86,16 +120,32 @@ def check_from_list(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| word_list, = (list(args) + [None])[:1] | |||
| word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3] | |||
| if "word_list" in kwargs: | |||
| word_list = kwargs.get("word_list") | |||
| if not isinstance(word_list, list): | |||
| raise ValueError("word_list needs to be a list of words.") | |||
| for word in word_list: | |||
| if not isinstance(word, str): | |||
| raise ValueError("each word in word list needs to be type str.") | |||
| if "special_tokens" in kwargs: | |||
| special_tokens = kwargs.get("special_tokens") | |||
| if "special_first" in kwargs: | |||
| special_first = kwargs.get("special_first") | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| word_set = check_unique_list_of_words(word_list, "word_list") | |||
| token_set = check_unique_list_of_words(special_tokens, "special_tokens") | |||
| intersect = word_set.intersection(token_set) | |||
| if intersect != set(): | |||
| raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") | |||
| if special_first is None: | |||
| special_first = True | |||
| if not isinstance(special_first, bool): | |||
| raise ValueError("special_first needs to be a boolean value.") | |||
| kwargs["word_list"] = word_list | |||
| kwargs["special_tokens"] = special_tokens | |||
| kwargs["special_first"] = special_first | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| @@ -113,9 +163,9 @@ def check_from_dict(method): | |||
| raise ValueError("word_dict needs to be a list of word,id pairs.") | |||
| for word, word_id in word_dict.items(): | |||
| if not isinstance(word, str): | |||
| raise ValueError("each word in word_dict needs to be type str.") | |||
| raise ValueError("Each word in word_dict needs to be type string.") | |||
| if not (isinstance(word_id, int) and word_id >= 0): | |||
| raise ValueError("each word id needs to be positive integer.") | |||
| raise ValueError("Each word id needs to be positive integer.") | |||
| kwargs["word_dict"] = word_dict | |||
| return method(self, **kwargs) | |||
| @@ -135,11 +185,11 @@ def check_jieba_init(method): | |||
| mp_path = kwargs.get("mp_path") | |||
| if hmm_path is None: | |||
| raise ValueError( | |||
| "the dict of HMMSegment in cppjieba is not provided.") | |||
| "The dict of HMMSegment in cppjieba is not provided.") | |||
| kwargs["hmm_path"] = hmm_path | |||
| if mp_path is None: | |||
| raise ValueError( | |||
| "the dict of MPSegment in cppjieba is not provided.") | |||
| "The dict of MPSegment in cppjieba is not provided.") | |||
| kwargs["mp_path"] = mp_path | |||
| if model is not None: | |||
| kwargs["model"] = model | |||
| @@ -171,7 +221,7 @@ def check_jieba_add_word(method): | |||
| def check_jieba_add_dict(method): | |||
| """Wrapper method to check the parameters of add dict""" | |||
| """Wrapper method to check the parameters of add dict.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -189,10 +239,10 @@ def check_jieba_add_dict(method): | |||
| def check_from_dataset(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| # def from_dataset(cls, dataset, columns, freq_range=None, top_k=None): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| dataset, columns, freq_range, top_k = (list(args) + 4 * [None])[:4] | |||
| dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6] | |||
| if "dataset" in kwargs: | |||
| dataset = kwargs.get("dataset") | |||
| if "columns" in kwargs: | |||
| @@ -201,6 +251,10 @@ def check_from_dataset(method): | |||
| freq_range = kwargs.get("freq_range") | |||
| if "top_k" in kwargs: | |||
| top_k = kwargs.get("top_k") | |||
| if "special_tokens" in kwargs: | |||
| special_tokens = kwargs.get("special_tokens") | |||
| if "special_first" in kwargs: | |||
| special_first = kwargs.get("special_first") | |||
| if columns is None: | |||
| columns = [] | |||
| @@ -232,10 +286,23 @@ def check_from_dataset(method): | |||
| if isinstance(top_k, int) and top_k <= 0: | |||
| raise ValueError("top_k needs to be a positive integer.") | |||
| if special_first is None: | |||
| special_first = True | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| if not isinstance(special_first, bool): | |||
| raise ValueError("special_first needs to be a boolean value.") | |||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||
| kwargs["dataset"] = dataset | |||
| kwargs["columns"] = columns | |||
| kwargs["freq_range"] = freq_range | |||
| kwargs["top_k"] = top_k | |||
| kwargs["special_tokens"] = special_tokens | |||
| kwargs["special_first"] = special_first | |||
| return method(self, **kwargs) | |||
| @@ -0,0 +1,8 @@ | |||
| w1 | |||
| w2 | |||
| w3 | |||
| w4 | |||
| w5 | |||
| w6 | |||
| w7 | |||
| w8 | |||
| @@ -23,19 +23,21 @@ import mindspore.dataset.text as text | |||
| def test_demo_basic_from_dataset(): | |||
| """ this is a tutorial on how from_dataset should be used in a normal use case""" | |||
| data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) | |||
| vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None) | |||
| vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"], | |||
| special_first=True) | |||
| data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) | |||
| res = [] | |||
| for d in data.create_dict_iterator(): | |||
| res.append(d["text"].item()) | |||
| assert res == [4, 5, 3, 6, 7, 2] | |||
| assert res == [4, 5, 3, 6, 7, 2], res | |||
| def test_demo_basic_from_dataset_with_tokenizer(): | |||
| """ this is a tutorial on how from_dataset should be used in a normal use case with tokenizer""" | |||
| data = ds.TextFileDataset("../data/dataset/testTokenizerData/1.txt", shuffle=False) | |||
| data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer()) | |||
| vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None) | |||
| vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"], | |||
| special_first=True) | |||
| data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) | |||
| res = [] | |||
| for d in data.create_dict_iterator(): | |||
| @@ -55,7 +57,8 @@ def test_from_dataset(): | |||
| def test_config(freq_range, top_k): | |||
| corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"]) | |||
| vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k) | |||
| vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k, special_tokens=["<pad>", "<unk>"], | |||
| special_first=True) | |||
| corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab)) | |||
| res = [] | |||
| for d in corpus_dataset.create_dict_iterator(): | |||
| @@ -87,6 +90,35 @@ def test_from_dataset(): | |||
| assert test6_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test6_res) | |||
| def test_from_dataset_special_token(): | |||
| """ test build vocab with generator dataset """ | |||
| def gen_corpus(): | |||
| # key: word, value: number of occurrences, reason for using letters is so their order is apparent | |||
| corpus = {"D": 1, "C": 1, "B": 1, "A": 1} | |||
| for k, v in corpus.items(): | |||
| yield (np.array([k] * v, dtype='S'),) | |||
| def gen_input(texts): | |||
| for word in texts.split(" "): | |||
| yield (np.array(word, dtype='S'),) | |||
| def test_config(texts, top_k, special_tokens, special_first): | |||
| corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"]) | |||
| vocab = text.Vocab.from_dataset(corpus_dataset, None, None, top_k, special_tokens, special_first) | |||
| data = ds.GeneratorDataset(gen_input(texts), column_names=["text"]) | |||
| data = data.map(input_columns="text", operations=text.Lookup(vocab)) | |||
| res = [] | |||
| for d in data.create_dict_iterator(): | |||
| res.append(d["text"].item()) | |||
| return res | |||
| # test special tokens are inserted before | |||
| assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], True) == [2, 3, 4, 5, 0, 1] | |||
| # test special tokens are inserted after | |||
| assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], False) == [0, 1, 2, 3, 4, 5] | |||
| def test_from_dataset_exceptions(): | |||
| """ test various exceptions during that are checked in validator """ | |||
| @@ -105,8 +137,10 @@ def test_from_dataset_exceptions(): | |||
| test_config("text", (2, 3), 0, "top_k needs to be a positive integer") | |||
| test_config([123], (2, 3), 0, "columns need to be a list of strings") | |||
| if __name__ == '__main__': | |||
| test_demo_basic_from_dataset() | |||
| test_from_dataset() | |||
| test_from_dataset_exceptions() | |||
| test_demo_basic_from_dataset_with_tokenizer() | |||
| test_from_dataset_special_token() | |||
| @@ -33,7 +33,7 @@ def test_on_tokenized_line(): | |||
| word = line.split(',')[0] | |||
| jieba_op.add_word(word) | |||
| data = data.map(input_columns=["text"], operations=jieba_op) | |||
| vocab = text.Vocab.from_file(VOCAB_FILE, ",") | |||
| vocab = text.Vocab.from_file(VOCAB_FILE, ",", special_tokens=["<pad>", "<unk>"]) | |||
| lookup = text.Lookup(vocab) | |||
| data = data.map(input_columns=["text"], operations=lookup) | |||
| res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], | |||
| @@ -1,13 +1,31 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| # this file contains "home is behind the world head"; each word is on its own line | |||
| DATA_FILE = "../data/dataset/testVocab/words.txt" | |||
| VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt" | |||
| SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt" | |||
| def test_from_list(): | |||
| vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" ")) | |||
| def test_from_list_tutorial(): | |||
| vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True) | |||
| lookup = text.Lookup(vocab) | |||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||
| data = data.map(input_columns=["text"], operations=lookup) | |||
| @@ -18,8 +36,8 @@ def test_from_list(): | |||
| ind += 1 | |||
| def test_from_file(): | |||
| vocab = text.Vocab.from_file(VOCAB_FILE, ",") | |||
| def test_from_file_tutorial(): | |||
| vocab = text.Vocab.from_file(VOCAB_FILE, ",", None, ["<pad>", "<unk>"], True) | |||
| lookup = text.Lookup(vocab) | |||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||
| data = data.map(input_columns=["text"], operations=lookup) | |||
| @@ -30,7 +48,7 @@ def test_from_file(): | |||
| ind += 1 | |||
| def test_from_dict(): | |||
| def test_from_dict_tutorial(): | |||
| vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6}) | |||
| lookup = text.Lookup(vocab, 6) # 6 is the id of <unk> in this vocab | |||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||
| @@ -41,7 +59,61 @@ def test_from_dict(): | |||
| assert d["text"] == res[ind], ind | |||
| ind += 1 | |||
| def test_from_list(): | |||
| def gen(texts): | |||
| for word in texts.split(" "): | |||
| yield (np.array(word, dtype='S'),) | |||
| def test_config(lookup_str, vocab_input, special_tokens, special_first): | |||
| try: | |||
| vocab = text.Vocab.from_list(vocab_input, special_tokens, special_first) | |||
| data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"]) | |||
| data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) | |||
| res = [] | |||
| for d in data.create_dict_iterator(): | |||
| res.append(d["text"].item()) | |||
| return res | |||
| except ValueError as e: | |||
| return str(e) | |||
| # test normal operations | |||
| assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], True) == [2, 3, 4, 0, 1] | |||
| assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False) == [0, 1, 2, 3, 4] | |||
| assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True) == [2, 1, 0] | |||
| assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False) == [2, 1, 0] | |||
| # test exceptions | |||
| assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True) | |||
| assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True) | |||
| assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True) | |||
| def test_from_file(): | |||
| def gen(texts): | |||
| for word in texts.split(" "): | |||
| yield (np.array(word, dtype='S'),) | |||
| def test_config(lookup_str, special_tokens, special_first): | |||
| try: | |||
| vocab = text.Vocab.from_file(SIMPLE_VOCAB_FILE, special_tokens=special_tokens, special_first=special_first) | |||
| data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"]) | |||
| data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) | |||
| res = [] | |||
| for d in data.create_dict_iterator(): | |||
| res.append(d["text"].item()) | |||
| return res | |||
| except ValueError as e: | |||
| return str(e) | |||
| assert test_config("w1 w2 w3", ["s1", "s2", "s3"], True) == [3, 4, 5] | |||
| assert test_config("w1 w2 w3", ["s1", "s2", "s3"], False) == [0, 1, 2] | |||
| assert "special_tokens contains duplicate" in test_config("w1", ["s1", "s1"], True) | |||
| if __name__ == '__main__': | |||
| test_from_list_tutorial() | |||
| test_from_file_tutorial() | |||
| test_from_dict_tutorial() | |||
| test_from_list() | |||
| test_from_file() | |||
| test_from_dict() | |||