| @@ -71,7 +71,9 @@ __all__ = [ | |||||
| "QuoraLoader", | "QuoraLoader", | ||||
| "SNLILoader", | "SNLILoader", | ||||
| "QNLILoader", | "QNLILoader", | ||||
| "RTELoader" | |||||
| "RTELoader", | |||||
| "CRLoader" | |||||
| ] | ] | ||||
| from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader | from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader | ||||
| from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | ||||
| @@ -81,3 +83,4 @@ from .json import JsonLoader | |||||
| from .loader import Loader | from .loader import Loader | ||||
| from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | ||||
| from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | ||||
| from .coreference import CRLoader | |||||
| @@ -0,0 +1,24 @@ | |||||
| from ...core.dataset import DataSet | |||||
| from ..file_reader import _read_json | |||||
| from ...core.instance import Instance | |||||
| from .json import JsonLoader | |||||
| class CRLoader(JsonLoader): | |||||
| def __init__(self, fields=None, dropna=False): | |||||
| super().__init__(fields, dropna) | |||||
| def _load(self, path): | |||||
| """ | |||||
| 加载数据 | |||||
| :param path: | |||||
| :return: | |||||
| """ | |||||
| dataset = DataSet() | |||||
| for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||||
| if self.fields: | |||||
| ins = {self.fields[k]: v for k, v in d.items()} | |||||
| else: | |||||
| ins = d | |||||
| dataset.append(Instance(**ins)) | |||||
| return dataset | |||||
| @@ -37,6 +37,8 @@ __all__ = [ | |||||
| "QuoraPipe", | "QuoraPipe", | ||||
| "QNLIPipe", | "QNLIPipe", | ||||
| "MNLIPipe", | "MNLIPipe", | ||||
| "CoreferencePipe" | |||||
| ] | ] | ||||
| from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe | from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe | ||||
| @@ -46,3 +48,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe | |||||
| from .pipe import Pipe | from .pipe import Pipe | ||||
| from .conll import Conll2003Pipe | from .conll import Conll2003Pipe | ||||
| from .cws import CWSPipe | from .cws import CWSPipe | ||||
| from .coreference import CoreferencePipe | |||||
| @@ -0,0 +1,115 @@ | |||||
| __all__ = [ | |||||
| "CoreferencePipe" | |||||
| ] | |||||
| from .pipe import Pipe | |||||
| from ..data_bundle import DataBundle | |||||
| from ..loader.coreference import CRLoader | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| import numpy as np | |||||
| import collections | |||||
| class CoreferencePipe(Pipe): | |||||
| def __init__(self,config): | |||||
| super().__init__() | |||||
| self.config = config | |||||
| def process(self, data_bundle: DataBundle): | |||||
| genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} | |||||
| vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name='sentences') | |||||
| vocab.build_vocab() | |||||
| word2id = vocab.word2idx | |||||
| char_dict = get_char_dict(self.config.char_path) | |||||
| for name, ds in data_bundle.datasets.items(): | |||||
| ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter), | |||||
| self.config.max_sentences, is_train=name == 'train')[0], | |||||
| new_field_name='doc_np') | |||||
| ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter), | |||||
| self.config.max_sentences, is_train=name == 'train')[1], | |||||
| new_field_name='char_index') | |||||
| ds.apply(lambda x: doc2numpy(x['sentences'], word2id, char_dict, max(self.config.filter), | |||||
| self.config.max_sentences, is_train=name == 'train')[2], | |||||
| new_field_name='seq_len') | |||||
| ds.apply(lambda x: speaker2numpy(x["speakers"], self.config.max_sentences, is_train=name == 'train'), | |||||
| new_field_name='speaker_ids_np') | |||||
| ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre') | |||||
| ds.set_ignore_type('clusters') | |||||
| ds.set_padder('clusters', None) | |||||
| ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len") | |||||
| ds.set_target("clusters") | |||||
| return data_bundle | |||||
| def process_from_file(self, paths): | |||||
| bundle = CRLoader().load(paths) | |||||
| return self.process(bundle) | |||||
| # helper | |||||
| def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train): | |||||
| docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train) | |||||
| assert max(length) == max_len | |||||
| assert char_index.shape[0] == len(length) | |||||
| assert char_index.shape[1] == max_len | |||||
| doc_np = np.zeros((len(docvec), max_len), int) | |||||
| for i in range(len(docvec)): | |||||
| for j in range(len(docvec[i])): | |||||
| doc_np[i][j] = docvec[i][j] | |||||
| return doc_np, char_index, length | |||||
| def _doc2vec(doc,word2id,char_dict,max_filter,max_sentences,is_train): | |||||
| max_len = 0 | |||||
| max_word_length = 0 | |||||
| docvex = [] | |||||
| length = [] | |||||
| if is_train: | |||||
| sent_num = min(max_sentences,len(doc)) | |||||
| else: | |||||
| sent_num = len(doc) | |||||
| for i in range(sent_num): | |||||
| sent = doc[i] | |||||
| length.append(len(sent)) | |||||
| if (len(sent) > max_len): | |||||
| max_len = len(sent) | |||||
| sent_vec =[] | |||||
| for j,word in enumerate(sent): | |||||
| if len(word)>max_word_length: | |||||
| max_word_length = len(word) | |||||
| if word in word2id: | |||||
| sent_vec.append(word2id[word]) | |||||
| else: | |||||
| sent_vec.append(word2id["UNK"]) | |||||
| docvex.append(sent_vec) | |||||
| char_index = np.zeros((sent_num, max_len, max_word_length),dtype=int) | |||||
| for i in range(sent_num): | |||||
| sent = doc[i] | |||||
| for j,word in enumerate(sent): | |||||
| char_index[i, j, :len(word)] = [char_dict[c] for c in word] | |||||
| return docvex,char_index,length,max_len | |||||
| def speaker2numpy(speakers_raw,max_sentences,is_train): | |||||
| if is_train and len(speakers_raw)> max_sentences: | |||||
| speakers_raw = speakers_raw[0:max_sentences] | |||||
| speakers = flatten(speakers_raw) | |||||
| speaker_dict = {s: i for i, s in enumerate(set(speakers))} | |||||
| speaker_ids = np.array([speaker_dict[s] for s in speakers]) | |||||
| return speaker_ids | |||||
| # 展平 | |||||
| def flatten(l): | |||||
| return [item for sublist in l for item in sublist] | |||||
| def get_char_dict(path): | |||||
| vocab = ["<UNK>"] | |||||
| with open(path) as f: | |||||
| vocab.extend(c.strip() for c in f.readlines()) | |||||
| char_dict = collections.defaultdict(int) | |||||
| char_dict.update({c: i for i, c in enumerate(vocab)}) | |||||
| return char_dict | |||||
| @@ -1,4 +1,4 @@ | |||||
| # 共指消解复现 | |||||
| # 指代消解复现 | |||||
| ## 介绍 | ## 介绍 | ||||
| Coreference resolution是查找文本中指向同一现实实体的所有表达式的任务。 | Coreference resolution是查找文本中指向同一现实实体的所有表达式的任务。 | ||||
| 对于涉及自然语言理解的许多更高级别的NLP任务来说, | 对于涉及自然语言理解的许多更高级别的NLP任务来说, | ||||
| @@ -1,68 +0,0 @@ | |||||
| from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance | |||||
| from fastNLP.io.file_reader import _read_json | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.io.data_bundle import DataBundle | |||||
| from reproduction.coreference_resolution.model.config import Config | |||||
| import reproduction.coreference_resolution.model.preprocess as preprocess | |||||
| class CRLoader(JsonLoader): | |||||
| def __init__(self, fields=None, dropna=False): | |||||
| super().__init__(fields, dropna) | |||||
| def _load(self, path): | |||||
| """ | |||||
| 加载数据 | |||||
| :param path: | |||||
| :return: | |||||
| """ | |||||
| dataset = DataSet() | |||||
| for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||||
| if self.fields: | |||||
| ins = {self.fields[k]: v for k, v in d.items()} | |||||
| else: | |||||
| ins = d | |||||
| dataset.append(Instance(**ins)) | |||||
| return dataset | |||||
| def process(self, paths, **kwargs): | |||||
| data_info = DataBundle() | |||||
| for name in ['train', 'test', 'dev']: | |||||
| data_info.datasets[name] = self.load(paths[name]) | |||||
| config = Config() | |||||
| vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences') | |||||
| vocab.build_vocab() | |||||
| word2id = vocab.word2idx | |||||
| char_dict = preprocess.get_char_dict(config.char_path) | |||||
| data_info.vocabs = vocab | |||||
| genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} | |||||
| for name, ds in data_info.datasets.items(): | |||||
| ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||||
| config.max_sentences, is_train=name=='train')[0], | |||||
| new_field_name='doc_np') | |||||
| ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||||
| config.max_sentences, is_train=name=='train')[1], | |||||
| new_field_name='char_index') | |||||
| ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter), | |||||
| config.max_sentences, is_train=name=='train')[2], | |||||
| new_field_name='seq_len') | |||||
| ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'), | |||||
| new_field_name='speaker_ids_np') | |||||
| ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre') | |||||
| ds.set_ignore_type('clusters') | |||||
| ds.set_padder('clusters', None) | |||||
| ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len") | |||||
| ds.set_target("clusters") | |||||
| # train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False) | |||||
| # train, dev = train_dev.split(343 / (2802 + 343), shuffle=False) | |||||
| return data_info | |||||
| @@ -1,14 +1,14 @@ | |||||
| import unittest | import unittest | ||||
| from ..data_load.cr_loader import CRLoader | |||||
| from fastNLP.io.pipe.coreference import CoreferencePipe | |||||
| from reproduction.coreference_resolution.model.config import Config | |||||
| class Test_CRLoader(unittest.TestCase): | class Test_CRLoader(unittest.TestCase): | ||||
| def test_cr_loader(self): | def test_cr_loader(self): | ||||
| train_path = 'data/train.english.jsonlines.mini' | |||||
| dev_path = 'data/dev.english.jsonlines.minid' | |||||
| test_path = 'data/test.english.jsonlines' | |||||
| cr = CRLoader() | |||||
| data_info = cr.process({'train':train_path,'dev':dev_path,'test':test_path}) | |||||
| print(data_info.datasets['train'][0]) | |||||
| print(data_info.datasets['dev'][0]) | |||||
| print(data_info.datasets['test'][0]) | |||||
| config = Config() | |||||
| bundle = CoreferencePipe(config).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path}) | |||||
| print(bundle.datasets['train'][0]) | |||||
| print(bundle.datasets['dev'][0]) | |||||
| print(bundle.datasets['test'][0]) | |||||
| @@ -7,7 +7,8 @@ from torch.optim import Adam | |||||
| from fastNLP.core.callback import Callback, GradientClipCallback | from fastNLP.core.callback import Callback, GradientClipCallback | ||||
| from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
| from reproduction.coreference_resolution.data_load.cr_loader import CRLoader | |||||
| from fastNLP.io.pipe.coreference import CoreferencePipe | |||||
| from reproduction.coreference_resolution.model.config import Config | from reproduction.coreference_resolution.model.config import Config | ||||
| from reproduction.coreference_resolution.model.model_re import Model | from reproduction.coreference_resolution.model.model_re import Model | ||||
| from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss | from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss | ||||
| @@ -38,11 +39,8 @@ if __name__ == "__main__": | |||||
| @cache_results('cache.pkl') | @cache_results('cache.pkl') | ||||
| def cache(): | def cache(): | ||||
| cr_train_dev_test = CRLoader() | |||||
| data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path, | |||||
| 'test': config.test_path}) | |||||
| return data_info | |||||
| bundle = CoreferencePipe(Config()).process_from_file({'train': config.train_path, 'dev': config.dev_path,'test': config.test_path}) | |||||
| return bundle | |||||
| data_info = cache() | data_info = cache() | ||||
| print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])), | print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])), | ||||
| "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"]))) | "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"]))) | ||||
| @@ -1,7 +1,8 @@ | |||||
| import torch | import torch | ||||
| from reproduction.coreference_resolution.model.config import Config | from reproduction.coreference_resolution.model.config import Config | ||||
| from reproduction.coreference_resolution.model.metric import CRMetric | from reproduction.coreference_resolution.model.metric import CRMetric | ||||
| from reproduction.coreference_resolution.data_load.cr_loader import CRLoader | |||||
| from fastNLP.io.pipe.coreference import CoreferencePipe | |||||
| from fastNLP import Tester | from fastNLP import Tester | ||||
| import argparse | import argparse | ||||
| @@ -11,13 +12,12 @@ if __name__=='__main__': | |||||
| parser.add_argument('--path') | parser.add_argument('--path') | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| cr_loader = CRLoader() | |||||
| config = Config() | config = Config() | ||||
| data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path, | |||||
| 'test': config.test_path}) | |||||
| bundle = CoreferencePipe(Config()).process_from_file( | |||||
| {'train': config.train_path, 'dev': config.dev_path, 'test': config.test_path}) | |||||
| metirc = CRMetric() | metirc = CRMetric() | ||||
| model = torch.load(args.path) | model = torch.load(args.path) | ||||
| tester = Tester(data_info.datasets['test'],model,metirc,batch_size=1,device="cuda:0") | |||||
| tester = Tester(bundle.datasets['test'],model,metirc,batch_size=1,device="cuda:0") | |||||
| tester.test() | tester.test() | ||||
| print('test over') | print('test over') | ||||