#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
import os
import re
import glob
import copy
import random
import json
import collections
from itertools import combinations
import numpy as np
from random import shuffle
import torch.utils.data
import time
import pickle
from nltk.tokenize import sent_tokenize
import utils
from logger import *
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
class Vocab(object):
"""Vocabulary class for mapping between words and ids (integers)"""
def __init__(self, vocab_file, max_size):
"""
Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.
        :param vocab_file: string; path to the vocab file, which is assumed to contain one "<word>\t<frequency>" pair per line (tab-separated), sorted with the most frequent word first. This code doesn't actually use the frequencies, though.
:param max_size: int; The maximum size of the resulting Vocabulary.
"""
self._word_to_id = {}
self._id_to_word = {}
self._count = 0 # keeps track of total number of words in the Vocab
        # [PAD], [UNK], [START] and [STOP] get the ids 0, 1, 2, 3.
for w in [PAD_TOKEN, UNKNOWN_TOKEN, START_DECODING, STOP_DECODING]:
self._word_to_id[w] = self._count
self._id_to_word[self._count] = w
self._count += 1
# Read the vocab file and add words up to max_size
        with open(vocab_file, 'r', encoding='utf8') as vocab_f:  # utf-8 encoding prevents decode errors on non-ASCII vocab entries
cnt = 0
for line in vocab_f:
cnt += 1
pieces = line.split("\t")
# pieces = line.split()
w = pieces[0]
# print(w)
if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
                    raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
if w in self._word_to_id:
logger.error('Duplicated word in vocabulary file Line %d : %s' % (cnt, w))
continue
self._word_to_id[w] = self._count
self._id_to_word[self._count] = w
self._count += 1
if max_size != 0 and self._count >= max_size:
logger.info("[INFO] max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
break
logger.info("[INFO] Finished constructing vocabulary of %i total words. Last word added: %s", self._count, self._id_to_word[self._count-1])
def word2id(self, word):
"""Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV."""
if word not in self._word_to_id:
return self._word_to_id[UNKNOWN_TOKEN]
return self._word_to_id[word]
def id2word(self, word_id):
"""Returns the word (string) corresponding to an id (integer)."""
if word_id not in self._id_to_word:
raise ValueError('Id not found in vocab: %d' % word_id)
return self._id_to_word[word_id]
def size(self):
"""Returns the total size of the vocabulary"""
return self._count
def word_list(self):
"""Return the word list of the vocabulary"""
return self._word_to_id.keys()
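# Vocab usage sketch (illustrative only; the "vocab" path and max_size value below are assumed):
#   vocab = Vocab("vocab", max_size=50000)
#   pad_id = vocab.word2id(PAD_TOKEN)       # 0, since [PAD] is inserted first
#   unk_id = vocab.word2id("never-seen")    # out-of-vocabulary words map to the [UNK] id (1)
#   assert vocab.id2word(pad_id) == PAD_TOKEN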
class Word_Embedding(object):
def __init__(self, path, vocab):
"""
        :param path: string; the path of the pretrained word embedding file
        :param vocab: Vocab object; the vocabulary whose words should be loaded
"""
logger.info("[INFO] Loading external word embedding...")
self._path = path
self._vocablist = vocab.word_list()
self._vocab = vocab
    def load_my_vecs(self, k=200):
        """Load the first k dimensions of the pretrained vector for every in-vocabulary word"""
        word_vecs = {}
        with open(self._path, encoding="utf-8") as f:
            lines = f.readlines()[1:]  # skip the word2vec-style header line
            for line in lines:
                values = line.split(" ")
                word = values[0]
                if word in self._vocablist:  # only keep vectors for in-vocabulary words
                    # keep at most the first k float values that follow the word itself
                    vector = [float(val) for val in values[1:k + 1]]
                    word_vecs[word] = vector
        return word_vecs
def add_unknown_words_by_zero(self, word_vecs, k=200):
"""Solve unknown by zeros"""
zero = [0.0] * k
list_word2vec = []
oov = 0
iov = 0
for i in range(self._vocab.size()):
word = self._vocab.id2word(i)
if word not in word_vecs:
oov += 1
word_vecs[word] = zero
list_word2vec.append(word_vecs[word])
else:
iov += 1
list_word2vec.append(word_vecs[word])
logger.info("[INFO] oov count %d, iov count %d", oov, iov)
return list_word2vec
    def add_unknown_words_by_avg(self, word_vecs, k=200):
        """Fill the vectors of out-of-vocabulary words with the average of all known vectors"""
        # collect the pretrained vectors that were found for in-vocabulary words
        word_vecs_numpy = []
        for word in self._vocablist:
            if word in word_vecs:
                word_vecs_numpy.append(word_vecs[word])
        # column-wise sum over the known vectors, then the per-dimension average
        col = []
        for i in range(k):
            dim_sum = 0.0
            for j in range(len(word_vecs_numpy)):
                dim_sum += word_vecs_numpy[j][i]
            col.append(round(dim_sum, 6))
        avg_vec = []
        for m in range(k):
            avg = round(col[m] / len(word_vecs_numpy), 6)
            avg_vec.append(float(avg))
        # every OOV word receives the average vector
        list_word2vec = []
        oov = 0
        iov = 0
        for i in range(self._vocab.size()):
            word = self._vocab.id2word(i)
            if word not in word_vecs:
                oov += 1
                word_vecs[word] = avg_vec
                list_word2vec.append(word_vecs[word])
            else:
                iov += 1
                list_word2vec.append(word_vecs[word])
        logger.info("[INFO] External Word Embedding iov count: %d, oov count: %d", iov, oov)
        return list_word2vec
def add_unknown_words_by_uniform(self, word_vecs, uniform=0.25, k=200):
"""Solve unknown word by uniform(-0.25,0.25)"""
list_word2vec = []
oov = 0
iov = 0
for i in range(self._vocab.size()):
word = self._vocab.id2word(i)
if word not in word_vecs:
oov += 1
word_vecs[word] = np.random.uniform(-1 * uniform, uniform, k).round(6).tolist()
list_word2vec.append(word_vecs[word])
else:
iov += 1
list_word2vec.append(word_vecs[word])
logger.info("[INFO] oov count %d, iov count %d", oov, iov)
return list_word2vec
    # load word embedding, dropping part of the frequency-1 words
    def load_my_vecs_freq1(self, freqs, pro):
        """Load word embedding; words with corpus frequency 1 are kept only with probability pro"""
        word_vecs = {}
        with open(self._path, encoding="utf-8") as f:
            lines = f.readlines()[1:]  # skip the word2vec-style header line
            for line in lines:
                values = line.split(" ")
                word = values[0]
                if word in self._vocablist:  # only keep vectors for in-vocabulary words
                    if freqs[word] == 1:
                        a = np.random.uniform(0, 1, 1).round(2)
                        if pro < a:  # drop this frequency-1 word with probability (1 - pro)
                            continue
                    vector = [float(val) for val in values[1:]]
                    word_vecs[word] = vector
        return word_vecs
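# Word_Embedding usage sketch (illustrative only; "glove.200d.txt" is an assumed path and `vocab` is
# a pre-built Vocab instance). The loader expects word2vec-style text: one header line (skipped by
# readlines()[1:]), then "word v1 v2 ... vk" per line, space-separated. Vocabulary words missing
# from the file are filled in by one of the add_unknown_words_by_* strategies:
#   embed_loader = Word_Embedding("glove.200d.txt", vocab)
#   word_vecs = embed_loader.load_my_vecs(k=200)
#   embed_matrix = embed_loader.add_unknown_words_by_avg(word_vecs, k=200)  # vocab.size() vectors of length k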
class DomainDict(object):
"""Domain embedding for Newsroom"""
def __init__(self, path):
self.domain_list = self.readDomainlist(path)
# self.domain_list = ["foxnews.com", "cnn.com", "mashable.com", "nytimes.com", "washingtonpost.com"]
self.domain_number = len(self.domain_list)
self._domain_to_id = {}
self._id_to_domain = {}
self._cnt = 0
self._domain_to_id["X"] = self._cnt
self._id_to_domain[self._cnt] = "X"
self._cnt += 1
for i in range(self.domain_number):
domain = self.domain_list[i]
self._domain_to_id[domain] = self._cnt
self._id_to_domain[self._cnt] = domain
self._cnt += 1
def readDomainlist(self, path):
domain_list = []
with open(path) as f:
for line in f:
domain_list.append(line.split("\t")[0].strip())
logger.info(domain_list)
return domain_list
def domain2id(self, domain):
""" Returns the id (integer) of a domain (string). Returns "X" for unknow domain.
:param domain: string
:return: id; int
"""
if domain in self.domain_list:
return self._domain_to_id[domain]
else:
logger.info(domain)
return self._domain_to_id["X"]
def id2domain(self, domain_id):
""" Returns the domain (string) corresponding to an id (integer).
        :param domain_id: int;
:return: domain: string
"""
if domain_id not in self._id_to_domain:
raise ValueError('Id not found in DomainDict: %d' % domain_id)
        return self._id_to_domain[domain_id]
def size(self):
return self._cnt
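# DomainDict usage sketch (illustrative only; "domain.txt" is an assumed path). The file is expected
# to list one domain per line (only the first tab-separated field is used); id 0 is reserved for the
# unknown domain "X":
#   domaindict = DomainDict("domain.txt")
#   cnn_id = domaindict.domain2id("cnn.com")      # id taken from the file order (>= 1)
#   unk_id = domaindict.domain2id("example.org")  # unseen domains fall back to the id of "X" (0)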
class Example(object):
"""Class representing a train/val/test example for text summarization."""
def __init__(self, article_sents, abstract_sents, vocab, sent_max_len, label, domainid=None):
""" Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self.
:param article_sents: list of strings; one per article sentence. each token is separated by a single space.
:param abstract_sents: list of strings; one per abstract sentence. In each sentence, each token is separated by a single space.
:param domainid: int; publication of the example
:param vocab: Vocabulary object
:param sent_max_len: int; the maximum length of each sentence, padding all sentences to this length
:param label: list of int; the index of selected sentences
"""
self.sent_max_len = sent_max_len
self.enc_sent_len = []
self.enc_sent_input = []
self.enc_sent_input_pad = []
# origin_cnt = len(article_sents)
# article_sents = [re.sub(r"\n+\t+", " ", sent) for sent in article_sents]
# assert origin_cnt == len(article_sents)
# Process the article
for sent in article_sents:
article_words = sent.split()
            self.enc_sent_len.append(len(article_words))  # store the sentence length before truncation and padding
self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token
self._pad_encoder_input(vocab.word2id('[PAD]'))
# Store the original strings
self.original_article = " ".join(article_sents)
self.original_article_sents = article_sents
if isinstance(abstract_sents[0], list):
logger.debug("[INFO] Multi Reference summaries!")
self.original_abstract_sents = []
self.original_abstract = []
for summary in abstract_sents:
self.original_abstract_sents.append([sent.strip() for sent in summary])
self.original_abstract.append("\n".join([sent.replace("\n", "") for sent in summary]))
else:
self.original_abstract_sents = [sent.replace("\n", "") for sent in abstract_sents]
self.original_abstract = "\n".join(self.original_abstract_sents)
# Store the label
self.label = np.zeros(len(article_sents), dtype=int)
if label != []:
self.label[np.array(label)] = 1
self.label = list(self.label)
# Store the publication
        if domainid is not None:
if domainid == 0:
logger.debug("domain id = 0!")
self.domain = domainid
def _pad_encoder_input(self, pad_id):
"""
:param pad_id: int; token pad id
:return:
"""
max_len = self.sent_max_len
for i in range(len(self.enc_sent_input)):
article_words = self.enc_sent_input[i]
if len(article_words) > max_len:
article_words = article_words[:max_len]
while len(article_words) < max_len:
article_words.append(pad_id)
self.enc_sent_input_pad.append(article_words)
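# Example usage sketch (illustrative only; the toy sentences, the label and a pre-built Vocab
# instance `vocab` are assumed):
#   article = ["the cat sat on the mat .", "it was a sunny day ."]
#   abstract = ["a cat sat on a mat ."]
#   ex = Example(article, abstract, vocab, sent_max_len=10, label=[0])
#   # ex.enc_sent_input_pad -> two lists of 10 word ids each (truncated / padded with the [PAD] id)
#   # ex.label              -> [1, 0] (sentence 0 is selected)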
class ExampleSet(torch.utils.data.Dataset):
""" Constructor: Dataset of example(object) """
def __init__(self, data_path, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False):
""" Initializes the ExampleSet with the path of data
:param data_path: string; the path of data
:param vocab: object;
:param doc_max_timesteps: int; the maximum sentence number of a document, each example should pad sentences to this length
:param sent_max_len: int; the maximum token number of a sentence, each sentence should pad tokens to this length
        :param domaindict: object; the DomainDict used to map publications to domain ids
        :param randomX: bool; if True, randomly map a fraction (randomP) of the examples to the unknown domain "X"
        :param usetag: bool; if True, use e["tag"][0] instead of e["publication"] as the domain name
        """
self.domaindict = domaindict
if domaindict:
logger.info("[INFO] Use domain information in the dateset!")
if randomX==True:
logger.info("[INFO] Random some example to unknow domain X!")
self.randomP = 0.1
logger.info("[INFO] Start reading ExampleSet")
start = time.time()
self.example_list = []
self.doc_max_timesteps = doc_max_timesteps
cnt = 0
with open(data_path, 'r') as reader:
for line in reader:
try:
e = json.loads(line)
article_sent = e['text']
tag = e["tag"][0] if usetag else e['publication']
# logger.info(tag)
if "duc" in data_path:
abstract_sent = e["summaryList"] if "summaryList" in e.keys() else [e['summary']]
else:
abstract_sent = e['summary']
if domaindict:
                        if randomX:
p = np.random.rand()
if p <= self.randomP:
domainid = domaindict.domain2id("X")
else:
domainid = domaindict.domain2id(tag)
else:
domainid = domaindict.domain2id(tag)
else:
domainid = None
logger.debug((tag, domainid))
                except (ValueError, EOFError) as err:  # use a separate name so the parsed json object e is not shadowed
                    logger.debug(err)
break
else:
example = Example(article_sent, abstract_sent, vocab, sent_max_len, e["label"], domainid) # Process into an Example.
self.example_list.append(example)
cnt += 1
# print(cnt)
logger.info("[INFO] Finish reading ExampleSet. Total time is %f, Total size is %d", time.time() - start, len(self.example_list))
self.size = len(self.example_list)
# self.example_list.sort(key=lambda ex: ex.domain)
def get_example(self, index):
return self.example_list[index]
def __getitem__(self, index):
"""
:param index: int; the index of the example
:return
input_pad: [N, seq_len]
label: [N]
input_mask: [N]
domain: [1]
"""
item = self.example_list[index]
        enc_input = np.array(item.enc_sent_input_pad)
        label = np.array(item.label, dtype=int)
        # pad the document to doc_max_timesteps sentences (or truncate longer documents)
        if len(enc_input) < self.doc_max_timesteps:
            pad_number = self.doc_max_timesteps - len(enc_input)
            pad_matrix = np.zeros((pad_number, len(enc_input[0])))
            input_pad = np.vstack((enc_input, pad_matrix))
            label = np.append(label, np.zeros(pad_number, dtype=int))
            input_mask = np.append(np.ones(len(enc_input)), np.zeros(pad_number))
        else:
            input_pad = enc_input[:self.doc_max_timesteps]
            label = label[:self.doc_max_timesteps]
            input_mask = np.ones(self.doc_max_timesteps)
if self.domaindict:
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long(), item.domain
return torch.from_numpy(input_pad).long(), torch.from_numpy(label).long(), torch.from_numpy(input_mask).long()
def __len__(self):
return self.size
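# ExampleSet usage sketch (illustrative only; "train.label.jsonl" is an assumed path and `vocab` is
# a pre-built Vocab instance). Each line of the data file is a JSON object; the fields read above
# are "text" (list of sentence strings), "summary", "label" (indices of the selected sentences) and
# "publication" / "tag" when a DomainDict is supplied. A toy line (not real data) would look like:
#   {"text": ["sent one .", "sent two ."], "summary": ["sent one ."], "label": [0], "publication": "cnn.com"}
# and a dataset built from such a file can be indexed directly:
#   dataset = ExampleSet("train.label.jsonl", vocab, doc_max_timesteps=50, sent_max_len=100)
#   input_pad, label, input_mask = dataset[0]   # shapes: [N, seq_len], [N], [N]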
class MultiExampleSet(object):
    """ A collection of per-domain ExampleSet datasets, one per domain in the DomainDict (excluding the unknown domain "X") """
def __init__(self, data_dir, vocab, doc_max_timesteps, sent_max_len, domaindict=None, randomX=False, usetag=False):
self.datasets = [None] * (domaindict.size() - 1)
data_path_list = [os.path.join(data_dir, s) for s in os.listdir(data_dir) if s.endswith("label.jsonl")]
for data_path in data_path_list:
            fname = os.path.basename(data_path)  # e.g. cnn.com.label.jsonl
dataname = ".".join(fname.split(".")[:-2])
domainid = domaindict.domain2id(dataname)
logger.info("[INFO] domain name: %s, domain id: %d" % (dataname, domainid))
self.datasets[domainid - 1] = ExampleSet(data_path, vocab, doc_max_timesteps, sent_max_len, domaindict, randomX, usetag)
def get(self, id):
return self.datasets[id]
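# MultiExampleSet usage sketch (illustrative only; "data/" is an assumed directory and `vocab`,
# `domaindict` are pre-built objects). It expects one "<domain>.label.jsonl" file per domain in
# data_dir and builds one ExampleSet per domain id (the unknown domain "X" gets no dataset, hence
# the domainid - 1 indexing):
#   multiset = MultiExampleSet("data/", vocab, doc_max_timesteps=50, sent_max_len=100, domaindict=domaindict)
#   cnn_set = multiset.get(domaindict.domain2id("cnn.com") - 1)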
from torch.utils.data.dataloader import default_collate
def my_collate_fn(batch):
    '''
    Keep only the examples that share the domain of the first example in the batch, so that every
    collated batch is single-domain.
    :param batch: list of (input_pad, label, input_mask, domain) tuples
    :return: the collated single-domain batch, or an empty Tensor if nothing is left
    '''
    start_domain = batch[0][-1]
    # for i in range(len(batch)):
    #     print(batch[i][-1], end=',')
    batch = list(filter(lambda x: x[-1] == start_domain, batch))
    logger.debug("start_domain %d" % start_domain)
    logger.debug("batch_len %d" % len(batch))
    if len(batch) == 0:
        return torch.Tensor()
    return default_collate(batch)  # collate the filtered batch with the default collate function
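if __name__ == "__main__":
    # Minimal smoke test for this module (illustrative sketch). The "vocab" and "train.label.jsonl"
    # paths below are assumptions; point them at your own vocabulary file and labelled jsonl data.
    demo_vocab = Vocab("vocab", max_size=50000)
    demo_set = ExampleSet("train.label.jsonl", demo_vocab, doc_max_timesteps=50, sent_max_len=100)
    demo_loader = torch.utils.data.DataLoader(demo_set, batch_size=4, shuffle=False)
    # When the dataset carries domain ids (4-tuple items), my_collate_fn can be passed as
    # collate_fn=my_collate_fn to keep every batch single-domain.
    for input_pad, label, input_mask in demo_loader:
        # Expected shapes: [B, doc_max_timesteps, sent_max_len], [B, doc_max_timesteps], [B, doc_max_timesteps]
        print(input_pad.size(), label.size(), input_mask.size())
        break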