You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

CWSDataLoader.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
  2. from fastNLP.core.vocabulary import VocabularyOption
  3. from fastNLP.io.base_loader import DataSetLoader, DataBundle
  4. from typing import Union, Dict, List, Iterator
  5. from fastNLP import DataSet
  6. from fastNLP import Instance
  7. from fastNLP import Vocabulary
  8. from fastNLP import Const
  9. from reproduction.utils import check_dataloader_paths
  10. from functools import partial
  11. class SigHanLoader(DataSetLoader):
  12. """
  13. 任务相关的说明可以在这里找到http://sighan.cs.uchicago.edu/
  14. 支持的数据格式为,一行一句,不同的word用空格隔开。如下例
  15. 共同 创造 美好 的 新 世纪 —— 二○○一年 新年
  16. 女士 们 , 先生 们 , 同志 们 , 朋友 们 :
  17. 读取sighan中的数据集,返回的DataSet将包含以下的内容fields:
  18. raw_chars: list(str), 每个元素是一个汉字
  19. chars: list(str), 每个元素是一个index(汉字对应的index)
  20. target: list(int), 根据不同的encoding_type会有不同的变化
  21. :param target_type: target的类型,当前支持以下的两种: "bmes", "shift_relay"
  22. """
  23. def __init__(self, target_type:str):
  24. super().__init__()
  25. if target_type.lower() not in ('bmes', 'shift_relay'):
  26. raise ValueError("target_type only supports 'bmes', 'shift_relay'.")
  27. self.target_type = target_type
  28. if target_type=='bmes':
  29. self._word_len_to_target = self._word_len_to_bems
  30. elif target_type=='shift_relay':
  31. self._word_len_to_target = self._word_lens_to_relay
  32. @staticmethod
  33. def _word_lens_to_relay(word_lens: Iterator[int]):
  34. """
  35. [1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长);
  36. :param word_lens:
  37. :return: {'target': , 'end_seg_mask':, 'start_seg_mask':}
  38. """
  39. tags = []
  40. end_seg_mask = []
  41. start_seg_mask = []
  42. for word_len in word_lens:
  43. tags.extend([idx for idx in range(word_len - 1, -1, -1)])
  44. end_seg_mask.extend([0] * (word_len - 1) + [1])
  45. start_seg_mask.extend([1] + [0] * (word_len - 1))
  46. return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask}
  47. @staticmethod
  48. def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]:
  49. """
  50. :param word_lens: 每个word的长度
  51. :return:
  52. """
  53. tags = []
  54. for word_len in word_lens:
  55. if word_len==1:
  56. tags.append('S')
  57. else:
  58. tags.append('B')
  59. for _ in range(word_len-2):
  60. tags.append('M')
  61. tags.append('E')
  62. return {'target':tags}
  63. @staticmethod
  64. def _gen_bigram(chars:List[str])->List[str]:
  65. """
  66. :param chars:
  67. :return:
  68. """
  69. return [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])]
  70. def load(self, path:str, bigram:bool=False)->DataSet:
  71. """
  72. :param path: str
  73. :param bigram: 是否使用bigram feature
  74. :return:
  75. """
  76. dataset = DataSet()
  77. with open(path, 'r', encoding='utf-8') as f:
  78. for line in f:
  79. line = line.strip()
  80. if not line: # 去掉空行
  81. continue
  82. parts = line.split()
  83. word_lens = map(len, parts)
  84. chars = list(''.join(parts))
  85. tags = self._word_len_to_target(word_lens)
  86. assert len(chars)==len(tags['target'])
  87. dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars)))
  88. if len(dataset)==0:
  89. raise RuntimeError(f"{path} has no valid data.")
  90. if bigram:
  91. dataset.apply_field(self._gen_bigram, field_name='raw_chars', new_field_name='bigrams')
  92. return dataset
  93. def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None,
  94. char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None,
  95. bigram_embed_opt:EmbeddingOption=None, L:int=4):
  96. """
  97. 支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如
  98. Option::
  99. 共同 创造 美好 的 新 世纪 —— 二○○一年 新年 贺词
  100. ( 二○○○年 十二月 三十一日 ) ( 附 图片 1 张 )
  101. 女士 们 , 先生 们 , 同志 们 , 朋友 们 :
  102. paths支持两种格式,第一种是str,第二种是Dict[str, str].
  103. Option::
  104. # 1. str类型
  105. # 1.1 传入具体的文件路径
  106. data = SigHanLoader('bmes').process('/path/to/cws/data.txt') # 将读取data.txt的内容
  107. # 包含以下的内容data.vocabs['chars']:Vocabulary对象,
  108. # data.vocabs['target']: Vocabulary对象,根据encoding_type可能会没有该值
  109. # data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项
  110. # data.datasets['train']: DataSet对象
  111. # 包含的field有:
  112. # raw_chars: list[str], 每个元素是一个汉字
  113. # chars: list[int], 每个元素是汉字对应的index
  114. # target: list[int], 根据encoding_type有对应的变化
  115. # 1.2 传入一个目录, 里面必须包含train.txt文件
  116. data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt
  117. # 包含以下的内容data.vocabs['chars']: Vocabulary对象
  118. # data.vocabs['target']:Vocabulary对象
  119. # data.embeddings['chars']: 仅在提供了预训练embedding路径的情况下,为Embedding对象;
  120. # data.datasets['train']: DataSet对象
  121. # 包含的field有:
  122. # raw_chars: list[str], 每个元素是一个汉字
  123. # chars: list[int], 每个元素是汉字对应的index
  124. # target: list[int], 根据encoding_type有对应的变化
  125. # data.datasets['dev']: DataSet对象,如果文件夹下包含了dev.txt;内容与data.datasets['train']一样
  126. # 2. dict类型, key是文件的名称,value是对应的读取路径. 必须包含'train'这个key
  127. paths = {'train': '/path/to/train/train.txt', 'test':'/path/to/test/test.txt', 'dev':'/path/to/dev/dev.txt'}
  128. data = SigHanLoader(paths).process(paths)
  129. # 结果与传入目录时是一致的,但是可以传入多个数据集。data.datasets中的key将与这里传入的一致
  130. :param paths: 支持传入目录,文件路径,以及dict。
  131. :param char_vocab_opt: 用于构建chars的vocabulary参数,默认为min_freq=2
  132. :param char_embed_opt: 用于读取chars的Embedding的参数,默认不读取pretrained的embedding
  133. :param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。
  134. 为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e<eos>
  135. :param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效
  136. :param L: 当target_type为shift_relay时传入的segment长度
  137. :return:
  138. """
  139. # 推荐大家使用这个check_data_loader_paths进行paths的验证
  140. paths = check_dataloader_paths(paths)
  141. datasets = {}
  142. data = DataBundle()
  143. bigram = bigram_vocab_opt is not None
  144. for name, path in paths.items():
  145. dataset = self.load(path, bigram=bigram)
  146. datasets[name] = dataset
  147. input_fields = []
  148. target_fields = []
  149. # 创建vocab
  150. char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt)
  151. char_vocab.from_dataset(datasets['train'], field_name='raw_chars')
  152. char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars')
  153. data.vocabs[Const.CHAR_INPUT] = char_vocab
  154. input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET])
  155. target_fields.append(Const.TARGET)
  156. # 创建target
  157. if self.target_type == 'bmes':
  158. target_vocab = Vocabulary(unknown=None, padding=None)
  159. target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S'])
  160. target_vocab.index_dataset(*datasets.values(), field_name='target')
  161. data.vocabs[Const.TARGET] = target_vocab
  162. if char_embed_opt is not None:
  163. char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab)
  164. data.embeddings['chars'] = char_embed
  165. if bigram:
  166. bigram_vocab = Vocabulary(**bigram_vocab_opt)
  167. bigram_vocab.from_dataset(datasets['train'], field_name='bigrams')
  168. bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams')
  169. data.vocabs['bigrams'] = bigram_vocab
  170. if bigram_embed_opt is not None:
  171. bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab)
  172. data.embeddings['bigrams'] = bigram_embed
  173. input_fields.append('bigrams')
  174. if self.target_type == 'shift_relay':
  175. func = partial(self._clip_target, L=L)
  176. for name, dataset in datasets.items():
  177. res = dataset.apply_field(func, field_name='target')
  178. relay_target = [res_i[0] for res_i in res]
  179. relay_mask = [res_i[1] for res_i in res]
  180. dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False)
  181. dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False)
  182. if self.target_type == 'shift_relay':
  183. input_fields.extend(['end_seg_mask'])
  184. target_fields.append('start_seg_mask')
  185. # 将dataset加入DataInfo
  186. for name, dataset in datasets.items():
  187. dataset.set_input(*input_fields)
  188. dataset.set_target(*target_fields)
  189. data.datasets[name] = dataset
  190. return data
  191. @staticmethod
  192. def _clip_target(target:List[int], L:int):
  193. """
  194. 只有在target_type为shift_relay的使用
  195. :param target: List[int]
  196. :param L:
  197. :return:
  198. """
  199. relay_target_i = []
  200. tmp = []
  201. for j in range(len(target) - 1):
  202. tmp.append(target[j])
  203. if target[j] > target[j + 1]:
  204. pass
  205. else:
  206. relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
  207. tmp = []
  208. # 处理未结束的部分
  209. if len(tmp) == 0:
  210. relay_target_i.append(0)
  211. else:
  212. tmp.append(target[-1])
  213. relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
  214. relay_mask_i = []
  215. j = 0
  216. while j < len(target):
  217. seg_len = target[j] + 1
  218. if target[j] < L:
  219. relay_mask_i.extend([0] * (seg_len))
  220. else:
  221. relay_mask_i.extend([1] * (seg_len - L) + [0] * L)
  222. j = seg_len + j
  223. return relay_target_i, relay_mask_i