
dataloader.py

from time import time
from datetime import timedelta

from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.modules.encoder._bert import BertTokenizer
from fastNLP.io.base_loader import DataBundle
from fastNLP.core.const import Const


class BertData(JsonLoader):

    def __init__(self, max_nsents=60, max_ntokens=100, max_len=512):
        fields = {'article': 'article',
                  'label': 'label'}
        super(BertData, self).__init__(fields=fields)

        self.max_nsents = max_nsents
        self.max_ntokens = max_ntokens
        self.max_len = max_len

        self.tokenizer = BertTokenizer.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']

    def _load(self, paths):
        dataset = super(BertData, self)._load(paths)
        return dataset

    def process(self, paths):

        def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
            # keep at most max_nsents sentences, each truncated to max_ntokens words
            article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
            return article[:max_nsents]

        def truncate_labels(instance):
            # drop label indices that point beyond the truncated article
            label = list(filter(lambda x: x < len(instance['article']), instance['label']))
            return label

        def bert_tokenize(instance, tokenizer, max_len, pad_value):
            # join sentences with [SEP] [CLS] delimiters, tokenize, and pad to max_len
            article = instance['article']
            article = ' [SEP] [CLS] '.join(article)
            word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
            word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
            token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
            while len(token_ids) < max_len:
                token_ids.append(pad_value)
            assert len(token_ids) == max_len
            return token_ids

        def get_seg_id(instance, max_len, sep_id):
            # alternate segment ids 0/1 between consecutive [SEP]-delimited sentences
            _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
            segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
            segment_id = []
            for i, length in enumerate(segs):
                if i % 2 == 0:
                    segment_id += length * [0]
                else:
                    segment_id += length * [1]
            while len(segment_id) < max_len:
                segment_id.append(0)
            return segment_id

        def get_cls_id(instance, cls_id):
            # positions of the [CLS] token that starts each sentence
            classification_id = [i for i, idx in enumerate(instance['article']) if idx == cls_id]
            return classification_id

        def get_labels(instance):
            # binary vector over sentences: 1 if the sentence belongs to the summary
            labels = [0] * len(instance['cls_id'])
            label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))
            for idx in label_idx:
                labels[idx] = 1
            return labels

        datasets = {}
        for name in paths:
            datasets[name] = self._load(paths[name])

            # remove empty samples
            datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)

            # truncate articles
            datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')

            # truncate labels
            datasets[name].apply(truncate_labels, new_field_name='label')

            # tokenize and convert tokens to ids
            datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')

            # get segment ids
            datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')

            # get classification ([CLS]) positions
            datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')

            # get labels
            datasets[name].apply(get_labels, new_field_name='label')

            # rename fields
            datasets[name].rename_field('article', Const.INPUTS(0))
            datasets[name].rename_field('segment_id', Const.INPUTS(1))
            datasets[name].rename_field('cls_id', Const.INPUTS(2))
            datasets[name].rename_field('label', Const.TARGET)

            # set input and target
            datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
            datasets[name].set_target(Const.TARGET)

            # set padding value (the article field was renamed to Const.INPUTS(0) above)
            datasets[name].set_pad_val(Const.INPUTS(0), 0)

        return DataBundle(datasets=datasets)
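
# Note: BertData.process expects a dict mapping split names (e.g. 'train', 'test') to
# raw json paths whose instances carry 'article' and 'label' fields. Each returned
# dataset exposes the token ids, segment ids and [CLS] positions as inputs
# (Const.INPUTS(0)-(2)) and the binary per-sentence labels as target (Const.TARGET).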


class BertSumLoader(JsonLoader):

    def __init__(self):
        fields = {'article': 'article',
                  'segment_id': 'segment_id',
                  'cls_id': 'cls_id',
                  'label': Const.TARGET}
        super(BertSumLoader, self).__init__(fields=fields)

    def _load(self, paths):
        dataset = super(BertSumLoader, self)._load(paths)
        return dataset

    def process(self, paths):

        def get_seq_len(instance):
            return len(instance['article'])

        print('Start loading datasets !!!')
        start = time()

        # load datasets
        datasets = {}
        for name in paths:
            datasets[name] = self._load(paths[name])

            datasets[name].apply(get_seq_len, new_field_name='seq_len')

            # set input and target
            datasets[name].set_input('article', 'segment_id', 'cls_id')
            datasets[name].set_target(Const.TARGET)

            # set padding value
            datasets[name].set_pad_val('article', 0)
            datasets[name].set_pad_val('segment_id', 0)
            datasets[name].set_pad_val('cls_id', -1)
            datasets[name].set_pad_val(Const.TARGET, 0)

        print('Finished in {}'.format(timedelta(seconds=time() - start)))

        return DataBundle(datasets=datasets)
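

# A minimal usage sketch, not part of the original file. The json paths below are
# hypothetical placeholders, and the pretrained-BERT path in BertData.__init__ must
# point to a real uncased_L-12_H-768_A-12 checkpoint for the first step to run.
if __name__ == '__main__':
    # Step 1 (assumed workflow): turn raw article/label json into BERT-ready fields.
    raw_paths = {'train': 'data/train.json', 'test': 'data/test.json'}  # hypothetical paths
    bert_data = BertData(max_nsents=60, max_ntokens=100, max_len=512)
    processed = bert_data.process(raw_paths)
    print(len(processed.datasets['train']))

    # Step 2 (assumed workflow): load json that already contains the
    # article/segment_id/cls_id/label fields produced by the preprocessing step.
    bert_paths = {'train': 'data/train.bert.json', 'test': 'data/test.bert.json'}  # hypothetical paths
    data_bundle = BertSumLoader().process(bert_paths)
    print(len(data_bundle.datasets['train']))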