@@ -62,7 +62,7 @@ class ExtCNNDMPipe(Pipe):
        db.set_input(Const.INPUT, Const.INPUT_LEN)
        db.set_target(Const.TARGET, Const.INPUT_LEN)
        print("[INFO] Load existing vocab from %s!" % self.vocab_path)
        # print("[INFO] Load existing vocab from %s!" % self.vocab_path)
        word_list = []
        with open(self.vocab_path, 'r', encoding='utf8') as vocab_f:
            cnt = 2  # pad and unk
@@ -1,188 +0,0 @@
import pickle

import numpy as np

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_bundle import DataBundle
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.core.const import Const

from tools.logger import *

WORD_PAD = "[PAD]"
WORD_UNK = "[UNK]"
DOMAIN_UNK = "X"
TAG_UNK = "X"


class SummarizationLoader(JsonLoader):
    """
    Load a summarization dataset. The loaded DataSet contains the fields::

        text: list(str), document
        summary: list(str), summary
        text_wd: list(list(str)), tokenized document
        summary_wd: list(list(str)), tokenized summary
        label: list(int), indices of the label sentences
        flatten_label: list(int), 0 or 1, flattened labels
        domain: str, optional
        tag: list(str), optional

    Data sources: CNN/DailyMail, Newsroom, DUC
    """

    def __init__(self):
        super(SummarizationLoader, self).__init__()

    def _load(self, path):
        ds = super(SummarizationLoader, self)._load(path)

        def _lower_text(text_list):
            return [text.lower() for text in text_list]

        def _split_list(text_list):
            return [text.split() for text in text_list]

        def _convert_label(label, sent_len):
            np_label = np.zeros(sent_len, dtype=int)
            if label:
                np_label[np.array(label)] = 1
            return np_label.tolist()

        ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
        ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
        ds.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
        ds.apply(lambda x: _split_list(x['summary']), new_field_name='summary_wd')
        ds.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")

        return ds
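
# For example, _convert_label maps the indices of the label sentences to a 0/1
# vector over all sentences of a document (a standalone re-run of the helper):
import numpy as np

def _convert_label(label, sent_len):
    np_label = np.zeros(sent_len, dtype=int)
    if label:
        np_label[np.array(label)] = 1
    return np_label.tolist()

print(_convert_label([0, 2], 4))  # -> [1, 0, 1, 0]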
    def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
        """
        :param paths: dict, path for each dataset
        :param vocab_size: int, max size of the vocab
        :param vocab_path: str, path to the vocab file
        :param sent_max_len: int, max number of tokens per sentence
        :param doc_max_timesteps: int, max number of sentences per document
        :param domain: bool, build a vocab for the publication, use 'X' for unknown
        :param tag: bool, build a vocab for the tags, use 'X' for unknown
        :param load_vocab_file: bool, build the vocab (False) or load it from vocab_path (True)
        :return: DataBundle
            datasets: dict, keys correspond to the paths dict
            vocabs: dict, keys: vocab (if "train" in paths), domain (if domain=True), tag (if tag=True)
            embeddings: optional
        """

        def _pad_sent(text_wd):
            pad_text_wd = []
            for sent_wd in text_wd:
                if len(sent_wd) < sent_max_len:
                    pad_num = sent_max_len - len(sent_wd)
                    sent_wd.extend([WORD_PAD] * pad_num)
                else:
                    sent_wd = sent_wd[:sent_max_len]
                pad_text_wd.append(sent_wd)
            return pad_text_wd

        def _token_mask(text_wd):
            token_mask_list = []
            for sent_wd in text_wd:
                token_num = len(sent_wd)
                if token_num < sent_max_len:
                    mask = [1] * token_num + [0] * (sent_max_len - token_num)
                else:
                    mask = [1] * sent_max_len
                token_mask_list.append(mask)
            return token_mask_list

        def _pad_label(label):
            text_len = len(label)
            if text_len < doc_max_timesteps:
                pad_label = label + [0] * (doc_max_timesteps - text_len)
            else:
                pad_label = label[:doc_max_timesteps]
            return pad_label

        def _pad_doc(text_wd):
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                padding = [WORD_PAD] * sent_max_len
                pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
            else:
                pad_text = text_wd[:doc_max_timesteps]
            return pad_text

        def _sent_mask(text_wd):
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
            else:
                sent_mask = [1] * doc_max_timesteps
            return sent_mask
        datasets = {}
        train_ds = None
        for key, value in paths.items():
            ds = self.load(value)
            # pad each sentence to sent_max_len tokens
            ds.apply(lambda x: _pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
            ds.apply(lambda x: _token_mask(x["text_wd"]), new_field_name="pad_token_mask")
            # pad each document to doc_max_timesteps sentences
            ds.apply(lambda x: _pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
            ds.apply(lambda x: _sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
            ds.apply(lambda x: _pad_label(x["flatten_label"]), new_field_name="pad_label")
            # rename fields
            ds.rename_field("pad_text", Const.INPUT)
            ds.rename_field("seq_len", Const.INPUT_LEN)
            ds.rename_field("pad_label", Const.TARGET)
            # set input and target
            ds.set_input(Const.INPUT, Const.INPUT_LEN)
            ds.set_target(Const.TARGET, Const.INPUT_LEN)
            datasets[key] = ds
            if "train" in key:
                train_ds = datasets[key]

        vocab_dict = {}
        if not load_vocab_file:
            logger.info("[INFO] Build new vocab from training dataset!")
            if train_ds is None:
                raise ValueError("Lacking a train file to build the vocabulary!")
            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.from_dataset(train_ds, field_name=["text_wd", "summary_wd"])
            vocab_dict["vocab"] = vocabs
        else:
            logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
            word_list = []
            with open(vocab_path, 'r', encoding='utf8') as vocab_f:
                cnt = 2  # pad and unk
                for line in vocab_f:
                    pieces = line.split("\t")
                    word_list.append(pieces[0])
                    cnt += 1
                    if cnt > vocab_size:
                        break
            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.add_word_lst(word_list)
            vocabs.build_vocab()
            vocab_dict["vocab"] = vocabs

        if domain:
            domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
            domaindict.from_dataset(train_ds, field_name="publication")
            vocab_dict["domain"] = domaindict
        if tag:
            tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
            tagdict.from_dataset(train_ds, field_name="tag")
            vocab_dict["tag"] = tagdict

        for ds in datasets.values():
            vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)

        return DataBundle(vocabs=vocab_dict, datasets=datasets)
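
# A hypothetical usage sketch of the loader above (the file names, vocab file,
# and sizes here are made-up placeholders, not paths from this repo):
# loader = SummarizationLoader()
# db = loader.process(paths={"train": "train.jsonl", "valid": "val.jsonl"},
#                     vocab_size=50000,
#                     vocab_path="vocab.txt",      # one "word\tcount" pair per line
#                     sent_max_len=100,
#                     doc_max_timesteps=50,
#                     load_vocab_file=True)
# train_ds = db.get_dataset("train")               # padded, indexed DataSet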
@@ -94,6 +94,8 @@ class Encoder(nn.Module):
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
        # print(enc_embed_input.size())
        # print(enc_pos_embed_input.size())
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
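
# For reference, a self-contained sketch of a standard sinusoid position table;
# this is an assumption about what transformer.Models.get_sinusoid_encoding_table
# builds, not a copy of the repo's implementation:
import numpy as np

def sinusoid_table(n_pos, dim, padding_idx=0):
    pos = np.arange(n_pos)[:, None]                      # [n_pos, 1]
    i = np.arange(dim)[None, :]                          # [1, dim]
    angle = pos / np.power(10000.0, 2 * (i // 2) / dim)  # [n_pos, dim]
    table = np.where(i % 2 == 0, np.sin(angle), np.cos(angle))
    table[padding_idx] = 0.0                             # index 0 is the pad position
    return table

print(sinusoid_table(5, 8).shape)  # (5, 8)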
@@ -17,11 +17,12 @@ class SummarizationModel(nn.Module):
        """
        :param hps: hyperparameters for the model
        :param vocab: vocab object
        :param embed: word embedding
        """
        super(SummarizationModel, self).__init__()

        self._hps = hps
        self.Train = (hps.mode == 'train')

        # sentence encoder
        self.encoder = Encoder(hps, embed)

@@ -45,18 +46,19 @@ class SummarizationModel(nn.Module):
        self.wh = nn.Linear(self.d_v, 2)

    def forward(self, input, input_len, Train):
    def forward(self, words, seq_len):
        """
        :param words: [batch_size, N, seq_len], word idx long tensor
        :param seq_len: [batch_size, N], 1 for a sentence and 0 for padding
        :return:
            p_sent: [batch_size, N, 2]
            output_slf_attn: (optional) [n_head, batch_size, N, N]
        """
        input = words
        input_len = seq_len

        # -- Sentence Encoder
        self.sent_embedding = self.encoder(input)  # [batch, N, Co * kernel_sizes]

@@ -67,7 +69,7 @@ class SummarizationModel(nn.Module):
        self.inputs[0] = self.sent_embedding.permute(1, 0, 2)  # [N, batch, Co * kernel_sizes]
        self.input_masks[0] = input_len.permute(1, 0).unsqueeze(2)

        self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train)  # [batch, N, hidden_size]
        self.lstm_output_state = self.deep_lstm(self.inputs, self.input_masks, Train=self.Train)  # [batch, N, hidden_size]

        # -- Prepare masks
        batch_size, N = input_len.size()
@@ -21,7 +21,7 @@ import torch
import torch.nn.functional as F

from fastNLP.core.losses import LossBase

from tools.logger import *
from fastNLP.core._logger import logger


class MyCrossEntropyLoss(LossBase):
    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
@@ -20,14 +20,60 @@ from __future__ import division
import torch
import torch.nn.functional as F
from rouge import Rouge

from fastNLP.core.const import Const
from fastNLP.core.metrics import MetricBase

from tools.logger import *
# from tools.logger import *
from fastNLP.core._logger import logger
from tools.utils import pyrouge_score_all, pyrouge_score_all_multi


class LossMetric(MetricBase):
    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
        super().__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.padding_idx = padding_idx
        self.reduce = reduce  # must be 'none' so the loss can be reshaped and masked below
        self.loss = 0.0
        self.iteration = 0

    def evaluate(self, pred, target, mask):
        """
        :param pred: [batch, N, 2]
        :param target: [batch, N]
        :param mask: [batch, N]
        :return:
        """
        batch, N, _ = pred.size()
        pred = pred.view(-1, 2)
        target = target.view(-1)
        loss = F.cross_entropy(input=pred, target=target,
                               ignore_index=self.padding_idx, reduction=self.reduce)
        loss = loss.view(batch, -1)
        loss = loss.masked_fill(mask.eq(0), 0)
        loss = loss.sum(1).mean()
        self.loss += loss
        self.iteration += 1

    def get_metric(self, reset=True):
        epoch_avg_loss = self.loss / self.iteration
        if reset:
            self.loss = 0.0
            self.iteration = 0
        # negate so that a larger metric value means a better (lower) loss
        metric = {"loss": -epoch_avg_loss}
        logger.info(metric)
        return metric
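
# A standalone torch sketch of the masked loss computed in evaluate() above;
# reduce must be 'none' so the loss can be reshaped and masked per sentence:
import torch
import torch.nn.functional as F

batch, N = 2, 3
pred = torch.randn(batch, N, 2)                 # per-sentence logits
target = torch.randint(0, 2, (batch, N))        # 0/1 extraction labels
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])     # 1 = real sentence, 0 = pad

loss = F.cross_entropy(pred.view(-1, 2), target.view(-1), reduction='none')
loss = loss.view(batch, -1).masked_fill(mask.eq(0), 0)  # zero out padded positions
print(loss.sum(1).mean())                       # mean over documents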

class LabelFMetric(MetricBase):
    def __init__(self, pred=None, target=None):
        super().__init__()
@@ -51,7 +51,7 @@ class TransformerModel(nn.Module):
            ffn_inner_hidden_size: FFN hidden size
            atten_dropout_prob: dropout rate
            doc_max_timesteps: max number of sentences in a document
        :param vocab:
        :param embed: word embedding
        """
        super(TransformerModel, self).__init__()
@@ -28,7 +28,7 @@ from fastNLP.core.const import Const
from fastNLP.io.model_io import ModelSaver
from fastNLP.core.callback import Callback, EarlyStopError

from tools.logger import *
from fastNLP.core._logger import logger


class TrainCallback(Callback):
    def __init__(self, hps, patience=3, quit_all=True):

@@ -36,6 +36,9 @@ class TrainCallback(Callback):
        self._hps = hps
        self.patience = patience
        self.wait = 0
        self.train_loss = 0.0
        self.prev_train_avg_loss = 1000.0
        self.train_dir = os.path.join(self._hps.save_root, "train")

        if not isinstance(quit_all, bool):
            raise ValueError("In KeyboardInterrupt, the quit_all argument must be a bool.")
@@ -43,20 +46,7 @@ class TrainCallback(Callback):
    def on_epoch_begin(self):
        self.epoch_start_time = time.time()

    # def on_loss_begin(self, batch_y, predict_y):
    #     """
    #
    #     :param batch_y: dict
    #         input_len: [batch, N]
    #     :param predict_y: dict
    #         p_sent: [batch, N, 2]
    #     :return:
    #     """
    #     input_len = batch_y[Const.INPUT_LEN]
    #     batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100)
    #     # predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1)
    #     # logger.debug(predict_y["p_sent"][0:5,:,:])
        self.model.Train = True
    def on_backward_begin(self, loss):
        """

@@ -72,19 +62,34 @@ class TrainCallback(Callback):
                    logger.info(name)
                    logger.info(param.grad.data.sum())
                raise Exception("train loss is not finite. Stopping.")
        self.train_loss += loss.data

    def on_backward_end(self):
        if self._hps.grad_clip:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm)
        torch.cuda.empty_cache()

    def on_epoch_end(self):
        logger.info('   | end of epoch {:3d} | time: {:5.2f}s | '
                    .format(self.epoch, (time.time() - self.epoch_start_time)))
        epoch_avg_loss = self.train_loss / self.n_steps
        logger.info('   | end of epoch {:3d} | time: {:5.2f}s | train loss: {:5.6f}'
                    .format(self.epoch, (time.time() - self.epoch_start_time), epoch_avg_loss))
        # if the average train loss went up, snapshot the model as an early-stop checkpoint
        if self.prev_train_avg_loss < epoch_avg_loss:
            save_file = os.path.join(self.train_dir, "earlystop.pkl")
            self.save_model(save_file)
        else:
            self.prev_train_avg_loss = epoch_avg_loss
        self.train_loss = 0.0

        # save the model after every epoch
        save_file = os.path.join(self.train_dir, "epoch_%d.pkl" % self.epoch)
        self.save_model(save_file)
    def on_valid_begin(self):
        self.valid_start_time = time.time()
        self.model.Train = False

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        logger.info('   | end of valid {:3d} | time: {:5.2f}s | '

@@ -95,9 +100,7 @@ class TrainCallback(Callback):
            if self.wait == self.patience:
                train_dir = os.path.join(self._hps.save_root, "train")
                save_file = os.path.join(train_dir, "earlystop.pkl")
                saver = ModelSaver(save_file)
                saver.save_pytorch(self.model)
                logger.info('[INFO] Saving early stop model to %s', save_file)
                self.save_model(save_file)
                raise EarlyStopError("Early stopping raised.")
            else:
                self.wait += 1

@@ -111,14 +114,12 @@ class TrainCallback(Callback):
                param_group['lr'] = new_lr
            logger.info("[INFO] The learning rate now is %f", new_lr)

    def on_exception(self, exception):
        if isinstance(exception, KeyboardInterrupt):
            logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
            train_dir = os.path.join(self._hps.save_root, "train")
            save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(self.model)
            logger.info('[INFO] Saving early stop model to %s', save_file)
            save_file = os.path.join(self.train_dir, "earlystop.pkl")
            self.save_model(save_file)

            if self.quit_all is True:
                sys.exit(0)  # exit the program directly

@@ -127,6 +128,11 @@ class TrainCallback(Callback):
            else:
                raise exception  # re-raise the unfamiliar error

    def save_model(self, save_file):
        saver = ModelSaver(save_file)
        saver.save_pytorch(self.model)
        logger.info('[INFO] Saving model to %s', save_file)
@@ -1,562 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
import torch.nn.init as init

import data
from tools.logger import *
from transformer.Models import get_sinusoid_encoding_table


class Encoder(nn.Module):
    def __init__(self, hps, vocab):
        super(Encoder, self).__init__()

        self._hps = hps
        self._vocab = vocab
        self.sent_max_len = hps.sent_max_len

        vocab_size = len(vocab)
        logger.info("[INFO] Vocabulary size is %d", vocab_size)
        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # word embedding
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=vocab.word2id('[PAD]'))
        if hps.word_embedding:
            word2vec = data.Word_Embedding(hps.embedding_path, vocab)
            word_vecs = word2vec.load_my_vecs(embed_size)
            # pretrained_weight = word2vec.add_unknown_words_by_zero(word_vecs, embed_size)
            pretrained_weight = word2vec.add_unknown_words_by_avg(word_vecs, embed_size)
            pretrained_weight = np.array(pretrained_weight)
            self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
            self.embed.weight.requires_grad = hps.embed_train

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing W for CNN......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size

        return fan_in, fan_out

    def forward(self, input):
        # input: a batch of Example objects [batch_size, N, seq_len]
        vocab = self._vocab

        batch_size, N, _ = input.size()
        input = input.view(-1, input.size(2))  # [batch_size*N, L]
        input_sent_len = ((input != vocab.word2id('[PAD]')).sum(dim=1)).int()  # [batch_size*N, 1]
        enc_embed_input = self.embed(input)  # [batch_size*N, L, D]

        input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
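
# A toy shape walkthrough of the convolutional sentence encoder above
# (standalone, with made-up sizes):
import torch
import torch.nn as nn
import torch.nn.functional as F

BN, L, D, Co = 6, 20, 16, 8                # batch*N, tokens, emb dim, out channels
convs = nn.ModuleList([nn.Conv2d(1, Co, kernel_size=(h, D)) for h in range(1, 4)])

x = torch.randn(BN, L, D).unsqueeze(1)                           # [BN, 1, L, D]
feats = [F.relu(conv(x)).squeeze(3) for conv in convs]           # [BN, Co, L-h+1]
pooled = [F.max_pool1d(f, f.size(2)).squeeze(2) for f in feats]  # [BN, Co]
sent_embedding = torch.cat(pooled, 1)                            # [BN, Co * 3]
print(sent_embedding.shape)                # torch.Size([6, 24])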

class DomainEncoder(Encoder):
    def __init__(self, hps, vocab, domaindict):
        super(DomainEncoder, self).__init__(hps, vocab)

        # domain embedding
        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, input, domain):
        """
        :param input: [batch_size, N, seq_len], N sentences, seq_len tokens each
        :param domain: [batch_size]
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
        """
        batch_size, N, _ = input.size()
        sent_embedding = super().forward(input)
        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding

class MultiDomainEncoder(Encoder):
    def __init__(self, hps, vocab, domaindict):
        super(MultiDomainEncoder, self).__init__(hps, vocab)

        self.domain_size = domaindict.size()

        # domain embedding
        self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, input, domain):
        """
        :param input: [batch_size, N, seq_len], N sentences, seq_len tokens each
        :param domain: [batch_size, domain_size], multi-hot domain indicator
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
        """
        batch_size, N, _ = input.size()

        # logger.info(domain[:5, :])

        sent_embedding = super().forward(input)
        domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1)
        domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1)  # [batch * domain_size]

        enc_domain_input = self.domain_embedding(domain_padding)  # [batch * domain_size, D]
        enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float()  # [batch, domain_size, D]

        # logger.info(enc_domain_input[:5, :])  # [batch, domain_size, D]

        # average the embeddings of the active domains
        enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding
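
# A toy sketch of the multi-hot domain averaging above: embed every domain id,
# zero the domains a document does not belong to, then average the rest:
import torch
import torch.nn as nn

batch, domain_size, D = 2, 4, 5
embedding = nn.Embedding(domain_size, D)
domain = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 0]])    # multi-hot [batch, domain_size]

ids = torch.arange(domain_size).unsqueeze(0).expand(batch, -1).reshape(-1)
emb = embedding(ids).view(batch, domain_size, -1) * domain.unsqueeze(-1).float()
doc_domain = emb.sum(1) / domain.sum(1).float().unsqueeze(-1)  # [batch, D]
print(doc_domain.shape)                                # torch.Size([2, 5])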

class BertEncoder(nn.Module):
    def __init__(self, hps):
        super(BertEncoder, self).__init__()

        from pytorch_pretrained_bert.modeling import BertModel

        self._hps = hps
        self.sent_max_len = hps.sent_max_len
        self._cuda = hps.cuda

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # word embedding
        self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16")
        self._bert.eval()
        for p in self._bert.parameters():
            p.requires_grad = False

        self.word_embedding_proj = nn.Linear(4096, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing W for CNN......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size

        return fan_in, fan_out

    def pad_encoder_input(self, input_list):
        """
        :param input_list: N [seq_len, hidden_state]
        :return: enc_sent_input_pad: list, N [max_len, hidden_state]
        """
        max_len = self.sent_max_len
        enc_sent_input_pad = []
        _, hidden_size = input_list[0].size()
        for i in range(len(input_list)):
            article_words = input_list[i]  # [seq_len, hidden_size]
            seq_len = article_words.size(0)
            if seq_len > max_len:
                pad_words = article_words[:max_len, :]
            else:
                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
                pad_words = torch.cat([article_words, pad_tensor], dim=0)
            enc_sent_input_pad.append(pad_words)
        return enc_sent_input_pad
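
# A standalone sketch of pad_encoder_input: truncate long sentences and zero-pad
# short ones to a fixed [max_len, hidden] shape (toy sizes):
import torch

max_len = 4
sents = [torch.randn(2, 8), torch.randn(6, 8)]   # sentences of [seq_len, hidden]
padded = []
for s in sents:
    if s.size(0) > max_len:
        padded.append(s[:max_len, :])            # truncate
    else:
        pad = torch.zeros(max_len - s.size(0), s.size(1))
        padded.append(torch.cat([s, pad], dim=0))
print(torch.stack(padded).shape)                 # torch.Size([2, 4, 8])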
    def forward(self, inputs, input_masks, enc_sent_len):
        """
        :param inputs: a batch of Example objects [batch_size, doc_len=512]
        :param input_masks: 0 or 1, [batch, doc_len=512]
        :param enc_sent_len: original sentence lengths [batch, N]
        :return:
        """
        # Use BERT to get the word embeddings
        batch_size, N = enc_sent_len.size()
        input_pad_list = []
        for i in range(batch_size):
            tokens_id = inputs[i]
            input_mask = input_masks[i]
            sent_len = enc_sent_len[i]
            input_ids = tokens_id.unsqueeze(0)
            input_mask = input_mask.unsqueeze(0)

            out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask)
            out = torch.cat(out[-4:], dim=-1).squeeze(0)  # [doc_len=512, hidden_state=4096]

            _, hidden_size = out.size()

            # restore the sentences (token positions start at 1, after the leading special token)
            last_end = 1
            enc_sent_input = []
            for length in sent_len:
                if length != 0 and last_end < 511:
                    enc_sent_input.append(out[last_end: min(511, last_end + length), :])
                    last_end += length
                else:
                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
                    enc_sent_input.append(pad_tensor)

            # pad the sentences
            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # [N, seq_len, hidden_state=4096]
            input_pad_list.append(torch.stack(enc_sent_input_pad))

        input_pad = torch.stack(input_pad_list)
        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]

        enc_embed_input = self.word_embedding_proj(input_pad)  # [batch_size * N, L, D]

        sent_pos_list = []
        for sentlen in enc_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding


class BertTagEncoder(BertEncoder):
    def __init__(self, hps, domaindict):
        super(BertTagEncoder, self).__init__(hps)

        # domain embedding
        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, inputs, input_masks, enc_sent_len, domain):
        sent_embedding = super().forward(inputs, input_masks, enc_sent_len)

        batch_size, N = enc_sent_len.size()

        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)

        return sent_embedding

class ELMoEndoer(nn.Module):
    def __init__(self, hps):
        super(ELMoEndoer, self).__init__()

        self._hps = hps
        self.sent_max_len = hps.sent_max_len

        from allennlp.modules.elmo import Elmo

        elmo_dim = 1024
        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"

        # elmo_dim = 512
        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # elmo embedding
        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
        self.embed_proj = nn.Linear(elmo_dim, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing W for CNN......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size

        return fan_in, fan_out

    def forward(self, input):
        # input: a batch of Example objects [batch_size, N, seq_len, character_len]
        batch_size, N, seq_len, _ = input.size()
        input = input.view(batch_size * N, seq_len, -1)  # [batch_size*N, seq_len, character_len]
        input_sent_len = ((input.sum(-1) != 0).sum(dim=1)).int()  # [batch_size*N, 1]
        logger.debug(input_sent_len.view(batch_size, -1))
        enc_embed_input = self.elmo(input)['elmo_representations'][0]  # [batch_size*N, L, D]
        enc_embed_input = self.embed_proj(enc_embed_input)

        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])

        sent_pos_list = []
        for sentlen in input_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()

        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding

class ELMoEndoer2(nn.Module):
    def __init__(self, hps):
        super(ELMoEndoer2, self).__init__()

        self._hps = hps
        self._cuda = hps.cuda
        self.sent_max_len = hps.sent_max_len

        from allennlp.modules.elmo import Elmo

        elmo_dim = 1024
        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"

        # elmo_dim = 512
        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # elmo embedding
        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
        self.embed_proj = nn.Linear(elmo_dim, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width)) for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing W for CNN......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size

        return fan_in, fan_out

    def pad_encoder_input(self, input_list):
        """
        :param input_list: N [seq_len, hidden_state]
        :return: enc_sent_input_pad: list, N [max_len, hidden_state]
        """
        max_len = self.sent_max_len
        enc_sent_input_pad = []
        _, hidden_size = input_list[0].size()
        for i in range(len(input_list)):
            article_words = input_list[i]  # [seq_len, hidden_size]
            seq_len = article_words.size(0)
            if seq_len > max_len:
                pad_words = article_words[:max_len, :]
            else:
                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
                pad_words = torch.cat([article_words, pad_tensor], dim=0)
            enc_sent_input_pad.append(pad_words)
        return enc_sent_input_pad

    def forward(self, inputs, input_masks, enc_sent_len):
        """
        :param inputs: a batch of Example objects [batch_size, doc_len=512, character_len=50]
        :param input_masks: 0 or 1, [batch, doc_len=512]
        :param enc_sent_len: original sentence lengths [batch, N]
        :return:
            sent_embedding: [batch, N, D]
        """
        # Use ELMo to get the word embeddings
        batch_size, N = enc_sent_len.size()
        input_pad_list = []

        elmo_output = self.elmo(inputs)['elmo_representations'][0]  # [batch_size, 512, D]
        elmo_output = elmo_output * input_masks.unsqueeze(-1).float()
        # print("END elmo")

        for i in range(batch_size):
            sent_len = enc_sent_len[i]  # [1, N]
            out = elmo_output[i]

            _, hidden_size = out.size()

            # restore the sentences
            last_end = 0
            enc_sent_input = []
            for length in sent_len:
                if length != 0 and last_end < 512:
                    enc_sent_input.append(out[last_end: min(512, last_end + length), :])
                    last_end += length
                else:
                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
                    enc_sent_input.append(pad_tensor)

            # pad the sentences
            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # [N, seq_len, hidden_state=4096]
            input_pad_list.append(torch.stack(enc_sent_input_pad))  # batch * [N, max_len, hidden_state]

        input_pad = torch.stack(input_pad_list)
        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]

        enc_embed_input = self.embed_proj(input_pad)  # [batch_size * N, L, D]

        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])

        sent_pos_list = []
        for sentlen in enc_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()

        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # (batch * N, Ci, L, D)
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * (batch*N, Co, W)
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * (batch*N, Co)
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # (batch*N, Co * kernel_sizes)
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
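
# A toy sketch of the "restore the sentence" loop used by both the BERT and
# ELMo encoders above: slice a flat [doc_len, hidden] output back into
# per-sentence chunks by their original lengths, substituting zero pads for
# empty slots:
import torch

doc_len, hidden, max_len = 12, 4, 5
out = torch.randn(doc_len, hidden)
sent_len = [3, 5, 4, 0]                          # 0 marks a padded sentence slot

chunks, last_end = [], 0
for length in sent_len:
    if length != 0 and last_end < doc_len:
        chunks.append(out[last_end:min(doc_len, last_end + length), :])
        last_end += length
    else:
        chunks.append(torch.zeros(max_len, hidden))
print([c.size(0) for c in chunks])               # [3, 5, 4, 5]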
@@ -21,6 +21,7 @@
import os
import sys
import json
import shutil
import argparse
import datetime
@@ -32,20 +33,25 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/')

from fastNLP.core._logger import logger
# from fastNLP.core._logger import _init_logger
from fastNLP.core.const import Const
from fastNLP.core.trainer import Trainer, Tester
from fastNLP.io.pipe.summarization import ExtCNNDMPipe
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.io.embed_loader import EmbedLoader

from tools.logger import *
# from tools.logger import *
# from model.TransformerModel import TransformerModel
from model.TForiginal import TransformerModel
from model.Metric import LabelFMetric, FastRougeMetric, PyRougeMetric
from model.LSTMModel import SummarizationModel
from model.Metric import LossMetric, LabelFMetric, FastRougeMetric, PyRougeMetric
from model.Loss import MyCrossEntropyLoss
from tools.Callback import TrainCallback

def setup_training(model, train_loader, valid_loader, hps):
    """Does setup before starting training (run_training)"""

@@ -60,32 +66,23 @@ def setup_training(model, train_loader, valid_loader, hps):
    else:
        logger.info("[INFO] Create new model for training...")

    try:
        run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted
    except KeyboardInterrupt:
        logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
        save_file = os.path.join(train_dir, "earlystop.pkl")
        saver = ModelSaver(save_file)
        saver.save_pytorch(model)
        logger.info('[INFO] Saving early stop model to %s', save_file)
    run_training(model, train_loader, valid_loader, hps)  # this is an infinite loop until interrupted

def run_training(model, train_loader, valid_loader, hps):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)
    if os.path.exists(train_dir): shutil.rmtree(train_dir)
    os.makedirs(train_dir)

    eval_dir = os.path.join(hps.save_root, "eval")  # make a subdir of the root dir for eval data
    if not os.path.exists(eval_dir): os.makedirs(eval_dir)

    lr = hps.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=hps.lr)
    criterion = MyCrossEntropyLoss(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none')
    # criterion = torch.nn.CrossEntropyLoss(reduce="none")

    trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion,
                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
                      metric_key="f", validate_every=-1, save_path=eval_dir,
                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader, metrics=[LossMetric(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none'), LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
                      metric_key="loss", validate_every=-1, save_path=eval_dir,
                      callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False)

    train_info = trainer.train(load_best_model=True)
@@ -98,8 +95,8 @@ def run_training(model, train_loader, valid_loader, hps):
    saver.save_pytorch(model)
    logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path)

def run_test(model, loader, hps, limited=False):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
def run_test(model, loader, hps):
    test_dir = os.path.join(hps.save_root, "test")  # make a subdir of the root dir for eval data
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir): os.makedirs(test_dir)

@@ -113,8 +110,8 @@ def run_test(model, loader, hps, limited=False):
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
    else:
        logger.error("No such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
        raise ValueError("No such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
        logger.error("No such model! Must be one of evalbestmodel/earlystop")
        raise ValueError("No such model! Must be one of evalbestmodel/earlystop")
    logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path)

    modelloader = ModelLoader()
@@ -174,13 +171,11 @@ def main():
    # Training
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent')
    parser.add_argument('--warmup_steps', type=int, default=4000, help='warmup steps')
    parser.add_argument('--grad_clip', action='store_true', default=False, help='enable gradient clipping')
    parser.add_argument('--max_grad_norm', type=float, default=10, help='max gradient norm for gradient clipping')

    # test
    parser.add_argument('-m', type=int, default=3, help='decode summary length')
    parser.add_argument('--limited', action='store_true', default=False, help='limit the decoded summary length')
    parser.add_argument('--test_model', type=str, default='evalbestmodel', help='choose the model to test [evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop]')
    parser.add_argument('--use_pyrouge', action='store_true', default=False, help='use pyrouge')
@@ -195,21 +190,22 @@ def main():
    VOCAL_FILE = args.vocab_path
    LOG_PATH = args.log_root

    # train_log setting
    # # train_log setting
    if not os.path.exists(LOG_PATH):
        if args.mode == "train":
            os.makedirs(LOG_PATH)
        else:
            logger.exception("[Error] Logdir %s doesn't exist. Run in train mode to create it.", LOG_PATH)
            raise Exception("[Error] Logdir %s doesn't exist. Run in train mode to create it." % (LOG_PATH))
    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_path = os.path.join(LOG_PATH, args.mode + "_" + nowTime)
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # logger = _init_logger(path=log_path)
    # file_handler = logging.FileHandler(log_path)
    # file_handler.setFormatter(formatter)
    # logger.addHandler(file_handler)

    logger.info("Pytorch %s", torch.__version__)
    # dataset
    hps = args
    dbPipe = ExtCNNDMPipe(vocab_size=hps.vocab_size,
                          vocab_path=VOCAL_FILE,

@@ -225,6 +221,8 @@ def main():
    paths = {"train": DATA_FILE, "valid": VALID_FILE}
    db = dbPipe.process_from_file(paths)

    # embedding
    if args.embedding == "glove":
        vocab = db.get_vocab("vocab")
        embed = torch.nn.Embedding(len(vocab), hps.word_emb_dim)

@@ -237,19 +235,24 @@ def main():
        logger.error("[ERROR] embedding To Be Continued!")
        sys.exit(1)

    # model
    if args.sentence_encoder == "transformer" and args.sentence_decoder == "SeqLab":
        model_param = json.load(open("config/transformer.config", "rb"))
        hps.__dict__.update(model_param)
        model = TransformerModel(hps, embed)
    elif args.sentence_encoder == "deeplstm" and args.sentence_decoder == "SeqLab":
        model_param = json.load(open("config/deeplstm.config", "rb"))
        hps.__dict__.update(model_param)
        model = SummarizationModel(hps, embed)
    else:
        logger.error("[ERROR] Model To Be Continued!")
        sys.exit(1)

    logger.info(hps)

    if hps.cuda:
        model = model.cuda()
        logger.info("[INFO] Use cuda")

    logger.info(hps)

    if hps.mode == 'train':
        db.get_dataset("valid").set_target("text", "summary")
        setup_training(model, db.get_dataset("train"), db.get_dataset("valid"), hps)
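
# A hypothetical training invocation; the flag names follow the argparse options
# shown in this diff, while --save_root, --cuda, and --mode are assumptions
# inferred from how hps/args are used above:
#   python train.py --mode train --vocab_path vocab.txt --log_root log/ \
#       --save_root save/ --embedding glove --sentence_encoder deeplstm \
#       --sentence_decoder SeqLab --lr 0.0001 --grad_clip --max_grad_norm 10 --cuda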