You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ChineseNER.py 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from fastNLP.io.base_loader import DataSetLoader, DataBundle
  2. from fastNLP.io import ConllLoader
  3. from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
  4. from fastNLP import Const
  5. from reproduction.utils import check_dataloader_paths
  6. from fastNLP import Vocabulary
  7. class ChineseNERLoader(DataSetLoader):
  8. """
  9. 读取中文命名实体数据集,包括PeopleDaily, MSRA-NER, Weibo。数据在这里可以找到https://github.com/OYE93/Chinese-NLP-Corpus/tree/master/NER
  10. 请确保输入数据的格式如下, 共两列,第一列为字,第二列为标签,不同句子以空行隔开
  11. 我 O
  12. 们 O
  13. 变 O
  14. 而 O
  15. 以 O
  16. 书 O
  17. 会 O
  18. ...
  19. """
  20. def __init__(self, encoding_type:str='bioes'):
  21. """
  22. :param str encoding_type: 支持bio和bioes格式
  23. """
  24. super().__init__()
  25. self._loader = ConllLoader(headers=['raw_chars', 'target'], indexes=[0, 1])
  26. assert encoding_type in ('bio', 'bioes')
  27. self._tag_converters = [iob2]
  28. if encoding_type == 'bioes':
  29. self._tag_converters.append(iob2bioes)
  30. def load(self, path:str):
  31. dataset = self._loader.load(path)
  32. def convert_tag_schema(tags):
  33. for converter in self._tag_converters:
  34. tags = converter(tags)
  35. return tags
  36. if self._tag_converters:
  37. dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
  38. return dataset
  39. def process(self, paths, bigrams=False, trigrams=False):
  40. """
  41. :param paths:
  42. :param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>]
  43. :param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>]
  44. :return: ~fastNLP.io.DataBundle
  45. 包含以下的fields
  46. raw_chars: List[str]
  47. chars: List[int]
  48. seq_len: int, 字的长度
  49. bigrams: List[int], optional
  50. trigrams: List[int], optional
  51. target: List[int]
  52. """
  53. paths = check_dataloader_paths(paths)
  54. data = DataBundle()
  55. input_fields = [Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET]
  56. target_fields = [Const.TARGET, Const.INPUT_LEN]
  57. for name, path in paths.items():
  58. dataset = self.load(path)
  59. if bigrams:
  60. dataset.apply_field(lambda raw_chars: [c1+c2 for c1, c2 in zip(raw_chars, raw_chars[1:]+['<eos>'])],
  61. field_name='raw_chars', new_field_name='bigrams')
  62. if trigrams:
  63. dataset.apply_field(lambda raw_chars: [c1+c2+c3 for c1, c2, c3 in zip(raw_chars,
  64. raw_chars[1:]+['<eos>'],
  65. raw_chars[2:]+['<eos>']*2)],
  66. field_name='raw_chars', new_field_name='trigrams')
  67. data.datasets[name] = dataset
  68. char_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='raw_chars',
  69. no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
  70. char_vocab.index_dataset(*data.datasets.values(), field_name='raw_chars', new_field_name=Const.CHAR_INPUT)
  71. data.vocabs[Const.CHAR_INPUT] = char_vocab
  72. target_vocab = Vocabulary(unknown=None, padding=None).from_dataset(data.datasets['train'], field_name=Const.TARGET)
  73. target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
  74. data.vocabs[Const.TARGET] = target_vocab
  75. if bigrams:
  76. bigram_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='bigrams',
  77. no_create_entry_dataset=[dataset for name, dataset in
  78. data.datasets.items() if name != 'train'])
  79. bigram_vocab.index_dataset(*data.datasets.values(), field_name='bigrams', new_field_name='bigrams')
  80. data.vocabs['bigrams'] = bigram_vocab
  81. input_fields.append('bigrams')
  82. if trigrams:
  83. trigram_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='trigrams',
  84. no_create_entry_dataset=[dataset for name, dataset in
  85. data.datasets.items() if name != 'train'])
  86. trigram_vocab.index_dataset(*data.datasets.values(), field_name='trigrams', new_field_name='trigrams')
  87. data.vocabs['trigrams'] = trigram_vocab
  88. input_fields.append('trigrams')
  89. for name, dataset in data.datasets.items():
  90. dataset.add_seq_len(Const.CHAR_INPUT)
  91. dataset.set_input(*input_fields)
  92. dataset.set_target(*target_fields)
  93. return data