Merge pull request !2187 from ZiruiWu/vocab_rework
@@ -56,10 +56,10 @@ Status NgramOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
   CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "n gram needs to be a positive number.\n");
   int32_t start_ind = l_len_ - std::min(l_len_, n - 1);
   int32_t end_ind = offsets.size() - r_len_ + std::min(r_len_, n - 1);
-  if (end_ind - start_ind < n) {
+  if (end_ind - start_ind <= n) {
     res.emplace_back(std::string());  // push back empty string
   } else {
-    if (end_ind - n < 0) RETURN_STATUS_UNEXPECTED("loop condition error!");
+    CHECK_FAIL_RETURN_UNEXPECTED(end_ind - n >= 0, "Incorrect loop condition");
     for (int i = start_ind; i < end_ind - n; i++) {
       res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size()));
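The boundary fix above is easier to see without the offset arithmetic. Below is a minimal Python sketch of the same windowing rule, assuming `offsets` carries one entry past the last token (so `end_ind - start_ind` is the token count plus one for an unpadded input); it is an illustration, not the operator's implementation:

```python
def ngram_windows(tokens, n, separator=" "):
    # Mirrors the corrected check: fewer than n tokens now yields a
    # single empty string rather than silently producing no output.
    if n <= 0:
        raise ValueError("n gram needs to be a positive number.")
    if len(tokens) < n:
        return [""]
    # One n-gram per window position, tokens joined by the separator.
    return [separator.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

# "Lone Star" has two tokens: one 2-gram, then "" for every larger n,
# which is exactly what the new corner-case test below asserts.
assert ngram_windows(["Lone", "Star"], 2) == ["Lone Star"]
assert ngram_windows(["Lone", "Star"], 5) == [""]
```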
@@ -4893,15 +4893,15 @@ class BuildVocabDataset(DatasetOp):
         text.Vocab.from_dataset()
 
     Args:
-        vocab(Vocab): vocab object
-        columns(str or list, optional): column names to get words from. It can be a list of column names.
-            (Default is None where all columns will be used. If any column isn't string type, will return error)
+        vocab(Vocab): vocab object.
+        columns(str or list, optional): column names to get words from. It can be a list of column names (Default is
+            None, all columns are used, return error if any column isn't string).
         freq_range(tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency
             range would be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
-            can be None, which corresponds to 0/total_words separately (default is None, all words are included)
+            can be None, which corresponds to 0/total_words separately (default is None, all words are included).
         top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are
-            taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default is None
-            all words are included)
+            taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken (default is None
+            all words are included).
 
     Returns:
         BuildVocabDataset
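A usage sketch for the documented parameters, going through the public `text.Vocab.from_dataset()` entry point mentioned above; the corpus file name and column name are placeholders, not part of this patch:

```python
import mindspore.dataset as ds
import mindspore.dataset.text as text

# Hypothetical corpus: one sentence per line, exposed as a single
# string column named "text" by TextFileDataset.
dataset = ds.TextFileDataset("corpus.txt")

vocab = text.Vocab.from_dataset(
    dataset,
    columns=["text"],      # only string columns are accepted
    freq_range=(3, None),  # None for max_frequency means total_words
    top_k=5000)            # applied after freq_range filtering
```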
@@ -51,7 +51,7 @@ def test_simple_ngram():
     """ test simple gram with only one n value"""
     plates_mottos = ["Friendly Manitoba", "Yours to Discover", "Land of Living Skies",
                      "Birthplace of the Confederation"]
-    n_gram_mottos = [[]]
+    n_gram_mottos = [[""]]
     n_gram_mottos.append(["Yours to Discover"])
     n_gram_mottos.append(['Land of Living', 'of Living Skies'])
     n_gram_mottos.append(['Birthplace of the', 'of the Confederation'])
@@ -81,6 +81,8 @@ def test_corner_cases():
         for data in dataset.create_dict_iterator():
             assert [d.decode("utf8") for d in data["text"]] == output_line, output_line
 
+    # test tensor length smaller than n
+    test_config("Lone Star", ["Lone Star", "", "", ""], [2, 3, 4, 5])
     # test empty separator
     test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="")
     # test separator with longer length
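The `test_config` helper itself lies outside the hunk. A plausible reconstruction, simplified to the parameters the corner cases above exercise and assuming the `text.Ngram` op accepts `n` and a `separator` keyword:

```python
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text

def test_config(input_line, output_line, n, sep=" "):
    # Wrap one whitespace-tokenized sentence as a single-row dataset.
    def gen():
        yield (np.array(input_line.split(" "), dtype="S"),)

    dataset = ds.GeneratorDataset(gen, column_names=["text"])
    dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, separator=sep))
    for data in dataset.create_dict_iterator():
        assert [d.decode("utf8") for d in data["text"]] == output_line, output_line
```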