
data_loader.py 10 kB

from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_loader import ConllLoader
import numpy as np
from itertools import chain
from fastNLP import DataSet, Vocabulary
from functools import partial
import os
from typing import Union, Dict
from reproduction.utils import check_dataloader_paths


class CTBxJointLoader(DataSetLoader):
    """
    The data folder should have the following structure:
        - train.conllx
        - dev.conllx
        - test.conllx
    Each file contains CoNLL-X formatted sentences (blank lines separate sentences), e.g.:
        1 费孝通 _ NR NR _ 3 nsubjpass _ _
        2 被 _ SB SB _ 3 pass _ _
        3 授予 _ VV VV _ 0 root _ _
        4 麦格赛赛 _ NR NR _ 5 nn _ _
        5 奖 _ NN NN _ 3 dobj _ _

        1 新华社 _ NR NR _ 7 dep _ _
        2 马尼拉 _ NR NR _ 7 dep _ _
        3 8月 _ NT NT _ 7 dep _ _
        4 31日 _ NT NT _ 7 dep _ _
        ...
    """

    def __init__(self):
        # columns 1, 3, 6 and 7 of each CoNLL-X line: word form, POS tag, head index, dependency label
        self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7])

    def load(self, path: str):
        """
        Given a file path, read the data into a DataSet with the following fields:
            words: list[str]
            pos_tags: list[str]
            heads: list[int]
            labels: list[str]
        :param path: path to a single .conllx file
        :return: DataSet
        """
        dataset = self._loader.load(path)
        dataset.heads.int()  # cast the head indices to int
        return dataset
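
    # For the first sentence in the class docstring, `load` yields an instance roughly like:
    #   words:    ['费孝通', '被', '授予', '麦格赛赛', '奖']
    #   pos_tags: ['NR', 'SB', 'VV', 'NR', 'NN']
    #   heads:    [3, 3, 0, 5, 3]
    #   labels:   ['nsubjpass', 'pass', 'root', 'nn', 'dobj']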

    def process(self, paths):
        """
        :param paths: path to the folder described in the class docstring, or a dict mapping
            split names ('train'/'dev'/'test') to file paths
        :return: DataBundle whose DataSets contain the following fields
            chars: characters of the sentence (with a pseudo root prepended), indexed
            bigrams: character bigrams, indexed
            trigrams: character trigrams, indexed
            pre_chars: characters indexed with the unpruned vocabulary
            pre_bigrams: bigrams indexed with the unpruned vocabulary
            pre_trigrams: trigrams indexed with the unpruned vocabulary
            seg_targets: (word length - 1) at every word-final character, 0 elsewhere
            seg_masks: 1 at every word-final character, 0 elsewhere
            seq_lens: number of characters, including the prepended pseudo root
            char_labels: character-level dependency labels, indexed
            char_heads: character-level head positions
            pun_masks: 1 for characters belonging to punctuation (PU) words, 0 otherwise
            gold_word_pairs: gold (head span, dependent span) pairs, punctuation excluded
            gold_label_word_pairs: gold (head span, label, dependent span) triples, punctuation excluded
        """
        paths = check_dataloader_paths(paths)
        data = DataBundle()
        for name, path in paths.items():
            dataset = self.load(path)
            data.datasets[name] = dataset

        char_labels_vocab = Vocabulary(padding=None, unknown=None)

        def process(dataset, char_label_vocab):
            dataset.apply(add_word_lst, new_field_name='word_lst')
            dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars')
            dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams')
            dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams')
            dataset.apply(add_char_heads, new_field_name='char_heads')
            dataset.apply(add_char_labels, new_field_name='char_labels')
            dataset.apply(add_segs, new_field_name='seg_targets')
            dataset.apply(add_mask, new_field_name='seg_masks')
            dataset.add_seq_len('chars', new_field_name='seq_lens')
            dataset.apply(add_pun_masks, new_field_name='pun_masks')
            if len(char_label_vocab.word_count) == 0:
                char_label_vocab.from_dataset(dataset, field_name='char_labels')
            char_label_vocab.index_dataset(dataset, field_name='char_labels')
            new_dataset = add_root(dataset)
            new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True)
            global add_label_word_pairs
            add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab)
            new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True)
            new_dataset.set_pad_val('char_labels', -1)
            new_dataset.set_pad_val('char_heads', -1)
            return new_dataset

        for name in list(paths.keys()):
            dataset = data.datasets[name]
            dataset = process(dataset, char_labels_vocab)
            data.datasets[name] = dataset
        data.vocabs['char_labels'] = char_labels_vocab

        char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars')
        bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams')
        trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams')

        for name in ['chars', 'bigrams', 'trigrams']:
            vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values()))
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name)
            data.vocabs['pre_{}'.format(name)] = vocab

        for name, vocab in zip(['chars', 'bigrams', 'trigrams'],
                               [char_vocab, bigram_vocab, trigram_vocab]):
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name)
            data.vocabs[name] = vocab

        for name, dataset in data.datasets.items():
            dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads',
                              'pre_chars', 'pre_bigrams', 'pre_trigrams')
            dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels',
                               'char_heads', 'pun_masks', 'gold_label_word_pairs')

        return data
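
# Note on the two vocabulary sets built in `process`: `chars`/`bigrams`/`trigrams` are
# indexed with frequency-pruned vocabularies (min_freq 2/5/5) built on the training split,
# while `pre_chars`/`pre_bigrams`/`pre_trigrams` are indexed with unpruned vocabularies
# built over all splits (via no_create_entry_dataset), presumably for looking up
# pretrained embeddings.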


def add_label_word_pairs(instance, label_vocab):
    # List[((head_start, head_end], label_idx, (dep_start, dep_end]), ...]
    word_end_indexes = np.array(list(map(len, instance['word_lst'])))
    word_end_indexes = np.cumsum(word_end_indexes).tolist()
    word_end_indexes.insert(0, 0)
    word_pairs = []
    labels = instance['labels']
    pos_tags = instance['pos_tags']
    for idx, head in enumerate(instance['heads']):
        if pos_tags[idx] == 'PU':  # skip punctuation
            continue
        label = label_vocab.to_index(labels[idx])
        if head == 0:
            word_pairs.append(('root', label, (word_end_indexes[idx], word_end_indexes[idx + 1])))
        else:
            word_pairs.append(((word_end_indexes[head - 1], word_end_indexes[head]), label,
                               (word_end_indexes[idx], word_end_indexes[idx + 1])))
    return word_pairs


def add_word_pairs(instance):
    # List[((head_start, head_end], (dep_start, dep_end]), ...]
    word_end_indexes = np.array(list(map(len, instance['word_lst'])))
    word_end_indexes = np.cumsum(word_end_indexes).tolist()
    word_end_indexes.insert(0, 0)
    word_pairs = []
    pos_tags = instance['pos_tags']
    for idx, head in enumerate(instance['heads']):
        if pos_tags[idx] == 'PU':  # skip punctuation
            continue
        if head == 0:
            word_pairs.append(('root', (word_end_indexes[idx], word_end_indexes[idx + 1])))
        else:
            word_pairs.append(((word_end_indexes[head - 1], word_end_indexes[head]),
                               (word_end_indexes[idx], word_end_indexes[idx + 1])))
    return word_pairs
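
# Worked example for the first docstring sentence (heads [3, 3, 0, 5, 3]):
# word_end_indexes becomes [0, 3, 4, 6, 10, 11], so add_word_pairs returns
#   [((4, 6), (0, 3)), ((4, 6), (3, 4)), ('root', (4, 6)),
#    ((10, 11), (6, 10)), ((4, 6), (10, 11))]
# with (start, end] character spans over the root-augmented sequence; add_label_word_pairs
# additionally puts the indexed dependency label between head span and dependent span.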


def add_root(dataset):
    new_dataset = DataSet()
    for sample in dataset:
        chars = ['char_root'] + sample['chars']
        bigrams = ['bigram_root'] + sample['bigrams']
        trigrams = ['trigram_root'] + sample['trigrams']
        seq_lens = sample['seq_lens'] + 1
        char_labels = [0] + sample['char_labels']
        char_heads = [0] + sample['char_heads']
        sample['chars'] = chars
        sample['bigrams'] = bigrams
        sample['trigrams'] = trigrams
        sample['seq_lens'] = seq_lens
        sample['char_labels'] = char_labels
        sample['char_heads'] = char_heads
        new_dataset.append(sample)
    return new_dataset
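
# add_root prepends a pseudo-root token at position 0 of every character-level sequence;
# this is why add_char_heads numbers characters from 1 and why a head value of 0 means
# "attached to the root".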


def add_pun_masks(instance):
    tags = instance['pos_tags']
    pun_masks = []
    for word, tag in zip(instance['words'], tags):
        if tag == 'PU':
            pun_masks.extend([1] * len(word))
        else:
            pun_masks.extend([0] * len(word))
    return pun_masks


def add_word_lst(instance):
    words = instance['words']
    word_lst = [list(word) for word in words]
    return word_lst
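
# Small example: for an instance with words ['费孝通', '被'], add_word_lst returns
# [['费', '孝', '通'], ['被']]; add_pun_masks puts 1 on every character of a PU-tagged
# word and 0 elsewhere, so the docstring sample (no punctuation) gets an all-zero mask.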


def add_bigram(instance):
    chars = instance['chars']
    length = len(chars)
    chars = chars + ['<eos>']
    bigrams = []
    for i in range(length):
        bigrams.append(''.join(chars[i:i + 2]))
    return bigrams


def add_trigram(instance):
    chars = instance['chars']
    length = len(chars)
    chars = chars + ['<eos>'] * 2
    trigrams = []
    for i in range(length):
        trigrams.append(''.join(chars[i:i + 3]))
    return trigrams
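
# n-gram example: for chars ['复', '旦'], add_bigram returns ['复旦', '旦<eos>'] and
# add_trigram returns ['复旦<eos>', '旦<eos><eos>'], i.e. one n-gram per character,
# padded with '<eos>' at the sentence end.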


def add_char_heads(instance):
    words = instance['word_lst']
    heads = instance['heads']
    char_heads = []
    char_index = 1  # characters are numbered from 1 because position 0 is the root
    head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0]  # trailing 0 handles head == 0 (root), since head - 1 == -1
    for word, head in zip(words, heads):
        char_head = []
        if len(word) > 1:
            char_head.append(char_index + 1)
            char_index += 1
            for _ in range(len(word) - 2):
                char_index += 1
                char_head.append(char_index)
        char_index += 1
        char_head.append(head_end_indexes[head - 1])
        char_heads.extend(char_head)
    return char_heads
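
# Worked example (first docstring sentence, heads [3, 3, 0, 5, 3]): with the pseudo root
# at position 0 and characters numbered from 1, add_char_heads returns
#   [2, 3, 6, 6, 6, 0, 8, 9, 10, 11, 6]
# every non-final character of a word points to the next character of that word, and the
# word-final character points to the final character of its head word (0 for the root).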


def add_char_labels(instance):
    """
    Assign a character-level label to every character in word_lst.
    For example, for "复旦大学 位于" the segmentation tags are "B M M E B E"; the intra-word
    dependencies are "复(dep)->旦(head)", "旦(dep)->大(head)", ..., each labelled 'APP',
    while the label of the word-final character 学 is the dependency label of the whole
    word 复旦大学.
    :param instance:
    :return: list of character-level labels
    """
    words = instance['word_lst']
    labels = instance['labels']
    char_labels = []
    for word, label in zip(words, labels):
        for _ in range(len(word) - 1):
            char_labels.append('APP')
        char_labels.append(label)
    return char_labels
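
# For the same sentence (labels ['nsubjpass', 'pass', 'root', 'nn', 'dobj']) this yields
#   ['APP', 'APP', 'nsubjpass', 'pass', 'APP', 'root', 'APP', 'APP', 'APP', 'nn', 'dobj']
# i.e. intra-word arcs are labelled 'APP' and each word-final character carries the word's
# dependency label.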


# add seg_targets
def add_segs(instance):
    words = instance['word_lst']
    segs = [0] * len(instance['chars'])
    index = 0
    for word in words:
        index = index + len(word) - 1
        segs[index] = len(word) - 1
        index = index + 1
    return segs
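
# For the sample sentence (word lengths [3, 1, 2, 4, 1]) add_segs returns
#   [0, 0, 2, 0, 0, 1, 0, 0, 0, 3, 0]
# i.e. each word-final character position stores (word length - 1), all other positions 0.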


# add target_masks
def add_mask(instance):
    words = instance['word_lst']
    mask = []
    for word in words:
        mask.extend([0] * (len(word) - 1))
        mask.append(1)
    return mask
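
# add_mask is the binary counterpart of add_segs: for the same sample sentence it returns
#   [0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1]
# marking word-final characters with 1.


if __name__ == '__main__':
    # Minimal usage sketch. The directory path is hypothetical; it is assumed to contain
    # train.conllx / dev.conllx / test.conllx as described in the class docstring.
    loader = CTBxJointLoader()
    data = loader.process('/path/to/ctb_conllx')
    print(data.datasets['train'][:2])
    print(len(data.vocabs['chars']))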