
data.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it."""

import os
import re
import glob
import copy
import random
import json
import collections
from itertools import combinations
import numpy as np
from random import shuffle
import torch.utils.data
import time
import pickle
from nltk.tokenize import sent_tokenize

import utils
from logger import *

# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

PAD_TOKEN = '[PAD]'  # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNKNOWN_TOKEN = '[UNK]'  # This has a vocab id, which is used to represent out-of-vocabulary words
START_DECODING = '[START]'  # This has a vocab id, which is used at the start of every decoder input sequence
STOP_DECODING = '[STOP]'  # This has a vocab id, which is used at the end of untruncated target sequences

# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.

class Vocab(object):
    """Vocabulary class for mapping between words and ids (integers)"""

    def __init__(self, vocab_file, max_size):
        """
        Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.

        :param vocab_file: string; path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though.
        :param max_size: int; the maximum size of the resulting Vocabulary.
        """
        self._word_to_id = {}
        self._id_to_word = {}
        self._count = 0  # keeps track of the total number of words in the Vocab

        # [PAD], [UNK], [START] and [STOP] get the ids 0, 1, 2, 3.
        for w in [PAD_TOKEN, UNKNOWN_TOKEN, START_DECODING, STOP_DECODING]:
            self._word_to_id[w] = self._count
            self._id_to_word[self._count] = w
            self._count += 1

        # Read the vocab file and add words up to max_size
        with open(vocab_file, 'r', encoding='utf8') as vocab_f:  # utf8 encoding prevents decoding errors
            cnt = 0
            for line in vocab_f:
                cnt += 1
                pieces = line.split("\t")
                w = pieces[0]
                if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
                    raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
                if w in self._word_to_id:
                    logger.error('Duplicated word in vocabulary file line %d: %s' % (cnt, w))
                    continue
                self._word_to_id[w] = self._count
                self._id_to_word[self._count] = w
                self._count += 1
                if max_size != 0 and self._count >= max_size:
                    logger.info("[INFO] max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
                    break
        logger.info("[INFO] Finished constructing vocabulary of %i total words. Last word added: %s", self._count, self._id_to_word[self._count - 1])

    def word2id(self, word):
        """Returns the id (integer) of a word (string). Returns the [UNK] id if the word is OOV."""
        if word not in self._word_to_id:
            return self._word_to_id[UNKNOWN_TOKEN]
        return self._word_to_id[word]

    def id2word(self, word_id):
        """Returns the word (string) corresponding to an id (integer)."""
        if word_id not in self._id_to_word:
            raise ValueError('Id not found in vocab: %d' % word_id)
        return self._id_to_word[word_id]

    def size(self):
        """Returns the total size of the vocabulary"""
        return self._count

    def word_list(self):
        """Returns the word list of the vocabulary"""
        return self._word_to_id.keys()
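
# Illustrative usage sketch (not part of the original file): building a Vocab and mapping
# between words and ids. The file name "vocab" and the lookup word are hypothetical; the
# vocab file is assumed to contain tab-separated "<word>\t<frequency>" lines.
def _vocab_usage_example(vocab_file="vocab"):
    vocab = Vocab(vocab_file, max_size=50000)
    pad_id = vocab.word2id(PAD_TOKEN)           # 0, because [PAD] is inserted first
    unk_id = vocab.word2id("some-rare-word")    # OOV words fall back to the [UNK] id
    assert vocab.id2word(pad_id) == PAD_TOKEN
    return pad_id, unk_id, vocab.size()
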

class Word_Embedding(object):
    def __init__(self, path, vocab):
        """
        :param path: string; the path of the word embedding file
        :param vocab: object;
        """
        logger.info("[INFO] Loading external word embedding...")
        self._path = path
        self._vocablist = vocab.word_list()
        self._vocab = vocab

    def load_my_vecs(self, k=200):
        """Load word embeddings"""
        word_vecs = {}
        with open(self._path, encoding="utf-8") as f:
            count = 0
            lines = f.readlines()[1:]
            for line in lines:
                values = line.split(" ")
                word = values[0]
                count += 1
                if word in self._vocablist:  # only keep vectors for words that are in the vocab
                    vector = []
                    for index, val in enumerate(values):  # "index" avoids shadowing the outer "count"
                        if index == 0:
                            continue
                        if index <= k:
                            vector.append(float(val))
                    word_vecs[word] = vector
        return word_vecs

    def add_unknown_words_by_zero(self, word_vecs, k=200):
        """Fill unknown words with zero vectors"""
        zero = [0.0] * k
        list_word2vec = []
        oov = 0
        iov = 0
        for i in range(self._vocab.size()):
            word = self._vocab.id2word(i)
            if word not in word_vecs:
                oov += 1
                word_vecs[word] = zero
                list_word2vec.append(word_vecs[word])
            else:
                iov += 1
                list_word2vec.append(word_vecs[word])
        logger.info("[INFO] oov count %d, iov count %d", oov, iov)
        return list_word2vec

    def add_unknown_words_by_avg(self, word_vecs, k=200):
        """Fill unknown words with the average of the known word embeddings"""
        word_vecs_numpy = []
        for word in self._vocablist:
            if word in word_vecs:
                word_vecs_numpy.append(word_vecs[word])
        col = []
        for i in range(k):
            total = 0.0
            for j in range(int(len(word_vecs_numpy))):
                total += word_vecs_numpy[j][i]
            total = round(total, 6)
            col.append(total)
        avg_vector = []
        for m in range(k):
            avg = col[m] / int(len(word_vecs_numpy))
            avg = round(avg, 6)
            avg_vector.append(float(avg))
        list_word2vec = []
        oov = 0
        iov = 0
        for i in range(self._vocab.size()):
            word = self._vocab.id2word(i)
            if word not in word_vecs:
                oov += 1
                word_vecs[word] = avg_vector
                list_word2vec.append(word_vecs[word])
            else:
                iov += 1
                list_word2vec.append(word_vecs[word])
        logger.info("[INFO] External Word Embedding iov count: %d, oov count: %d", iov, oov)
        return list_word2vec

    def add_unknown_words_by_uniform(self, word_vecs, uniform=0.25, k=200):
        """Fill unknown words with vectors drawn from uniform(-0.25, 0.25)"""
        list_word2vec = []
        oov = 0
        iov = 0
        for i in range(self._vocab.size()):
            word = self._vocab.id2word(i)
            if word not in word_vecs:
                oov += 1
                word_vecs[word] = np.random.uniform(-1 * uniform, uniform, k).round(6).tolist()
                list_word2vec.append(word_vecs[word])
            else:
                iov += 1
                list_word2vec.append(word_vecs[word])
        logger.info("[INFO] oov count %d, iov count %d", oov, iov)
        return list_word2vec

    # load word embedding, randomly dropping words whose corpus frequency is 1
    def load_my_vecs_freq1(self, freqs, pro):
        word_vecs = {}
        with open(self._path, encoding="utf-8") as f:
            lines = f.readlines()[1:]
            for line in lines:
                values = line.split(" ")
                word = values[0]
                if word in self._vocablist:  # whether to judge if in vocab
                    if freqs[word] == 1:
                        a = np.random.uniform(0, 1, 1).round(2)
                        if pro < a:
                            continue
                    vector = []
                    for count, val in enumerate(values):
                        if count == 0:
                            continue
                        vector.append(float(val))
                    word_vecs[word] = vector
        return word_vecs
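
# Illustrative usage sketch (not part of the original file): loading pretrained vectors and
# backfilling OOV words so the resulting matrix lines up with the vocab ids. The path
# "glove.txt" and the dimension 200 are hypothetical; the embedding file is assumed to have
# a header line followed by "<word> <v1> ... <vk>" lines.
def _embedding_usage_example(vocab, embedding_path="glove.txt", dim=200):
    embed_loader = Word_Embedding(embedding_path, vocab)
    word_vecs = embed_loader.load_my_vecs(k=dim)
    # any of the three backfill strategies works; each returns one vector per vocab id
    vectors = embed_loader.add_unknown_words_by_avg(word_vecs, k=dim)
    return np.array(vectors)  # shape [vocab.size(), dim]
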

class DomainDict(object):
    """Domain embedding for Newsroom"""

    def __init__(self, path):
        self.domain_list = self.readDomainlist(path)
        # self.domain_list = ["foxnews.com", "cnn.com", "mashable.com", "nytimes.com", "washingtonpost.com"]
        self.domain_number = len(self.domain_list)
        self._domain_to_id = {}
        self._id_to_domain = {}
        self._cnt = 0
        # "X" is reserved for unknown domains
        self._domain_to_id["X"] = self._cnt
        self._id_to_domain[self._cnt] = "X"
        self._cnt += 1
        for i in range(self.domain_number):
            domain = self.domain_list[i]
            self._domain_to_id[domain] = self._cnt
            self._id_to_domain[self._cnt] = domain
            self._cnt += 1

    def readDomainlist(self, path):
        domain_list = []
        with open(path) as f:
            for line in f:
                domain_list.append(line.split("\t")[0].strip())
        logger.info(domain_list)
        return domain_list

    def domain2id(self, domain):
        """ Returns the id (integer) of a domain (string). Returns the id of "X" for an unknown domain.

        :param domain: string
        :return: id; int
        """
        if domain in self.domain_list:
            return self._domain_to_id[domain]
        else:
            logger.info(domain)
            return self._domain_to_id["X"]

    def id2domain(self, domain_id):
        """ Returns the domain (string) corresponding to an id (integer).

        :param domain_id: int;
        :return: domain: string
        """
        if domain_id not in self._id_to_domain:
            raise ValueError('Id not found in DomainDict: %d' % domain_id)
        return self._id_to_domain[domain_id]

    def size(self):
        return self._cnt
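
# Illustrative usage sketch (not part of the original file): the domain file path
# "domains.txt" is hypothetical and is assumed to hold one tab-separated domain per line.
# Unseen domains fall back to the reserved "X" id (0); known domains start at id 1.
def _domaindict_usage_example(domain_path="domains.txt"):
    domaindict = DomainDict(domain_path)
    cnn_id = domaindict.domain2id("cnn.com")       # id of a known domain (if listed in the file)
    unk_id = domaindict.domain2id("example.org")   # unknown domains map to the "X" id
    assert domaindict.id2domain(unk_id) == "X"
    return cnn_id, unk_id, domaindict.size()
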

class Example(object):
    """Class representing a train/val/test example for text summarization."""

    def __init__(self, article_sents, abstract_sents, vocab, sent_max_len, label, domainid=None):
        """ Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.

        :param article_sents: list of strings; one per article sentence. Each token is separated by a single space.
        :param abstract_sents: list of strings; one per abstract sentence. In each sentence, each token is separated by a single space.
        :param vocab: Vocabulary object
        :param sent_max_len: int; the maximum length of each sentence, padding all sentences to this length
        :param label: list of int; the indices of the selected sentences
        :param domainid: int; publication (domain) of the example
        """
        self.sent_max_len = sent_max_len
        self.enc_sent_len = []
        self.enc_sent_input = []
        self.enc_sent_input_pad = []

        # Process the article
        for sent in article_sents:
            article_words = sent.split()
            self.enc_sent_len.append(len(article_words))  # store the length after truncation but before padding
            self.enc_sent_input.append([vocab.word2id(w) for w in article_words])  # list of word ids; OOVs are represented by the id for the UNK token
        self._pad_encoder_input(vocab.word2id('[PAD]'))

        # Store the original strings
        self.original_article = " ".join(article_sents)
        self.original_article_sents = article_sents
        if isinstance(abstract_sents[0], list):
            logger.debug("[INFO] Multiple reference summaries!")
            self.original_abstract_sents = []
            self.original_abstract = []
            for summary in abstract_sents:
                self.original_abstract_sents.append([sent.strip() for sent in summary])
                self.original_abstract.append("\n".join([sent.replace("\n", "") for sent in summary]))
        else:
            self.original_abstract_sents = [sent.replace("\n", "") for sent in abstract_sents]
            self.original_abstract = "\n".join(self.original_abstract_sents)

        # Store the label
        self.label = np.zeros(len(article_sents), dtype=int)
        if label != []:
            self.label[np.array(label)] = 1
        self.label = list(self.label)

        # Store the publication (domain)
        if domainid is not None:
            if domainid == 0:
                logger.debug("domain id = 0!")
            self.domain = domainid

    def _pad_encoder_input(self, pad_id):
        """
        :param pad_id: int; token pad id
        :return:
        """
        max_len = self.sent_max_len
        for i in range(len(self.enc_sent_input)):
            article_words = self.enc_sent_input[i]
            if len(article_words) > max_len:
                article_words = article_words[:max_len]
            while len(article_words) < max_len:
                article_words.append(pad_id)
            self.enc_sent_input_pad.append(article_words)
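
# Illustrative usage sketch (not part of the original file): building a single Example from
# two pre-tokenized article sentences, where the first sentence is labelled as part of the
# extractive summary. The sentences, summary, and size limit are hypothetical.
def _example_usage_example(vocab):
    article_sents = ["the cat sat on the mat .", "it was a sunny day ."]
    abstract_sents = ["the cat sat on the mat ."]
    ex = Example(article_sents, abstract_sents, vocab, sent_max_len=10, label=[0])
    # each sentence is padded/truncated to sent_max_len word ids
    assert all(len(s) == 10 for s in ex.enc_sent_input_pad)
    return ex.label  # [1, 0]
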

class ExampleSet(torch.utils.data.Dataset):
    """ Dataset of Example objects """

    def __init__(self, data_path, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False):
        """ Initializes the ExampleSet with the path of the data

        :param data_path: string; the path of the data
        :param vocab: object;
        :param doc_max_timesteps: int; the maximum sentence number of a document, each example should pad sentences to this length
        :param sent_max_len: int; the maximum token number of a sentence, each sentence should pad tokens to this length
        :param domaindict: object; the domain dict to embed domain
        :param randomX: bool; whether to randomly map some examples to the unknown domain "X"
        :param usetag: bool; whether to use the first element of the "tag" field instead of "publication" as the domain
        """
        self.domaindict = domaindict
        if domaindict:
            logger.info("[INFO] Use domain information in the dataset!")
            if randomX:
                logger.info("[INFO] Randomly map some examples to the unknown domain X!")
                self.randomP = 0.1

        logger.info("[INFO] Start reading ExampleSet")
        start = time.time()
        self.example_list = []
        self.doc_max_timesteps = doc_max_timesteps
        cnt = 0
        with open(data_path, 'r') as reader:
            for line in reader:
                try:
                    e = json.loads(line)
                    article_sent = e['text']
                    tag = e["tag"][0] if usetag else e['publication']
                    if "duc" in data_path:
                        abstract_sent = e["summaryList"] if "summaryList" in e.keys() else [e['summary']]
                    else:
                        abstract_sent = e['summary']
                    if domaindict:
                        if randomX:
                            p = np.random.rand()
                            if p <= self.randomP:
                                domainid = domaindict.domain2id("X")
                            else:
                                domainid = domaindict.domain2id(tag)
                        else:
                            domainid = domaindict.domain2id(tag)
                    else:
                        domainid = None
                    logger.debug((tag, domainid))
                except (ValueError, EOFError) as err:
                    logger.debug(err)
                    break
                else:
                    example = Example(article_sent, abstract_sent, vocab, sent_max_len, e["label"], domainid)  # Process into an Example.
                    self.example_list.append(example)
                    cnt += 1
        logger.info("[INFO] Finish reading ExampleSet. Total time is %f, Total size is %d", time.time() - start, len(self.example_list))
        self.size = len(self.example_list)

    def get_example(self, index):
        return self.example_list[index]

    def __getitem__(self, index):
        """
        :param index: int; the index of the example
        :return
            input_pad: [N, seq_len]
            label: [N]
            input_mask: [N]
            domain: [1]
        """
        item = self.example_list[index]
        input = np.array(item.enc_sent_input_pad)
        label = np.array(item.label, dtype=int)
        # pad input to doc_max_timesteps
        if len(input) < self.doc_max_timesteps:
            pad_number = self.doc_max_timesteps - len(input)
            pad_matrix = np.zeros((pad_number, len(input[0])))
            input_pad = np.vstack((input, pad_matrix))
            label = np.append(label, np.zeros(pad_number, dtype=int))
            input_mask = np.append(np.ones(len(input)), np.zeros(pad_number))
        else:
            input_pad = input[:self.doc_max_timesteps]
            label = label[:self.doc_max_timesteps]
            input_mask = np.ones(self.doc_max_timesteps)
        if self.domaindict:
            return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long(), item.domain
        return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long()

    def __len__(self):
        return self.size
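
# Illustrative usage sketch (not part of the original file): wrapping an ExampleSet in a
# standard PyTorch DataLoader. The paths "train.label.jsonl" and "vocab" and the size limits
# are hypothetical; each batch yields (input_pad, label, input_mask) tensors, plus the
# domain id when a DomainDict is supplied.
def _exampleset_usage_example():
    vocab = Vocab("vocab", max_size=50000)
    dataset = ExampleSet("train.label.jsonl", vocab, doc_max_timesteps=50, sent_max_len=100)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    for input_pad, label, input_mask in loader:
        # input_pad: [batch, doc_max_timesteps, sent_max_len]; label/input_mask: [batch, doc_max_timesteps]
        break
    return input_pad.shape
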

class MultiExampleSet(object):
    """ One ExampleSet per domain, indexed by (domain id - 1) """

    def __init__(self, data_dir, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False):
        self.datasets = [None] * (domaindict.size() - 1)
        data_path_list = [os.path.join(data_dir, s) for s in os.listdir(data_dir) if s.endswith("label.jsonl")]
        for data_path in data_path_list:
            fname = data_path.split("/")[-1]  # e.g. cnn.com.label.jsonl
            dataname = ".".join(fname.split(".")[:-2])
            domainid = domaindict.domain2id(dataname)
            logger.info("[INFO] domain name: %s, domain id: %d" % (dataname, domainid))
            self.datasets[domainid - 1] = ExampleSet(data_path, vocab, doc_max_timesteps, sent_max_len, domaindict, randomX, usetag)

    def get(self, id):
        return self.datasets[id]

from torch.utils.data.dataloader import default_collate

def my_collate_fn(batch):
    """
    Keep only the samples that share the domain of the first sample in the batch.

    :param batch: list of (input_pad, label, input_mask, domain)
    :return: the collated batch, or an empty tensor if nothing is left
    """
    start_domain = batch[0][-1]
    batch = list(filter(lambda x: x[-1] == start_domain, batch))
    print("start_domain %d" % start_domain)
    print("batch_len %d" % len(batch))
    if len(batch) == 0:
        return torch.Tensor()
    return default_collate(batch)  # collate the filtered samples with the default collate function
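
# Illustrative usage sketch (not part of the original file): using my_collate_fn so that every
# collated batch contains samples from a single domain. The batch size is hypothetical; this
# assumes the dataset was built with a DomainDict, so each item carries a domain id as its
# last element.
def _collate_usage_example(dataset):
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=my_collate_fn)
    input_pad, label, input_mask, domain = next(iter(loader))
    # all samples in this (possibly shrunken) batch share the same domain id
    return domain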