| @@ -13,7 +13,8 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 | |||
| __all__ = ["Instance", "FieldArray", "Batch", "Vocabulary", "DataSet", | |||
| "Trainer", "Tester", "Callback", | |||
| "Padder", "AutoPadder", "EngChar2DPadder", | |||
| "AccuracyMetric", "Optimizer", "SGD", "Adam", | |||
| "AccuracyMetric", "BMESF1PreRecMetric", "SpanFPreRecMetric", "SQuADMetric", | |||
| "Optimizer", "SGD", "Adam", | |||
| "Sampler", "SequentialSampler", "BucketSampler", "RandomSampler", | |||
| "LossFunc", "CrossEntropyLoss", "L1Loss", "BCELoss", "NLLLoss", "LossInForward", | |||
| "cache_results"] | |||
| @@ -17,7 +17,7 @@ from .dataset import DataSet | |||
| from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
| from .instance import Instance | |||
| from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | |||
| from .metrics import AccuracyMetric | |||
| from .metrics import AccuracyMetric, BMESF1PreRecMetric, SpanFPreRecMetric, SQuADMetric | |||
| from .optimizer import Optimizer, SGD, Adam | |||
| from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | |||
| from .tester import Tester | |||
| @@ -236,6 +236,7 @@ class CallbackManager(Callback): | |||
| for env_name, env_val in env.items(): | |||
| for callback in self.callbacks: | |||
| print(callback, env_name, env_val ) | |||
| setattr(callback, '_' + env_name, env_val) # Callback.trainer | |||
| @_transfer | |||
| @@ -425,19 +426,25 @@ class LRFinder(Callback): | |||
| super(LRFinder, self).__init__() | |||
| self.start_lr, self.end_lr = start_lr, end_lr | |||
| self.num_it = self.batch_per_epoch | |||
| self.stop = False | |||
| self.best_loss = 0. | |||
| self.best_lr = None | |||
| self.loss_history = [] | |||
| self.smooth_value = SmoothValue(0.8) | |||
| self.opt = None | |||
| scale = (self.end_lr - self.start_lr) / self.num_it | |||
| self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it)) | |||
| self.find = None | |||
| self.loader = ModelLoader() | |||
| @property | |||
| def lr_gen(self): | |||
| scale = (self.end_lr - self.start_lr) / self.batch_per_epoch | |||
| return (self.start_lr + scale * (step + 1) for step in range(self.batch_per_epoch)) | |||
| @property | |||
| def num_it(self): | |||
| return self.batch_per_epoch | |||
| def on_epoch_begin(self): | |||
| if self.epoch == 1: # first epoch | |||
| self.opt = self.trainer.optimizer # pytorch optimizer | |||
| @@ -418,6 +418,7 @@ class AutoPadder(Padder): | |||
| return False | |||
| def __call__(self, contents, field_name, field_ele_dtype): | |||
| if not _is_iterable(contents[0]): | |||
| array = np.array([content for content in contents], dtype=field_ele_dtype) | |||
| elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||
| @@ -430,6 +430,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
| class SpanFPreRecMetric(MetricBase): | |||
| """ | |||
| 别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||
| 在序列标注问题中,以span的方式计算F, pre, rec. | |||
| 比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | |||
| @@ -619,6 +620,8 @@ class SpanFPreRecMetric(MetricBase): | |||
| class BMESF1PreRecMetric(MetricBase): | |||
| """ | |||
| 别名::class:`fastNLP.BMESF1PreRecMetric` :class:`fastNLP.core.metrics.BMESF1PreRecMetric` | |||
| 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
| next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
| @@ -826,6 +829,8 @@ def _pred_topk(y_prob, k=1): | |||
| class SQuADMetric(MetricBase): | |||
| """ | |||
| 别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | |||
| SQuAD数据集metric | |||
| :param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` | |||
| @@ -350,7 +350,7 @@ class Trainer(object): | |||
| :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
| :param nn.modules model: 待训练的模型 | |||
| :param torch.optim.Optimizer optimizer: 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
| :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
| :param int batch_size: 训练和验证的时候的batch大小。 | |||
| :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
| :param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | |||
| @@ -403,7 +403,6 @@ class Trainer(object): | |||
| callbacks=None, | |||
| check_code_level=0): | |||
| super(Trainer, self).__init__() | |||
| if not isinstance(train_data, DataSet): | |||
| raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||
| if not isinstance(model, nn.Module): | |||
| @@ -468,7 +467,7 @@ class Trainer(object): | |||
| len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
| self.model = _move_model_to_device(self.model, device=device) | |||
| if isinstance(optimizer, torch.optim.Optimizer): | |||
| self.optimizer = optimizer | |||
| elif isinstance(optimizer, Optimizer): | |||
| @@ -493,6 +492,7 @@ class Trainer(object): | |||
| self.step = 0 | |||
| self.start_time = None # start timestamp | |||
| print("callback_manager") | |||
| self.callback_manager = CallbackManager(env={"trainer": self}, | |||
| callbacks=callbacks) | |||
| @@ -616,7 +616,7 @@ def seq_lens_to_masks(seq_lens, float=False): | |||
| assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | |||
| batch_size = seq_lens.size(0) | |||
| max_len = seq_lens.max() | |||
| indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||
| indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long() | |||
| masks = indexes.lt(seq_lens.unsqueeze(1)) | |||
| if float: | |||
| @@ -2,16 +2,18 @@ from functools import wraps | |||
| from collections import Counter | |||
| from .dataset import DataSet | |||
| def _check_build_vocab(func): | |||
| """A decorator to make sure the indexing is built before used. | |||
| """ | |||
| @wraps(func) # to solve missing docstring | |||
| @wraps(func) # to solve missing docstring | |||
| def _wrapper(self, *args, **kwargs): | |||
| if self.word2idx is None or self.rebuild is True: | |||
| self.build_vocab() | |||
| return func(self, *args, **kwargs) | |||
| return _wrapper | |||
| @@ -19,7 +21,8 @@ def _check_build_status(func): | |||
| """A decorator to check whether the vocabulary updates after the last build. | |||
| """ | |||
| @wraps(func) # to solve missing docstring | |||
| @wraps(func) # to solve missing docstring | |||
| def _wrapper(self, *args, **kwargs): | |||
| if self.rebuild is False: | |||
| self.rebuild = True | |||
| @@ -28,7 +31,7 @@ def _check_build_status(func): | |||
| "Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||
| self.max_size, func.__name__)) | |||
| return func(self, *args, **kwargs) | |||
| return _wrapper | |||
| @@ -50,15 +53,15 @@ class Vocabulary(object): | |||
| 若为 ``None`` , 则不限制大小. Default: ``None`` | |||
| :param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||
| 若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | |||
| :param str padding: padding的字符. 如果设置为 ``None`` , | |||
| :param str optional padding: padding的字符. 如果设置为 ``None`` , | |||
| 则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
| Default: '<pad>' | |||
| :param str unknow: unknow的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
| :param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
| 如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | |||
| 为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
| Default: '<unk>' | |||
| """ | |||
| def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
| self.max_size = max_size | |||
| self.min_freq = min_freq | |||
| @@ -68,7 +71,7 @@ class Vocabulary(object): | |||
| self.word2idx = None | |||
| self.idx2word = None | |||
| self.rebuild = True | |||
| @_check_build_status | |||
| def update(self, word_lst): | |||
| """依次增加序列中词在词典中的出现频率 | |||
| @@ -76,7 +79,7 @@ class Vocabulary(object): | |||
| :param list word_lst: a list of strings | |||
| """ | |||
| self.word_count.update(word_lst) | |||
| @_check_build_status | |||
| def add(self, word): | |||
| """ | |||
| @@ -85,7 +88,7 @@ class Vocabulary(object): | |||
| :param str word: 新词 | |||
| """ | |||
| self.word_count[word] += 1 | |||
| @_check_build_status | |||
| def add_word(self, word): | |||
| """ | |||
| @@ -94,7 +97,7 @@ class Vocabulary(object): | |||
| :param str word: 新词 | |||
| """ | |||
| self.add(word) | |||
| @_check_build_status | |||
| def add_word_lst(self, word_lst): | |||
| """ | |||
| @@ -103,7 +106,7 @@ class Vocabulary(object): | |||
| :param list[str] word_lst: 词的序列 | |||
| """ | |||
| self.update(word_lst) | |||
| def build_vocab(self): | |||
| """ | |||
| 根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | |||
| @@ -116,7 +119,7 @@ class Vocabulary(object): | |||
| self.word2idx[self.padding] = len(self.word2idx) | |||
| if self.unknown is not None: | |||
| self.word2idx[self.unknown] = len(self.word2idx) | |||
| max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
| words = self.word_count.most_common(max_size) | |||
| if self.min_freq is not None: | |||
| @@ -127,18 +130,18 @@ class Vocabulary(object): | |||
| self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
| self.build_reverse_vocab() | |||
| self.rebuild = False | |||
| def build_reverse_vocab(self): | |||
| """ | |||
| 基于 "word to index" dict, 构建 "index to word" dict. | |||
| """ | |||
| self.idx2word = {i: w for w, i in self.word2idx.items()} | |||
| @_check_build_vocab | |||
| def __len__(self): | |||
| return len(self.word2idx) | |||
| @_check_build_vocab | |||
| def __contains__(self, item): | |||
| """ | |||
| @@ -148,7 +151,7 @@ class Vocabulary(object): | |||
| :return: True or False | |||
| """ | |||
| return item in self.word2idx | |||
| def has_word(self, w): | |||
| """ | |||
| 检查词是否被记录 | |||
| @@ -163,7 +166,7 @@ class Vocabulary(object): | |||
| :return: ``True`` or ``False`` | |||
| """ | |||
| return self.__contains__(w) | |||
| @_check_build_vocab | |||
| def __getitem__(self, w): | |||
| """ | |||
| @@ -177,7 +180,7 @@ class Vocabulary(object): | |||
| return self.word2idx[self.unknown] | |||
| else: | |||
| raise ValueError("word {} not in vocabulary".format(w)) | |||
| @_check_build_vocab | |||
| def index_dataset(self, *datasets, field_name, new_field_name=None): | |||
| """ | |||
| @@ -194,6 +197,7 @@ class Vocabulary(object): | |||
| :param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||
| Default: ``None`` | |||
| """ | |||
| def index_instance(ins): | |||
| """ | |||
| 有几种情况, str, 1d-list, 2d-list | |||
| @@ -209,8 +213,8 @@ class Vocabulary(object): | |||
| else: | |||
| if isinstance(field[0][0], list): | |||
| raise RuntimeError("Only support field with 2 dimensions.") | |||
| return[[self.to_index(c) for c in w] for w in field] | |||
| return [[self.to_index(c) for c in w] for w in field] | |||
| if new_field_name is None: | |||
| new_field_name = field_name | |||
| for idx, dataset in enumerate(datasets): | |||
| @@ -222,7 +226,7 @@ class Vocabulary(object): | |||
| raise e | |||
| else: | |||
| raise RuntimeError("Only DataSet type is allowed.") | |||
| def from_dataset(self, *datasets, field_name): | |||
| """ | |||
| 使用dataset的对应field中词构建词典 | |||
| @@ -243,7 +247,7 @@ class Vocabulary(object): | |||
| field_name = [field_name] | |||
| elif not isinstance(field_name, list): | |||
| raise TypeError('invalid argument field_name: {}'.format(field_name)) | |||
| def construct_vocab(ins): | |||
| for fn in field_name: | |||
| field = ins[fn] | |||
| @@ -256,6 +260,7 @@ class Vocabulary(object): | |||
| if isinstance(field[0][0], list): | |||
| raise RuntimeError("Only support field with 2 dimensions.") | |||
| [self.add_word_lst(w) for w in field] | |||
| for idx, dataset in enumerate(datasets): | |||
| if isinstance(dataset, DataSet): | |||
| try: | |||
| @@ -266,7 +271,7 @@ class Vocabulary(object): | |||
| else: | |||
| raise RuntimeError("Only DataSet type is allowed.") | |||
| return self | |||
| def to_index(self, w): | |||
| """ | |||
| 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 | |||
| @@ -282,7 +287,7 @@ class Vocabulary(object): | |||
| :return int index: the number | |||
| """ | |||
| return self.__getitem__(w) | |||
| @property | |||
| @_check_build_vocab | |||
| def unknown_idx(self): | |||
| @@ -292,7 +297,7 @@ class Vocabulary(object): | |||
| if self.unknown is None: | |||
| return None | |||
| return self.word2idx[self.unknown] | |||
| @property | |||
| @_check_build_vocab | |||
| def padding_idx(self): | |||
| @@ -302,7 +307,7 @@ class Vocabulary(object): | |||
| if self.padding is None: | |||
| return None | |||
| return self.word2idx[self.padding] | |||
| @_check_build_vocab | |||
| def to_word(self, idx): | |||
| """ | |||
| @@ -312,26 +317,26 @@ class Vocabulary(object): | |||
| :return str word: the word | |||
| """ | |||
| return self.idx2word[idx] | |||
| def __getstate__(self): | |||
| """Use to prepare data for pickle. | |||
| """ | |||
| len(self) # make sure vocab has been built | |||
| len(self) # make sure vocab has been built | |||
| state = self.__dict__.copy() | |||
| # no need to pickle idx2word as it can be constructed from word2idx | |||
| del state['idx2word'] | |||
| return state | |||
| def __setstate__(self, state): | |||
| """Use to restore state from pickle. | |||
| """ | |||
| self.__dict__.update(state) | |||
| self.build_reverse_vocab() | |||
| def __repr__(self): | |||
| return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) | |||
| def __iter__(self): | |||
| return iter(list(self.word_count.keys())) | |||
| @@ -1,13 +1,12 @@ | |||
| import time | |||
| import unittest | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.sampler import SequentialSampler | |||
| from fastNLP import Batch | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import SequentialSampler | |||
| def generate_fake_dataset(num_samples=1000): | |||
| @@ -16,11 +15,11 @@ def generate_fake_dataset(num_samples=1000): | |||
| :param num_samples: sample的数量 | |||
| :return: | |||
| """ | |||
| max_len = 50 | |||
| min_len = 10 | |||
| num_features = 4 | |||
| data_dict = {} | |||
| for i in range(num_features): | |||
| data = [] | |||
| @@ -28,9 +27,9 @@ def generate_fake_dataset(num_samples=1000): | |||
| for length in lengths: | |||
| data.append(np.random.randint(100, size=length)) | |||
| data_dict[str(i)] = data | |||
| dataset = DataSet(data_dict) | |||
| for i in range(num_features): | |||
| if np.random.randint(2) == 0: | |||
| dataset.set_input(str(i)) | |||
| @@ -38,6 +37,7 @@ def generate_fake_dataset(num_samples=1000): | |||
| dataset.set_target(str(i)) | |||
| return dataset | |||
| def construct_dataset(sentences): | |||
| """Construct a data set from a list of sentences. | |||
| @@ -51,18 +51,19 @@ def construct_dataset(sentences): | |||
| dataset.append(instance) | |||
| return dataset | |||
| class TestCase1(unittest.TestCase): | |||
| def test_simple(self): | |||
| dataset = construct_dataset( | |||
| [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||
| dataset.set_target() | |||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| cnt = 0 | |||
| for _, _ in batch: | |||
| cnt += 1 | |||
| self.assertEqual(cnt, 10) | |||
| def test_dataset_batching(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ds.set_input("x") | |||
| @@ -74,7 +75,7 @@ class TestCase1(unittest.TestCase): | |||
| self.assertEqual(len(y["y"]), 4) | |||
| self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | |||
| self.assertListEqual(list(y["y"][-1]), [5, 6]) | |||
| def test_list_padding(self): | |||
| ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||
| "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
| @@ -84,7 +85,7 @@ class TestCase1(unittest.TestCase): | |||
| for x, y in iter: | |||
| self.assertEqual(x["x"].shape, (4, 4)) | |||
| self.assertEqual(y["y"].shape, (4, 4)) | |||
| def test_numpy_padding(self): | |||
| ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||
| "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
| @@ -94,7 +95,7 @@ class TestCase1(unittest.TestCase): | |||
| for x, y in iter: | |||
| self.assertEqual(x["x"].shape, (4, 4)) | |||
| self.assertEqual(y["y"].shape, (4, 4)) | |||
| def test_list_to_tensor(self): | |||
| ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||
| "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
| @@ -106,7 +107,7 @@ class TestCase1(unittest.TestCase): | |||
| self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
| self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
| self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
| def test_numpy_to_tensor(self): | |||
| ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||
| "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
| @@ -118,7 +119,7 @@ class TestCase1(unittest.TestCase): | |||
| self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
| self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
| self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
| def test_list_of_list_to_tensor(self): | |||
| ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] + | |||
| [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | |||
| @@ -130,7 +131,7 @@ class TestCase1(unittest.TestCase): | |||
| self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
| self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
| self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
| def test_list_of_numpy_to_tensor(self): | |||
| ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] + | |||
| [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | |||
| @@ -139,16 +140,16 @@ class TestCase1(unittest.TestCase): | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| for x, y in iter: | |||
| print(x, y) | |||
| def test_sequential_batch(self): | |||
| batch_size = 32 | |||
| num_samples = 1000 | |||
| dataset = generate_fake_dataset(num_samples) | |||
| batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
| for batch_x, batch_y in batch: | |||
| pass | |||
| """ | |||
| def test_multi_workers_batch(self): | |||
| batch_size = 32 | |||
| @@ -4,14 +4,13 @@ import numpy as np | |||
| import torch | |||
| from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
| LRFinder, \ | |||
| TensorboardCallback | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.losses import BCELoss | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| from fastNLP.core.optimizer import SGD | |||
| from fastNLP.core.trainer import Trainer | |||
| LRFinder, TensorboardCallback | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import BCELoss | |||
| from fastNLP import AccuracyMetric | |||
| from fastNLP import SGD | |||
| from fastNLP import Trainer | |||
| from fastNLP.models.base_model import NaiveClassifier | |||
| @@ -20,15 +19,15 @@ def prepare_env(): | |||
| mean = np.array([-3, -3]) | |||
| cov = np.array([[1, 0], [0, 1]]) | |||
| class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
| mean = np.array([3, 3]) | |||
| cov = np.array([[1, 0], [0, 1]]) | |||
| class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
| data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
| [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
| return data_set | |||
| data_set = prepare_fake_dataset() | |||
| data_set.set_input("x") | |||
| data_set.set_target("y") | |||
| @@ -37,19 +36,7 @@ def prepare_env(): | |||
| class TestCallback(unittest.TestCase): | |||
| def test_echo_callback(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=2, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| callbacks=[EchoCallback()]) | |||
| trainer.train() | |||
| def test_gradient_clip(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| @@ -64,7 +51,7 @@ class TestCallback(unittest.TestCase): | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | |||
| trainer.train() | |||
| def test_early_stop(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| @@ -79,7 +66,7 @@ class TestCallback(unittest.TestCase): | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[EarlyStopCallback(5)]) | |||
| trainer.train() | |||
| def test_lr_scheduler(self): | |||
| data_set, model = prepare_env() | |||
| optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | |||
| @@ -95,7 +82,7 @@ class TestCallback(unittest.TestCase): | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||
| trainer.train() | |||
| def test_KeyBoardInterrupt(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| @@ -108,7 +95,7 @@ class TestCallback(unittest.TestCase): | |||
| use_tqdm=False, | |||
| callbacks=[ControlC(False)]) | |||
| trainer.train() | |||
| def test_LRFinder(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| @@ -121,7 +108,7 @@ class TestCallback(unittest.TestCase): | |||
| use_tqdm=False, | |||
| callbacks=[LRFinder(len(data_set) // 32)]) | |||
| trainer.train() | |||
| def test_TensorboardCallback(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| @@ -136,21 +123,22 @@ class TestCallback(unittest.TestCase): | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[TensorboardCallback("loss", "metric")]) | |||
| trainer.train() | |||
| def test_readonly_property(self): | |||
| from fastNLP.core.callback import Callback | |||
| passed_epochs = [] | |||
| total_epochs = 5 | |||
| class MyCallback(Callback): | |||
| def __init__(self): | |||
| super(MyCallback, self).__init__() | |||
| def on_epoch_begin(self): | |||
| passed_epochs.append(self.epoch) | |||
| print(self.n_epochs, self.n_steps, self.batch_size) | |||
| print(self.model) | |||
| print(self.optimizer) | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| @@ -164,4 +152,4 @@ class TestCallback(unittest.TestCase): | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[MyCallback()]) | |||
| trainer.train() | |||
| assert passed_epochs == list(range(1, total_epochs+1)) | |||
| assert passed_epochs == list(range(1, total_epochs + 1)) | |||
| @@ -1,9 +1,10 @@ | |||
| import os | |||
| import unittest | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.fieldarray import FieldArray | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP import DataSet | |||
| from fastNLP import FieldArray | |||
| from fastNLP import Instance | |||
| from fastNLP.io import CSVLoader | |||
| class TestDataSetInit(unittest.TestCase): | |||
| @@ -167,13 +168,11 @@ class TestDataSetMethods(unittest.TestCase): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
| d1, d2 = ds.split(0.1) | |||
| def test_apply2(self): | |||
| def split_sent(ins): | |||
| return ins['raw_sentence'].split() | |||
| dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | |||
| sep='\t') | |||
| csv_loader = CSVLoader(headers=['raw_sentence', 'label'],sep='\t') | |||
| dataset = csv_loader.load('../data_for_tests/tutorial_sample_dataset.csv') | |||
| dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | |||
| dataset.apply(split_sent, new_field_name='words', is_input=True) | |||
| # print(dataset) | |||
| @@ -208,7 +207,7 @@ class TestDataSetMethods(unittest.TestCase): | |||
| self.assertEqual(ans.content, [[5, 6]] * 10) | |||
| def test_add_null(self): | |||
| # TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||
| # TODO test failed because 'fastNLP\core\field.py:143: RuntimeError' | |||
| ds = DataSet() | |||
| with self.assertRaises(RuntimeError) as RE: | |||
| ds.add_field('test', []) | |||
| @@ -2,7 +2,7 @@ import unittest | |||
| import numpy as np | |||
| from fastNLP.core.fieldarray import FieldArray | |||
| from fastNLP import FieldArray | |||
| class TestFieldArrayInit(unittest.TestCase): | |||
| @@ -170,7 +170,7 @@ class TestPadder(unittest.TestCase): | |||
| 测试AutoPadder能否正常工作 | |||
| :return: | |||
| """ | |||
| from fastNLP.core.fieldarray import AutoPadder | |||
| from fastNLP import AutoPadder | |||
| padder = AutoPadder() | |||
| content = ['This is a str', 'this is another str'] | |||
| self.assertListEqual(content, padder(content, None, np.str).tolist()) | |||
| @@ -194,7 +194,7 @@ class TestPadder(unittest.TestCase): | |||
| 测试EngChar2DPadder能不能正确使用 | |||
| :return: | |||
| """ | |||
| from fastNLP.core.fieldarray import EngChar2DPadder | |||
| from fastNLP import EngChar2DPadder | |||
| padder = EngChar2DPadder(pad_length=0) | |||
| contents = [1, 2] | |||
| @@ -225,11 +225,11 @@ class TestPadder(unittest.TestCase): | |||
| ) | |||
| def test_None_dtype(self): | |||
| from fastNLP.core.fieldarray import AutoPadder | |||
| from fastNLP import AutoPadder | |||
| padder = AutoPadder() | |||
| content = [ | |||
| [[1, 2, 3], [4, 5], [7, 8, 9, 10]], | |||
| [[1]] | |||
| ] | |||
| ans = padder(content, None, None) | |||
| ans = padder(content, None, None).tolist() | |||
| self.assertListEqual(content, ans) | |||
| @@ -1,33 +1,33 @@ | |||
| import unittest | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP import Instance | |||
| class TestCase(unittest.TestCase): | |||
| def test_init(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
| ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||
| self.assertTrue(isinstance(ins.fields, dict)) | |||
| self.assertEqual(ins.fields, fields) | |||
| ins = Instance(**fields) | |||
| self.assertEqual(ins.fields, fields) | |||
| def test_add_field(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
| ins = Instance(**fields) | |||
| ins.add_field("z", [1, 1, 1]) | |||
| fields.update({"z": [1, 1, 1]}) | |||
| self.assertEqual(ins.fields, fields) | |||
| def test_get_item(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
| ins = Instance(**fields) | |||
| self.assertEqual(ins["x"], [1, 2, 3]) | |||
| self.assertEqual(ins["y"], [4, 5, 6]) | |||
| self.assertEqual(ins["z"], [1, 1, 1]) | |||
| def test_repr(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
| ins = Instance(**fields) | |||
| @@ -3,7 +3,7 @@ import unittest | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import fastNLP.core.losses as loss | |||
| import fastNLP as loss | |||
| from fastNLP.core.losses import squash, unpad | |||
| @@ -14,21 +14,21 @@ class TestLoss(unittest.TestCase): | |||
| b = torch.empty(3, dtype=torch.long).random_(5) | |||
| ans = ce({"my_predict": a}, {"my_truth": b}) | |||
| self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | |||
| def test_BCELoss(self): | |||
| bce = loss.BCELoss(pred="my_predict", target="my_truth") | |||
| a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) | |||
| b = torch.randn((3, 5), requires_grad=False) | |||
| ans = bce({"my_predict": a}, {"my_truth": b}) | |||
| self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) | |||
| def test_L1Loss(self): | |||
| l1 = loss.L1Loss(pred="my_predict", target="my_truth") | |||
| a = torch.randn(3, 5, requires_grad=False) | |||
| b = torch.randn(3, 5) | |||
| ans = l1({"my_predict": a}, {"my_truth": b}) | |||
| self.assertEqual(ans, torch.nn.functional.l1_loss(a, b)) | |||
| def test_NLLLoss(self): | |||
| l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||
| a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||
| @@ -43,34 +43,34 @@ class TestLosserError(unittest.TestCase): | |||
| pred_dict = {"pred": torch.zeros(4, 3)} | |||
| target_dict = {'target': torch.zeros(4).long()} | |||
| los = loss.CrossEntropyLoss() | |||
| print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||
| # | |||
| def test_losser2(self): | |||
| # (2) with corrupted size | |||
| pred_dict = {"pred": torch.zeros(16, 3)} | |||
| target_dict = {'target': torch.zeros(16, 3).long()} | |||
| los = loss.CrossEntropyLoss() | |||
| with self.assertRaises(RuntimeError): | |||
| print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||
| def test_losser3(self): | |||
| # (2) with corrupted size | |||
| pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0} | |||
| target_dict = {'target': torch.zeros(16).long()} | |||
| los = loss.CrossEntropyLoss() | |||
| print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||
| def test_check_error(self): | |||
| l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||
| a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||
| b = torch.tensor([1, 0, 4]) | |||
| with self.assertRaises(Exception): | |||
| ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | |||
| with self.assertRaises(Exception): | |||
| ans = l1({"my_predict": a}, {"truth": b, "my": a}) | |||
| @@ -80,7 +80,7 @@ class TestLossUtils(unittest.TestCase): | |||
| a, b = squash(torch.randn(3, 5), torch.randn(3, 5)) | |||
| self.assertEqual(tuple(a.size()), (3, 5)) | |||
| self.assertEqual(tuple(b.size()), (15,)) | |||
| def test_unpad(self): | |||
| a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8)) | |||
| self.assertEqual(tuple(a.size()), (5, 8, 3)) | |||
| @@ -3,8 +3,8 @@ import unittest | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| from fastNLP.core.metrics import BMESF1PreRecMetric | |||
| from fastNLP import AccuracyMetric | |||
| from fastNLP import BMESF1PreRecMetric | |||
| from fastNLP.core.metrics import _pred_topk, _accuracy_topk | |||
| @@ -14,24 +14,24 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| pred_dict = {"pred": torch.zeros(4, 3)} | |||
| target_dict = {'target': torch.zeros(4)} | |||
| metric = AccuracyMetric() | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| print(metric.get_metric()) | |||
| def test_AccuracyMetric2(self): | |||
| # (2) with corrupted size | |||
| try: | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4)} | |||
| metric = AccuracyMetric() | |||
| metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
| print(metric.get_metric()) | |||
| except Exception as e: | |||
| print(e) | |||
| return | |||
| print("No exception catches.") | |||
| def test_AccuracyMetric3(self): | |||
| # (3) the second batch is corrupted size | |||
| try: | |||
| @@ -39,17 +39,17 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| print(metric.get_metric()) | |||
| except Exception as e: | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_AccuaryMetric4(self): | |||
| # (5) check reset | |||
| metric = AccuracyMetric() | |||
| @@ -61,7 +61,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| self.assertTrue(isinstance(res, dict)) | |||
| self.assertTrue("acc" in res) | |||
| self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) | |||
| def test_AccuaryMetric5(self): | |||
| # (5) check reset | |||
| metric = AccuracyMetric() | |||
| @@ -71,7 +71,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| res = metric.get_metric(reset=False) | |||
| ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() | |||
| self.assertAlmostEqual(res["acc"], float(ans), places=4) | |||
| def test_AccuaryMetric6(self): | |||
| # (6) check numpy array is not acceptable | |||
| try: | |||
| @@ -83,7 +83,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_AccuaryMetric7(self): | |||
| # (7) check map, match | |||
| metric = AccuracyMetric(pred='predictions', target='targets') | |||
| @@ -93,7 +93,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| res = metric.get_metric() | |||
| ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() | |||
| self.assertAlmostEqual(res["acc"], float(ans), places=4) | |||
| def test_AccuaryMetric8(self): | |||
| try: | |||
| metric = AccuracyMetric(pred='predictions', target='targets') | |||
| @@ -105,7 +105,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_AccuaryMetric9(self): | |||
| # (9) check map, include unused | |||
| try: | |||
| @@ -118,12 +118,12 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_AccuaryMetric10(self): | |||
| # (10) check _fast_metric | |||
| try: | |||
| metric = AccuracyMetric() | |||
| pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3)*3} | |||
| pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | |||
| target_dict = {'targets': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
| @@ -131,7 +131,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| print(e) | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| def test_seq_len(self): | |||
| N = 256 | |||
| seq_len = torch.zeros(N).long() | |||
| @@ -145,20 +145,21 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| metric(pred_dict=pred, target_dict=target) | |||
| self.assertDictEqual(metric.get_metric(), {'acc': 1.}) | |||
| class SpanF1PreRecMetric(unittest.TestCase): | |||
| def test_case1(self): | |||
| from fastNLP.core.metrics import _bmes_tag_to_spans | |||
| from fastNLP.core.metrics import _bio_tag_to_spans | |||
| bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | |||
| bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | |||
| expect_bmes_res = set() | |||
| expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)), | |||
| ('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))]) | |||
| ('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))]) | |||
| expect_bio_res = set() | |||
| expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)), | |||
| ('6', (4, 5)), ('7', (6, 7))]) | |||
| self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) | |||
| ('6', (4, 5)), ('7', (6, 7))]) | |||
| self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) | |||
| self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | |||
| # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
| # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
| @@ -171,19 +172,19 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| # bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | |||
| # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
| # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
| def test_case2(self): | |||
| # 测试不带label的 | |||
| from fastNLP.core.metrics import _bmes_tag_to_spans | |||
| from fastNLP.core.metrics import _bio_tag_to_spans | |||
| bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | |||
| bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | |||
| expect_bmes_res = set() | |||
| expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))]) | |||
| expect_bio_res = set() | |||
| expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) | |||
| self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) | |||
| self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) | |||
| self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | |||
| # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
| # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
| @@ -195,7 +196,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| # bio_strs = np.random.choice(bio, size=100) | |||
| # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
| # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
| def tese_case3(self): | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from collections import Counter | |||
| @@ -213,7 +214,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| continue | |||
| vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count | |||
| return vocab | |||
| number_labels = 4 | |||
| # bio tag | |||
| fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
| @@ -221,26 +222,26 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
| bio_sequence = torch.FloatTensor( | |||
| [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | |||
| 0.0470, 0.0971], | |||
| [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||
| 0.7987, -0.3970], | |||
| [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||
| 0.6880, 1.4348], | |||
| [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||
| -1.6876, -0.8917], | |||
| [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||
| 1.4217, 0.2622]], | |||
| [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||
| 1.3592, -0.8973], | |||
| [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||
| -0.4025, -0.3417], | |||
| [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||
| 0.2861, -0.3966], | |||
| [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||
| 0.0213, 1.4777], | |||
| [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||
| 1.3024, 0.2001]]] | |||
| 0.0470, 0.0971], | |||
| [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||
| 0.7987, -0.3970], | |||
| [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||
| 0.6880, 1.4348], | |||
| [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||
| -1.6876, -0.8917], | |||
| [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||
| 1.4217, 0.2622]], | |||
| [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||
| 1.3592, -0.8973], | |||
| [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||
| -0.4025, -0.3417], | |||
| [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||
| 0.2861, -0.3966], | |||
| [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||
| 0.0213, 1.4777], | |||
| [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||
| 1.3024, 0.2001]]] | |||
| ) | |||
| bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], | |||
| [5., 6., 8., 6., 0.]]) | |||
| @@ -250,8 +251,8 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| 'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | |||
| 'f': 0.12499999999994846} | |||
| self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | |||
| #bmes tag | |||
| # bmes tag | |||
| bmes_sequence = torch.FloatTensor( | |||
| [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | |||
| -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | |||
| @@ -268,7 +269,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | |||
| 0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | |||
| -0.3779, -0.3195]], | |||
| [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | |||
| 0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | |||
| -0.1103, 0.4417], | |||
| @@ -285,22 +286,22 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| 2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | |||
| -0.7344, -1.2046]]] | |||
| ) | |||
| bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.], | |||
| [ 6., 15., 6., 15., 5.]]) | |||
| bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.], | |||
| [6., 15., 6., 15., 5.]]) | |||
| fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||
| fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||
| fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||
| fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||
| expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | |||
| 'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | |||
| 'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | |||
| 'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | |||
| 'pre': 0.499999999999995, 'rec': 0.499999999999995} | |||
| self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | |||
| # 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 | |||
| # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | |||
| # from allennlp.training.metrics import SpanBasedF1Measure | |||
| @@ -349,6 +350,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | |||
| # fastnlp_bmes_metric.get_metric()) | |||
| class TestBMESF1PreRecMetric(unittest.TestCase): | |||
| def test_case1(self): | |||
| seq_lens = torch.LongTensor([4, 2]) | |||
| @@ -356,20 +358,20 @@ class TestBMESF1PreRecMetric(unittest.TestCase): | |||
| target = torch.LongTensor([[0, 1, 2, 3], | |||
| [3, 3, 0, 0]]) | |||
| pred_dict = {'pred': pred} | |||
| target_dict = {'target': target, 'seq_lens': seq_lens} | |||
| target_dict = {'target': target, 'seq_len': seq_lens} | |||
| metric = BMESF1PreRecMetric() | |||
| metric(pred_dict, target_dict) | |||
| metric.get_metric() | |||
| def test_case2(self): | |||
| # 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} | |||
| seq_lens = torch.LongTensor([4, 2]) | |||
| target = torch.LongTensor([[0, 1, 2, 3], | |||
| [3, 3, 0, 0]]) | |||
| pred_dict = {'pred': target} | |||
| target_dict = {'target': target, 'seq_lens': seq_lens} | |||
| target_dict = {'target': target, 'seq_len': seq_lens} | |||
| metric = BMESF1PreRecMetric() | |||
| metric(pred_dict, target_dict) | |||
| self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) | |||
| @@ -381,5 +383,5 @@ class TestUsefulFunctions(unittest.TestCase): | |||
| # multi-class | |||
| _ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) | |||
| _ = _pred_topk(np.random.randint(0, 3, size=(10, 1))) | |||
| # 跑通即可 | |||
| @@ -2,7 +2,7 @@ import unittest | |||
| import torch | |||
| from fastNLP.core.optimizer import SGD, Adam | |||
| from fastNLP import SGD, Adam | |||
| class TestOptim(unittest.TestCase): | |||
| @@ -12,42 +12,42 @@ class TestOptim(unittest.TestCase): | |||
| self.assertTrue("momentum" in optim.__dict__["settings"]) | |||
| res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue(isinstance(res, torch.optim.SGD)) | |||
| optim = SGD(lr=0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue(isinstance(res, torch.optim.SGD)) | |||
| optim = SGD(lr=0.002, momentum=0.989) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
| self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | |||
| optim = SGD(0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue(isinstance(res, torch.optim.SGD)) | |||
| with self.assertRaises(TypeError): | |||
| _ = SGD("???") | |||
| with self.assertRaises(TypeError): | |||
| _ = SGD(0.001, lr=0.002) | |||
| def test_Adam(self): | |||
| optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue("lr" in optim.__dict__["settings"]) | |||
| self.assertTrue("weight_decay" in optim.__dict__["settings"]) | |||
| res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue(isinstance(res, torch.optim.Adam)) | |||
| optim = Adam(lr=0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue(isinstance(res, torch.optim.Adam)) | |||
| optim = Adam(lr=0.002, weight_decay=0.989) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
| self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | |||
| optim = Adam(0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
| @@ -3,9 +3,9 @@ import unittest | |||
| import torch | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.sampler import SequentialSampler, RandomSampler, \ | |||
| k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler | |||
| from fastNLP import DataSet | |||
| from fastNLP import SequentialSampler, RandomSampler, BucketSampler | |||
| from fastNLP.core.sampler import k_means_1d, k_means_bucketing, simple_sort_bucketing | |||
| class TestSampler(unittest.TestCase): | |||
| @@ -1,32 +1,25 @@ | |||
| import unittest | |||
| import numpy as np | |||
| from torch import nn | |||
| import time | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import AccuracyMetric | |||
| from fastNLP import Tester | |||
| data_name = "pku_training.utf8" | |||
| pickle_path = "data_for_tests" | |||
| import numpy as np | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| import time | |||
| from fastNLP.core.utils import _CheckError | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.losses import BCELoss | |||
| from fastNLP.core.losses import CrossEntropyLoss | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| from fastNLP.core.optimizer import SGD | |||
| from fastNLP.core.tester import Tester | |||
| from fastNLP.models.base_model import NaiveClassifier | |||
| def prepare_fake_dataset(): | |||
| mean = np.array([-3, -3]) | |||
| cov = np.array([[1, 0], [0, 1]]) | |||
| class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
| mean = np.array([3, 3]) | |||
| cov = np.array([[1, 0], [0, 1]]) | |||
| class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
| data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
| [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
| return data_set | |||
| @@ -39,6 +32,7 @@ def prepare_fake_dataset2(*args, size=100): | |||
| data[arg] = np.random.randn(size, 5) | |||
| return DataSet(data=data) | |||
| class TestTester(unittest.TestCase): | |||
| def test_case_1(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| @@ -46,10 +40,12 @@ class TestTester(unittest.TestCase): | |||
| dataset.rename_field('x_unused', 'x2') | |||
| dataset.set_input('x1', 'x2') | |||
| dataset.set_target('y', 'x1') | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.fc = nn.Linear(5, 4) | |||
| def forward(self, x1, x2): | |||
| x1 = self.fc(x1) | |||
| x2 = self.fc(x2) | |||
| @@ -57,7 +53,7 @@ class TestTester(unittest.TestCase): | |||
| time.sleep(0.1) | |||
| # loss = F.cross_entropy(x, y) | |||
| return {'preds': x} | |||
| model = Model() | |||
| with self.assertRaises(NameError): | |||
| tester = Tester( | |||
| @@ -5,25 +5,24 @@ import numpy as np | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.losses import BCELoss | |||
| from fastNLP.core.losses import CrossEntropyLoss | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| from fastNLP.core.optimizer import SGD | |||
| from fastNLP.core.trainer import Trainer | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import BCELoss | |||
| from fastNLP import CrossEntropyLoss | |||
| from fastNLP import AccuracyMetric | |||
| from fastNLP import SGD | |||
| from fastNLP import Trainer | |||
| from fastNLP.models.base_model import NaiveClassifier | |||
| def prepare_fake_dataset(): | |||
| mean = np.array([-3, -3]) | |||
| cov = np.array([[1, 0], [0, 1]]) | |||
| class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
| mean = np.array([3, 3]) | |||
| cov = np.array([[1, 0], [0, 1]]) | |||
| class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
| data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
| [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
| return data_set | |||
| @@ -42,11 +41,11 @@ class TrainerTestGround(unittest.TestCase): | |||
| data_set = prepare_fake_dataset() | |||
| data_set.set_input("x", flag=True) | |||
| data_set.set_target("y", flag=True) | |||
| train_set, dev_set = data_set.split(0.3) | |||
| model = NaiveClassifier(2, 1) | |||
| trainer = Trainer(train_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| @@ -63,26 +62,26 @@ class TrainerTestGround(unittest.TestCase): | |||
| """ | |||
| # 应该正确运行 | |||
| """ | |||
| def test_trainer_suggestion1(self): | |||
| # 检查报错提示能否正确提醒用户。 | |||
| # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | |||
| dataset = prepare_fake_dataset2('x') | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.fc = nn.Linear(5, 4) | |||
| def forward(self, x1, x2, y): | |||
| x1 = self.fc(x1) | |||
| x2 = self.fc(x2) | |||
| x = x1 + x2 | |||
| loss = F.cross_entropy(x, y) | |||
| return {'loss': loss} | |||
| model = Model() | |||
| with self.assertRaises(RuntimeError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| @@ -97,25 +96,25 @@ class TrainerTestGround(unittest.TestCase): | |||
| (2). You need to provide ['x1', 'x2'] in DataSet and set it as input. | |||
| """ | |||
| def test_trainer_suggestion2(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| # 这里传入forward需要的数据,看是否可以运行 | |||
| dataset = prepare_fake_dataset2('x1', 'x2') | |||
| dataset.set_input('x1', 'x2', 'y', flag=True) | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.fc = nn.Linear(5, 4) | |||
| def forward(self, x1, x2, y): | |||
| x1 = self.fc(x1) | |||
| x2 = self.fc(x2) | |||
| x = x1 + x2 | |||
| loss = F.cross_entropy(x, y) | |||
| return {'loss': loss} | |||
| model = Model() | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| @@ -127,25 +126,25 @@ class TrainerTestGround(unittest.TestCase): | |||
| """ | |||
| # 应该正确运行 | |||
| """ | |||
| def test_trainer_suggestion3(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| # 这里传入forward需要的数据,但是forward没有返回loss这个key | |||
| dataset = prepare_fake_dataset2('x1', 'x2') | |||
| dataset.set_input('x1', 'x2', 'y', flag=True) | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.fc = nn.Linear(5, 4) | |||
| def forward(self, x1, x2, y): | |||
| x1 = self.fc(x1) | |||
| x2 = self.fc(x2) | |||
| x = x1 + x2 | |||
| loss = F.cross_entropy(x, y) | |||
| return {'wrong_loss_key': loss} | |||
| model = Model() | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| @@ -155,23 +154,25 @@ class TrainerTestGround(unittest.TestCase): | |||
| print_every=2 | |||
| ) | |||
| trainer.train() | |||
| def test_trainer_suggestion4(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| # 这里传入forward需要的数据,是否可以正确提示unused | |||
| dataset = prepare_fake_dataset2('x1', 'x2') | |||
| dataset.set_input('x1', 'x2', 'y', flag=True) | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.fc = nn.Linear(5, 4) | |||
| def forward(self, x1, x2, y): | |||
| x1 = self.fc(x1) | |||
| x2 = self.fc(x2) | |||
| x = x1 + x2 | |||
| loss = F.cross_entropy(x, y) | |||
| return {'losses': loss} | |||
| model = Model() | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| @@ -180,7 +181,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| def test_trainer_suggestion5(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| # 这里传入多余参数,让其duplicate, 但这里因为y不会被调用,所以其实不会报错 | |||
| @@ -188,17 +189,19 @@ class TrainerTestGround(unittest.TestCase): | |||
| dataset.rename_field('x_unused', 'x2') | |||
| dataset.set_input('x1', 'x2', 'y') | |||
| dataset.set_target('y') | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.fc = nn.Linear(5, 4) | |||
| def forward(self, x1, x2, y): | |||
| x1 = self.fc(x1) | |||
| x2 = self.fc(x2) | |||
| x = x1 + x2 | |||
| loss = F.cross_entropy(x, y) | |||
| return {'loss': loss} | |||
| model = Model() | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| @@ -206,7 +209,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| def test_trainer_suggestion6(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| # 这里传入多余参数,让其duplicate | |||
| @@ -214,10 +217,12 @@ class TrainerTestGround(unittest.TestCase): | |||
| dataset.rename_field('x_unused', 'x2') | |||
| dataset.set_input('x1', 'x2') | |||
| dataset.set_target('y', 'x1') | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.fc = nn.Linear(5, 4) | |||
| def forward(self, x1, x2): | |||
| x1 = self.fc(x1) | |||
| x2 = self.fc(x2) | |||
| @@ -225,7 +230,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| time.sleep(0.1) | |||
| # loss = F.cross_entropy(x, y) | |||
| return {'preds': x} | |||
| model = Model() | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| @@ -236,7 +241,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| metrics=AccuracyMetric(), | |||
| use_tqdm=False, | |||
| print_every=2) | |||
| """ | |||
| def test_trainer_multiprocess(self): | |||
| dataset = prepare_fake_dataset2('x1', 'x2') | |||
| @@ -1,8 +1,7 @@ | |||
| import unittest | |||
| import _pickle | |||
| from fastNLP import cache_results | |||
| from fastNLP.io.embed_loader import EmbedLoader | |||
| from fastNLP.io import EmbedLoader | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| import time | |||
| @@ -11,11 +10,13 @@ import torch | |||
| from torch import nn | |||
| from fastNLP.core.utils import _move_model_to_device, _get_model_device | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param = nn.Parameter(torch.zeros(0)) | |||
| class TestMoveModelDeivce(unittest.TestCase): | |||
| def test_case1(self): | |||
| # 测试str | |||
| @@ -35,36 +36,36 @@ class TestMoveModelDeivce(unittest.TestCase): | |||
| _move_model_to_device(model, 'cuda:1000') | |||
| # 测试None | |||
| model = _move_model_to_device(model, None) | |||
| def test_case2(self): | |||
| # 测试使用int初始化 | |||
| model = Model() | |||
| if torch.cuda.is_available(): | |||
| model = _move_model_to_device(model, 0) | |||
| assert model.param.device == torch.device('cuda:0') | |||
| assert model.param.device==torch.device('cuda:0'), "The model should be in " | |||
| assert model.param.device == torch.device('cuda:0'), "The model should be in " | |||
| with self.assertRaises(Exception): | |||
| _move_model_to_device(model, 100) | |||
| with self.assertRaises(Exception): | |||
| _move_model_to_device(model, -1) | |||
| def test_case3(self): | |||
| # 测试None | |||
| model = Model() | |||
| device = _get_model_device(model) | |||
| model = _move_model_to_device(model, None) | |||
| assert device==_get_model_device(model), "The device should not change." | |||
| assert device == _get_model_device(model), "The device should not change." | |||
| if torch.cuda.is_available(): | |||
| model.cuda() | |||
| device = _get_model_device(model) | |||
| model = _move_model_to_device(model, None) | |||
| assert device==_get_model_device(model), "The device should not change." | |||
| assert device == _get_model_device(model), "The device should not change." | |||
| model = nn.DataParallel(model, device_ids=[0]) | |||
| _move_model_to_device(model, None) | |||
| with self.assertRaises(Exception): | |||
| _move_model_to_device(model, 'cpu') | |||
| def test_case4(self): | |||
| # 测试传入list的内容 | |||
| model = Model() | |||
| @@ -78,15 +79,17 @@ class TestMoveModelDeivce(unittest.TestCase): | |||
| device = [torch.device('cuda:0'), torch.device('cuda:0')] | |||
| with self.assertRaises(Exception): | |||
| _model = _move_model_to_device(model, device) | |||
| if torch.cuda.device_count()>1: | |||
| if torch.cuda.device_count() > 1: | |||
| device = [0, 1] | |||
| _model = _move_model_to_device(model, device) | |||
| assert isinstance(_model, nn.DataParallel) | |||
| device = ['cuda', 'cuda:1'] | |||
| with self.assertRaises(Exception): | |||
| _move_model_to_device(model, device) | |||
| def test_case5(self): | |||
| if not torch.cuda.is_available(): | |||
| return | |||
| # torch.device() | |||
| device = torch.device('cpu') | |||
| model = Model() | |||
| @@ -106,10 +109,11 @@ def process_data_1(embed_file, cws_train): | |||
| d = DataSet() | |||
| for line in f: | |||
| line = line.strip() | |||
| if len(line)>0: | |||
| if len(line) > 0: | |||
| d.append(Instance(raw=line)) | |||
| return embed, vocab, d | |||
| class TestCache(unittest.TestCase): | |||
| def test_cache_save(self): | |||
| try: | |||
| @@ -127,10 +131,10 @@ class TestCache(unittest.TestCase): | |||
| end_time = time.time() | |||
| read_time = end_time - start_time | |||
| print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||
| self.assertGreater(pre_time-0.5, read_time) | |||
| self.assertGreater(pre_time - 0.5, read_time) | |||
| finally: | |||
| os.remove('test/demo1.pkl') | |||
| def test_cache_save_overwrite_path(self): | |||
| try: | |||
| start_time = time.time() | |||
| @@ -149,10 +153,10 @@ class TestCache(unittest.TestCase): | |||
| end_time = time.time() | |||
| read_time = end_time - start_time | |||
| print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||
| self.assertGreater(pre_time-0.5, read_time) | |||
| self.assertGreater(pre_time - 0.5, read_time) | |||
| finally: | |||
| os.remove('test/demo_overwrite.pkl') | |||
| def test_cache_refresh(self): | |||
| try: | |||
| start_time = time.time() | |||
| @@ -171,34 +175,38 @@ class TestCache(unittest.TestCase): | |||
| end_time = time.time() | |||
| read_time = end_time - start_time | |||
| print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||
| self.assertGreater(0.1, pre_time-read_time) | |||
| self.assertGreater(0.1, pre_time - read_time) | |||
| finally: | |||
| os.remove('test/demo1.pkl') | |||
| def test_duplicate_keyword(self): | |||
| with self.assertRaises(RuntimeError): | |||
| @cache_results(None) | |||
| def func_verbose(a, _verbose): | |||
| pass | |||
| func_verbose(0, 1) | |||
| with self.assertRaises(RuntimeError): | |||
| @cache_results(None) | |||
| def func_cache(a, _cache_fp): | |||
| pass | |||
| func_cache(1, 2) | |||
| with self.assertRaises(RuntimeError): | |||
| @cache_results(None) | |||
| def func_refresh(a, _refresh): | |||
| pass | |||
| func_refresh(1, 2) | |||
| def test_create_cache_dir(self): | |||
| @cache_results('test/demo1/demo.pkl') | |||
| def cache(): | |||
| return 1, 2 | |||
| try: | |||
| results = cache() | |||
| print(results) | |||
| finally: | |||
| os.remove('test/demo1/demo.pkl') | |||
| os.rmdir('test/demo1') | |||
| os.rmdir('test/demo1') | |||
| @@ -1,9 +1,9 @@ | |||
| import unittest | |||
| from collections import Counter | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP import Vocabulary | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||
| "works", "well", "in", "most", "cases", "scales", "well"] | |||
| @@ -12,92 +12,93 @@ counter = Counter(text) | |||
| class TestAdd(unittest.TestCase): | |||
| def test_add(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab = Vocabulary() | |||
| for word in text: | |||
| vocab.add(word) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_add_word(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab = Vocabulary() | |||
| for word in text: | |||
| vocab.add_word(word) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_add_word_lst(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab = Vocabulary() | |||
| vocab.add_word_lst(text) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_update(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab = Vocabulary() | |||
| vocab.update(text) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_from_dataset(self): | |||
| start_char = 65 | |||
| num_samples = 10 | |||
| # 0 dim | |||
| dataset = DataSet() | |||
| for i in range(num_samples): | |||
| ins = Instance(char=chr(start_char+i)) | |||
| ins = Instance(char=chr(start_char + i)) | |||
| dataset.append(ins) | |||
| vocab = Vocabulary() | |||
| vocab.from_dataset(dataset, field_name='char') | |||
| for i in range(num_samples): | |||
| self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||
| self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||
| vocab.index_dataset(dataset, field_name='char') | |||
| # 1 dim | |||
| dataset = DataSet() | |||
| for i in range(num_samples): | |||
| ins = Instance(char=[chr(start_char+i)]*6) | |||
| ins = Instance(char=[chr(start_char + i)] * 6) | |||
| dataset.append(ins) | |||
| vocab = Vocabulary() | |||
| vocab.from_dataset(dataset, field_name='char') | |||
| for i in range(num_samples): | |||
| self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||
| self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||
| vocab.index_dataset(dataset, field_name='char') | |||
| # 2 dim | |||
| dataset = DataSet() | |||
| for i in range(num_samples): | |||
| ins = Instance(char=[[chr(start_char+i) for _ in range(6)] for _ in range(6)]) | |||
| ins = Instance(char=[[chr(start_char + i) for _ in range(6)] for _ in range(6)]) | |||
| dataset.append(ins) | |||
| vocab = Vocabulary() | |||
| vocab.from_dataset(dataset, field_name='char') | |||
| for i in range(num_samples): | |||
| self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||
| self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||
| vocab.index_dataset(dataset, field_name='char') | |||
| class TestIndexing(unittest.TestCase): | |||
| def test_len(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
| vocab = Vocabulary(unknown=None, padding=None) | |||
| vocab.update(text) | |||
| self.assertEqual(len(vocab), len(counter)) | |||
| def test_contains(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
| vocab = Vocabulary(unknown=None) | |||
| vocab.update(text) | |||
| self.assertTrue(text[-1] in vocab) | |||
| self.assertFalse("~!@#" in vocab) | |||
| self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | |||
| self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||
| def test_index(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab = Vocabulary() | |||
| vocab.update(text) | |||
| res = [vocab[w] for w in set(text)] | |||
| self.assertEqual(len(res), len(set(res))) | |||
| res = [vocab.to_index(w) for w in set(text)] | |||
| self.assertEqual(len(res), len(set(res))) | |||
| def test_to_word(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab = Vocabulary() | |||
| vocab.update(text) | |||
| self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||
| def test_iteration(self): | |||
| vocab = Vocabulary() | |||
| text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||
| @@ -110,26 +111,26 @@ class TestIndexing(unittest.TestCase): | |||
| class TestOther(unittest.TestCase): | |||
| def test_additional_update(self): | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab = Vocabulary() | |||
| vocab.update(text) | |||
| _ = vocab["well"] | |||
| self.assertEqual(vocab.rebuild, False) | |||
| vocab.add("hahaha") | |||
| self.assertEqual(vocab.rebuild, True) | |||
| _ = vocab["hahaha"] | |||
| self.assertEqual(vocab.rebuild, False) | |||
| self.assertTrue("hahaha" in vocab) | |||
| def test_warning(self): | |||
| vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||
| vocab = Vocabulary(max_size=len(set(text))) | |||
| vocab.update(text) | |||
| self.assertEqual(vocab.rebuild, True) | |||
| print(len(vocab)) | |||
| self.assertEqual(vocab.rebuild, False) | |||
| vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) | |||
| # this will print a warning | |||
| self.assertEqual(vocab.rebuild, True) | |||