from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_loader import ConllLoader
import numpy as np
from itertools import chain
from fastNLP import DataSet, Vocabulary
from functools import partial
import os
from typing import Union, Dict
from reproduction.utils import check_dataloader_paths


class CTBxJointLoader(DataSetLoader):
    """
    The data folder should contain the following files:
        -train.conllx
        -dev.conllx
        -test.conllx
    Each file looks like the following (blank lines separate different sentences):
        1	费孝通	_	NR	NR	_	3	nsubjpass	_	_
        2	被	_	SB	SB	_	3	pass	_	_
        3	授予	_	VV	VV	_	0	root	_	_
        4	麦格赛赛	_	NR	NR	_	5	nn	_	_
        5	奖	_	NN	NN	_	3	dobj	_	_

        1	新华社	_	NR	NR	_	7	dep	_	_
        2	马尼拉	_	NR	NR	_	7	dep	_	_
        3	8月	_	NT	NT	_	7	dep	_	_
        4	31日	_	NT	NT	_	7	dep	_	_
        ...
    """
    def __init__(self):
        self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7])

    def load(self, path: str):
        """
        Read the file at the given path into a DataSet with the following fields:
            words: list[str]
            pos_tags: list[str]
            heads: list[int]
            labels: list[str]

        :param path: path to a .conllx file
        :return: DataSet
        """
        dataset = self._loader.load(path)
        dataset.heads.int()
        return dataset

    def process(self, paths):
        """
        :param paths: folder containing the train/dev/test .conllx files, or a dict mapping split names to file paths
        :return: DataBundle whose DataSets contain the following fields
            chars, bigrams, trigrams, pre_chars, pre_bigrams, pre_trigrams,
            seg_targets, seg_masks, seq_lens, char_labels, char_heads,
            pun_masks, gold_word_pairs, gold_label_word_pairs
        """
        paths = check_dataloader_paths(paths)
        data = DataBundle()

        for name, path in paths.items():
            dataset = self.load(path)
            data.datasets[name] = dataset

        char_labels_vocab = Vocabulary(padding=None, unknown=None)

        def process(dataset, char_label_vocab):
            dataset.apply(add_word_lst, new_field_name='word_lst')
            dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars')
            dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams')
            dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams')
            dataset.apply(add_char_heads, new_field_name='char_heads')
            dataset.apply(add_char_labels, new_field_name='char_labels')
            dataset.apply(add_segs, new_field_name='seg_targets')
            dataset.apply(add_mask, new_field_name='seg_masks')
            dataset.add_seq_len('chars', new_field_name='seq_lens')
            dataset.apply(add_pun_masks, new_field_name='pun_masks')
            if len(char_label_vocab.word_count) == 0:
                char_label_vocab.from_dataset(dataset, field_name='char_labels')
            char_label_vocab.index_dataset(dataset, field_name='char_labels')
            new_dataset = add_root(dataset)
            new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True)
            global add_label_word_pairs
            add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab)
            new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True)

            new_dataset.set_pad_val('char_labels', -1)
            new_dataset.set_pad_val('char_heads', -1)

            return new_dataset

        for name in list(paths.keys()):
            dataset = data.datasets[name]
            dataset = process(dataset, char_labels_vocab)
            data.datasets[name] = dataset

        data.vocabs['char_labels'] = char_labels_vocab

        char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars')
        bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams')
        trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams')

        for name in ['chars', 'bigrams', 'trigrams']:
            vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values()))
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name)
            data.vocabs['pre_{}'.format(name)] = vocab

        for name, vocab in zip(['chars', 'bigrams', 'trigrams'], [char_vocab, bigram_vocab, trigram_vocab]):
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name)
            data.vocabs[name] = vocab

        for name, dataset in data.datasets.items():
            dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads',
                              'pre_chars', 'pre_bigrams', 'pre_trigrams')
            dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels',
                               'char_heads', 'pun_masks', 'gold_label_word_pairs')

        return data


def add_label_word_pairs(instance, label_vocab):
    # returns List[((head_start, head_end], label_index, (dep_start, dep_end]), ...] over character spans
    word_end_indexes = np.array(list(map(len, instance['word_lst'])))
    word_end_indexes = np.cumsum(word_end_indexes).tolist()
    word_end_indexes.insert(0, 0)
    word_pairs = []
    labels = instance['labels']
    pos_tags = instance['pos_tags']
    for idx, head in enumerate(instance['heads']):
        if pos_tags[idx] == 'PU':  # skip punctuation; it is not recorded
            continue
        label = label_vocab.to_index(labels[idx])
        if head == 0:
            word_pairs.append(('root', label, (word_end_indexes[idx], word_end_indexes[idx + 1])))
        else:
            word_pairs.append(((word_end_indexes[head - 1], word_end_indexes[head]), label,
                               (word_end_indexes[idx], word_end_indexes[idx + 1])))
    return word_pairs


def add_word_pairs(instance):
    # returns List[((head_start, head_end], (dep_start, dep_end]), ...] over character spans
    word_end_indexes = np.array(list(map(len, instance['word_lst'])))
    word_end_indexes = np.cumsum(word_end_indexes).tolist()
    word_end_indexes.insert(0, 0)
    word_pairs = []
    pos_tags = instance['pos_tags']
    for idx, head in enumerate(instance['heads']):
        if pos_tags[idx] == 'PU':  # skip punctuation; it is not recorded
            continue
        if head == 0:
            word_pairs.append(('root', (word_end_indexes[idx], word_end_indexes[idx + 1])))
        else:
            word_pairs.append(((word_end_indexes[head - 1], word_end_indexes[head]),
                               (word_end_indexes[idx], word_end_indexes[idx + 1])))
    return word_pairs


def add_root(dataset):
    # prepend an artificial root token to every character-level field
    new_dataset = DataSet()
    for sample in dataset:
        chars = ['char_root'] + sample['chars']
        bigrams = ['bigram_root'] + sample['bigrams']
        trigrams = ['trigram_root'] + sample['trigrams']
        seq_lens = sample['seq_lens'] + 1
        char_labels = [0] + sample['char_labels']
        char_heads = [0] + sample['char_heads']
        sample['chars'] = chars
        sample['bigrams'] = bigrams
        sample['trigrams'] = trigrams
        sample['seq_lens'] = seq_lens
        sample['char_labels'] = char_labels
        sample['char_heads'] = char_heads
        new_dataset.append(sample)
    return new_dataset


def add_pun_masks(instance):
    # 1 for characters belonging to punctuation words, 0 otherwise
    tags = instance['pos_tags']
    pun_masks = []
    for word, tag in zip(instance['words'], tags):
        if tag == 'PU':
            pun_masks.extend([1] * len(word))
        else:
            pun_masks.extend([0] * len(word))
    return pun_masks


def add_word_lst(instance):
    words = instance['words']
    word_lst = [list(word) for word in words]
    return word_lst


def add_bigram(instance):
    chars = instance['chars']
    length = len(chars)
    chars = chars + ['']
    bigrams = []
    for i in range(length):
        bigrams.append(''.join(chars[i:i + 2]))
    return bigrams


def add_trigram(instance):
    chars = instance['chars']
    length = len(chars)
    chars = chars + [''] * 2
    trigrams = []
    for i in range(length):
        trigrams.append(''.join(chars[i:i + 3]))
    return trigrams


def add_char_heads(instance):
    words = instance['word_lst']
    heads = instance['heads']
    char_heads = []
    char_index = 1  # start from 1 because position 0 is reserved for the root node
    head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0]  # head==0 means root, and 0-1=-1 picks the trailing 0
    for word, head in zip(words, heads):
        char_head = []
        if len(word) > 1:
            char_head.append(char_index + 1)
            char_index += 1
        for _ in range(len(word) - 2):
            char_index += 1
            char_head.append(char_index)
        char_index += 1
        char_head.append(head_end_indexes[head - 1])
        char_heads.extend(char_head)
    return char_heads


def add_char_labels(instance):
    """
    Assign a label to every character in word_lst. For example, for "复旦大学 位于"
    the segmentation tags are "B M M E B E", and the intra-word dependencies are
    "复(dep)->旦(head)", "旦(dep)->大(head)", ..., each labeled 'APP'; the last
    character of a word ("学") gets the dependency label of the whole word ("复旦大学").

    :param instance:
    :return:
    """
    words = instance['word_lst']
    labels = instance['labels']
    char_labels = []
    for word, label in zip(words, labels):
        for _ in range(len(word) - 1):
            char_labels.append('APP')
        char_labels.append(label)
    return char_labels


# add seg_targets
def add_segs(instance):
    words = instance['word_lst']
    segs = [0] * len(instance['chars'])
    index = 0
    for word in words:
        index = index + len(word) - 1
        segs[index] = len(word) - 1
        index = index + 1
    return segs


# add target_masks
def add_mask(instance):
    words = instance['word_lst']
    mask = []
    for word in words:
        mask.extend([0] * (len(word) - 1))
        mask.append(1)
    return mask
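

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original loader).
    # It assumes a hypothetical folder './data/ctb' holding the train.conllx /
    # dev.conllx / test.conllx files described in the CTBxJointLoader docstring;
    # a dict mapping split names to file paths should also be accepted by
    # check_dataloader_paths.
    loader = CTBxJointLoader()
    data_bundle = loader.process('./data/ctb')

    # The returned DataBundle holds one processed DataSet per split plus the
    # vocabularies built from the training split.
    print(data_bundle.datasets['train'][:2])
    print(len(data_bundle.vocabs['chars']), len(data_bundle.vocabs['char_labels']))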