| @@ -1,11 +1,4 @@ | |||
| from collections import Counter | |||
| from copy import deepcopy | |||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
| DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1} | |||
| def isiterable(p_object): | |||
| try: | |||
| @@ -57,22 +50,16 @@ class Vocabulary(object): | |||
| vocab.to_word(5) | |||
| """ | |||
| def __init__(self, need_default=True, max_size=None, min_freq=None): | |||
| def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||
| """ | |||
| :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||
| :param int max_size: set the max number of words in Vocabulary. Default: None | |||
| :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | |||
| """ | |||
| self.max_size = max_size | |||
| self.min_freq = min_freq | |||
| self.word_count = Counter() | |||
| self.has_default = need_default | |||
| if self.has_default: | |||
| self.padding_label = DEFAULT_PADDING_LABEL | |||
| self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||
| else: | |||
| self.padding_label = None | |||
| self.unknown_label = None | |||
| self.unknown = unknown | |||
| self.padding = padding | |||
| self.word2idx = None | |||
| self.idx2word = None | |||
| self.rebuild = True | |||
| @@ -113,17 +100,18 @@ class Vocabulary(object): | |||
| """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | |||
| """ | |||
| if self.has_default: | |||
| self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||
| self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) | |||
| self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) | |||
| else: | |||
| self.word2idx = {} | |||
| self.word2idx = {} | |||
| if self.padding is not None: | |||
| self.word2idx[self.padding] = 0 | |||
| if self.unknown is not None: | |||
| self.word2idx[self.unknown] = 1 | |||
| max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
| words = self.word_count.most_common(max_size) | |||
| if self.min_freq is not None: | |||
| words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
| if self.word2idx is not None: | |||
| words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||
| start_idx = len(self.word2idx) | |||
| self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
| self.build_reverse_vocab() | |||
| @@ -159,8 +147,8 @@ class Vocabulary(object): | |||
| """ | |||
| if w in self.word2idx: | |||
| return self.word2idx[w] | |||
| elif self.has_default: | |||
| return self.word2idx[self.unknown_label] | |||
| if self.unknown is not None: | |||
| return self.word2idx[self.unknown] | |||
| else: | |||
| raise ValueError("word {} not in vocabulary".format(w)) | |||
| @@ -175,21 +163,16 @@ class Vocabulary(object): | |||
| @property | |||
| @check_build_vocab | |||
| def unknown_idx(self): | |||
| if self.unknown_label is None: | |||
| if self.unknown is None: | |||
| return None | |||
| return self.word2idx[self.unknown_label] | |||
| def __setattr__(self, name, val): | |||
| self.__dict__[name] = val | |||
| if name in ["unknown_label", "padding_label"]: | |||
| self.word2idx = None | |||
| return self.word2idx[self.unknown] | |||
| @property | |||
| @check_build_vocab | |||
| def padding_idx(self): | |||
| if self.padding_label is None: | |||
| if self.padding is None: | |||
| return None | |||
| return self.word2idx[self.padding_label] | |||
| return self.word2idx[self.padding] | |||
| @check_build_vocab | |||
| def to_word(self, idx): | |||
| @@ -4,6 +4,7 @@ import numpy as np | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| from fastNLP.core.utils import CheckError | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.losses import BCELoss | |||
| @@ -56,7 +57,8 @@ class TrainerTestGround(unittest.TestCase): | |||
| dev_data=dev_set, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=True) | |||
| use_tqdm=True, | |||
| save_path=None) | |||
| trainer.train() | |||
| """ | |||
| # 应该正确运行 | |||
| @@ -145,16 +147,14 @@ class TrainerTestGround(unittest.TestCase): | |||
| return {'wrong_loss_key': loss} | |||
| model = Model() | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| trainer.train() | |||
| """ | |||
| # 应该正确运行 | |||
| """ | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| trainer.train() | |||
| def test_trainer_suggestion4(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| @@ -173,12 +173,13 @@ class TrainerTestGround(unittest.TestCase): | |||
| return {'loss': loss} | |||
| model = Model() | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| def test_trainer_suggestion5(self): | |||
| # 检查报错提示能否正确提醒用户 | |||
| @@ -225,14 +226,15 @@ class TrainerTestGround(unittest.TestCase): | |||
| return {'pred': x} | |||
| model = Model() | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| dev_data=dataset, | |||
| metrics=AccuracyMetric(), | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| dev_data=dataset, | |||
| metrics=AccuracyMetric(), | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| def test_case2(self): | |||
| # check metrics Wrong | |||
| @@ -10,36 +10,36 @@ counter = Counter(text) | |||
| class TestAdd(unittest.TestCase): | |||
| def test_add(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| for word in text: | |||
| vocab.add(word) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_add_word(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| for word in text: | |||
| vocab.add_word(word) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_add_word_lst(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab.add_word_lst(text) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| def test_update(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| self.assertEqual(vocab.word_count, counter) | |||
| class TestIndexing(unittest.TestCase): | |||
| def test_len(self): | |||
| vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
| vocab.update(text) | |||
| self.assertEqual(len(vocab), len(counter)) | |||
| def test_contains(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
| vocab.update(text) | |||
| self.assertTrue(text[-1] in vocab) | |||
| self.assertFalse("~!@#" in vocab) | |||
| @@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase): | |||
| self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||
| def test_index(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| res = [vocab[w] for w in set(text)] | |||
| self.assertEqual(len(res), len(set(res))) | |||
| @@ -56,14 +56,14 @@ class TestIndexing(unittest.TestCase): | |||
| self.assertEqual(len(res), len(set(res))) | |||
| def test_to_word(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||
| class TestOther(unittest.TestCase): | |||
| def test_additional_update(self): | |||
| vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
| vocab = Vocabulary(max_size=None, min_freq=None) | |||
| vocab.update(text) | |||
| _ = vocab["well"] | |||
| @@ -77,7 +77,7 @@ class TestOther(unittest.TestCase): | |||
| self.assertTrue("hahaha" in vocab) | |||
| def test_warning(self): | |||
| vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None) | |||
| vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||
| vocab.update(text) | |||
| self.assertEqual(vocab.rebuild, True) | |||
| print(len(vocab)) | |||