| @@ -37,7 +37,7 @@ __all__ = [ | |||
| "AccuracyMetric", | |||
| "SpanFPreRecMetric", | |||
| "SQuADMetric", | |||
| "ExtractiveQAMetric", | |||
| "Optimizer", | |||
| "SGD", | |||
| @@ -61,3 +61,4 @@ __version__ = '0.4.0' | |||
| from .core import * | |||
| from . import models | |||
| from . import modules | |||
| from .io import data_loader | |||
| @@ -21,7 +21,7 @@ from .dataset import DataSet | |||
| from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
| from .instance import Instance | |||
| from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | |||
| from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric | |||
| from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric | |||
| from .optimizer import Optimizer, SGD, Adam | |||
| from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | |||
| from .tester import Tester | |||
| @@ -6,7 +6,7 @@ __all__ = [ | |||
| "MetricBase", | |||
| "AccuracyMetric", | |||
| "SpanFPreRecMetric", | |||
| "SQuADMetric" | |||
| "ExtractiveQAMetric" | |||
| ] | |||
| import inspect | |||
| @@ -24,6 +24,7 @@ from .utils import seq_len_to_mask | |||
| from .vocabulary import Vocabulary | |||
| from abc import abstractmethod | |||
| class MetricBase(object): | |||
| """ | |||
| 所有metrics的基类,,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。 | |||
| @@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1): | |||
| return y_pred_topk, y_prob_topk | |||
| class SQuADMetric(MetricBase): | |||
| r""" | |||
| 别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | |||
| class ExtractiveQAMetric(MetricBase): | |||
| """ | |||
| 别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` | |||
| SQuAD数据集metric | |||
| 抽取式QA(如SQuAD)的metric. | |||
| :param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | |||
| :param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | |||
| @@ -755,7 +756,7 @@ class SQuADMetric(MetricBase): | |||
| def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | |||
| beta=1, right_open=True, print_predict_stat=False): | |||
| super(SQuADMetric, self).__init__() | |||
| super(ExtractiveQAMetric, self).__init__() | |||
| self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | |||
| @@ -4,16 +4,26 @@ | |||
| 这些模块的使用方法如下: | |||
| """ | |||
| __all__ = [ | |||
| 'SSTLoader', | |||
| 'IMDBLoader', | |||
| 'MatchingLoader', | |||
| 'SNLILoader', | |||
| 'MNLILoader', | |||
| 'MTL16Loader', | |||
| 'QNLILoader', | |||
| 'QuoraLoader', | |||
| 'RTELoader', | |||
| 'SSTLoader', | |||
| 'SNLILoader', | |||
| 'YelpLoader', | |||
| ] | |||
| from .imdb import IMDBLoader | |||
| from .matching import MatchingLoader | |||
| from .mnli import MNLILoader | |||
| from .mtl import MTL16Loader | |||
| from .qnli import QNLILoader | |||
| from .quora import QuoraLoader | |||
| from .rte import RTELoader | |||
| from .snli import SNLILoader | |||
| from .sst import SSTLoader | |||
| from .matching import MatchingLoader, SNLILoader, \ | |||
| MNLILoader, QNLILoader, QuoraLoader, RTELoader | |||
| from .yelp import YelpLoader | |||
| @@ -0,0 +1,96 @@ | |||
| from typing import Union, Dict | |||
| from ..embed_loader import EmbeddingOption, EmbedLoader | |||
| from ..base_loader import DataSetLoader, DataInfo | |||
| from ...core.vocabulary import VocabularyOption, Vocabulary | |||
| from ...core.dataset import DataSet | |||
| from ...core.instance import Instance | |||
| from ...core.const import Const | |||
| from ..utils import get_tokenizer | |||
| class IMDBLoader(DataSetLoader): | |||
| """ | |||
| 读取IMDB数据集,DataSet包含以下fields: | |||
| words: list(str), 需要分类的文本 | |||
| target: str, 文本的标签 | |||
| """ | |||
| def __init__(self): | |||
| super(IMDBLoader, self).__init__() | |||
| self.tokenizer = get_tokenizer() | |||
| def _load(self, path): | |||
| dataset = DataSet() | |||
| with open(path, 'r', encoding="utf-8") as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if not line: | |||
| continue | |||
| parts = line.split('\t') | |||
| target = parts[0] | |||
| words = self.tokenizer(parts[1].lower()) | |||
| dataset.append(Instance(words=words, target=target)) | |||
| if len(dataset) == 0: | |||
| raise RuntimeError(f"{path} has no valid data.") | |||
| return dataset | |||
| def process(self, | |||
| paths: Union[str, Dict[str, str]], | |||
| src_vocab_opt: VocabularyOption = None, | |||
| tgt_vocab_opt: VocabularyOption = None, | |||
| char_level_op=False): | |||
| datasets = {} | |||
| info = DataInfo() | |||
| for name, path in paths.items(): | |||
| dataset = self.load(path) | |||
| datasets[name] = dataset | |||
| def wordtochar(words): | |||
| chars = [] | |||
| for word in words: | |||
| word = word.lower() | |||
| for char in word: | |||
| chars.append(char) | |||
| chars.append('') | |||
| chars.pop() | |||
| return chars | |||
| if char_level_op: | |||
| for dataset in datasets.values(): | |||
| dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') | |||
| datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) | |||
| src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
| src_vocab.from_dataset(datasets['train'], field_name='words') | |||
| src_vocab.index_dataset(*datasets.values(), field_name='words') | |||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
| if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||
| tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||
| tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||
| info.vocabs = { | |||
| Const.INPUT: src_vocab, | |||
| Const.TARGET: tgt_vocab | |||
| } | |||
| info.datasets = datasets | |||
| for name, dataset in info.datasets.items(): | |||
| dataset.set_input(Const.INPUT) | |||
| dataset.set_target(Const.TARGET) | |||
| return info | |||
| @@ -5,14 +5,13 @@ from typing import Union, Dict | |||
| from ...core.const import Const | |||
| from ...core.vocabulary import Vocabulary | |||
| from ..base_loader import DataInfo, DataSetLoader | |||
| from ..dataset_loader import JsonLoader, CSVLoader | |||
| from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
| from ...modules.encoder._bert import BertTokenizer | |||
| class MatchingLoader(DataSetLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||
| 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader` | |||
| 读取Matching任务的数据集 | |||
| @@ -227,204 +226,3 @@ class MatchingLoader(DataSetLoader): | |||
| data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||
| return data_info | |||
| class SNLILoader(MatchingLoader, JsonLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||
| 读取SNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| fields = { | |||
| 'sentence1_binary_parse': Const.INPUTS(0), | |||
| 'sentence2_binary_parse': Const.INPUTS(1), | |||
| 'gold_label': Const.TARGET, | |||
| } | |||
| paths = paths if paths is not None else { | |||
| 'train': 'snli_1.0_train.jsonl', | |||
| 'dev': 'snli_1.0_dev.jsonl', | |||
| 'test': 'snli_1.0_test.jsonl'} | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| JsonLoader.__init__(self, fields=fields) | |||
| def _load(self, path): | |||
| ds = JsonLoader._load(self, path) | |||
| parentheses_table = str.maketrans({'(': None, ')': None}) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(0)) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(1)) | |||
| ds.drop(lambda x: x[Const.TARGET] == '-') | |||
| return ds | |||
| class RTELoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` | |||
| 读取RTE数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev': 'dev.tsv', | |||
| 'test': 'test.tsv' # test set has not label | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| self.fields = { | |||
| 'sentence1': Const.INPUTS(0), | |||
| 'sentence2': Const.INPUTS(1), | |||
| 'label': Const.TARGET, | |||
| } | |||
| CSVLoader.__init__(self, sep='\t') | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| for k, v in self.fields.items(): | |||
| if v in ds.get_field_names(): | |||
| ds.rename_field(k, v) | |||
| for fields in ds.get_all_fields(): | |||
| if Const.INPUT in fields: | |||
| ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
| return ds | |||
| class QNLILoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` | |||
| 读取QNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev': 'dev.tsv', | |||
| 'test': 'test.tsv' # test set has not label | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| self.fields = { | |||
| 'question': Const.INPUTS(0), | |||
| 'sentence': Const.INPUTS(1), | |||
| 'label': Const.TARGET, | |||
| } | |||
| CSVLoader.__init__(self, sep='\t') | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| for k, v in self.fields.items(): | |||
| if v in ds.get_field_names(): | |||
| ds.rename_field(k, v) | |||
| for fields in ds.get_all_fields(): | |||
| if Const.INPUT in fields: | |||
| ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
| return ds | |||
| class MNLILoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | |||
| 读取MNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev_matched': 'dev_matched.tsv', | |||
| 'dev_mismatched': 'dev_mismatched.tsv', | |||
| 'test_matched': 'test_matched.tsv', | |||
| 'test_mismatched': 'test_mismatched.tsv', | |||
| # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||
| # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||
| # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| CSVLoader.__init__(self, sep='\t') | |||
| self.fields = { | |||
| 'sentence1_binary_parse': Const.INPUTS(0), | |||
| 'sentence2_binary_parse': Const.INPUTS(1), | |||
| 'gold_label': Const.TARGET, | |||
| } | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| for k, v in self.fields.items(): | |||
| if k in ds.get_field_names(): | |||
| ds.rename_field(k, v) | |||
| if Const.TARGET in ds.get_field_names(): | |||
| if ds[0][Const.TARGET] == 'hidden': | |||
| ds.delete_field(Const.TARGET) | |||
| parentheses_table = str.maketrans({'(': None, ')': None}) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(0)) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(1)) | |||
| if Const.TARGET in ds.get_field_names(): | |||
| ds.drop(lambda x: x[Const.TARGET] == '-') | |||
| return ds | |||
| class QuoraLoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` | |||
| 读取MNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev': 'dev.tsv', | |||
| 'test': 'test.tsv', | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| return ds | |||
| @@ -0,0 +1,60 @@ | |||
| from ...core import Const | |||
| from .matching import MatchingLoader | |||
| from ..dataset_loader import CSVLoader | |||
| class MNLILoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader` | |||
| 读取MNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev_matched': 'dev_matched.tsv', | |||
| 'dev_mismatched': 'dev_mismatched.tsv', | |||
| 'test_matched': 'test_matched.tsv', | |||
| 'test_mismatched': 'test_mismatched.tsv', | |||
| # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||
| # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||
| # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| CSVLoader.__init__(self, sep='\t') | |||
| self.fields = { | |||
| 'sentence1_binary_parse': Const.INPUTS(0), | |||
| 'sentence2_binary_parse': Const.INPUTS(1), | |||
| 'gold_label': Const.TARGET, | |||
| } | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| for k, v in self.fields.items(): | |||
| if k in ds.get_field_names(): | |||
| ds.rename_field(k, v) | |||
| if Const.TARGET in ds.get_field_names(): | |||
| if ds[0][Const.TARGET] == 'hidden': | |||
| ds.delete_field(Const.TARGET) | |||
| parentheses_table = str.maketrans({'(': None, ')': None}) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(0)) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(1)) | |||
| if Const.TARGET in ds.get_field_names(): | |||
| ds.drop(lambda x: x[Const.TARGET] == '-') | |||
| return ds | |||
| @@ -0,0 +1,65 @@ | |||
| from typing import Union, Dict | |||
| from ..base_loader import DataInfo | |||
| from ..dataset_loader import CSVLoader | |||
| from ...core.vocabulary import Vocabulary, VocabularyOption | |||
| from ...core.const import Const | |||
| from ..utils import check_dataloader_paths | |||
| class MTL16Loader(CSVLoader): | |||
| """ | |||
| 读取MTL16数据集,DataSet包含以下fields: | |||
| words: list(str), 需要分类的文本 | |||
| target: str, 文本的标签 | |||
| 数据来源:https://pan.baidu.com/s/1c2L6vdA | |||
| """ | |||
| def __init__(self): | |||
| super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t') | |||
| def _load(self, path): | |||
| dataset = super(MTL16Loader, self)._load(path) | |||
| dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT) | |||
| if len(dataset) == 0: | |||
| raise RuntimeError(f"{path} has no valid data.") | |||
| return dataset | |||
| def process(self, | |||
| paths: Union[str, Dict[str, str]], | |||
| src_vocab_opt: VocabularyOption = None, | |||
| tgt_vocab_opt: VocabularyOption = None,): | |||
| paths = check_dataloader_paths(paths) | |||
| datasets = {} | |||
| info = DataInfo() | |||
| for name, path in paths.items(): | |||
| dataset = self.load(path) | |||
| datasets[name] = dataset | |||
| src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
| src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||
| src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
| if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||
| tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||
| tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||
| info.vocabs = { | |||
| Const.INPUT: src_vocab, | |||
| Const.TARGET: tgt_vocab | |||
| } | |||
| info.datasets = datasets | |||
| for name, dataset in info.datasets.items(): | |||
| dataset.set_input(Const.INPUT) | |||
| dataset.set_target(Const.TARGET) | |||
| return info | |||
| @@ -0,0 +1,45 @@ | |||
| from ...core import Const | |||
| from .matching import MatchingLoader | |||
| from ..dataset_loader import CSVLoader | |||
| class QNLILoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader` | |||
| 读取QNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev': 'dev.tsv', | |||
| 'test': 'test.tsv' # test set has not label | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| self.fields = { | |||
| 'question': Const.INPUTS(0), | |||
| 'sentence': Const.INPUTS(1), | |||
| 'label': Const.TARGET, | |||
| } | |||
| CSVLoader.__init__(self, sep='\t') | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| for k, v in self.fields.items(): | |||
| if k in ds.get_field_names(): | |||
| ds.rename_field(k, v) | |||
| for fields in ds.get_all_fields(): | |||
| if Const.INPUT in fields: | |||
| ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
| return ds | |||
| @@ -0,0 +1,32 @@ | |||
| from ...core import Const | |||
| from .matching import MatchingLoader | |||
| from ..dataset_loader import CSVLoader | |||
| class QuoraLoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader` | |||
| 读取MNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev': 'dev.tsv', | |||
| 'test': 'test.tsv', | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| return ds | |||
| @@ -0,0 +1,45 @@ | |||
| from ...core import Const | |||
| from .matching import MatchingLoader | |||
| from ..dataset_loader import CSVLoader | |||
| class RTELoader(MatchingLoader, CSVLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader` | |||
| 读取RTE数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| paths = paths if paths is not None else { | |||
| 'train': 'train.tsv', | |||
| 'dev': 'dev.tsv', | |||
| 'test': 'test.tsv' # test set has not label | |||
| } | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| self.fields = { | |||
| 'sentence1': Const.INPUTS(0), | |||
| 'sentence2': Const.INPUTS(1), | |||
| 'label': Const.TARGET, | |||
| } | |||
| CSVLoader.__init__(self, sep='\t') | |||
| def _load(self, path): | |||
| ds = CSVLoader._load(self, path) | |||
| for k, v in self.fields.items(): | |||
| if k in ds.get_field_names(): | |||
| ds.rename_field(k, v) | |||
| for fields in ds.get_all_fields(): | |||
| if Const.INPUT in fields: | |||
| ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
| return ds | |||
| @@ -0,0 +1,44 @@ | |||
| from ...core import Const | |||
| from .matching import MatchingLoader | |||
| from ..dataset_loader import JsonLoader | |||
| class SNLILoader(MatchingLoader, JsonLoader): | |||
| """ | |||
| 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader` | |||
| 读取SNLI数据集,读取的DataSet包含fields:: | |||
| words1: list(str),第一句文本, premise | |||
| words2: list(str), 第二句文本, hypothesis | |||
| target: str, 真实标签 | |||
| 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
| """ | |||
| def __init__(self, paths: dict=None): | |||
| fields = { | |||
| 'sentence1_binary_parse': Const.INPUTS(0), | |||
| 'sentence2_binary_parse': Const.INPUTS(1), | |||
| 'gold_label': Const.TARGET, | |||
| } | |||
| paths = paths if paths is not None else { | |||
| 'train': 'snli_1.0_train.jsonl', | |||
| 'dev': 'snli_1.0_dev.jsonl', | |||
| 'test': 'snli_1.0_test.jsonl'} | |||
| MatchingLoader.__init__(self, paths=paths) | |||
| JsonLoader.__init__(self, fields=fields) | |||
| def _load(self, path): | |||
| ds = JsonLoader._load(self, path) | |||
| parentheses_table = str.maketrans({'(': None, ')': None}) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(0)) | |||
| ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
| new_field_name=Const.INPUTS(1)) | |||
| ds.drop(lambda x: x[Const.TARGET] == '-') | |||
| return ds | |||
| @@ -1,19 +1,19 @@ | |||
| from typing import Iterable | |||
| from typing import Union, Dict | |||
| from nltk import Tree | |||
| import spacy | |||
| from ..base_loader import DataInfo, DataSetLoader | |||
| from ..dataset_loader import CSVLoader | |||
| from ...core.vocabulary import VocabularyOption, Vocabulary | |||
| from ...core.dataset import DataSet | |||
| from ...core.const import Const | |||
| from ...core.instance import Instance | |||
| from ..utils import check_dataloader_paths, get_tokenizer | |||
| class SSTLoader(DataSetLoader): | |||
| URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||
| DATA_DIR = 'sst/' | |||
| """ | |||
| 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||
| 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader` | |||
| 读取SST数据集, DataSet包含fields:: | |||
| @@ -26,6 +26,9 @@ class SSTLoader(DataSetLoader): | |||
| :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||
| """ | |||
| URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||
| DATA_DIR = 'sst/' | |||
| def __init__(self, subtree=False, fine_grained=False): | |||
| self.subtree = subtree | |||
| @@ -98,3 +101,72 @@ class SSTLoader(DataSetLoader): | |||
| return info | |||
| class SST2Loader(CSVLoader): | |||
| """ | |||
| 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', | |||
| """ | |||
| def __init__(self): | |||
| super(SST2Loader, self).__init__(sep='\t') | |||
| self.tokenizer = get_tokenizer() | |||
| self.field = {'sentence': Const.INPUT, 'label': Const.TARGET} | |||
| def _load(self, path: str) -> DataSet: | |||
| ds = super(SST2Loader, self)._load(path) | |||
| ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT) | |||
| print("all count:", len(ds)) | |||
| return ds | |||
| def process(self, | |||
| paths: Union[str, Dict[str, str]], | |||
| src_vocab_opt: VocabularyOption = None, | |||
| tgt_vocab_opt: VocabularyOption = None, | |||
| char_level_op=False): | |||
| paths = check_dataloader_paths(paths) | |||
| datasets = {} | |||
| info = DataInfo() | |||
| for name, path in paths.items(): | |||
| dataset = self.load(path) | |||
| datasets[name] = dataset | |||
| def wordtochar(words): | |||
| chars = [] | |||
| for word in words: | |||
| word = word.lower() | |||
| for char in word: | |||
| chars.append(char) | |||
| chars.append('') | |||
| chars.pop() | |||
| return chars | |||
| input_name, target_name = Const.INPUT, Const.TARGET | |||
| info.vocabs={} | |||
| # 就分隔为char形式 | |||
| if char_level_op: | |||
| for dataset in datasets.values(): | |||
| dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||
| src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
| src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||
| src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
| if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||
| tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||
| tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||
| info.vocabs = { | |||
| Const.INPUT: src_vocab, | |||
| Const.TARGET: tgt_vocab | |||
| } | |||
| info.datasets = datasets | |||
| for name, dataset in info.datasets.items(): | |||
| dataset.set_input(Const.INPUT) | |||
| dataset.set_target(Const.TARGET) | |||
| return info | |||
| @@ -0,0 +1,126 @@ | |||
| import csv | |||
| from typing import Iterable | |||
| from ...core.const import Const | |||
| from ...core import DataSet, Instance, Vocabulary | |||
| from ...core.vocabulary import VocabularyOption | |||
| from ..base_loader import DataInfo,DataSetLoader | |||
| from typing import Union, Dict | |||
| from ..utils import check_dataloader_paths, get_tokenizer | |||
| class YelpLoader(DataSetLoader): | |||
| """ | |||
| 读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: | |||
| words: list(str), 需要分类的文本 | |||
| target: str, 文本的标签 | |||
| chars:list(str),未index的字符列表 | |||
| 数据集:yelp_full/yelp_polarity | |||
| :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||
| :param lower: 是否需要自动转小写,默认为False。 | |||
| """ | |||
| def __init__(self, fine_grained=False, lower=False): | |||
| super(YelpLoader, self).__init__() | |||
| tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', | |||
| '4.0': 'positive', '5.0': 'very positive'} | |||
| if not fine_grained: | |||
| tag_v['1.0'] = tag_v['2.0'] | |||
| tag_v['5.0'] = tag_v['4.0'] | |||
| self.fine_grained = fine_grained | |||
| self.tag_v = tag_v | |||
| self.lower = lower | |||
| self.tokenizer = get_tokenizer() | |||
| def _load(self, path): | |||
| ds = DataSet() | |||
| csv_reader = csv.reader(open(path, encoding='utf-8')) | |||
| all_count = 0 | |||
| real_count = 0 | |||
| for row in csv_reader: | |||
| all_count += 1 | |||
| if len(row) == 2: | |||
| target = self.tag_v[row[0] + ".0"] | |||
| words = clean_str(row[1], self.tokenizer, self.lower) | |||
| if len(words) != 0: | |||
| ds.append(Instance(words=words, target=target)) | |||
| real_count += 1 | |||
| print("all count:", all_count) | |||
| print("real count:", real_count) | |||
| return ds | |||
| def process(self, paths: Union[str, Dict[str, str]], | |||
| train_ds: Iterable[str] = None, | |||
| src_vocab_op: VocabularyOption = None, | |||
| tgt_vocab_op: VocabularyOption = None, | |||
| char_level_op=False): | |||
| paths = check_dataloader_paths(paths) | |||
| info = DataInfo(datasets=self.load(paths)) | |||
| src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
| if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||
| _train_ds = [info.datasets[name] | |||
| for name in train_ds] if train_ds else info.datasets.values() | |||
| def wordtochar(words): | |||
| chars = [] | |||
| for word in words: | |||
| word = word.lower() | |||
| for char in word: | |||
| chars.append(char) | |||
| chars.append('') | |||
| chars.pop() | |||
| return chars | |||
| input_name, target_name = Const.INPUT, Const.TARGET | |||
| info.vocabs = {} | |||
| # 就分隔为char形式 | |||
| if char_level_op: | |||
| for dataset in info.datasets.values(): | |||
| dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||
| else: | |||
| src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||
| src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name) | |||
| info.vocabs[input_name] = src_vocab | |||
| tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||
| tgt_vocab.index_dataset( | |||
| *info.datasets.values(), | |||
| field_name=target_name, new_field_name=target_name) | |||
| info.vocabs[target_name] = tgt_vocab | |||
| info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False) | |||
| for name, dataset in info.datasets.items(): | |||
| dataset.set_input(Const.INPUT) | |||
| dataset.set_target(Const.TARGET) | |||
| return info | |||
| def clean_str(sentence, tokenizer, char_lower=False): | |||
| """ | |||
| heavily borrowed from github | |||
| https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb | |||
| :param sentence: is a str | |||
| :return: | |||
| """ | |||
| if char_lower: | |||
| sentence = sentence.lower() | |||
| import re | |||
| nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||
| words = tokenizer(sentence) | |||
| words_collection = [] | |||
| for word in words: | |||
| if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']: | |||
| continue | |||
| tt = nonalpnum.split(word) | |||
| t = ''.join(tt) | |||
| if t != '': | |||
| words_collection.append(t) | |||
| return words_collection | |||
| @@ -1,14 +0,0 @@ | |||
| __all__ = [ | |||
| "MaxPool", | |||
| "MaxPoolWithMask", | |||
| "AvgPool", | |||
| "MultiHeadAttention", | |||
| ] | |||
| from .pooling import MaxPool | |||
| from .pooling import MaxPoolWithMask | |||
| from .pooling import AvgPool | |||
| from .pooling import AvgPoolWithMask | |||
| from .attention import MultiHeadAttention | |||
| @@ -8,9 +8,9 @@ import torch | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| from ..dropout import TimestepDropout | |||
| from fastNLP.modules.dropout import TimestepDropout | |||
| from ..utils import initial_parameter | |||
| from fastNLP.modules.utils import initial_parameter | |||
| class DotAttention(nn.Module): | |||
| @@ -3,7 +3,7 @@ __all__ = [ | |||
| ] | |||
| from torch import nn | |||
| from ..aggregator.attention import MultiHeadAttention | |||
| from fastNLP.modules.encoder.attention import MultiHeadAttention | |||
| from ..dropout import TimestepDropout | |||
| @@ -11,7 +11,7 @@ Coreference resolution是查找文本中指向同一现实实体的所有表达 | |||
| 由于版权问题,本文无法提供数据集的下载,请自行下载。 | |||
| 原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。 | |||
| 代码实现采用了论文作者Lee的预处理方法,具体细节参加[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 | |||
| 代码实现采用了论文作者Lee的预处理方法,具体细节参见[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 | |||
| 处理之后的数据集为json格式,例子: | |||
| ``` | |||
| { | |||
| @@ -25,12 +25,12 @@ Coreference resolution是查找文本中指向同一现实实体的所有表达 | |||
| ### embedding 数据集下载 | |||
| [turian emdedding](https://lil.cs.washington.edu/coref/turian.50d.txt) | |||
| [glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip) | |||
| [glove embedding](https://nlp.stanford.edu/data/glove.840B.300d.zip) | |||
| ## 运行 | |||
| ```python | |||
| ```shell | |||
| # 训练代码 | |||
| CUDA_VISIBLE_DEVICES=0 python train.py | |||
| # 测试代码 | |||
| @@ -39,9 +39,9 @@ CUDA_VISIBLE_DEVICES=0 python valid.py | |||
| ## 结果 | |||
| 原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。 | |||
| 其中allenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 | |||
| 其中AllenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 | |||
| 在与allenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 | |||
| 在与AllenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 | |||
| ## 问题 | |||
| @@ -2,7 +2,7 @@ | |||
| 这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). | |||
| 复现的模型有(按论文发表时间顺序排序): | |||
| - CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||
| - CNTN:[模型代码](model/cntn.py); [训练代码](matching_cntn.py). | |||
| 论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | |||
| - ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). | |||
| 论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | |||
| @@ -21,7 +21,7 @@ | |||
| model name | SNLI | MNLI | RTE | QNLI | Quora | |||
| :---: | :---: | :---: | :---: | :---: | :---: | |||
| CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - | | |||
| CNTN [代码](model/cntn.py); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 77.79 vs - | 63.29/63.16(dev) vs - | 57.04(dev) vs - | 62.38(dev) vs - | - | | |||
| ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | | |||
| DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | |||
| MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | | |||
| @@ -44,7 +44,7 @@ Performance on Test set: | |||
| model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | |||
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | |||
| __performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16 | |||
| __performance__ | 77.79 | 88.13 | - | 87.9 | 90.6 | 91.16 | |||
| ## MNLI | |||
| [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | |||
| @@ -60,7 +60,7 @@ Performance on Test set(matched/mismatched): | |||
| model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | |||
| :---: | :---: | :---: | :---: | :---: | :---: | | |||
| __performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||
| __performance__ | 63.29/63.16(dev) | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||
| ## RTE | |||
| @@ -92,7 +92,7 @@ Performance on __Dev__ set: | |||
| model name | CNTN | ESIM | DIIN | MwAN | BERT | |||
| :---: | :---: | :---: | :---: | :---: | :---: | |||
| __performance__ | - | 76.97 | - | 74.6 | - | |||
| __performance__ | 62.38 | 76.97 | - | 74.6 | - | |||
| ## Quora | |||
| @@ -3,3 +3,5 @@ torch>=1.0.0 | |||
| tqdm>=4.28.1 | |||
| nltk>=3.4.1 | |||
| requests | |||
| spacy | |||
| h5py | |||