| @@ -0,0 +1,284 @@ | |||
| from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||
| from fastNLP.io.dataset_loader import ConllLoader | |||
| import numpy as np | |||
| from itertools import chain | |||
| from fastNLP import DataSet, Vocabulary | |||
| from functools import partial | |||
| import os | |||
| from typing import Union, Dict | |||
| from reproduction.utils import check_dataloader_paths | |||
class CTBxJointLoader(DataSetLoader):
    """
    Loader for CTB-style conllx folders used by the joint CWS + dependency-parsing model.

    The folder is expected to contain
        -train.conllx
        -dev.conllx
        -test.conllx
    where each file holds blank-line-separated sentences, one token per line, e.g.
        1 费孝通 _ NR NR _ 3 nsubjpass _ _
        2 被 _ SB SB _ 3 pass _ _
        3 授予 _ VV VV _ 0 root _ _
        ...
    """
    def __init__(self):
        # columns 1/3/6/7 of the conllx file: word, POS tag, head index, dependency label
        self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7])

    def load(self, path: str):
        """
        Read one conllx file into a DataSet with the fields:
            words: list[str]
            pos_tags: list[str]
            heads: list[int]
            labels: list[str]

        :param path: path to a conllx file
        :return: DataSet
        """
        dataset = self._loader.load(path)
        dataset.heads.int()  # heads are parsed as strings; convert the field to int
        return dataset

    def process(self, paths):
        """
        Load and preprocess all splits found under ``paths``.

        :param paths: folder (or dict name->path) with train/dev/test conllx files
        :return: DataInfo whose datasets carry the fields
            chars, bigrams, trigrams, pre_chars, pre_bigrams, pre_trigrams,
            seg_targets, seg_masks, seq_lens, char_labels, char_heads,
            pun_masks, gold_word_pairs, gold_label_word_pairs
        """
        paths = check_dataloader_paths(paths)
        data = DataInfo()

        for name, path in paths.items():
            dataset = self.load(path)
            data.datasets[name] = dataset

        char_labels_vocab = Vocabulary(padding=None, unknown=None)

        def process(dataset, char_label_vocab):
            dataset.apply(add_word_lst, new_field_name='word_lst')
            dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars')
            dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams')
            dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams')
            dataset.apply(add_char_heads, new_field_name='char_heads')
            dataset.apply(add_char_labels, new_field_name='char_labels')
            dataset.apply(add_segs, new_field_name='seg_targets')
            dataset.apply(add_mask, new_field_name='seg_masks')
            dataset.add_seq_len('chars', new_field_name='seq_lens')
            dataset.apply(add_pun_masks, new_field_name='pun_masks')
            # build the label vocabulary once (from the first dataset processed)
            if len(char_label_vocab.word_count) == 0:
                char_label_vocab.from_dataset(dataset, field_name='char_labels')
            char_label_vocab.index_dataset(dataset, field_name='char_labels')
            new_dataset = add_root(dataset)
            new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True)
            # BUGFIX: bind label_vocab with a local partial instead of rebinding the
            # module-level ``add_label_word_pairs`` via ``global`` + ``partial``, which
            # mutated module state and re-wrapped the function on every call.
            new_dataset.apply(partial(add_label_word_pairs, label_vocab=char_label_vocab),
                              new_field_name='gold_label_word_pairs', ignore_type=True)
            new_dataset.set_pad_val('char_labels', -1)
            new_dataset.set_pad_val('char_heads', -1)
            return new_dataset

        for name in list(paths.keys()):
            dataset = data.datasets[name]
            dataset = process(dataset, char_labels_vocab)
            data.datasets[name] = dataset

        data.vocabs['char_labels'] = char_labels_vocab

        char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars')
        bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams')
        trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams')

        # vocabularies for pretrained embeddings: built over all splits, written to pre_* fields
        for name in ['chars', 'bigrams', 'trigrams']:
            vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values()))
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name)
            data.vocabs['pre_{}'.format(name)] = vocab

        for name, vocab in zip(['chars', 'bigrams', 'trigrams'],
                               [char_vocab, bigram_vocab, trigram_vocab]):
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name)
            data.vocabs[name] = vocab

        for name, dataset in data.datasets.items():
            dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars',
                              'pre_bigrams', 'pre_trigrams')
            dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels',
                               'char_heads',
                               'pun_masks', 'gold_label_word_pairs')

        return data
def add_label_word_pairs(instance, label_vocab):
    """Build gold labeled word pairs for one sentence.

    Each entry is ``(head_span_or_'root', label_index, dep_span)`` where spans are
    half-open ``[start, end)`` character intervals. Punctuation tokens (POS 'PU')
    are skipped.
    """
    # character boundary of every word: boundaries[i] .. boundaries[i+1]
    boundaries = [0]
    for word in instance['word_lst']:
        boundaries.append(boundaries[-1] + len(word))

    labels = instance['labels']
    pos_tags = instance['pos_tags']
    pairs = []
    for i, head in enumerate(instance['heads']):
        if pos_tags[i] == 'PU':  # punctuation is not scored
            continue
        label_idx = label_vocab.to_index(labels[i])
        dep_span = (boundaries[i], boundaries[i + 1])
        if head == 0:
            pairs.append(('root', label_idx, dep_span))
        else:
            pairs.append(((boundaries[head - 1], boundaries[head]), label_idx, dep_span))
    return pairs
def add_word_pairs(instance):
    """Build gold unlabeled word pairs for one sentence.

    Each entry is ``(head_span_or_'root', dep_span)`` with half-open ``[start, end)``
    character spans; punctuation tokens (POS 'PU') are skipped.
    """
    boundaries = [0]
    for word in instance['word_lst']:
        boundaries.append(boundaries[-1] + len(word))

    pos_tags = instance['pos_tags']
    pairs = []
    for i, head in enumerate(instance['heads']):
        if pos_tags[i] == 'PU':  # punctuation is not scored
            continue
        dep_span = (boundaries[i], boundaries[i + 1])
        if head == 0:
            pairs.append(('root', dep_span))
        else:
            pairs.append(((boundaries[head - 1], boundaries[head]), dep_span))
    return pairs
def add_root(dataset):
    """Prepend an artificial root token to every sample and rebuild the DataSet.

    The root gets label 0, head 0, and placeholder char/bigram/trigram tokens;
    ``seq_lens`` grows by one accordingly.
    """
    rebuilt = DataSet()
    for sample in dataset:
        sample['chars'] = ['char_root'] + sample['chars']
        sample['bigrams'] = ['bigram_root'] + sample['bigrams']
        sample['trigrams'] = ['trigram_root'] + sample['trigrams']
        sample['seq_lens'] = sample['seq_lens'] + 1
        sample['char_labels'] = [0] + sample['char_labels']
        sample['char_heads'] = [0] + sample['char_heads']
        rebuilt.append(sample)
    return rebuilt
def add_pun_masks(instance):
    """Return a per-character mask: 1 for chars of punctuation words (POS 'PU'), else 0."""
    masks = []
    for word, tag in zip(instance['words'], instance['pos_tags']):
        masks.extend([1 if tag == 'PU' else 0] * len(word))
    return masks
def add_word_lst(instance):
    """Split every word into its list of characters."""
    return [list(word) for word in instance['words']]
def add_bigram(instance):
    """Return character bigrams, padding the end with one '<eos>' token."""
    padded = instance['chars'] + ['<eos>']
    return [padded[i] + padded[i + 1] for i in range(len(padded) - 1)]
def add_trigram(instance):
    """Return character trigrams, padding the end with two '<eos>' tokens."""
    padded = instance['chars'] + ['<eos>'] * 2
    return [''.join(padded[i:i + 3]) for i in range(len(padded) - 2)]
def add_char_heads(instance):
    """Convert word-level heads into character-level heads.

    Character positions are 1-based (position 0 is reserved for the root token).
    Inside a word every non-final character points at the next character; the
    final character points at the last character of its head word (or 0 for root).
    """
    words = instance['word_lst']
    # end position (1-based, inclusive) of every word; trailing 0 so that
    # head == 0 (root) resolves via index -1 to position 0
    ends = []
    total = 0
    for word in words:
        total += len(word)
        ends.append(total)
    ends.append(0)

    char_heads = []
    pos = 1  # current character position, 1-based because of the root slot
    for word, head in zip(words, instance['heads']):
        # chain of intra-word arcs: each char (except the last) heads to the next char
        for _ in range(len(word) - 1):
            char_heads.append(pos + 1)
            pos += 1
        # last char of the word attaches to the last char of the head word
        char_heads.append(ends[head - 1])
        pos += 1
    return char_heads
def add_char_labels(instance):
    """Expand word-level dependency labels to character level.

    Inside a word every character except the last is tagged 'APP' (append to the
    next character); the last character carries the word's own dependency label.
    E.g. for a 4-char word the labels are 'APP', 'APP', 'APP', <word label>.
    """
    char_labels = []
    for word, label in zip(instance['word_lst'], instance['labels']):
        char_labels.extend(['APP'] * (len(word) - 1))
        char_labels.append(label)
    return char_labels
# add seg_targets
def add_segs(instance):
    """Per-character segmentation targets: at each word's last char store len(word)-1, else 0."""
    segs = [0] * len(instance['chars'])
    end = -1
    for word in instance['word_lst']:
        end += len(word)
        segs[end] = len(word) - 1
    return segs
# add target_masks
def add_mask(instance):
    """Per-character mask that is 1 only on the last character of each word."""
    mask = []
    for word in instance['word_lst']:
        mask += [0] * (len(word) - 1) + [1]
    return mask
| @@ -0,0 +1,311 @@ | |||
| from fastNLP.models.biaffine_parser import BiaffineParser | |||
| from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from torch.nn import functional as F | |||
| from fastNLP.modules.dropout import TimestepDropout | |||
| from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||
| from fastNLP import seq_len_to_mask | |||
| from fastNLP.modules import Embedding | |||
def drop_input_independent(word_embeddings, dropout_emb):
    """Token-level dropout: zero out whole embedding vectors independently per token.

    :param word_embeddings: [batch_size, seq_length, emb_dim]
    :param dropout_emb: probability of dropping a token's entire embedding
    :return: masked embeddings, same shape as the input
    """
    batch_size, seq_length, _ = word_embeddings.size()
    keep_prob = 1 - dropout_emb
    token_masks = torch.bernoulli(word_embeddings.new(batch_size, seq_length).fill_(keep_prob))
    return word_embeddings * token_masks.unsqueeze(dim=2)
class CharBiaffineParser(BiaffineParser):
    """Character-level biaffine parser for joint Chinese word segmentation and
    dependency parsing.

    Characters plus bigram/trigram features are embedded (optionally summed with
    frozen pretrained embeddings), encoded by a (variational) BiLSTM, and scored
    by a biaffine arc classifier and a bilinear label classifier. Segmentation is
    expressed through the special 'app' label (index ``app_index``), which
    attaches a character to the character immediately following it.
    """
    def __init__(self, char_vocab_size,
                 emb_dim,
                 bigram_vocab_size,
                 trigram_vocab_size,
                 num_label,
                 rnn_layers=3,
                 rnn_hidden_size=800,  # hidden size of one LSTM direction
                 arc_mlp_size=500,
                 label_mlp_size=100,
                 dropout=0.3,
                 encoder='lstm',
                 use_greedy_infer=False,
                 app_index = 0,
                 pre_chars_embed=None,
                 pre_bigrams_embed=None,
                 pre_trigrams_embed=None):
        # NOTE: intentionally skips BiaffineParser.__init__ (resolves to the next
        # class in the MRO) so this subclass can build its own sub-modules.
        super(BiaffineParser, self).__init__()
        rnn_out_size = 2 * rnn_hidden_size  # bidirectional encoder output size
        self.char_embed = Embedding((char_vocab_size, emb_dim))
        self.bigram_embed = Embedding((bigram_vocab_size, emb_dim))
        self.trigram_embed = Embedding((trigram_vocab_size, emb_dim))
        # Optional frozen pretrained embeddings; added to the trainable ones in forward().
        # NOTE(review): truthiness test on the embed arguments — presumably the callers
        # pass None or a non-empty init; confirm an empty container is never passed.
        if pre_chars_embed:
            self.pre_char_embed = Embedding(pre_chars_embed)
            self.pre_char_embed.requires_grad = False
        if pre_bigrams_embed:
            self.pre_bigram_embed = Embedding(pre_bigrams_embed)
            self.pre_bigram_embed.requires_grad = False
        if pre_trigrams_embed:
            self.pre_trigram_embed = Embedding(pre_trigrams_embed)
            self.pre_trigram_embed.requires_grad = False
        self.timestep_drop = TimestepDropout(dropout)
        self.encoder_name = encoder

        if encoder == 'var-lstm':
            self.encoder = VarLSTM(input_size=emb_dim*3,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   input_dropout=dropout,
                                   hidden_dropout=dropout,
                                   bidirectional=True)
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(input_size=emb_dim*3,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   dropout=dropout,
                                   bidirectional=True)
        else:
            raise ValueError('unsupported encoder type: {}'.format(encoder))

        # single MLP producing the concatenated arc-dep/arc-head/label-dep/label-head
        # representations, split apart in forward()
        self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                 nn.LeakyReLU(0.1),
                                 TimestepDropout(p=dropout),)
        self.arc_mlp_size = arc_mlp_size
        self.label_mlp_size = label_mlp_size
        self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
        self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
        self.use_greedy_infer = use_greedy_infer
        self.reset_parameters()
        self.dropout = dropout

        self.app_index = app_index
        self.num_label = num_label
        if self.app_index != 0:
            # the app-masking logic in forward() hard-codes label index 0
            raise ValueError("现在app_index必须等于0")

    def reset_parameters(self):
        """Initialize all parameters except embeddings and modules that
        provide their own reset/init method."""
        for name, m in self.named_modules():
            if 'embed' in name:
                pass  # keep embedding init as constructed
            elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'):
                pass  # module initializes itself
            else:
                for p in m.parameters():
                    if len(p.size()) > 1:
                        nn.init.xavier_normal_(p, gain=0.1)
                    else:
                        nn.init.uniform_(p, -0.1, 0.1)

    def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None,
                pre_trigrams=None):
        """
        max_len includes the root position.

        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len
        :param trigrams: batch_size x max_len
        :param seq_lens: batch_size
        :param gold_heads: batch_size x max_len, or None at inference time
        :param pre_chars: batch_size x max_len (indices into the pretrained vocab)
        :param pre_bigrams: batch_size x max_len
        :param pre_trigrams: batch_size x max_len
        :return dict: parsing results
                      arc_pred: [batch_size, seq_len, seq_len]
                      label_pred: [batch_size, seq_len, seq_len]
                      mask: [batch_size, seq_len]
                      head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
        """
        # prepare embeddings
        batch_size, seq_len = chars.shape
        # print('forward {} {}'.format(batch_size, seq_len))

        # get sequence mask
        mask = seq_len_to_mask(seq_lens).long()

        chars = self.char_embed(chars)  # [N,L] -> [N,L,C_0]
        bigrams = self.bigram_embed(bigrams)  # [N,L] -> [N,L,C_1]
        trigrams = self.trigram_embed(trigrams)

        # pretrained embeddings are summed into the trainable ones
        if pre_chars is not None:
            pre_chars = self.pre_char_embed(pre_chars)
            # pre_chars = self.pre_char_fc(pre_chars)
            chars = pre_chars + chars
        if pre_bigrams is not None:
            pre_bigrams = self.pre_bigram_embed(pre_bigrams)
            # pre_bigrams = self.pre_bigram_fc(pre_bigrams)
            bigrams = bigrams + pre_bigrams
        if pre_trigrams is not None:
            pre_trigrams = self.pre_trigram_embed(pre_trigrams)
            # pre_trigrams = self.pre_trigram_fc(pre_trigrams)
            trigrams = trigrams + pre_trigrams

        x = torch.cat([chars, bigrams, trigrams], dim=2)  # -> [N,L,C]

        # encoder, extract features
        if self.training:
            x = drop_input_independent(x, self.dropout)
        # sort by length for pack_padded_sequence, then restore the original order
        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
        x = x[sort_idx]
        x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
        feat, _ = self.encoder(x)  # -> [N,L,C]
        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
        _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
        feat = feat[unsort_idx]
        feat = self.timestep_drop(feat)

        # for arc biaffine
        # mlp, reduce dim
        feat = self.mlp(feat)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

        # use gold or predicted arc to predict label
        if gold_heads is None or not self.training:
            # use greedy decoding in training
            if self.training or self.use_greedy_infer:
                heads = self.greedy_decoder(arc_pred, mask)
            else:
                heads = self.mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training  # must be training mode
            if gold_heads is None:
                heads = self.greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = gold_heads
        # heads: batch_size x max_len

        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1)
        # gather, for each dependent, the representation of its (gold or predicted) head
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep)  # [N, max_len, num_label]

        # restrict the 'app' label: it may only be predicted when the head is the
        # immediately following character
        arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\
            .repeat(batch_size, 1)  # batch_size x max_len
        app_masks = heads.ne(arange_index)  # batch_size x max_len, positions that may NOT take 'app'
        app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
        app_masks[:, :, 1:] = 0  # only label index 0 ('app') is ever masked
        label_pred = label_pred.masked_fill(app_masks, -np.inf)

        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
        if head_pred is not None:
            res_dict['head_pred'] = head_pred
        return res_dict

    @staticmethod
    def loss(arc_pred, label_pred, arc_true, label_true, mask):
        """
        Compute loss.

        :param arc_pred: [batch_size, seq_len, seq_len]
        :param label_pred: [batch_size, seq_len, n_tags]
        :param arc_true: [batch_size, seq_len]
        :param label_true: [batch_size, seq_len]
        :param mask: [batch_size, seq_len]
        :return: loss value
        """
        batch_size, seq_len, _ = arc_pred.shape
        flip_mask = (mask == 0)
        _arc_pred = arc_pred.clone()
        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))

        # NOTE(review): fills the root position of the GOLD tensors in place, so the
        # caller's arc_true/label_true are mutated — confirm intended.
        arc_true[:, 0].fill_(-1)
        label_true[:, 0].fill_(-1)

        arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
        label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)

        return arc_nll + label_nll

    def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams):
        """
        max_len includes the root position.

        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len
        :param trigrams: batch_size x max_len
        :param seq_lens: batch_size
        :param pre_chars: batch_size x max_len
        :param pre_bigrams: batch_size x max_len
        :param pre_trigrams: batch_size x max_len
        :return: dict with 'arc_pred' (decoded heads) and 'label_pred' (argmax label indices)
        """
        res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams,
                   pre_trigrams=pre_trigrams, gold_heads=None)
        output = {}
        # NOTE(review): decoded heads are exposed under the key 'arc_pred' here — confirm callers expect this
        output['arc_pred'] = res.pop('head_pred')
        _, label_pred = res.pop('label_pred').max(2)
        output['label_pred'] = label_pred
        return output
class CharParser(nn.Module):
    """Thin training/inference wrapper around :class:`CharBiaffineParser`.

    ``forward`` computes the joint arc+label loss against gold character heads
    and labels; ``predict`` decodes heads and labels without supervision.
    """
    def __init__(self, char_vocab_size,
                 emb_dim,
                 bigram_vocab_size,
                 trigram_vocab_size,
                 num_label,
                 rnn_layers=3,
                 rnn_hidden_size=400,  # size of one LSTM direction
                 arc_mlp_size=500,
                 label_mlp_size=100,
                 dropout=0.3,
                 encoder='var-lstm',
                 use_greedy_infer=False,
                 app_index = 0,
                 pre_chars_embed=None,
                 pre_bigrams_embed=None,
                 pre_trigrams_embed=None):
        super().__init__()
        self.parser = CharBiaffineParser(char_vocab_size, emb_dim, bigram_vocab_size,
                                         trigram_vocab_size, num_label, rnn_layers,
                                         rnn_hidden_size, arc_mlp_size, label_mlp_size,
                                         dropout, encoder, use_greedy_infer, app_index,
                                         pre_chars_embed=pre_chars_embed,
                                         pre_bigrams_embed=pre_bigrams_embed,
                                         pre_trigrams_embed=pre_trigrams_embed)

    def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None,
                pre_trigrams=None):
        """Run the parser with gold heads and return {'loss': joint loss}."""
        out = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars,
                          pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
        loss = self.parser.loss(out['arc_pred'], out['label_pred'], char_heads, char_labels, out['mask'])
        return {'loss': loss}

    def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None):
        """Decode heads/labels without gold supervision.

        :return: dict with 'head_preds' and 'label_preds'
        """
        out = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars,
                          pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
        _, label_preds = out.pop('label_pred').max(2)
        return {'head_preds': out.pop('head_pred'), 'label_preds': label_preds}
| @@ -0,0 +1,65 @@ | |||
| from fastNLP.core.callback import Callback | |||
| import torch | |||
| from torch import nn | |||
class OptimizerCallback(Callback):
    """Callback that advances a learning-rate scheduler every ``update_every``
    training steps.

    NOTE(review): the gradient-clipping, ``optimizer.step()`` and ``zero_grad()``
    lines are commented out, so this callback only steps the scheduler — the
    Trainer is presumably responsible for the actual optimizer update; confirm.
    """
    def __init__(self, optimizer, scheduler, update_every=4):
        super().__init__()

        self._optimizer = optimizer  # kept for reference; not stepped here (see NOTE above)
        self.scheduler = scheduler
        self._update_every = update_every

    def on_backward_end(self):
        # self.step is the global training-step counter exposed by the base Callback
        if self.step % self._update_every==0:
            # nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5)
            # self._optimizer.step()
            self.scheduler.step()
            # self.model.zero_grad()
class DevCallback(Callback):
    """Runs an extra Tester (typically on the test set) at every validation step
    and remembers the best results seen.

    ``metric_key`` selects, within the first metric of the tester, which value
    decides "best" (default 'u_f1').
    """
    def __init__(self, tester, metric_key='u_f1'):
        super().__init__()
        self.tester = tester
        setattr(tester, 'verbose', 0)  # silence the inner tester's own printing

        self.metric_key = metric_key

        self.record_best = False      # True while waiting for the matching dev result
        self.best_eval_value = 0      # best test-set value of metric_key so far
        self.best_eval_res = None     # test-set results at that point

        self.best_dev_res = None      # dev results recorded at the same epoch

    def on_valid_begin(self):
        # evaluate on the extra (test) set right before the trainer's own validation
        eval_res = self.tester.test()
        metric_name = self.tester.metrics[0].__class__.__name__
        metric_value = eval_res[metric_name][self.metric_key]
        if metric_value>self.best_eval_value:
            self.best_eval_value = metric_value
            self.best_epoch = self.trainer.epoch
            self.record_best = True  # grab the matching dev result in on_valid_end
            self.best_eval_res = eval_res
        self.test_eval_res = eval_res
        eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \
                   self.tester._format_eval_results(eval_res)
        self.pbar.write(eval_str)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        if self.record_best:
            self.best_dev_res = eval_result
            self.record_best = False
        if is_better_eval:
            # best according to the trainer's own dev metric
            self.best_dev_res_on_dev = eval_result
            self.best_test_res_on_dev = self.test_eval_res
            self.dev_epoch = self.epoch

    def on_train_end(self):
        # NOTE(review): best_epoch / dev_epoch and the *_on_dev attributes are only
        # assigned when an improvement occurred; if validation never improved this
        # will raise AttributeError — confirm intended.
        print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch,
                                                            self.tester._format_eval_results(self.best_eval_res),
                                                            self.tester._format_eval_results(self.best_dev_res)))
        print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch,
                                                            self.tester._format_eval_results(self.best_test_res_on_dev),
                                                            self.tester._format_eval_results(self.best_dev_res_on_dev)))
| @@ -0,0 +1,184 @@ | |||
| from fastNLP.core.metrics import MetricBase | |||
| from fastNLP.core.utils import seq_len_to_mask | |||
| import torch | |||
class SegAppCharParseF1Metric(MetricBase):
    """F1 over recovered (head word, dependent word) pairs for joint CWS + parsing.

    Predicted character labels/heads are decoded back into words via the 'app'
    label, then unlabeled and labeled head-dependent word pairs are compared
    against the gold pairs. Punctuation dependents are excluded from scoring.
    """
    def __init__(self, app_index):
        super().__init__()
        self.app_index = app_index

        self.parse_head_tp = 0   # correctly predicted unlabeled pairs
        self.parse_label_tp = 0  # correctly predicted labeled pairs
        self.rec_tol = 0         # total gold pairs (recall denominator)
        self.pre_tol = 0         # total predicted pairs (precision denominator)

    def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens,
                 pun_masks):
        """
        max_len excludes the root character.

        :param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size
        :param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size
        :param head_preds: batch_size x max_len
        :param label_preds: batch_size x max_len
        :param seq_lens: batch_size
        :param pun_masks: batch_size x max_len, 1 where the char belongs to a punctuation word
        """
        # strip the artificial root position
        head_preds = head_preds[:, 1:].tolist()
        label_preds = label_preds[:, 1:].tolist()
        seq_lens = (seq_lens - 1).tolist()

        # decode words with their heads, labels and character spans
        for b in range(len(head_preds)):
            seq_len = seq_lens[b]
            head_pred = head_preds[b][:seq_len]
            label_pred = label_preds[b][:seq_len]

            words = []   # [word_start, word_end) spans, root excluded
            heads = []
            labels = []
            ranges = []  # word index of each character
            word_idx = 0
            word_start_idx = 0
            for idx, (label, head) in enumerate(zip(label_pred, head_pred)):
                ranges.append(word_idx)
                if label == self.app_index:
                    pass  # 'app' glues this char to the next: still inside the word
                else:
                    labels.append(label)
                    heads.append(head)
                    words.append((word_start_idx, idx + 1))
                    word_start_idx = idx + 1
                    word_idx += 1

            head_dep_tuple = []        # (head_span_or_'root', dep_span)
            head_label_dep_tuple = []  # (head_span_or_'root', label, dep_span)
            for idx, head in enumerate(heads):
                span = words[idx]
                if span[0] == span[1] - 1 and pun_masks[b, span[0]]:
                    continue  # exclude punctuation dependents
                if head == 0:
                    head_dep_tuple.append(('root', words[idx]))
                    head_label_dep_tuple.append(('root', labels[idx], words[idx]))
                else:
                    head_word_idx = ranges[head - 1]
                    head_word_span = words[head_word_idx]
                    head_dep_tuple.append((head_word_span, words[idx]))
                    head_label_dep_tuple.append((head_word_span, labels[idx], words[idx]))

            gold_head_dep_tuple = set(gold_word_pairs[b])
            gold_head_label_dep_tuple = set(gold_label_word_pairs[b])

            for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple):
                if head_dep in gold_head_dep_tuple:
                    self.parse_head_tp += 1
                if head_label_dep in gold_head_label_dep_tuple:
                    self.parse_label_tp += 1
            self.pre_tol += len(head_dep_tuple)
            self.rec_tol += len(gold_head_dep_tuple)

    def get_metric(self, reset=True):
        """Return rounded unlabeled/labeled precision, recall and F1.

        BUGFIX: denominators are guarded with a small epsilon so an evaluation
        that produced no predicted or gold pairs yields 0.0 instead of raising
        ZeroDivisionError.
        """
        u_p = self.parse_head_tp / (self.pre_tol + 1e-6)
        u_r = self.parse_head_tp / (self.rec_tol + 1e-6)
        u_f = 2*u_p*u_r/(1e-6 + u_p + u_r)
        l_p = self.parse_label_tp / (self.pre_tol + 1e-6)
        l_r = self.parse_label_tp / (self.rec_tol + 1e-6)
        l_f = 2*l_p*l_r/(1e-6 + l_p + l_r)

        if reset:
            self.parse_head_tp = 0
            self.parse_label_tp = 0
            self.rec_tol = 0
            self.pre_tol = 0

        return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4),
                'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)}
class CWSMetric(MetricBase):
    """Word-segmentation precision/recall/F1 derived from predicted character labels.

    A word boundary is placed wherever the predicted label is NOT the 'app'
    label (or at the sentence end); the length of the word is compared against
    the gold segmentation target at the same position.
    """
    def __init__(self, app_index):
        super().__init__()
        self.app_index = app_index
        self.pre = 0  # number of predicted word-end positions
        self.rec = 0  # number of gold word-end positions
        self.tp = 0   # matching positions with matching word length

    def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens):
        """
        :param seg_targets: batch_size x max_len, at each word-end position holds len(word)-1
        :param seg_masks: batch_size x max_len, 1 only at word-end positions
        :param label_preds: batch_size x max_len
        :param seq_lens: batch_size
        """
        pred_masks = torch.zeros_like(seg_masks)
        pred_segs = torch.zeros_like(seg_targets)

        lengths = (seq_lens - 1).tolist()  # drop the root position
        for row, labels in enumerate(label_preds[:, 1:].tolist()):
            labels = labels[:lengths[row]]
            last = len(labels) - 1
            run = 0  # count of preceding 'app' chars in the current word
            for col, label in enumerate(labels):
                if label == self.app_index and col != last:
                    run += 1
                else:
                    # word ends here; record its length (as len-1) and mark the position
                    pred_segs[row, col] = run
                    pred_masks[row, col] = 1
                    run = 0

        length_match = seg_targets.eq(pred_segs)  # predicted word length agrees with gold
        self.rec += seg_masks.sum().item()
        self.pre += pred_masks.sum().item()
        # count positions where both sides end a word AND the lengths agree
        self.tp += (length_match & (pred_masks.byte() & seg_masks.byte())).sum().item()

    def get_metric(self, reset=True):
        recall = round(self.tp / (self.rec + 1e-6), 4)
        precision = round(self.tp / (self.pre + 1e-6), 4)
        f1 = round(2 * recall * precision / (precision + recall + 1e-6), 4)

        if reset:
            self.pre = 0
            self.rec = 0
            self.tp = 0

        return {'rec': recall, 'pre': precision, 'f1': f1}
class ParserMetric(MetricBase):
    """Plain UAS/LAS accuracy over non-root positions."""
    def __init__(self, ):
        super().__init__()
        self.num_arc = 0     # correctly attached positions
        self.num_label = 0   # correctly attached AND labeled positions
        self.num_sample = 0  # scored positions

    def get_metric(self, reset=True):
        uas = round(self.num_arc*1.0 / self.num_sample, 4)
        las = round(self.num_label*1.0 / self.num_sample, 4)
        if reset:
            self.num_sample = self.num_label = self.num_arc = 0
        return {'UAS': uas, 'LAS': las}

    def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None):
        """Accumulate arc/label accuracy counts for one batch."""
        if seq_lens is None:
            seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.byte)
        else:
            seq_mask = seq_len_to_mask(seq_lens.long(), float=False)
        seq_mask[:, 0] = 0  # never score the artificial <root> position
        arc_correct = (head_preds == heads) & seq_mask
        label_correct = (label_preds == labels) & arc_correct
        self.num_arc += arc_correct.float().sum().item()
        self.num_label += label_correct.float().sum().item()
        self.num_sample += seq_mask.sum().item()
| @@ -0,0 +1,16 @@ | |||
| Code for paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697) | |||
| ### 准备数据 | |||
| 1. 数据应该为conll格式,1, 3, 6, 7列应该对应为'words', 'pos_tags', 'heads', 'labels'. | |||
| 2. 将train, dev, test放在同一个folder下,并将该folder路径填入train.py中的data_folder变量里。 | |||
| 3. 从[百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA)(提取:ua53)下载预训练vector,放到同一个folder下,并将train.py中vector_folder变量正确设置。 | |||
| ### 运行代码 | |||
| ``` | |||
| python train.py | |||
| ``` | |||
| ### 其它 | |||
在ctb5上,使用以上默认参数应该就可以跑出论文中报道的结果(甚至可能更高一些);在ctb7上,使用默认参数会比论文结果低0.1%左右,需要调节
learning rate scheduler。
| @@ -0,0 +1,124 @@ | |||
| import sys | |||
| sys.path.append('../..') | |||
| from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader | |||
| from fastNLP.modules.encoder.embedding import StaticEmbedding | |||
| from torch import nn | |||
| from functools import partial | |||
| from reproduction.joint_cws_parse.models.CharParser import CharParser | |||
| from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric | |||
| from fastNLP import cache_results, BucketSampler, Trainer | |||
| from torch import optim | |||
| from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback | |||
| from torch.optim.lr_scheduler import LambdaLR, StepLR | |||
| from fastNLP import Tester | |||
| from fastNLP import GradientClipCallback, LRScheduler | |||
| import os | |||
def set_random_seed(random_seed=666):
    """Seed every RNG source (python, numpy, torch) for reproducible runs.

    :param int random_seed: seed shared by all RNG sources.
    """
    import random, numpy, torch
    random.seed(random_seed)
    numpy.random.seed(random_seed)
    # torch.manual_seed seeds the CPU generator AND every CUDA device
    # (the original torch.cuda.manual_seed only seeded the *current*
    # device, leaving other GPUs unseeded in multi-GPU setups).
    torch.manual_seed(random_seed)
# NOTE(review): despite the name, this initializer draws from a *normal*
# distribution N(0, 0.02**2), not a uniform one — name kept because it is
# referenced below.
uniform_init = partial(nn.init.normal_, std=0.02)

###################################################
# Hyper-parameters that may need tuning go here.
lr = 0.002  # learning rate; suggested range 0.01~0.001
dropout = 0.33  # 0.3~0.6
weight_decay = 0  # 1e-5, 1e-6, 0
arc_mlp_size = 500  # 200, 300
rnn_hidden_size = 400  # 200, 300, 400
rnn_layers = 3  # 2, 3
encoder = 'var-lstm'  # var-lstm, lstm
emb_size = 100  # 64 , 100
label_mlp_size = 100
batch_size = 32
update_every = 4  # presumably gradient-accumulation steps — verify against fastNLP Trainer
n_epochs = 100
data_folder = ''  # folder holding the data; must contain train, dev and test files
vector_folder = ''  # folder with pretrained vectors; must contain: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
####################################################
set_random_seed(1234)
device = 0

# Load the CTB data once and pull out every vocabulary the model needs.
data = CTBxJointLoader().process(data_folder)
(char_labels_vocab, pre_chars_vocab, pre_bigrams_vocab, pre_trigrams_vocab,
 chars_vocab, bigrams_vocab, trigrams_vocab) = (
    data.vocabs[key] for key in (
        'char_labels', 'pre_chars', 'pre_bigrams', 'pre_trigrams',
        'chars', 'bigrams', 'trigrams'))
def _load_pretrained_embed(vocab, filename):
    """Build a StaticEmbedding from a pretrained n-gram vector file.

    The weights are divided by their standard deviation afterwards;
    NOTE(review): presumably this rescales the pretrained vectors to unit
    std to match the randomly-initialized embeddings — kept from original.
    """
    embed = StaticEmbedding(vocab,
                            model_dir_or_name=os.path.join(vector_folder, filename),
                            init_method=uniform_init, normalize=False)
    embed.embedding.weight.data = embed.embedding.weight.data / embed.embedding.weight.data.std()
    return embed

# The same load-and-rescale procedure for uni-, bi- and tri-gram vectors
# (was three copy-pasted stanzas that could silently drift apart).
pre_chars_embed = _load_pretrained_embed(pre_chars_vocab, '1grams_t3_m50_corpus.txt')
pre_bigrams_embed = _load_pretrained_embed(pre_bigrams_vocab, '2grams_t3_m50_corpus.txt')
pre_trigrams_embed = _load_pretrained_embed(pre_trigrams_vocab, '3grams_t3_m50_corpus.txt')
print(data)  # summary of the datasets/vocabs produced by the loader above

# Joint CWS + dependency parsing model over characters; 'APP' is the special
# char label marking a character that appends to the current word.
model = CharParser(char_vocab_size=len(chars_vocab),
                   emb_dim=emb_size,
                   bigram_vocab_size=len(bigrams_vocab),
                   trigram_vocab_size=len(trigrams_vocab),
                   num_label=len(char_labels_vocab),
                   rnn_layers=rnn_layers,
                   rnn_hidden_size=rnn_hidden_size,
                   arc_mlp_size=arc_mlp_size,
                   label_mlp_size=label_mlp_size,
                   dropout=dropout,
                   encoder=encoder,
                   use_greedy_infer=False,
                   app_index=char_labels_vocab['APP'],
                   pre_chars_embed=pre_chars_embed,
                   pre_bigrams_embed=pre_bigrams_embed,
                   pre_trigrams_embed=pre_trigrams_embed)

# Both metrics need the index of the special 'APP' label.
metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP'])
metric2 = CWSMetric(char_labels_vocab['APP'])
metrics = [metric1, metric2]

# NOTE(review): betas=[0.9, 0.9] deviates from Adam's default beta2=0.999 —
# presumably the paper's setting; confirm before changing.
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr,
                       weight_decay=weight_decay, betas=[0.9, 0.9])

# Batch sentences of similar length together to reduce padding.
sampler = BucketSampler(seq_len_field_name='seq_lens')
callbacks = []
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
# Multiply lr by 0.75 every 18 scheduler steps (presumably epochs when driven
# by the LRScheduler callback — verify against fastNLP).
scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
# optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
# callbacks.append(optim_callback)
scheduler_callback = LRScheduler(scheduler)
callbacks.append(scheduler_callback)
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))

# Evaluate on the *test* set from inside training via DevCallback.
tester = Tester(data=data.datasets['test'], model=model, metrics=metrics,
                batch_size=64, device=device, verbose=0)
dev_callback = DevCallback(tester)
callbacks.append(dev_callback)

# loss=None: presumably the model returns its own loss (fastNLP convention —
# confirm); metric_key='u_f1' picks the best checkpoint by unlabeled F1.
trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3,
                  validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
                  check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
                  device=device, callbacks=callbacks, update_every=update_every)
trainer.train()