| @@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||
| class Predictor(object): | |||
| """ | |||
| An interface for predicting outputs based on trained models. | |||
| 一个根据训练模型预测输出的预测器(Predictor) | |||
| It does not care about evaluations of the model, which is different from Tester. | |||
| This is a high-level model wrapper to be called by FastNLP. | |||
| This class does not share any operations with Trainer and Tester. | |||
| Currently, Predictor does not support GPU. | |||
| 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | |||
| 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 | |||
| :param torch.nn.Module network: 用来完成预测任务的模型 | |||
| """ | |||
| def __init__(self, network): | |||
| @@ -30,18 +30,19 @@ class Predictor(object): | |||
| self.batch_size = 1 | |||
| self.batch_output = [] | |||
| def predict(self, data, seq_len_field_name=None): | |||
| """Perform inference using the trained model. | |||
| def predict(self, data: DataSet, seq_len_field_name=None): | |||
| """用已经训练好的模型进行inference. | |||
| :param data: a DataSet object. | |||
| :param str seq_len_field_name: field name indicating sequence lengths | |||
| :return: list of batch outputs | |||
| :param fastNLP.DataSet data: 待预测的数据集 | |||
| :param str seq_len_field_name: 表示序列长度信息的field名字 | |||
| :return: dict dict里面的内容为模型预测的结果 | |||
| """ | |||
| if not isinstance(data, DataSet): | |||
| raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | |||
| if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | |||
| raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | |||
| prev_training = self.network.training | |||
| self.network.eval() | |||
| network_device = _get_model_device(self.network) | |||
| batch_output = defaultdict(list) | |||
| @@ -74,4 +75,5 @@ class Predictor(object): | |||
| else: | |||
| batch_output[key].append(value) | |||
| self.network.train(prev_training) | |||
| return batch_output | |||
| @@ -11,7 +11,7 @@ | |||
| ## Matching (自然语言推理/句子匹配) | |||
| - still in progress | |||
| - [Matching 任务复现](matching/) | |||
| ## Sequence Labeling (序列标注) | |||
| @@ -1,32 +1,34 @@ | |||
| # Matching任务模型复现 | |||
| 这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。 | |||
| 这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). | |||
| 复现的模型有(按论文发表时间顺序排序): | |||
| - CNTN:复现代码(still in progress)[](). | |||
| - CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||
| 论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | |||
| - ESIM:[复现代码](model/esim.py). | |||
| - ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). | |||
| 论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | |||
| - DIIN:复现代码(still in progress)[](). | |||
| - DIIN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||
| 论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf). | |||
| - MwAN:复现代码(still in progress)[](). | |||
| - MwAN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||
| 论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf). | |||
| - BERT:[复现代码](model/bert.py). | |||
| - BERT:[模型代码](model/bert.py); [训练代码](matching_bert.py). | |||
| 论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). | |||
| # 数据集及复现结果汇总 | |||
| 使用fastNLP复现的结果vs论文汇报结果 | |||
| 使用fastNLP复现的结果vs论文汇报结果,在前面的表示使用fastNLP复现的结果 | |||
| '\-'表示我们仍未复现或者论文原文没有汇报 | |||
| model name | SNLI | MNLI | RTE | QNLI | Quora | |||
| :---: | :---: | :---: | :---: | :---: | :---: | |||
| CNTN ; [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | - | - | - | - | - | | |||
| ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs - | 57.04(dev) / - | 76.97(dev) / - | - | | |||
| DIIN [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | |||
| MwAN [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.5 vs 88.3 | - vs 78.5/77.7 | - | - | vs 89.12 | | |||
| CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - | | |||
| ESIM[代码](model/esim.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | | |||
| DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | |||
| MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | | |||
| BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - | | |||
| *ESIM模型由MNLI官方复现的结果为72.4/72.1,ESIM原论文当中没有汇报MNLI数据集的结果。 | |||
| # 数据集复现结果及其他主要模型对比 | |||
| ## SNLI | |||
| [Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) | |||
| @@ -42,7 +44,7 @@ Performance on Test set: | |||
| model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | |||
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | |||
| __performance__ | - | 88.13 | - | - | 90.6 | 91.16 | |||
| __performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16 | |||
| ## MNLI | |||
| [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | |||
| @@ -58,11 +60,13 @@ Performance on Test set(matched/mismatched): | |||
| model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | |||
| :---: | :---: | :---: | :---: | :---: | :---: | | |||
| __performance__ | - | - | - | - | - | | |||
| __performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||
| ## RTE | |||
| Still in progress. | |||
| ## QNLI | |||
| ### From GLUE baselines | |||
| @@ -73,17 +77,24 @@ Performance on Test set: | |||
| model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo | |||
| :---: | :---: | :---: | :---: | :---: | | |||
| __performance__ | 74.6 | 74.3 | 75.5 | 79.8 | | |||
| *这些LSTM-based的baseline是由QNLI官方实现并测试的。 | |||
| #### Transformer-based | |||
| model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN | |||
| :---: | :---: | :---: | :---: | :---: | | |||
| __performance__ | 87.4 | 90.5 | 92.7 | 96.0 | | |||
| ### 基于fastNLP复现的结果 | |||
| Performance on Dev set: | |||
| Performance on __Dev__ set: | |||
| model name | CNTN | ESIM | DIIN | MwAN | BERT | |||
| :---: | :---: | :---: | :---: | :---: | :---: | |||
| __performance__ | - | 76.97 | - | - | - | |||
| __performance__ | - | 76.97 | - | 74.6 | - | |||
| ## Quora | |||
| Still in progress. | |||
| @@ -5,8 +5,8 @@ from typing import Union, Dict | |||
| from fastNLP.core.const import Const | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.base_loader import DataInfo | |||
| from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader | |||
| from fastNLP.io.base_loader import DataInfo, DataSetLoader | |||
| from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | |||
| from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
| from fastNLP.modules.encoder._bert import BertTokenizer | |||
| @@ -348,6 +348,9 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||
| 'dev_mismatched': 'dev_mismatched.tsv', | |||
| 'test_matched': 'test_matched.tsv', | |||
| 'test_mismatched': 'test_mismatched.tsv', | |||
| # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||
| # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||
| # test_0.9_matched与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| CSVLoader.__init__(self, sep='\t') | |||
| @@ -364,6 +367,10 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||
| if k in ds.get_field_names(): | |||
| ds.rename_field(k, v) | |||
| if Const.TARGET in ds.get_field_names(): | |||
| if ds[0][Const.TARGET] == 'hidden': | |||
| ds.delete_field(Const.TARGET) | |||
| parentheses_table = str.maketrans({'(': None, ')': None}) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
| @@ -0,0 +1,102 @@ | |||
"""Fine-tune and evaluate BERT on a sentence-matching / NLI task with fastNLP.

The script selects a dataset loader from ``BERTConfig.task``, trains
``BertForNLI`` on the configured training split while validating on the dev
split, reloads the best model, and finally runs the tester on the test split.
"""
import random

import numpy as np
import torch

from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam

from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
    MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.bert import BertForNLI


# define hyper-parameters
class BERTConfig:
    """Hyper-parameters and dataset settings for one BERT matching run."""

    task = 'snli'  # selects the entry of ``task_loaders`` below

    batch_size_per_gpu = 6
    n_epochs = 6
    lr = 2e-5
    seq_len_type = 'bert'
    seed = 42

    # Keys looked up in ``data_info.datasets`` for each stage.
    # NOTE(review): some loaders may expose differently named splits
    # (e.g. MNLI matched/mismatched) -- confirm these names per task.
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'

    save_path = None  # where to store the model; None means the model is not saved
    bert_dir = 'path/to/bert/dir'  # directory holding the pre-trained BERT parameter files


arg = BERTConfig()

# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)

# With no visible GPU, ``n_gpu * batch_size_per_gpu`` would be 0 and the device
# list would be empty. Fall back to a single batch share and pass device=None
# (documented by fastNLP as "leave the model where it is").
batch_size = max(n_gpu, 1) * arg.batch_size_per_gpu
devices = list(range(n_gpu)) if n_gpu > 0 else None

# load data set
# Every task is processed with the same options; only the loader class and the
# data path differ, so they are kept in one lookup table.
task_loaders = {
    'snli': (SNLILoader, 'path/to/snli/data'),
    'rte': (RTELoader, 'path/to/rte/data'),
    'qnli': (QNLILoader, 'path/to/qnli/data'),
    'mnli': (MNLILoader, 'path/to/mnli/data'),
    'quora': (QuoraLoader, 'path/to/quora/data'),
}
if arg.task not in task_loaders:
    raise RuntimeError(f'NOT support {arg.task} task yet!')

loader_cls, data_path = task_loaders[arg.task]
data_info = loader_cls().process(
    paths=data_path, to_lower=True, seq_len_type=arg.seq_len_type,
    bert_tokenizer=arg.bert_dir, cut_text=512,
    get_index=True, concat='bert',
)

# define model
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)

# define trainer
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  batch_size=batch_size,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets[arg.dev_dataset_name],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=devices,
                  check_code_level=-1,
                  save_path=arg.save_path)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets[arg.test_dataset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=batch_size,
    device=devices,
)

# test model
tester.test()
| @@ -1,47 +1,103 @@ | |||
| import argparse | |||
| import random | |||
| import numpy as np | |||
| import torch | |||
| from torch.optim import Adamax | |||
| from torch.optim.lr_scheduler import StepLR | |||
| from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||
| from fastNLP.core import Trainer, Tester, AccuracyMetric, Const | |||
| from fastNLP.core.callback import GradientClipCallback, LRScheduler | |||
| from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding | |||
| from reproduction.matching.data.MatchingDataLoader import SNLILoader | |||
| from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ | |||
| MNLILoader, QNLILoader, QuoraLoader | |||
| from reproduction.matching.model.esim import ESIMModel | |||
| argument = argparse.ArgumentParser() | |||
| argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove') | |||
| argument.add_argument('--batch-size-per-gpu', type=int, default=128) | |||
| argument.add_argument('--n-epochs', type=int, default=100) | |||
| argument.add_argument('--lr', type=float, default=1e-4) | |||
| argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len') | |||
| argument.add_argument('--save-dir', type=str, default=None) | |||
| arg = argument.parse_args() | |||
| bert_dirs = 'path/to/bert/dir' | |||
| # define hyper-parameters | |||
| class ESIMConfig: | |||
| task = 'snli' | |||
| embedding = 'glove' | |||
| batch_size_per_gpu = 196 | |||
| n_epochs = 30 | |||
| lr = 2e-3 | |||
| seq_len_type = 'seq_len' | |||
| # seq_len表示在process的时候用len(words)来表示长度信息; | |||
| # mask表示用0/1掩码矩阵来表示长度信息; | |||
| seed = 42 | |||
| train_dataset_name = 'train' | |||
| dev_dataset_name = 'dev' | |||
| test_dataset_name = 'test' | |||
| save_path = None # 模型存储的位置,None表示不存储模型。 | |||
| arg = ESIMConfig() | |||
| # set random seed | |||
| random.seed(arg.seed) | |||
| np.random.seed(arg.seed) | |||
| torch.manual_seed(arg.seed) | |||
| n_gpu = torch.cuda.device_count() | |||
| if n_gpu > 0: | |||
| torch.cuda.manual_seed_all(arg.seed) | |||
| # load data set | |||
| data_info = SNLILoader().process( | |||
| paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||
| get_index=True, concat=False, | |||
| ) | |||
| if arg.task == 'snli': | |||
| data_info = SNLILoader().process( | |||
| paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
| get_index=True, concat=False, | |||
| ) | |||
| elif arg.task == 'rte': | |||
| data_info = RTELoader().process( | |||
| paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
| get_index=True, concat=False, | |||
| ) | |||
| elif arg.task == 'qnli': | |||
| data_info = QNLILoader().process( | |||
| paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
| get_index=True, concat=False, | |||
| ) | |||
| elif arg.task == 'mnli': | |||
| data_info = MNLILoader().process( | |||
| paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
| get_index=True, concat=False, | |||
| ) | |||
| elif arg.task == 'quora': | |||
| data_info = QuoraLoader().process( | |||
| paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
| get_index=True, concat=False, | |||
| ) | |||
| else: | |||
| raise RuntimeError(f'NOT support {arg.task} task yet!') | |||
| # load embedding | |||
| if arg.embedding == 'elmo': | |||
| embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||
| elif arg.embedding == 'glove': | |||
| embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||
| embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False) | |||
| else: | |||
| raise ValueError(f'now we only support elmo or glove embedding for esim model!') | |||
| raise RuntimeError(f'NOT support {arg.embedding} embedding yet!') | |||
| # define model | |||
| model = ESIMModel(embedding) | |||
| model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET])) | |||
| # define optimizer and callback | |||
| optimizer = Adamax(lr=arg.lr, params=model.parameters()) | |||
| scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍 | |||
| callbacks = [ | |||
| GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10) | |||
| LRScheduler(scheduler), | |||
| ] | |||
| # define trainer | |||
| trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||
| optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||
| trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||
| optimizer=optimizer, | |||
| batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
| n_epochs=arg.n_epochs, print_every=-1, | |||
| dev_data=data_info.datasets['dev'], | |||
| dev_data=data_info.datasets[arg.dev_dataset_name], | |||
| metrics=AccuracyMetric(), metric_key='acc', | |||
| device=[i for i in range(torch.cuda.device_count())], | |||
| check_code_level=-1, | |||
| @@ -52,7 +108,7 @@ trainer.train(load_best_model=True) | |||
| # define tester | |||
| tester = Tester( | |||
| data=data_info.datasets['test'], | |||
| data=data_info.datasets[arg.test_dataset_name], | |||
| model=model, | |||
| metrics=AccuracyMetric(), | |||
| batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
| @@ -81,6 +81,7 @@ class ESIMModel(BaseModel): | |||
| out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | |||
| logits = torch.tanh(self.classifier(out)) | |||
| # logits = self.classifier(out) | |||
| if target is not None: | |||
| loss_fct = CrossEntropyLoss() | |||
| @@ -91,7 +92,8 @@ class ESIMModel(BaseModel): | |||
| return {Const.OUTPUT: logits} | |||
| def predict(self, **kwargs): | |||
| return self.forward(**kwargs) | |||
| pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) | |||
| return {Const.OUTPUT: pred} | |||
| # input [batch_size, len , hidden] | |||
| # mask [batch_size, len] (111...00) | |||
| @@ -127,7 +129,7 @@ class BiRNN(nn.Module): | |||
| def forward(self, x, x_mask): | |||
| # Sort x | |||
| lengths = x_mask.data.eq(1).long().sum(1).squeeze() | |||
| lengths = x_mask.data.eq(1).long().sum(1) | |||
| _, idx_sort = torch.sort(lengths, dim=0, descending=True) | |||
| _, idx_unsort = torch.sort(idx_sort, dim=0) | |||
| lengths = list(lengths[idx_sort]) | |||