
dataloader.py

import pickle

import numpy as np

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import DataBundle
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.core.const import Const

from tools.logger import *

WORD_PAD = "[PAD]"
WORD_UNK = "[UNK]"
DOMAIN_UNK = "X"
TAG_UNK = "X"


class SummarizationLoader(JsonLoader):
    """
    Load summarization datasets. The loaded DataSet contains the following fields::

        text: list(str), document sentences
        summary: list(str), summary sentences
        text_wd: list(list(str)), tokenized document
        summary_wd: list(list(str)), tokenized summary
        label: list(int), indexes of the sentences selected as the extractive summary
        flatten_label: list(int), 0 or 1, flattened labels
        domain: str, optional
        tag: list(str), optional

    Data sources: CNN/DailyMail, Newsroom, DUC
    """
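
    # Example of one input JSON line assumed by _load (hypothetical values; the
    # "publication" and "tag" fields are only needed when process() is called
    # with domain=True / tag=True):
    #   {"text": ["first sentence .", "second sentence ."],
    #    "summary": ["gold summary sentence ."],
    #    "label": [0],
    #    "publication": "cnndm",
    #    "tag": ["news"]}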

    def __init__(self):
        super(SummarizationLoader, self).__init__()

    def _load(self, path):
        ds = super(SummarizationLoader, self)._load(path)

        def _lower_text(text_list):
            return [text.lower() for text in text_list]

        def _split_list(text_list):
            return [text.split() for text in text_list]

        def _convert_label(label, sent_len):
            # turn the list of selected sentence indexes into a 0/1 vector over sentences
            np_label = np.zeros(sent_len, dtype=int)
            if label != []:
                np_label[np.array(label)] = 1
            return np_label.tolist()

        ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
        ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
        ds.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
        ds.apply(lambda x: _split_list(x['summary']), new_field_name='summary_wd')
        ds.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")

        return ds

    def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab=True):
        """
        :param paths: dict, data path for each dataset split
        :param vocab_size: int, maximum size of the vocabulary
        :param vocab_path: str, path of the vocabulary file
        :param sent_max_len: int, maximum number of tokens per sentence
        :param doc_max_timesteps: int, maximum number of sentences per document
        :param domain: bool, whether to build a vocabulary for the publication field, using 'X' for unknown
        :param tag: bool, whether to build a vocabulary for the tag field, using 'X' for unknown
        :param load_vocab: bool, load an existing vocabulary file (True) or build one from the training set (False)
        :return: DataBundle
            datasets: dict whose keys correspond to the keys of the paths dict
            vocabs: dict with keys: vocab (if "train" in paths), domain (if domain=True), tag (if tag=True)
            embeddings: optional
        """

        def _pad_sent(text_wd):
            # pad or truncate every sentence to exactly sent_max_len tokens
            pad_text_wd = []
            for sent_wd in text_wd:
                if len(sent_wd) < sent_max_len:
                    pad_num = sent_max_len - len(sent_wd)
                    sent_wd.extend([WORD_PAD] * pad_num)
                else:
                    sent_wd = sent_wd[:sent_max_len]
                pad_text_wd.append(sent_wd)
            return pad_text_wd

        def _token_mask(text_wd):
            # 1 for a real token, 0 for a padding token
            token_mask_list = []
            for sent_wd in text_wd:
                token_num = len(sent_wd)
                if token_num < sent_max_len:
                    mask = [1] * token_num + [0] * (sent_max_len - token_num)
                else:
                    mask = [1] * sent_max_len
                token_mask_list.append(mask)
            return token_mask_list

        def _pad_label(label):
            # pad or truncate the per-sentence labels to doc_max_timesteps entries
            text_len = len(label)
            if text_len < doc_max_timesteps:
                pad_label = label + [0] * (doc_max_timesteps - text_len)
            else:
                pad_label = label[:doc_max_timesteps]
            return pad_label

        def _pad_doc(text_wd):
            # pad or truncate the document to doc_max_timesteps sentences
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                padding = [WORD_PAD] * sent_max_len
                pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
            else:
                pad_text = text_wd[:doc_max_timesteps]
            return pad_text

        def _sent_mask(text_wd):
            # 1 for a real sentence, 0 for a padded sentence
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
            else:
                sent_mask = [1] * doc_max_timesteps
            return sent_mask
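
        # Worked example with hypothetical sizes: for sent_max_len=4 and doc_max_timesteps=3,
        # text_wd = [["a", "b"], ["c", "d", "e", "f", "g"]] gives
        #   _pad_sent  -> [["a", "b", "[PAD]", "[PAD]"], ["c", "d", "e", "f"]]
        #   _pad_doc   -> the two padded sentences plus one all-[PAD] sentence
        #   _sent_mask -> [1, 1, 0]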

        datasets = {}
        train_ds = None
        for key, value in paths.items():
            ds = self.load(value)
            # pad sentences
            ds.apply(lambda x: _pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
            ds.apply(lambda x: _token_mask(x["text_wd"]), new_field_name="pad_token_mask")
            # pad documents
            ds.apply(lambda x: _pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
            ds.apply(lambda x: _sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
            ds.apply(lambda x: _pad_label(x["flatten_label"]), new_field_name="pad_label")
            # rename fields
            ds.rename_field("pad_text", Const.INPUT)
            ds.rename_field("seq_len", Const.INPUT_LEN)
            ds.rename_field("pad_label", Const.TARGET)
            # set input and target
            ds.set_input(Const.INPUT, Const.INPUT_LEN)
            ds.set_target(Const.TARGET, Const.INPUT_LEN)
            datasets[key] = ds
            if "train" in key:
                train_ds = datasets[key]

        vocab_dict = {}
        if not load_vocab:
            logger.info("[INFO] Build new vocab from training dataset!")
            if train_ds is None:
                raise ValueError("Lack of a train file to build the vocabulary!")
            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.from_dataset(train_ds, field_name=["text_wd", "summary_wd"])
            vocab_dict["vocab"] = vocabs
        else:
            logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
            word_list = []
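            # Assumed vocab file format: one word per line, optionally followed by a
            # tab-separated count; only the first column is used, and reading stops
            # once vocab_size words (counting pad and unk) have been collected.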
            with open(vocab_path, 'r', encoding='utf8') as vocab_f:
                cnt = 2  # pad and unk
                for line in vocab_f:
                    pieces = line.split("\t")
                    word_list.append(pieces[0])
                    cnt += 1
                    if cnt > vocab_size:
                        break
            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.add_word_lst(word_list)
            vocabs.build_vocab()
            vocab_dict["vocab"] = vocabs

        if domain:
            domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
            domaindict.from_dataset(train_ds, field_name="publication")
            vocab_dict["domain"] = domaindict
        if tag:
            tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
            tagdict.from_dataset(train_ds, field_name="tag")
            vocab_dict["tag"] = tagdict

        for ds in datasets.values():
            vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)

        return DataBundle(vocabs=vocab_dict, datasets=datasets)
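

# A minimal usage sketch (not part of the original file). The split names, file paths,
# and size parameters below are hypothetical; adjust them to the actual dataset layout.
if __name__ == "__main__":
    loader = SummarizationLoader()
    data_bundle = loader.process(
        paths={"train": "data/train.jsonl", "val": "data/val.jsonl"},  # hypothetical paths
        vocab_size=50000,
        vocab_path="data/vocab",  # hypothetical vocab file, read because load_vocab=True
        sent_max_len=100,
        doc_max_timesteps=50,
        load_vocab=True,
    )
    # data_bundle.datasets maps each split name to its padded, indexed DataSet;
    # data_bundle.vocabs holds "vocab" plus the optional "domain"/"tag" vocabularies.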