
OntoNoteLoader.py
from typing import Union, Dict

from fastNLP import DataSet
from fastNLP import Vocabulary
from fastNLP import Const
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io import ConllLoader
from reproduction.utils import check_dataloader_paths
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2


class OntoNoteNERDataLoader(DataSetLoader):
    """
    Reads OntoNotes data that has already been converted to the CoNLL format.
    The conversion of OntoNotes to the CoNLL format is described at
    https://github.com/yhcc/OntoNotes-5.0-NER.
    """

    def __init__(self, encoding_type: str = 'bioes'):
        assert encoding_type in ('bioes', 'bio')
        self.encoding_type = encoding_type
        if encoding_type == 'bioes':
            self.encoding_method = iob2bioes
        else:
            self.encoding_method = iob2
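
    # A rough illustration of the two encodings (assuming iob2bioes implements
    # the usual BIO-to-BIOES conversion):
    #   bio:   ['B-PER', 'I-PER', 'O', 'B-LOC']
    #   bioes: ['B-PER', 'E-PER', 'O', 'S-LOC']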

    def load(self, path: str) -> DataSet:
        """
        Given a file path, reads the data. The returned DataSet contains the
        following fields:

        raw_words: List[str]
        target: List[str]

        :param path:
        :return:
        """
        dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)

        def convert_to_bio(tags):
            # OntoNotes marks entity spans with parentheses, e.g. '(PERSON*',
            # '*)'. Convert them to BIO tags first, then apply the requested
            # encoding (BIO or BIOES).
            bio_tags = []
            flag = None
            for tag in tags:
                label = tag.strip("()*")
                if '(' in tag:
                    bio_label = 'B-' + label
                    flag = label
                elif flag:
                    bio_label = 'I-' + flag
                else:
                    bio_label = 'O'
                if ')' in tag:
                    flag = None
                bio_tags.append(bio_label)
            return self.encoding_method(bio_tags)
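
        # For example, the span tags ['(PERSON*', '*)', '*'] become
        # ['B-PERSON', 'I-PERSON', 'O'] in BIO, and
        # ['B-PERSON', 'E-PERSON', 'O'] after the BIOES conversion.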

        def convert_word(words):
            converted_words = []
            for word in words:
                word = word.replace('/.', '.')  # some sentence-final '.' tokens appear as '/.'
                if not word.startswith('-'):
                    converted_words.append(word)
                    continue
                # The following tokens are escaped brackets; convert them back
                # to the original characters.
                tfrs = {'-LRB-': '(',
                        '-RRB-': ')',
                        '-LSB-': '[',
                        '-RSB-': ']',
                        '-LCB-': '{',
                        '-RCB-': '}'
                        }
                if word in tfrs:
                    converted_words.append(tfrs[word])
                else:
                    converted_words.append(word)
            return converted_words
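
        # For example, ['-LRB-', 'U.S.', '-RRB-', '/.'] becomes
        # ['(', 'U.S.', ')', '.'].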

        dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words')
        dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')

        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt: VocabularyOption = None,
                lower: bool = True) -> DataBundle:
        """
        Reads and processes the data. The returned DataBundle contains:

        vocabs:
            word: Vocabulary
            target: Vocabulary
        datasets:
            train: DataSet
                words: List[int], set as input
                target: int. The label, set as both input and target
                seq_len: int. The sentence length, set as both input and target
                raw_words: List[str]
            xxx (may vary depending on the given paths)

        :param paths:
        :param word_vocab_opt: initialization options for the word vocabulary
        :param lower: whether to lowercase the words
        :return:
        """
        paths = check_dataloader_paths(paths)
        data = DataBundle()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]

        for name, path in paths.items():
            dataset = self.load(path)
            # Copy raw_words into the words field; the copy is indexed below,
            # while raw_words keeps the original strings.
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                dataset.words.lower()
            data.datasets[name] = dataset

        # Build the word vocabulary from the training split only; words seen
        # exclusively in the other splits do not create new entries.
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data.datasets.items()
                                                         if name != 'train'])
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # Build a vocabulary over the original, case-preserving words.
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # Build the target vocabulary.
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data

if __name__ == '__main__':
    loader = OntoNoteNERDataLoader()
    dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
    print(dataset.target.value_count())
    print(dataset[:4])
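
    # A minimal usage sketch for process(); the split paths below are
    # placeholders, not real files.
    # data = loader.process({'train': '/path/to/train.txt',
    #                        'dev': '/path/to/dev.txt',
    #                        'test': '/path/to/test.txt'}, lower=True)
    # print(data.vocabs[Const.INPUT])
    # print(data.datasets['train'][:2])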
  125. """
  126. train 115812 2200752
  127. development 15680 304684
  128. test 12217 230111
  129. train 92403 1901772
  130. valid 13606 279180
  131. test 10258 204135
  132. """