* remove and fix other unit tests * add more code comments tags/v0.2.0
| @@ -5,7 +5,8 @@ class Batch(object): | |||||
| """Batch is an iterable object which iterates over mini-batches. | """Batch is an iterable object which iterates over mini-batches. | ||||
| :: | :: | ||||
| for batch_x, batch_y in Batch(data_set): | |||||
| for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): | |||||
| """ | """ | ||||
| @@ -15,6 +16,8 @@ class Batch(object): | |||||
| :param dataset: a DataSet object | :param dataset: a DataSet object | ||||
| :param batch_size: int, the size of the batch | :param batch_size: int, the size of the batch | ||||
| :param sampler: a Sampler object | :param sampler: a Sampler object | ||||
| :param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors. | |||||
| """ | """ | ||||
| self.dataset = dataset | self.dataset = dataset | ||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| @@ -30,17 +33,6 @@ class Batch(object): | |||||
| return self | return self | ||||
| def __next__(self): | def __next__(self): | ||||
| """ | |||||
| :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
| E.g. | |||||
| :: | |||||
| {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) | |||||
| batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
| All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. | |||||
| """ | |||||
| if self.curidx >= len(self.idx_list): | if self.curidx >= len(self.idx_list): | ||||
| raise StopIteration | raise StopIteration | ||||
| else: | else: | ||||
| @@ -117,22 +117,20 @@ class DataSet(object): | |||||
| assert name in self.field_arrays | assert name in self.field_arrays | ||||
| self.field_arrays[name].append(field) | self.field_arrays[name].append(field) | ||||
| def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False): | |||||
| def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||||
| """ | """ | ||||
| :param name: | |||||
| :param str name: | |||||
| :param fields: | :param fields: | ||||
| :param padding_val: | |||||
| :param need_tensor: | |||||
| :param is_target: | |||||
| :param int padding_val: | |||||
| :param bool is_input: | |||||
| :param bool is_target: | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
| assert len(self) == len(fields) | assert len(self) == len(fields) | ||||
| self.field_arrays[name] = FieldArray(name, fields, | |||||
| padding_val=padding_val, | |||||
| need_tensor=need_tensor, | |||||
| is_target=is_target) | |||||
| self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | |||||
| is_input=is_input) | |||||
| def delete_field(self, name): | def delete_field(self, name): | ||||
| self.field_arrays.pop(name) | self.field_arrays.pop(name) | ||||
| @@ -2,7 +2,19 @@ import numpy as np | |||||
| class FieldArray(object): | class FieldArray(object): | ||||
| """FieldArray is the collection of Instances of the same Field. | |||||
| It is the basic element of DataSet class. | |||||
| """ | |||||
| def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | ||||
| """ | |||||
| :param str name: the name of the FieldArray | |||||
| :param list content: a list of int, float, or other objects. | |||||
| :param int padding_val: the integer for padding. Default: 0. | |||||
| :param bool is_target: If True, this FieldArray is used to compute loss. | |||||
| :param bool is_input: If True, this FieldArray is used as the model input. | |||||
| """ | |||||
| self.name = name | self.name = name | ||||
| self.content = content | self.content = content | ||||
| self.padding_val = padding_val | self.padding_val = padding_val | ||||
| @@ -24,23 +36,28 @@ class FieldArray(object): | |||||
| assert isinstance(name, int) | assert isinstance(name, int) | ||||
| self.content[name] = val | self.content[name] = val | ||||
| def get(self, idxes): | |||||
| if isinstance(idxes, int): | |||||
| return self.content[idxes] | |||||
| def get(self, indices): | |||||
| """Fetch instances based on indices. | |||||
| :param indices: an int, or a list of int. | |||||
| :return: | |||||
| """ | |||||
| if isinstance(indices, int): | |||||
| return self.content[indices] | |||||
| assert self.is_input is True or self.is_target is True | assert self.is_input is True or self.is_target is True | ||||
| batch_size = len(idxes) | |||||
| batch_size = len(indices) | |||||
| # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | ||||
| if isinstance(self.content[0], int) or isinstance(self.content[0], float): | if isinstance(self.content[0], int) or isinstance(self.content[0], float): | ||||
| if self.dtype is None: | if self.dtype is None: | ||||
| self.dtype = np.int64 if isinstance(self.content[0], int) else np.double | self.dtype = np.int64 if isinstance(self.content[0], int) else np.double | ||||
| array = np.array([self.content[i] for i in idxes], dtype=self.dtype) | |||||
| array = np.array([self.content[i] for i in indices], dtype=self.dtype) | |||||
| else: | else: | ||||
| if self.dtype is None: | if self.dtype is None: | ||||
| self.dtype = np.int64 | self.dtype = np.int64 | ||||
| max_len = max([len(self.content[i]) for i in idxes]) | |||||
| max_len = max([len(self.content[i]) for i in indices]) | |||||
| array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | ||||
| for i, idx in enumerate(idxes): | |||||
| for i, idx in enumerate(indices): | |||||
| array[i][:len(self.content[idx])] = self.content[idx] | array[i][:len(self.content[idx])] = self.content[idx] | ||||
| return array | return array | ||||
| @@ -1,16 +1,27 @@ | |||||
| class Instance(object): | class Instance(object): | ||||
| """An instance which consists of Fields is an example in the DataSet. | |||||
| """An Instance is an example of data. It is the collection of Fields. | |||||
| :: | |||||
| Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||||
| """ | """ | ||||
| def __init__(self, **fields): | def __init__(self, **fields): | ||||
| """ | |||||
| :param fields: a dict of (field name: field) | |||||
| """ | |||||
| self.fields = fields | self.fields = fields | ||||
| def add_field(self, field_name, field): | def add_field(self, field_name, field): | ||||
| """Add a new field to the instance. | |||||
| :param field_name: str, the name of the field. | |||||
| :param field: | |||||
| """ | |||||
| self.fields[field_name] = field | self.fields[field_name] = field | ||||
| return self | |||||
| def __getitem__(self, name): | def __getitem__(self, name): | ||||
| if name in self.fields: | if name in self.fields: | ||||
| @@ -21,17 +32,5 @@ class Instance(object): | |||||
| def __setitem__(self, name, field): | def __setitem__(self, name, field): | ||||
| return self.add_field(name, field) | return self.add_field(name, field) | ||||
| def __getattr__(self, item): | |||||
| if hasattr(self, 'fields') and item in self.fields: | |||||
| return self.fields[item] | |||||
| else: | |||||
| raise AttributeError('{} does not exist.'.format(item)) | |||||
| def __setattr__(self, key, value): | |||||
| if hasattr(self, 'fields'): | |||||
| self.__setitem__(key, value) | |||||
| else: | |||||
| super().__setattr__(key, value) | |||||
| def __repr__(self): | def __repr__(self): | ||||
| return self.fields.__repr__() | return self.fields.__repr__() | ||||
| @@ -1,5 +1,5 @@ | |||||
| from copy import deepcopy | |||||
| from collections import Counter | from collections import Counter | ||||
| from copy import deepcopy | |||||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | ||||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | ||||
| @@ -20,6 +20,7 @@ def check_build_vocab(func): | |||||
| if self.word2idx is None: | if self.word2idx is None: | ||||
| self.build_vocab() | self.build_vocab() | ||||
| return func(self, *args, **kwargs) | return func(self, *args, **kwargs) | ||||
| return _wrapper | return _wrapper | ||||
| @@ -34,6 +35,7 @@ class Vocabulary(object): | |||||
| vocab["word"] | vocab["word"] | ||||
| vocab.to_word(5) | vocab.to_word(5) | ||||
| """ | """ | ||||
| def __init__(self, need_default=True, max_size=None, min_freq=None): | def __init__(self, need_default=True, max_size=None, min_freq=None): | ||||
| """ | """ | ||||
| :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | ||||
| @@ -54,24 +56,36 @@ class Vocabulary(object): | |||||
| self.idx2word = None | self.idx2word = None | ||||
| def update(self, word_lst): | def update(self, word_lst): | ||||
| """add word or list of words into Vocabulary | |||||
| """Add a list of words into the vocabulary. | |||||
| :param word: a list of string or a single string | |||||
| :param list word_lst: a list of strings | |||||
| """ | """ | ||||
| self.word_count.update(word_lst) | self.word_count.update(word_lst) | ||||
| def add(self, word): | def add(self, word): | ||||
| """Add a single word into the vocabulary. | |||||
| :param str word: a word or token. | |||||
| """ | |||||
| self.word_count[word] += 1 | self.word_count[word] += 1 | ||||
| def add_word(self, word): | def add_word(self, word): | ||||
| """Add a single word into the vocabulary. | |||||
| :param str word: a word or token. | |||||
| """ | |||||
| self.add(word) | self.add(word) | ||||
| def add_word_lst(self, word_lst): | def add_word_lst(self, word_lst): | ||||
| self.update(word_lst) | |||||
| """Add a list of words into the vocabulary. | |||||
| :param list word_lst: a list of strings | |||||
| """ | |||||
| self.update(word_lst) | |||||
| def build_vocab(self): | def build_vocab(self): | ||||
| """build 'word to index' dict, and filter the word using `max_size` and `min_freq` | |||||
| """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | |||||
| """ | """ | ||||
| if self.has_default: | if self.has_default: | ||||
| self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | ||||
| @@ -85,11 +99,12 @@ class Vocabulary(object): | |||||
| if self.min_freq is not None: | if self.min_freq is not None: | ||||
| words = filter(lambda kv: kv[1] >= self.min_freq, words) | words = filter(lambda kv: kv[1] >= self.min_freq, words) | ||||
| start_idx = len(self.word2idx) | start_idx = len(self.word2idx) | ||||
| self.word2idx.update({w:i+start_idx for i, (w,_) in enumerate(words)}) | |||||
| self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||||
| self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
| def build_reverse_vocab(self): | def build_reverse_vocab(self): | ||||
| """build 'index to word' dict based on 'word to index' dict | |||||
| """Build 'index to word' dict based on 'word to index' dict. | |||||
| """ | """ | ||||
| self.idx2word = {i: w for w, i in self.word2idx.items()} | self.idx2word = {i: w for w, i in self.word2idx.items()} | ||||
| @@ -97,6 +112,15 @@ class Vocabulary(object): | |||||
| def __len__(self): | def __len__(self): | ||||
| return len(self.word2idx) | return len(self.word2idx) | ||||
| @check_build_vocab | |||||
| def __contains__(self, item): | |||||
| """Check if a word in vocabulary. | |||||
| :param item: the word | |||||
| :return: True or False | |||||
| """ | |||||
| return item in self.word2idx | |||||
| def has_word(self, w): | def has_word(self, w): | ||||
| return self.__contains__(w) | return self.__contains__(w) | ||||
| @@ -114,8 +138,8 @@ class Vocabulary(object): | |||||
| raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
| def to_index(self, w): | def to_index(self, w): | ||||
| """ like to_index(w) function, turn a word to the index | |||||
| if w is not in Vocabulary, return the unknown label | |||||
| """ Turn a word to an index. | |||||
| If w is not in Vocabulary, return the unknown label. | |||||
| :param str w: | :param str w: | ||||
| """ | """ | ||||
| @@ -144,12 +168,14 @@ class Vocabulary(object): | |||||
| def to_word(self, idx): | def to_word(self, idx): | ||||
| """given a word's index, return the word itself | """given a word's index, return the word itself | ||||
| :param int idx: | |||||
| :param int idx: the index | |||||
| :return str word: the indexed word | |||||
| """ | """ | ||||
| return self.idx2word[idx] | return self.idx2word[idx] | ||||
| def __getstate__(self): | def __getstate__(self): | ||||
| """use to prepare data for pickle | |||||
| """Use to prepare data for pickle. | |||||
| """ | """ | ||||
| state = self.__dict__.copy() | state = self.__dict__.copy() | ||||
| # no need to pickle idx2word as it can be constructed from word2idx | # no need to pickle idx2word as it can be constructed from word2idx | ||||
| @@ -157,16 +183,9 @@ class Vocabulary(object): | |||||
| return state | return state | ||||
| def __setstate__(self, state): | def __setstate__(self, state): | ||||
| """use to restore state from pickle | |||||
| """Use to restore state from pickle. | |||||
| """ | """ | ||||
| self.__dict__.update(state) | self.__dict__.update(state) | ||||
| self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
| @check_build_vocab | |||||
| def __contains__(self, item): | |||||
| """Check if a word in vocabulary. | |||||
| :param item: the word | |||||
| :return: True or False | |||||
| """ | |||||
| return item in self.word2idx | |||||
| @@ -1,17 +1,18 @@ | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.instance import Instance | |||||
| from fastNLP.core.dataset import construct_dataset | |||||
| from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
| class TestCase1(unittest.TestCase): | class TestCase1(unittest.TestCase): | ||||
| def test(self): | |||||
| dataset = DataSet([Instance(x=["I", "am", "here"])] * 40) | |||||
| def test_simple(self): | |||||
| dataset = construct_dataset( | |||||
| [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||||
| dataset.set_target() | |||||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False) | batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False) | ||||
| for batch_x, batch_y in batch: | |||||
| print(batch_x, batch_y) | |||||
| # TODO: weird due to change in dataset.py | |||||
| cnt = 0 | |||||
| for _, _ in batch: | |||||
| cnt += 1 | |||||
| self.assertEqual(cnt, 10) | |||||
| @@ -1,20 +1,20 @@ | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import DataSet | |||||
| class TestDataSet(unittest.TestCase): | class TestDataSet(unittest.TestCase): | ||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||||
| def test_case_1(self): | def test_case_1(self): | ||||
| # TODO: | |||||
| pass | |||||
| ds = DataSet() | |||||
| ds.add_field(name="xx", fields=["a", "b", "e", "d"]) | |||||
| self.assertTrue("xx" in ds.field_arrays) | |||||
| self.assertEqual(len(ds.field_arrays["xx"]), 4) | |||||
| self.assertEqual(ds.get_length(), 4) | |||||
| self.assertEqual(ds.get_fields(), ds.field_arrays) | |||||
| try: | |||||
| ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"]) | |||||
| except BaseException as e: | |||||
| self.assertTrue(isinstance(e, AssertionError)) | |||||
| @@ -1,42 +0,0 @@ | |||||
| import unittest | |||||
| from fastNLP.core.field import CharTextField, LabelField, SeqLabelField | |||||
| class TestField(unittest.TestCase): | |||||
| def test_char_field(self): | |||||
| text = "PhD applicants must submit a Research Plan and a resume " \ | |||||
| "specify your class ranking written in English and a list of research" \ | |||||
| " publications if any".split() | |||||
| max_word_len = max([len(w) for w in text]) | |||||
| field = CharTextField(text, max_word_len, is_target=False) | |||||
| all_char = set() | |||||
| for word in text: | |||||
| all_char.update([ch for ch in word]) | |||||
| char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)} | |||||
| self.assertEqual(field.index(char_vocab), | |||||
| [[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text]) | |||||
| self.assertEqual(field.get_length(), len(text)) | |||||
| self.assertEqual(field.contents(), text) | |||||
| tensor = field.to_tensor(50) | |||||
| self.assertEqual(tuple(tensor.shape), (50, max_word_len)) | |||||
| def test_label_field(self): | |||||
| label = LabelField("A", is_target=True) | |||||
| self.assertEqual(label.get_length(), 1) | |||||
| self.assertEqual(label.index({"A": 10}), 10) | |||||
| label = LabelField(30, is_target=True) | |||||
| self.assertEqual(label.get_length(), 1) | |||||
| tensor = label.to_tensor(0) | |||||
| self.assertEqual(tensor.shape, ()) | |||||
| self.assertEqual(int(tensor), 30) | |||||
| def test_seq_label_field(self): | |||||
| seq = ["a", "b", "c", "d", "a", "c", "a", "b"] | |||||
| field = SeqLabelField(seq) | |||||
| vocab = {"a": 10, "b": 20, "c": 30, "d": 40} | |||||
| self.assertEqual(field.index(vocab), [vocab[x] for x in seq]) | |||||
| tensor = field.to_tensor(10) | |||||
| self.assertEqual(tuple(tensor.shape), (10,)) | |||||
| @@ -0,0 +1,6 @@ | |||||
| import unittest | |||||
| class TestFieldArray(unittest.TestCase): | |||||
| def test(self): | |||||
| pass | |||||
| @@ -0,0 +1,29 @@ | |||||
| import unittest | |||||
| from fastNLP.core.instance import Instance | |||||
| class TestCase(unittest.TestCase): | |||||
| def test_init(self): | |||||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||||
| ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||||
| self.assertTrue(isinstance(ins.fields, dict)) | |||||
| self.assertEqual(ins.fields, fields) | |||||
| ins = Instance(**fields) | |||||
| self.assertEqual(ins.fields, fields) | |||||
| def test_add_field(self): | |||||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||||
| ins = Instance(**fields) | |||||
| ins.add_field("z", [1, 1, 1]) | |||||
| fields.update({"z": [1, 1, 1]}) | |||||
| self.assertEqual(ins.fields, fields) | |||||
| def test_get_item(self): | |||||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||||
| ins = Instance(**fields) | |||||
| self.assertEqual(ins["x"], [1, 2, 3]) | |||||
| self.assertEqual(ins["y"], [4, 5, 6]) | |||||
| self.assertEqual(ins["z"], [1, 1, 1]) | |||||
| @@ -1,44 +1,42 @@ | |||||
| import unittest | |||||
| import torch | import torch | ||||
| from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | ||||
| k_means_1d, k_means_bucketing, simple_sort_bucketing | k_means_1d, k_means_bucketing, simple_sort_bucketing | ||||
| def test_convert_to_torch_tensor(): | |||||
| data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||||
| ans = convert_to_torch_tensor(data, False) | |||||
| assert isinstance(ans, torch.Tensor) | |||||
| assert tuple(ans.shape) == (3, 5) | |||||
| def test_sequential_sampler(): | |||||
| sampler = SequentialSampler() | |||||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
| for idx, i in enumerate(sampler(data)): | |||||
| assert idx == i | |||||
| def test_random_sampler(): | |||||
| sampler = RandomSampler() | |||||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
| ans = [data[i] for i in sampler(data)] | |||||
| assert len(ans) == len(data) | |||||
| for d in ans: | |||||
| assert d in data | |||||
| def test_k_means(): | |||||
| centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||||
| centroids, assign = list(centroids), list(assign) | |||||
| assert len(centroids) == 2 | |||||
| assert len(assign) == 10 | |||||
| def test_k_means_bucketing(): | |||||
| res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||||
| assert len(res) == 2 | |||||
| def test_simple_sort_bucketing(): | |||||
| _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||||
| assert len(_) == 10 | |||||
| class TestSampler(unittest.TestCase): | |||||
| def test_convert_to_torch_tensor(self): | |||||
| data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||||
| ans = convert_to_torch_tensor(data, False) | |||||
| assert isinstance(ans, torch.Tensor) | |||||
| assert tuple(ans.shape) == (3, 5) | |||||
| def test_sequential_sampler(self): | |||||
| sampler = SequentialSampler() | |||||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
| for idx, i in enumerate(sampler(data)): | |||||
| assert idx == i | |||||
| def test_random_sampler(self): | |||||
| sampler = RandomSampler() | |||||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
| ans = [data[i] for i in sampler(data)] | |||||
| assert len(ans) == len(data) | |||||
| for d in ans: | |||||
| assert d in data | |||||
| def test_k_means(self): | |||||
| centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||||
| centroids, assign = list(centroids), list(assign) | |||||
| assert len(centroids) == 2 | |||||
| assert len(assign) == 10 | |||||
| def test_k_means_bucketing(self): | |||||
| res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||||
| assert len(res) == 2 | |||||
| def test_simple_sort_bucketing(self): | |||||
| _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||||
| assert len(_) == 10 | |||||
| @@ -1,31 +0,0 @@ | |||||
| import unittest | |||||
| from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||||
| class TestVocabulary(unittest.TestCase): | |||||
| def test_vocab(self): | |||||
| import _pickle as pickle | |||||
| import os | |||||
| vocab = Vocabulary() | |||||
| filename = 'vocab' | |||||
| vocab.update(filename) | |||||
| vocab.update([filename, ['a'], [['b']], ['c']]) | |||||
| idx = vocab[filename] | |||||
| before_pic = (vocab.to_word(idx), vocab[filename]) | |||||
| with open(filename, 'wb') as f: | |||||
| pickle.dump(vocab, f) | |||||
| with open(filename, 'rb') as f: | |||||
| vocab = pickle.load(f) | |||||
| os.remove(filename) | |||||
| vocab.build_reverse_vocab() | |||||
| after_pic = (vocab.to_word(idx), vocab[filename]) | |||||
| TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} | |||||
| TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) | |||||
| TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} | |||||
| self.assertEqual(before_pic, after_pic) | |||||
| self.assertDictEqual(TRUE_DICT, vocab.word2idx) | |||||
| self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,61 @@ | |||||
| import unittest | |||||
| from collections import Counter | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||||
| "works", "well", "in", "most", "cases", "scales", "well"] | |||||
| counter = Counter(text) | |||||
| class TestAdd(unittest.TestCase): | |||||
| def test_add(self): | |||||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
| for word in text: | |||||
| vocab.add(word) | |||||
| self.assertEqual(vocab.word_count, counter) | |||||
| def test_add_word(self): | |||||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
| for word in text: | |||||
| vocab.add_word(word) | |||||
| self.assertEqual(vocab.word_count, counter) | |||||
| def test_add_word_lst(self): | |||||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
| vocab.add_word_lst(text) | |||||
| self.assertEqual(vocab.word_count, counter) | |||||
| def test_update(self): | |||||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
| vocab.update(text) | |||||
| self.assertEqual(vocab.word_count, counter) | |||||
| class TestIndexing(unittest.TestCase): | |||||
| def test_len(self): | |||||
| vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||||
| vocab.update(text) | |||||
| self.assertEqual(len(vocab), len(counter)) | |||||
| def test_contains(self): | |||||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
| vocab.update(text) | |||||
| self.assertTrue(text[-1] in vocab) | |||||
| self.assertFalse("~!@#" in vocab) | |||||
| self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | |||||
| self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||||
| def test_index(self): | |||||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
| vocab.update(text) | |||||
| res = [vocab[w] for w in set(text)] | |||||
| self.assertEqual(len(res), len(set(res))) | |||||
| res = [vocab.to_index(w) for w in set(text)] | |||||
| self.assertEqual(len(res), len(set(res))) | |||||
| def test_to_word(self): | |||||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
| vocab.update(text) | |||||
| self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||||