You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Conll2003Loader.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from fastNLP.core.vocabulary import VocabularyOption
  2. from fastNLP.io.base_loader import DataSetLoader, DataBundle
  3. from typing import Union, Dict
  4. from fastNLP import Vocabulary
  5. from fastNLP import Const
  6. from reproduction.utils import check_dataloader_paths
  7. from fastNLP.io import ConllLoader
  8. from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
class Conll2003DataLoader(DataSetLoader):
    def __init__(self, task: str = 'ner', encoding_type: str = 'bioes'):
        """
        Load English corpora in the CoNLL-2003 format; dataset details at
        https://www.clips.uantwerpen.be/conll2003/ner/.

        When ``task`` is ``'pos'`` the returned DataSet's target is taken from
        column 2; for ``'chunk'`` from column 3; for ``'ner'`` from column 4.
        All lines starting with "-DOCSTART- -X- O O" are ignored (they are
        document separators, not prediction targets), so the instance count may
        be lower than figures reported in some papers.

        For the ner and chunk tasks the loaded targets are re-encoded to
        ``encoding_type``; for the pos task the pos column is used as-is.

        :param task: which labelling task to load; one of 'ner', 'pos', 'chunk'.
        :param encoding_type: target tag scheme for ner/chunk; 'bioes' adds an
            IOB2->BIOES conversion on top of the IOB->IOB2 normalisation.
        """
        assert task in ('ner', 'pos', 'chunk')
        # Column index of the target tags for each task (column 0 is the words).
        index = {'ner': 3, 'pos': 1, 'chunk': 2}[task]
        self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
        # Tag-scheme converters applied in order to each sentence's tag list.
        self._tag_converters = []
        if task in ('ner', 'chunk'):
            # iob2 normalises raw IOB tags to IOB2 first.
            self._tag_converters = [iob2]
            if encoding_type == 'bioes':
                # Then optionally convert IOB2 to the richer BIOES scheme.
                self._tag_converters.append(iob2bioes)

    def load(self, path: str):
        """
        Load one CoNLL file and apply the configured tag-scheme conversions
        to the target field. Returns the fastNLP DataSet.
        """
        dataset = self._loader.load(path)

        def convert_tag_schema(tags):
            # Pipe the tag list through every converter (iob2, then iob2bioes).
            for converter in self._tag_converters:
                tags = converter(tags)
            return tags

        if self._tag_converters:
            dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
        return dataset

    def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt: VocabularyOption = None, lower: bool = False):
        """
        Read and process the data. Lines starting with '-DOCSTART-' are ignored.

        :param paths: a single path or a name->path dict; normalised by
            ``check_dataloader_paths`` (a 'train' entry is required below —
            ``from_dataset`` is built from it).
        :param word_vocab_opt: initialisation options for the word Vocabulary;
            when None a ``Vocabulary(min_freq=2)`` is used.
        :param lower: whether to lowercase all words (the original-cased tokens
            remain available via the 'cap_words' field).
        :return: a DataBundle with datasets, a word vocab, a cased-word vocab
            ('cap_words') and a target vocab.
        """
        # Read the data
        paths = check_dataloader_paths(paths)
        data = DataBundle()
        input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, path in paths.items():
            dataset = self.load(path)
            # Copy raw_words into the INPUT field (identity map; INPUT gets
            # indexed later, raw_words keeps the original tokens).
            dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
            if lower:
                # NOTE(review): presumably lowers the 'words' field in place via
                # the FieldArray — if FieldArray.lower() is not in-place this is
                # a no-op; confirm against the fastNLP version in use.
                dataset.words.lower()
            data.datasets[name] = dataset

        # Build the word vocabulary (counts from 'train' only; other splits are
        # registered as no-create-entry so they don't add vocab entries).
        word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
        word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name != 'train'])
        word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab

        # Case-preserving word vocabulary: indexes the original-cased tokens
        # into a separate 'cap_words' field (useful for capitalisation features).
        cap_word_vocab = Vocabulary()
        cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words',
                                    no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name != 'train'])
        cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
        input_fields.append('cap_words')
        data.vocabs['cap_words'] = cap_word_vocab

        # Build the target vocabulary (no unknown/padding tokens: every tag in
        # any split must get its own index).
        target_vocab = Vocabulary(unknown=None, padding=None)
        target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
        target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
        data.vocabs[Const.TARGET] = target_vocab

        # Add sequence lengths and mark the model-facing input/target fields.
        for name, dataset in data.datasets.items():
            dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
            dataset.set_input(*input_fields)
            dataset.set_target(*target_fields)

        return data
# Script entry point intentionally empty: this module is used as a library.
if __name__ == '__main__':
    pass