@@ -37,7 +37,7 @@ __all__ = [
     "AccuracyMetric",
     "SpanFPreRecMetric",
-    "SQuADMetric",
+    "ExtractiveQAMetric",
     "Optimizer",
     "SGD",
@@ -61,3 +61,4 @@ __version__ = '0.4.0'
 from .core import *
 from . import models
 from . import modules
+from .io import data_loader
@@ -21,7 +21,7 @@ from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
-from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric
+from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
 from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester
@@ -6,7 +6,7 @@ __all__ = [
     "MetricBase",
     "AccuracyMetric",
     "SpanFPreRecMetric",
-    "SQuADMetric"
+    "ExtractiveQAMetric"
 ]

 import inspect
@@ -24,6 +24,7 @@ from .utils import seq_len_to_mask
 from .vocabulary import Vocabulary
 from abc import abstractmethod


 class MetricBase(object):
     """
     Base class for all metrics. Every metric passed to Trainer or Tester must inherit from this class and override its evaluate() and get_metric() methods.
@@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1):
     return y_pred_topk, y_prob_topk


-class SQuADMetric(MetricBase):
-    r"""
-    Alias: :class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric`
+class ExtractiveQAMetric(MetricBase):
+    """
+    Alias: :class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric`

-    Metric for the SQuAD dataset.
+    Metric for extractive QA tasks such as SQuAD.

     :param pred1: mapping for `pred1` in the parameter map; None means `pred1` maps to `pred1`
     :param pred2: mapping for `pred2` in the parameter map; None means `pred2` maps to `pred2`

@@ -755,7 +756,7 @@ class SQuADMetric(MetricBase):
     def __init__(self, pred1=None, pred2=None, target1=None, target2=None,
                  beta=1, right_open=True, print_predict_stat=False):

-        super(SQuADMetric, self).__init__()
+        super(ExtractiveQAMetric, self).__init__()

         self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2)
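For reference, a minimal usage sketch of the renamed metric. The field names below (`start_logits`, `end_idx`, …) are hypothetical placeholders for whatever the model and dataset actually expose:

```python
# Sketch: wiring hypothetical model outputs / target fields into the metric.
from fastNLP import ExtractiveQAMetric

metric = ExtractiveQAMetric(pred1='start_logits',  # span-start prediction (hypothetical field name)
                            pred2='end_logits',    # span-end prediction (hypothetical field name)
                            target1='start_idx',   # gold span start (hypothetical field name)
                            target2='end_idx')     # gold span end (hypothetical field name)
# The metric is then passed to Trainer/Tester like any other MetricBase subclass.
```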
@@ -4,16 +4,26 @@
 These modules are used as follows:
 """
 __all__ = [
-    'SSTLoader',
+    'IMDBLoader',

     'MatchingLoader',
-    'SNLILoader',
     'MNLILoader',
+    'MTL16Loader',
     'QNLILoader',
     'QuoraLoader',
     'RTELoader',
+    'SSTLoader',
+    'SNLILoader',
+    'YelpLoader',
 ]

+from .imdb import IMDBLoader
+from .matching import MatchingLoader
+from .mnli import MNLILoader
+from .mtl import MTL16Loader
+from .qnli import QNLILoader
+from .quora import QuoraLoader
+from .rte import RTELoader
+from .snli import SNLILoader
 from .sst import SSTLoader
-from .matching import MatchingLoader, SNLILoader, \
-    MNLILoader, QNLILoader, QuoraLoader, RTELoader
+from .yelp import YelpLoader
@@ -0,0 +1,96 @@
+from typing import Union, Dict
+
+from ..embed_loader import EmbeddingOption, EmbedLoader
+from ..base_loader import DataSetLoader, DataInfo
+from ...core.vocabulary import VocabularyOption, Vocabulary
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+from ...core.const import Const
+
+from ..utils import get_tokenizer, check_dataloader_paths
+
+
+class IMDBLoader(DataSetLoader):
+    """
+    Reads the IMDB dataset. The DataSet contains the following fields:
+
+        words: list(str), the text to be classified
+        target: str, the label of the text
+
+    """
+
+    def __init__(self):
+        super(IMDBLoader, self).__init__()
+        self.tokenizer = get_tokenizer()
+
+    def _load(self, path):
+        dataset = DataSet()
+        with open(path, 'r', encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                parts = line.split('\t')
+                target = parts[0]
+                words = self.tokenizer(parts[1].lower())
+                dataset.append(Instance(words=words, target=target))
+        if len(dataset) == 0:
+            raise RuntimeError(f"{path} has no valid data.")
+        return dataset
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,
+                char_level_op=False):
+        # normalize a single path / directory into a {name: path} dict,
+        # as the other loaders do (a bare str would otherwise crash .items())
+        paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        def wordtochar(words):
+            # split each word into characters, inserting '' between words
+            chars = []
+            for word in words:
+                word = word.lower()
+                for char in word:
+                    chars.append(char)
+                chars.append('')
+            chars.pop()
+            return chars
+
+        if char_level_op:
+            for dataset in datasets.values():
+                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
+
+        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name='words')
+        src_vocab.index_dataset(*datasets.values(), field_name='words')
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name='target')
+        tgt_vocab.index_dataset(*datasets.values(), field_name='target')
+
+        info.vocabs = {
+            Const.INPUT: src_vocab,
+            Const.TARGET: tgt_vocab
+        }
+
+        info.datasets = datasets
+
+        for name, dataset in info.datasets.items():
+            dataset.set_input(Const.INPUT)
+            dataset.set_target(Const.TARGET)
+
+        return info
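A usage sketch for the new loader, assuming hypothetical tab-separated `label<TAB>text` files:

```python
# Sketch: building DataSets and vocabularies from hypothetical IMDB files.
loader = IMDBLoader()
info = loader.process(
    paths={'train': 'imdb/train.txt', 'test': 'imdb/test.txt'},  # hypothetical paths
    char_level_op=True,  # additionally derive a 'chars' field from 'words'
)
# process() re-splits 10% of 'train' off as 'dev' (see the split(0.1) call above),
# so info.datasets holds 'train', 'dev' and 'test'; info.vocabs maps
# Const.INPUT / Const.TARGET to the word and label vocabularies.
```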
@@ -5,14 +5,13 @@ from typing import Union, Dict
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
 from ..base_loader import DataInfo, DataSetLoader
-from ..dataset_loader import JsonLoader, CSVLoader
 from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder._bert import BertTokenizer


 class MatchingLoader(DataSetLoader):
     """
-    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader`

     Reads datasets for Matching tasks
@@ -227,204 +226,3 @@ class MatchingLoader(DataSetLoader):
             data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])

         return data_info
-
-
-class SNLILoader(MatchingLoader, JsonLoader):
-    """
-    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
-
-    Reads the SNLI dataset. The resulting DataSet contains the fields::
-
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-
-    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
-    """
-
-    def __init__(self, paths: dict=None):
-        fields = {
-            'sentence1_binary_parse': Const.INPUTS(0),
-            'sentence2_binary_parse': Const.INPUTS(1),
-            'gold_label': Const.TARGET,
-        }
-        paths = paths if paths is not None else {
-            'train': 'snli_1.0_train.jsonl',
-            'dev': 'snli_1.0_dev.jsonl',
-            'test': 'snli_1.0_test.jsonl'}
-        MatchingLoader.__init__(self, paths=paths)
-        JsonLoader.__init__(self, fields=fields)
-
-    def _load(self, path):
-        ds = JsonLoader._load(self, path)
-
-        parentheses_table = str.maketrans({'(': None, ')': None})
-
-        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(0))
-        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(1))
-        ds.drop(lambda x: x[Const.TARGET] == '-')
-        return ds
-
-
-class RTELoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
-
-    Reads the RTE dataset. The resulting DataSet contains the fields::
-
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-
-    Data source:
-    """
-
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev': 'dev.tsv',
-            'test': 'test.tsv'  # the test set has no labels
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        self.fields = {
-            'sentence1': Const.INPUTS(0),
-            'sentence2': Const.INPUTS(1),
-            'label': Const.TARGET,
-        }
-        CSVLoader.__init__(self, sep='\t')
-
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-
-        for k, v in self.fields.items():
-            if v in ds.get_field_names():
-                ds.rename_field(k, v)
-        for fields in ds.get_all_fields():
-            if Const.INPUT in fields:
-                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
-
-        return ds
-
-
-class QNLILoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
-
-    Reads the QNLI dataset. The resulting DataSet contains the fields::
-
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-
-    Data source:
-    """
-
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev': 'dev.tsv',
-            'test': 'test.tsv'  # the test set has no labels
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        self.fields = {
-            'question': Const.INPUTS(0),
-            'sentence': Const.INPUTS(1),
-            'label': Const.TARGET,
-        }
-        CSVLoader.__init__(self, sep='\t')
-
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-
-        for k, v in self.fields.items():
-            if v in ds.get_field_names():
-                ds.rename_field(k, v)
-        for fields in ds.get_all_fields():
-            if Const.INPUT in fields:
-                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
-
-        return ds
-
-
-class MNLILoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
-
-    Reads the MNLI dataset. The resulting DataSet contains the fields::
-
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-
-    Data source:
-    """
-
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev_matched': 'dev_matched.tsv',
-            'dev_mismatched': 'dev_mismatched.tsv',
-            'test_matched': 'test_matched.tsv',
-            'test_mismatched': 'test_mismatched.tsv',
-            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
-            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
-            # test_0.9_matched and test_0.9_mismatched belong to MNLI 0.9 (data source: Kaggle)
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        CSVLoader.__init__(self, sep='\t')
-        self.fields = {
-            'sentence1_binary_parse': Const.INPUTS(0),
-            'sentence2_binary_parse': Const.INPUTS(1),
-            'gold_label': Const.TARGET,
-        }
-
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-
-        for k, v in self.fields.items():
-            if k in ds.get_field_names():
-                ds.rename_field(k, v)
-
-        if Const.TARGET in ds.get_field_names():
-            if ds[0][Const.TARGET] == 'hidden':
-                ds.delete_field(Const.TARGET)
-
-        parentheses_table = str.maketrans({'(': None, ')': None})
-
-        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(0))
-        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(1))
-
-        if Const.TARGET in ds.get_field_names():
-            ds.drop(lambda x: x[Const.TARGET] == '-')
-        return ds
-
-
-class QuoraLoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
-
-    Reads the Quora dataset. The resulting DataSet contains the fields::
-
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-
-    Data source:
-    """
-
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev': 'dev.tsv',
-            'test': 'test.tsv',
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
-
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-
-        return ds
@@ -0,0 +1,60 @@
+from ...core import Const
+
+from .matching import MatchingLoader
+from ..dataset_loader import CSVLoader
+
+
+class MNLILoader(MatchingLoader, CSVLoader):
+    """
+    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader`
+
+    Reads the MNLI dataset. The resulting DataSet contains the fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev_matched': 'dev_matched.tsv',
+            'dev_mismatched': 'dev_mismatched.tsv',
+            'test_matched': 'test_matched.tsv',
+            'test_mismatched': 'test_mismatched.tsv',
+            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
+            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+            # test_0.9_matched and test_0.9_mismatched belong to MNLI 0.9 (data source: Kaggle)
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        CSVLoader.__init__(self, sep='\t')
+        self.fields = {
+            'sentence1_binary_parse': Const.INPUTS(0),
+            'sentence2_binary_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            if k in ds.get_field_names():
+                ds.rename_field(k, v)
+
+        if Const.TARGET in ds.get_field_names():
+            if ds[0][Const.TARGET] == 'hidden':
+                ds.delete_field(Const.TARGET)
+
+        parentheses_table = str.maketrans({'(': None, ')': None})
+
+        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(1))
+
+        if Const.TARGET in ds.get_field_names():
+            ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
@@ -0,0 +1,65 @@
+from typing import Union, Dict
+
+from ..base_loader import DataInfo
+from ..dataset_loader import CSVLoader
+from ...core.vocabulary import Vocabulary, VocabularyOption
+from ...core.const import Const
+from ..utils import check_dataloader_paths
+
+
+class MTL16Loader(CSVLoader):
+    """
+    Reads the MTL16 dataset. The DataSet contains the following fields:
+
+        words: list(str), the text to be classified
+        target: str, the label of the text
+
+    Data source: https://pan.baidu.com/s/1c2L6vdA
+    """
+
+    def __init__(self):
+        super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t')
+
+    def _load(self, path):
+        dataset = super(MTL16Loader, self)._load(path)
+        dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT)
+        if len(dataset) == 0:
+            raise RuntimeError(f"{path} has no valid data.")
+        return dataset
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None):
+
+        paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
+        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
+        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)
+
+        info.vocabs = {
+            Const.INPUT: src_vocab,
+            Const.TARGET: tgt_vocab
+        }
+
+        info.datasets = datasets
+
+        for name, dataset in info.datasets.items():
+            dataset.set_input(Const.INPUT)
+            dataset.set_target(Const.TARGET)
+
+        return info
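Usage mirrors `IMDBLoader`. Because `process` normalizes its argument through `check_dataloader_paths`, a dict of named splits (and, presumably, a single file or directory path) is accepted; the paths below are hypothetical:

```python
# Sketch: hypothetical MTL16 task directory with tab-separated "label<TAB>text" files.
loader = MTL16Loader()
info = loader.process({'train': 'mtl16/apparel/train.tsv',
                       'dev': 'mtl16/apparel/dev.tsv',
                       'test': 'mtl16/apparel/test.tsv'})
# info.datasets['train'] etc. are indexed DataSets;
# info.vocabs[Const.INPUT] / info.vocabs[Const.TARGET] are the vocabularies.
```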
@@ -0,0 +1,45 @@
+from ...core import Const
+
+from .matching import MatchingLoader
+from ..dataset_loader import CSVLoader
+
+
+class QNLILoader(MatchingLoader, CSVLoader):
+    """
+    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader`
+
+    Reads the QNLI dataset. The resulting DataSet contains the fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv'  # the test set has no labels
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        self.fields = {
+            'question': Const.INPUTS(0),
+            'sentence': Const.INPUTS(1),
+            'label': Const.TARGET,
+        }
+        CSVLoader.__init__(self, sep='\t')
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            if k in ds.get_field_names():
+                ds.rename_field(k, v)
+        for fields in ds.get_all_fields():
+            if Const.INPUT in fields:
+                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+        return ds
@@ -0,0 +1,32 @@
+from ...core import Const
+
+from .matching import MatchingLoader
+from ..dataset_loader import CSVLoader
+
+
+class QuoraLoader(MatchingLoader, CSVLoader):
+    """
+    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader`
+
+    Reads the Quora dataset. The resulting DataSet contains the fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv',
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        return ds
@@ -0,0 +1,45 @@
+from ...core import Const
+
+from .matching import MatchingLoader
+from ..dataset_loader import CSVLoader
+
+
+class RTELoader(MatchingLoader, CSVLoader):
+    """
+    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader`
+
+    Reads the RTE dataset. The resulting DataSet contains the fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv'  # the test set has no labels
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        self.fields = {
+            'sentence1': Const.INPUTS(0),
+            'sentence2': Const.INPUTS(1),
+            'label': Const.TARGET,
+        }
+        CSVLoader.__init__(self, sep='\t')
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+
+        for k, v in self.fields.items():
+            if k in ds.get_field_names():
+                ds.rename_field(k, v)
+        for fields in ds.get_all_fields():
+            if Const.INPUT in fields:
+                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+        return ds
@@ -0,0 +1,44 @@
+from ...core import Const
+
+from .matching import MatchingLoader
+from ..dataset_loader import JsonLoader
+
+
+class SNLILoader(MatchingLoader, JsonLoader):
+    """
+    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader`
+
+    Reads the SNLI dataset. The resulting DataSet contains the fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+    """
+
+    def __init__(self, paths: dict=None):
+        fields = {
+            'sentence1_binary_parse': Const.INPUTS(0),
+            'sentence2_binary_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+        paths = paths if paths is not None else {
+            'train': 'snli_1.0_train.jsonl',
+            'dev': 'snli_1.0_dev.jsonl',
+            'test': 'snli_1.0_test.jsonl'}
+        MatchingLoader.__init__(self, paths=paths)
+        JsonLoader.__init__(self, fields=fields)
+
+    def _load(self, path):
+        ds = JsonLoader._load(self, path)
+
+        parentheses_table = str.maketrans({'(': None, ')': None})
+
+        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(1))
+        ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
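The `_load` methods above recover plain token lists from the `*_binary_parse` fields by deleting parentheses and splitting on whitespace; a self-contained illustration:

```python
# Parenthesis stripping used by _load, shown standalone:
parse = "( ( The cat ) ( sat down ) )"
table = str.maketrans({'(': None, ')': None})
print(parse.translate(table).strip().split())  # ['The', 'cat', 'sat', 'down']
```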
@@ -1,19 +1,19 @@
-from typing import Iterable
+from typing import Union, Dict
 from nltk import Tree
-import spacy

 from ..base_loader import DataInfo, DataSetLoader
+from ..dataset_loader import CSVLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
+from ...core.const import Const
 from ...core.instance import Instance
 from ..utils import check_dataloader_paths, get_tokenizer


 class SSTLoader(DataSetLoader):
-    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
-    DATA_DIR = 'sst/'
     """
-    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
+    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader`

     Reads the SST dataset. The DataSet contains the fields::

@@ -26,6 +26,9 @@ class SSTLoader(DataSetLoader):
     :param fine_grained: whether to use the SST-5 standard; if ``False``, SST-2 is used. Default: ``False``
     """

+    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
+    DATA_DIR = 'sst/'
+
     def __init__(self, subtree=False, fine_grained=False):
         self.subtree = subtree
@@ -98,3 +101,72 @@ class SSTLoader(DataSetLoader):

         return info

+
+class SST2Loader(CSVLoader):
+    """
+    Data source "SST": https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
+    """
+
+    def __init__(self):
+        super(SST2Loader, self).__init__(sep='\t')
+        self.tokenizer = get_tokenizer()
+        self.field = {'sentence': Const.INPUT, 'label': Const.TARGET}
+
+    def _load(self, path: str) -> DataSet:
+        ds = super(SST2Loader, self)._load(path)
+        ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT)
+        print("all count:", len(ds))
+        return ds
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,
+                char_level_op=False):
+
+        paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        def wordtochar(words):
+            # split each word into characters, inserting '' between words
+            chars = []
+            for word in words:
+                word = word.lower()
+                for char in word:
+                    chars.append(char)
+                chars.append('')
+            chars.pop()
+            return chars
+
+        input_name, target_name = Const.INPUT, Const.TARGET
+        info.vocabs = {}
+        # optionally split words into char form
+        if char_level_op:
+            for dataset in datasets.values():
+                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
+        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
+        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)
+
+        info.vocabs = {
+            Const.INPUT: src_vocab,
+            Const.TARGET: tgt_vocab
+        }
+
+        info.datasets = datasets
+
+        for name, dataset in info.datasets.items():
+            dataset.set_input(Const.INPUT)
+            dataset.set_target(Const.TARGET)
+
+        return info
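The `wordtochar` helper duplicated across `IMDBLoader`, `SST2Loader`, and `YelpLoader` flattens a word sequence into characters, inserting an empty string between words as a boundary marker:

```python
# Standalone trace of wordtochar:
def wordtochar(words):
    chars = []
    for word in words:
        word = word.lower()
        for char in word:
            chars.append(char)
        chars.append('')  # empty string marks a word boundary
    chars.pop()           # drop the trailing boundary marker
    return chars

print(wordtochar(['The', 'cat']))  # ['t', 'h', 'e', '', 'c', 'a', 't']
```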
@@ -0,0 +1,126 @@
+import csv
+import re
+from typing import Iterable, Union, Dict
+
+from ...core.const import Const
+from ...core import DataSet, Instance, Vocabulary
+from ...core.vocabulary import VocabularyOption
+from ..base_loader import DataInfo, DataSetLoader
+from ..utils import check_dataloader_paths, get_tokenizer
+
+
+class YelpLoader(DataSetLoader):
+    """
+    Reads the Yelp_full / Yelp_polarity dataset. The DataSet contains the fields:
+
+        words: list(str), the text to be classified
+        target: str, the label of the text
+        chars: list(str), the un-indexed list of characters
+
+    Datasets: yelp_full / yelp_polarity
+
+    :param fine_grained: whether to use the fine-grained 5-class labels; if ``False``, very negative/very positive are merged into negative/positive. Default: ``False``
+    :param lower: whether to lowercase the text automatically. Default: ``False``.
+    """
+
+    def __init__(self, fine_grained=False, lower=False):
+        super(YelpLoader, self).__init__()
+        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
+                 '4.0': 'positive', '5.0': 'very positive'}
+        if not fine_grained:
+            tag_v['1.0'] = tag_v['2.0']
+            tag_v['5.0'] = tag_v['4.0']
+        self.fine_grained = fine_grained
+        self.tag_v = tag_v
+        self.lower = lower
+        self.tokenizer = get_tokenizer()
+
+    def _load(self, path):
+        ds = DataSet()
+        csv_reader = csv.reader(open(path, encoding='utf-8'))
+        all_count = 0
+        real_count = 0
+        for row in csv_reader:
+            all_count += 1
+            if len(row) == 2:
+                target = self.tag_v[row[0] + ".0"]
+                words = clean_str(row[1], self.tokenizer, self.lower)
+                if len(words) != 0:
+                    ds.append(Instance(words=words, target=target))
+                    real_count += 1
+        print("all count:", all_count)
+        print("real count:", real_count)
+        return ds
+
+    def process(self, paths: Union[str, Dict[str, str]],
+                train_ds: Iterable[str] = None,
+                src_vocab_op: VocabularyOption = None,
+                tgt_vocab_op: VocabularyOption = None,
+                char_level_op=False):
+        paths = check_dataloader_paths(paths)
+        info = DataInfo(datasets=self.load(paths))
+        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
+        _train_ds = [info.datasets[name]
+                     for name in train_ds] if train_ds else info.datasets.values()
+
+        def wordtochar(words):
+            # split each word into characters, inserting '' between words
+            chars = []
+            for word in words:
+                word = word.lower()
+                for char in word:
+                    chars.append(char)
+                chars.append('')
+            chars.pop()
+            return chars
+
+        input_name, target_name = Const.INPUT, Const.TARGET
+        info.vocabs = {}
+        # optionally split words into char form instead of indexing words
+        if char_level_op:
+            for dataset in info.datasets.values():
+                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
+        else:
+            src_vocab.from_dataset(*_train_ds, field_name=input_name)
+            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
+            info.vocabs[input_name] = src_vocab
+
+        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
+        tgt_vocab.index_dataset(
+            *info.datasets.values(),
+            field_name=target_name, new_field_name=target_name)
+
+        info.vocabs[target_name] = tgt_vocab
+
+        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)
+
+        for name, dataset in info.datasets.items():
+            dataset.set_input(Const.INPUT)
+            dataset.set_target(Const.TARGET)
+
+        return info
+
+
+def clean_str(sentence, tokenizer, char_lower=False):
+    """
+    Heavily borrowed from
+    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
+    :param sentence: a str
+    :return:
+    """
+    if char_lower:
+        sentence = sentence.lower()
+    nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
+    words = tokenizer(sentence)
+    words_collection = []
+    for word in words:
+        if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
+            continue
+        tt = nonalpnum.split(word)
+        t = ''.join(tt)
+        if t != '':
+            words_collection.append(t)
+    return words_collection
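A quick standalone check of `clean_str`, using `str.split` as a stand-in tokenizer: tokens from the filter list (`-lrb-`, `-rrb-`, …) are dropped, and characters outside `[0-9a-zA-Z?!']` are stripped from the rest:

```python
# Standalone check of clean_str, with str.split standing in for the tokenizer:
print(clean_str("-lrb- Hello , world !! -rrb-", tokenizer=str.split))
# -> ['Hello', 'world', '!!']  (bracket tokens and bare punctuation are removed)
```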
@@ -1,14 +0,0 @@
-__all__ = [
-    "MaxPool",
-    "MaxPoolWithMask",
-    "AvgPool",
-    "MultiHeadAttention",
-]
-
-from .pooling import MaxPool
-from .pooling import MaxPoolWithMask
-from .pooling import AvgPool
-from .pooling import AvgPoolWithMask
-
-from .attention import MultiHeadAttention
@@ -8,9 +8,9 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from ..dropout import TimestepDropout
+from fastNLP.modules.dropout import TimestepDropout

-from ..utils import initial_parameter
+from fastNLP.modules.utils import initial_parameter


 class DotAttention(nn.Module):
@@ -3,7 +3,7 @@ __all__ = [
 ]
 from torch import nn

-from ..aggregator.attention import MultiHeadAttention
+from fastNLP.modules.encoder.attention import MultiHeadAttention
 from ..dropout import TimestepDropout
@@ -11,7 +11,7 @@ Coreference resolution finds all expressions in a text that refer to the same real-world entity
 Due to copyright restrictions, the dataset cannot be distributed here; please download it yourself.
 The raw dataset is in CoNLL format; see the official dataset page for a detailed description.
-The preprocessing follows the method of the paper's author, Lee; for details refer the [link](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh).
+The preprocessing follows the method of the paper's author, Lee; for details see the [link](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh).
 The processed dataset is in JSON format; an example:
 ```
 {
@@ -25,12 +25,12 @@ Coreference resolution finds all expressions in a text that refer to the same real-world entity

 ### Embedding downloads
 [turian embedding](https://lil.cs.washington.edu/coref/turian.50d.txt)
-[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip)
+[glove embedding](https://nlp.stanford.edu/data/glove.840B.300d.zip)

 ## Running
-```python
+```shell
 # training
 CUDA_VISIBLE_DEVICES=0 python train.py
 # testing
@@ -39,9 +39,9 @@ CUDA_VISIBLE_DEVICES=0 python valid.py

 ## Results
 The original authors report 67.2% on the test set; the AllenNLP reproduction reports [63.0%](https://allennlp.org/models).
-Note that allenNLP trained without speaker information, without variational dropout, and with only 100 antecedents rather than 250.
+Note that AllenNLP trained without speaker information, without variational dropout, and with only 100 antecedents rather than 250.

-With the same hyperparameters and configuration as allenNLP, this reproduction reaches an F1 of 63.6%.
+With the same hyperparameters and configuration as AllenNLP, this reproduction reaches an F1 of 63.6%.

 ## Issues
@@ -2,7 +2,7 @@
 Several well-known models for Matching tasks are reproduced here with fastNLP, aiming at performance consistent with the original papers. The evaluation metric for all of these tasks is accuracy (%).
 The reproduced models are (in order of publication):

-- CNTN: model code (still in progress) [](); training code (still in progress) []().
+- CNTN: [model code](model/cntn.py); [training code](matching_cntn.py).
 Paper: [Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
 - ESIM: [model code](model/esim.py); [training code](matching_esim.py).
 Paper: [Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
@@ -21,7 +21,7 @@

 model name | SNLI | MNLI | RTE | QNLI | Quora
 :---: | :---: | :---: | :---: | :---: | :---:
-CNTN [](); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
+CNTN [code](model/cntn.py); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 77.79 vs - | 63.29/63.16(dev) vs - | 57.04(dev) vs - | 62.38(dev) vs - | - |
 ESIM [code](model/esim.py); [paper](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
 DIIN [](); [paper](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
 MwAN [](); [paper](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
@@ -44,7 +44,7 @@ Performance on Test set:

 model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
 :---: | :---: | :---: | :---: | :---: | :---: | :---:
-__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16
+__performance__ | 77.79 | 88.13 | - | 87.9 | 90.6 | 91.16

 ## MNLI
 [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
@@ -60,7 +60,7 @@ Performance on Test set(matched/mismatched):

 model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
 :---: | :---: | :---: | :---: | :---: | :---: |
-__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |
+__performance__ | 63.29/63.16(dev) | 77.78/76.49 | - | 77.3/76.7(dev) | - |

 ## RTE

@@ -92,7 +92,7 @@ Performance on __Dev__ set:

 model name | CNTN | ESIM | DIIN | MwAN | BERT
 :---: | :---: | :---: | :---: | :---: | :---:
-__performance__ | - | 76.97 | - | 74.6 | -
+__performance__ | 62.38 | 76.97 | - | 74.6 | -

 ## Quora
@@ -3,3 +3,5 @@ torch>=1.0.0
 tqdm>=4.28.1
 nltk>=3.4.1
 requests
+spacy
+h5py