* remove and fix other unit tests * add more code commentstags/v0.2.0
| @@ -5,7 +5,8 @@ class Batch(object): | |||
| """Batch is an iterable object which iterates over mini-batches. | |||
| :: | |||
| for batch_x, batch_y in Batch(data_set): | |||
| for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): | |||
| """ | |||
| @@ -15,6 +16,8 @@ class Batch(object): | |||
| :param dataset: a DataSet object | |||
| :param batch_size: int, the size of the batch | |||
| :param sampler: a Sampler object | |||
| :param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors. | |||
| """ | |||
| self.dataset = dataset | |||
| self.batch_size = batch_size | |||
| @@ -30,17 +33,6 @@ class Batch(object): | |||
| return self | |||
| def __next__(self): | |||
| """ | |||
| :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||
| E.g. | |||
| :: | |||
| {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) | |||
| batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||
| All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. | |||
| """ | |||
| if self.curidx >= len(self.idx_list): | |||
| raise StopIteration | |||
| else: | |||
| @@ -117,22 +117,20 @@ class DataSet(object): | |||
| assert name in self.field_arrays | |||
| self.field_arrays[name].append(field) | |||
| def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False): | |||
| def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||
| """ | |||
| :param name: | |||
| :param str name: | |||
| :param fields: | |||
| :param padding_val: | |||
| :param need_tensor: | |||
| :param is_target: | |||
| :param int padding_val: | |||
| :param bool is_input: | |||
| :param bool is_target: | |||
| :return: | |||
| """ | |||
| if len(self.field_arrays) != 0: | |||
| assert len(self) == len(fields) | |||
| self.field_arrays[name] = FieldArray(name, fields, | |||
| padding_val=padding_val, | |||
| need_tensor=need_tensor, | |||
| is_target=is_target) | |||
| self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | |||
| is_input=is_input) | |||
| def delete_field(self, name): | |||
| self.field_arrays.pop(name) | |||
| @@ -2,7 +2,19 @@ import numpy as np | |||
| class FieldArray(object): | |||
| """FieldArray is the collection of Instances of the same Field. | |||
| It is the basic element of DataSet class. | |||
| """ | |||
| def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | |||
| """ | |||
| :param str name: the name of the FieldArray | |||
| :param list content: a list of int, float, or other objects. | |||
| :param int padding_val: the integer for padding. Default: 0. | |||
| :param bool is_target: If True, this FieldArray is used to compute loss. | |||
| :param bool is_input: If True, this FieldArray is used to the model input. | |||
| """ | |||
| self.name = name | |||
| self.content = content | |||
| self.padding_val = padding_val | |||
| @@ -24,23 +36,28 @@ class FieldArray(object): | |||
| assert isinstance(name, int) | |||
| self.content[name] = val | |||
| def get(self, idxes): | |||
| if isinstance(idxes, int): | |||
| return self.content[idxes] | |||
| def get(self, indices): | |||
| """Fetch instances based on indices. | |||
| :param indices: an int, or a list of int. | |||
| :return: | |||
| """ | |||
| if isinstance(indices, int): | |||
| return self.content[indices] | |||
| assert self.is_input is True or self.is_target is True | |||
| batch_size = len(idxes) | |||
| batch_size = len(indices) | |||
| # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | |||
| if isinstance(self.content[0], int) or isinstance(self.content[0], float): | |||
| if self.dtype is None: | |||
| self.dtype = np.int64 if isinstance(self.content[0], int) else np.double | |||
| array = np.array([self.content[i] for i in idxes], dtype=self.dtype) | |||
| array = np.array([self.content[i] for i in indices], dtype=self.dtype) | |||
| else: | |||
| if self.dtype is None: | |||
| self.dtype = np.int64 | |||
| max_len = max([len(self.content[i]) for i in idxes]) | |||
| max_len = max([len(self.content[i]) for i in indices]) | |||
| array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | |||
| for i, idx in enumerate(idxes): | |||
| for i, idx in enumerate(indices): | |||
| array[i][:len(self.content[idx])] = self.content[idx] | |||
| return array | |||
| @@ -1,16 +1,27 @@ | |||
| class Instance(object): | |||
| """An instance which consists of Fields is an example in the DataSet. | |||
| """An Instance is an example of data. It is the collection of Fields. | |||
| :: | |||
| Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
| """ | |||
| def __init__(self, **fields): | |||
| """ | |||
| :param fields: a dict of (field name: field) | |||
| """ | |||
| self.fields = fields | |||
| def add_field(self, field_name, field): | |||
| """Add a new field to the instance. | |||
| :param field_name: str, the name of the field. | |||
| :param field: | |||
| """ | |||
| self.fields[field_name] = field | |||
| return self | |||
| def __getitem__(self, name): | |||
| if name in self.fields: | |||
| @@ -21,17 +32,5 @@ class Instance(object): | |||
| def __setitem__(self, name, field): | |||
| return self.add_field(name, field) | |||
| def __getattr__(self, item): | |||
| if hasattr(self, 'fields') and item in self.fields: | |||
| return self.fields[item] | |||
| else: | |||
| raise AttributeError('{} does not exist.'.format(item)) | |||
| def __setattr__(self, key, value): | |||
| if hasattr(self, 'fields'): | |||
| self.__setitem__(key, value) | |||
| else: | |||
| super().__setattr__(key, value) | |||
| def __repr__(self): | |||
| return self.fields.__repr__() | |||
| @@ -1,5 +1,5 @@ | |||
| from copy import deepcopy | |||
| from collections import Counter | |||
| from copy import deepcopy | |||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
| @@ -20,6 +20,7 @@ def check_build_vocab(func): | |||
| if self.word2idx is None: | |||
| self.build_vocab() | |||
| return func(self, *args, **kwargs) | |||
| return _wrapper | |||
| @@ -34,6 +35,7 @@ class Vocabulary(object): | |||
| vocab["word"] | |||
| vocab.to_word(5) | |||
| """ | |||
| def __init__(self, need_default=True, max_size=None, min_freq=None): | |||
| """ | |||
| :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||
| @@ -54,24 +56,36 @@ class Vocabulary(object): | |||
| self.idx2word = None | |||
| def update(self, word_lst): | |||
| """add word or list of words into Vocabulary | |||
| """Add a list of words into the vocabulary. | |||
| :param word: a list of string or a single string | |||
| :param list word_lst: a list of strings | |||
| """ | |||
| self.word_count.update(word_lst) | |||
| def add(self, word): | |||
| """Add a single word into the vocabulary. | |||
| :param str word: a word or token. | |||
| """ | |||
| self.word_count[word] += 1 | |||
| def add_word(self, word): | |||
| """Add a single word into the vocabulary. | |||
| :param str word: a word or token. | |||
| """ | |||
| self.add(word) | |||
| def add_word_lst(self, word_lst): | |||
| self.update(word_lst) | |||
| """Add a list of words into the vocabulary. | |||
| :param list word_lst: a list of strings | |||
| """ | |||
| self.update(word_lst) | |||
| def build_vocab(self): | |||
| """build 'word to index' dict, and filter the word using `max_size` and `min_freq` | |||
| """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | |||
| """ | |||
| if self.has_default: | |||
| self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||
| @@ -85,11 +99,12 @@ class Vocabulary(object): | |||
| if self.min_freq is not None: | |||
| words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
| start_idx = len(self.word2idx) | |||
| self.word2idx.update({w:i+start_idx for i, (w,_) in enumerate(words)}) | |||
| self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
| self.build_reverse_vocab() | |||
| def build_reverse_vocab(self): | |||
| """build 'index to word' dict based on 'word to index' dict | |||
| """Build 'index to word' dict based on 'word to index' dict. | |||
| """ | |||
| self.idx2word = {i: w for w, i in self.word2idx.items()} | |||
| @@ -97,6 +112,15 @@ class Vocabulary(object): | |||
| def __len__(self): | |||
| return len(self.word2idx) | |||
| @check_build_vocab | |||
| def __contains__(self, item): | |||
| """Check if a word in vocabulary. | |||
| :param item: the word | |||
| :return: True or False | |||
| """ | |||
| return item in self.word2idx | |||
| def has_word(self, w): | |||
| return self.__contains__(w) | |||
| @@ -114,8 +138,8 @@ class Vocabulary(object): | |||
| raise ValueError("word {} not in vocabulary".format(w)) | |||
| def to_index(self, w): | |||
| """ like to_index(w) function, turn a word to the index | |||
| if w is not in Vocabulary, return the unknown label | |||
| """ Turn a word to an index. | |||
| If w is not in Vocabulary, return the unknown label. | |||
| :param str w: | |||
| """ | |||
| @@ -144,12 +168,14 @@ class Vocabulary(object): | |||
| def to_word(self, idx): | |||
| """given a word's index, return the word itself | |||
| :param int idx: | |||
| :param int idx: the index | |||
| :return str word: the indexed word | |||
| """ | |||
| return self.idx2word[idx] | |||
| def __getstate__(self): | |||
| """use to prepare data for pickle | |||
| """Use to prepare data for pickle. | |||
| """ | |||
| state = self.__dict__.copy() | |||
| # no need to pickle idx2word as it can be constructed from word2idx | |||
| @@ -157,16 +183,9 @@ class Vocabulary(object): | |||
| return state | |||
| def __setstate__(self, state): | |||
| """use to restore state from pickle | |||
| """Use to restore state from pickle. | |||
| """ | |||
| self.__dict__.update(state) | |||
| self.build_reverse_vocab() | |||
| @check_build_vocab | |||
| def __contains__(self, item): | |||
| """Check if a word in vocabulary. | |||
| :param item: the word | |||
| :return: True or False | |||
| """ | |||
| return item in self.word2idx | |||
| @@ -1,17 +1,18 @@ | |||
| import unittest | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.dataset import construct_dataset | |||
| from fastNLP.core.sampler import SequentialSampler | |||
| class TestCase1(unittest.TestCase): | |||
| def test(self): | |||
| dataset = DataSet([Instance(x=["I", "am", "here"])] * 40) | |||
| def test_simple(self): | |||
| dataset = construct_dataset( | |||
| [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||
| dataset.set_target() | |||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False) | |||
| for batch_x, batch_y in batch: | |||
| print(batch_x, batch_y) | |||
| # TODO: weird due to change in dataset.py | |||
| cnt = 0 | |||
| for _, _ in batch: | |||
| cnt += 1 | |||
| self.assertEqual(cnt, 10) | |||
| @@ -1,20 +1,20 @@ | |||
| import unittest | |||
| from fastNLP.core.dataset import DataSet | |||
| class TestDataSet(unittest.TestCase): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| ] | |||
| unlabeled_data_list = [ | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"] | |||
| ] | |||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||
| def test_case_1(self): | |||
| # TODO: | |||
| pass | |||
| ds = DataSet() | |||
| ds.add_field(name="xx", fields=["a", "b", "e", "d"]) | |||
| self.assertTrue("xx" in ds.field_arrays) | |||
| self.assertEqual(len(ds.field_arrays["xx"]), 4) | |||
| self.assertEqual(ds.get_length(), 4) | |||
| self.assertEqual(ds.get_fields(), ds.field_arrays) | |||
| try: | |||
| ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"]) | |||
| except BaseException as e: | |||
| self.assertTrue(isinstance(e, AssertionError)) | |||
| @@ -1,42 +0,0 @@ | |||
| import unittest | |||
| from fastNLP.core.field import CharTextField, LabelField, SeqLabelField | |||
| class TestField(unittest.TestCase): | |||
| def test_char_field(self): | |||
| text = "PhD applicants must submit a Research Plan and a resume " \ | |||
| "specify your class ranking written in English and a list of research" \ | |||
| " publications if any".split() | |||
| max_word_len = max([len(w) for w in text]) | |||
| field = CharTextField(text, max_word_len, is_target=False) | |||
| all_char = set() | |||
| for word in text: | |||
| all_char.update([ch for ch in word]) | |||
| char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)} | |||
| self.assertEqual(field.index(char_vocab), | |||
| [[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text]) | |||
| self.assertEqual(field.get_length(), len(text)) | |||
| self.assertEqual(field.contents(), text) | |||
| tensor = field.to_tensor(50) | |||
| self.assertEqual(tuple(tensor.shape), (50, max_word_len)) | |||
| def test_label_field(self): | |||
| label = LabelField("A", is_target=True) | |||
| self.assertEqual(label.get_length(), 1) | |||
| self.assertEqual(label.index({"A": 10}), 10) | |||
| label = LabelField(30, is_target=True) | |||
| self.assertEqual(label.get_length(), 1) | |||
| tensor = label.to_tensor(0) | |||
| self.assertEqual(tensor.shape, ()) | |||
| self.assertEqual(int(tensor), 30) | |||
| def test_seq_label_field(self): | |||
| seq = ["a", "b", "c", "d", "a", "c", "a", "b"] | |||
| field = SeqLabelField(seq) | |||
| vocab = {"a": 10, "b": 20, "c": 30, "d": 40} | |||
| self.assertEqual(field.index(vocab), [vocab[x] for x in seq]) | |||
| tensor = field.to_tensor(10) | |||
| self.assertEqual(tuple(tensor.shape), (10,)) | |||
| @@ -0,0 +1,6 @@ | |||
| import unittest | |||
| class TestFieldArray(unittest.TestCase): | |||
| def test(self): | |||
| pass | |||
| @@ -0,0 +1,29 @@ | |||
| import unittest | |||
| from fastNLP.core.instance import Instance | |||
| class TestCase(unittest.TestCase): | |||
| def test_init(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
| ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||
| self.assertTrue(isinstance(ins.fields, dict)) | |||
| self.assertEqual(ins.fields, fields) | |||
| ins = Instance(**fields) | |||
| self.assertEqual(ins.fields, fields) | |||
| def test_add_field(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
| ins = Instance(**fields) | |||
| ins.add_field("z", [1, 1, 1]) | |||
| fields.update({"z": [1, 1, 1]}) | |||
| self.assertEqual(ins.fields, fields) | |||
| def test_get_item(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
| ins = Instance(**fields) | |||
| self.assertEqual(ins["x"], [1, 2, 3]) | |||
| self.assertEqual(ins["y"], [4, 5, 6]) | |||
| self.assertEqual(ins["z"], [1, 1, 1]) | |||
| @@ -1,44 +1,42 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | |||
| k_means_1d, k_means_bucketing, simple_sort_bucketing | |||
| def test_convert_to_torch_tensor(): | |||
| data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||
| ans = convert_to_torch_tensor(data, False) | |||
| assert isinstance(ans, torch.Tensor) | |||
| assert tuple(ans.shape) == (3, 5) | |||
| def test_sequential_sampler(): | |||
| sampler = SequentialSampler() | |||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
| for idx, i in enumerate(sampler(data)): | |||
| assert idx == i | |||
| def test_random_sampler(): | |||
| sampler = RandomSampler() | |||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
| ans = [data[i] for i in sampler(data)] | |||
| assert len(ans) == len(data) | |||
| for d in ans: | |||
| assert d in data | |||
| def test_k_means(): | |||
| centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||
| centroids, assign = list(centroids), list(assign) | |||
| assert len(centroids) == 2 | |||
| assert len(assign) == 10 | |||
| def test_k_means_bucketing(): | |||
| res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||
| assert len(res) == 2 | |||
| def test_simple_sort_bucketing(): | |||
| _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||
| assert len(_) == 10 | |||
| class TestSampler(unittest.TestCase): | |||
| def test_convert_to_torch_tensor(self): | |||
| data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||
| ans = convert_to_torch_tensor(data, False) | |||
| assert isinstance(ans, torch.Tensor) | |||
| assert tuple(ans.shape) == (3, 5) | |||
| def test_sequential_sampler(self): | |||
| sampler = SequentialSampler() | |||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
| for idx, i in enumerate(sampler(data)): | |||
| assert idx == i | |||
| def test_random_sampler(self): | |||
| sampler = RandomSampler() | |||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
| ans = [data[i] for i in sampler(data)] | |||
| assert len(ans) == len(data) | |||
| for d in ans: | |||
| assert d in data | |||
| def test_k_means(self): | |||
| centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||
| centroids, assign = list(centroids), list(assign) | |||
| assert len(centroids) == 2 | |||
| assert len(assign) == 10 | |||
| def test_k_means_bucketing(self): | |||
| res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||
| assert len(res) == 2 | |||
| def test_simple_sort_bucketing(self): | |||
| _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||
| assert len(_) == 10 | |||
| @@ -1,31 +0,0 @@ | |||
| import unittest | |||
| from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||
| class TestVocabulary(unittest.TestCase): | |||
| def test_vocab(self): | |||
| import _pickle as pickle | |||
| import os | |||
| vocab = Vocabulary() | |||
| filename = 'vocab' | |||
| vocab.update(filename) | |||
| vocab.update([filename, ['a'], [['b']], ['c']]) | |||
| idx = vocab[filename] | |||
| before_pic = (vocab.to_word(idx), vocab[filename]) | |||
| with open(filename, 'wb') as f: | |||
| pickle.dump(vocab, f) | |||
| with open(filename, 'rb') as f: | |||
| vocab = pickle.load(f) | |||
| os.remove(filename) | |||
| vocab.build_reverse_vocab() | |||
| after_pic = (vocab.to_word(idx), vocab[filename]) | |||
| TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} | |||
| TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) | |||
| TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} | |||
| self.assertEqual(before_pic, after_pic) | |||
| self.assertDictEqual(TRUE_DICT, vocab.word2idx) | |||
| self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,61 @@ | |||
| import unittest | |||
| from collections import Counter | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||
| "works", "well", "in", "most", "cases", "scales", "well"] | |||
| counter = Counter(text) | |||
| class TestAdd(unittest.TestCase): | |||
| def test_add(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| for word in text: | |||
| vocab.add(word) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_add_word(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| for word in text: | |||
| vocab.add_word(word) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_add_word_lst(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab.add_word_lst(text) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_update(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| class TestIndexing(unittest.TestCase): | |||
| def test_len(self): | |||
| vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| self.assertEqual(len(vocab), len(counter)) | |||
| def test_contains(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| self.assertTrue(text[-1] in vocab) | |||
| self.assertFalse("~!@#" in vocab) | |||
| self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | |||
| self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||
| def test_index(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| res = [vocab[w] for w in set(text)] | |||
| self.assertEqual(len(res), len(set(res))) | |||
| res = [vocab.to_index(w) for w in set(text)] | |||
| self.assertEqual(len(res), len(set(res))) | |||
| def test_to_word(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||