|
|
|
@@ -4,20 +4,19 @@ import math
 import os
 import subprocess
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import json
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import Tensor, nn
 from torch.nn.init import xavier_uniform_
 from transformers import (BertConfig, BertModel, BertTokenizer, RobertaConfig,
                           RobertaModel, RobertaTokenizer)
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 
 from modelscope.outputs import OutputKeys
 from modelscope.utils import logger as logging
 from .configuration_palm import PalmConfig
 from .dureader_eval import compute_bleu_rouge, normalize
@@ -142,35 +141,27 @@ class MultiHeadedAttention(nn.Module):  # SelfAttention
             key = shape(key)
             value = shape(value)
 
-            if layer_cache is not None:
-                device = key.device
-                if layer_cache['self_keys'] is not None:
-                    key = torch.cat(
-                        (layer_cache['self_keys'].to(device), key), dim=2)
-                if layer_cache['self_values'] is not None:
-                    value = torch.cat(
-                        (layer_cache['self_values'].to(device), value),
-                        dim=2)
-                layer_cache['self_keys'] = key
-                layer_cache['self_values'] = value
+            device = key.device
+            if layer_cache['self_keys'] is not None:
+                key = torch.cat((layer_cache['self_keys'].to(device), key),
+                                dim=2)
+            if layer_cache['self_values'] is not None:
+                value = torch.cat(
+                    (layer_cache['self_values'].to(device), value), dim=2)
+            layer_cache['self_keys'] = key
+            layer_cache['self_values'] = value
         elif type == 'context':
             query = self.linear_query(query)
-            if layer_cache is not None:
-                if layer_cache['memory_keys'] is None:
-                    key, value = self.linear_keys(key), self.linear_values(
-                        value)
-                    key = shape(key)
-                    value = shape(value)
-                else:
-                    key, value = layer_cache['memory_keys'], layer_cache[
-                        'memory_values']
-                layer_cache['memory_keys'] = key
-                layer_cache['memory_values'] = value
-            else:
-                key, value = self.linear_keys(key), self.linear_values(
-                    value)
-                key = shape(key)
-                value = shape(value)
+            if layer_cache['memory_keys'] is None:
+                key, value = self.linear_keys(key), self.linear_values(
+                    value)
+                key = shape(key)
+                value = shape(value)
+            else:
+                key, value = layer_cache['memory_keys'], layer_cache[
+                    'memory_values']
+            layer_cache['memory_keys'] = key
+            layer_cache['memory_values'] = value
         else:
             key = self.linear_keys(key)
             value = self.linear_values(value)
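Note (not part of the patch): the self-attention cache above grows the cached
key/value tensors one step at a time by concatenating along dim=2, since the
shape convention here is [batch, heads, seq_len, dim_per_head]. A minimal,
self-contained sketch of that growth pattern with made-up sizes:

    import torch

    # Hypothetical sizes: batch=2, heads=4, dim_per_head=8.
    cache = {'self_keys': None}
    for step in range(3):
        key = torch.randn(2, 4, 1, 8)  # one new timestep per decoding step
        if cache['self_keys'] is not None:
            key = torch.cat((cache['self_keys'], key), dim=2)
        cache['self_keys'] = key
    print(cache['self_keys'].shape)  # torch.Size([2, 4, 3, 8])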
|
|
|
@@ -372,6 +363,44 @@ class PositionalEncoding(nn.Module):
         return self.pe[:, :emb.size(1)]
 
 
+class TransformerDecoderState:
+
+    def __init__(self, src: Tensor, cache_num_layers: int = -1):
+        self.src: Tensor = src
+        self.previous_input: Optional[Tensor] = None
+        self.previous_layer_inputs: Optional[Tensor] = None
+        self.cache: Optional[Dict[str, Any]] = None
+        if cache_num_layers != -1:
+            self._init_cache(cache_num_layers)
+
+    def update_state(self, new_input, previous_layer_inputs):
+        self.previous_input = new_input
+        self.previous_layer_inputs = previous_layer_inputs
+        self.cache = None
+
+    def _init_cache(self, num_layers):
+        self.cache = {}
+        for num in range(num_layers):
+            layer_cache = {
+                'memory_keys': None,
+                'memory_values': None,
+                'self_keys': None,
+                'self_values': None
+            }
+            self.cache['layer_{}'.format(num)] = layer_cache
+
+    def map_batch_fn(self, fn):
+
+        def _recursive_map(struct, batch_dim=0):
+            for k, v in struct.items():
+                if v is not None:
+                    if isinstance(v, dict):
+                        _recursive_map(v)
+                    else:
+                        struct[k] = fn(v, batch_dim)
+
+        self.src = fn(self.src, 0)
+        if self.cache is not None:
+            _recursive_map(self.cache)
+
+
 class TransformerDecoder(nn.Module):  # Decoder
     """
     The Transformer decoder from "Attention is All You Need".
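Note (not part of the patch): hoisting the state object to module level makes
decoding state an explicit argument rather than a hidden decoder attribute, so
the decoder module no longer carries mutable state between calls. A short
usage sketch under assumed sizes (batch=4, src_len=16, 2 layers, beam_size=5):

    import torch

    src = torch.zeros(4, 16, dtype=torch.long)  # [batch, src_len] token ids

    train_state = TransformerDecoderState(src)  # cache=None: training path
    infer_state = TransformerDecoderState(src, cache_num_layers=2)  # KV cache

    # Beam search later tiles src and every cached tensor along the batch dim:
    infer_state.map_batch_fn(lambda t, dim: t.repeat_interleave(5, dim=dim))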
|
|
|
@@ -403,44 +432,6 @@ class TransformerDecoder(nn.Module):  # Decoder
     """
     decoder_type = 'transformer'
 
-    class TransformerDecoderState:
-
-        def __init__(self, src):
-            self.src = src
-            self.previous_input = None
-            self.previous_layer_inputs = None
-            self.cache = None
-
-        def update_state(self, new_input, previous_layer_inputs):
-            self.previous_input = new_input
-            self.previous_layer_inputs = previous_layer_inputs
-            self.cache = None
-
-        def _init_cache(self, num_layers):
-            self.cache = {}
-            for num in range(num_layers):
-                layer_cache = {
-                    'memory_keys': None,
-                    'memory_values': None,
-                    'self_keys': None,
-                    'self_values': None
-                }
-                self.cache['layer_{}'.format(num)] = layer_cache
-
-        def map_batch_fn(self, fn):
-
-            def _recursive_map(struct, batch_dim=0):
-                for k, v in struct.items():
-                    if v is not None:
-                        if isinstance(v, dict):
-                            _recursive_map(v)
-                        else:
-                            struct[k] = fn(v, batch_dim)
-
-            self.src = fn(self.src, 0)
-            if self.cache is not None:
-                _recursive_map(self.cache)
-
     def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
         super().__init__()
@@ -458,13 +449,13 @@ class TransformerDecoder(nn.Module):  # Decoder
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
-        self.state = None
 
-    def init_state(self, src, with_cache=False):
-        self.state = self.TransformerDecoderState(src)
-        if with_cache:
-            self.state._init_cache(self.num_layers)
-
-    def forward(self, tgt, memory_bank, step=None, memory_masks=None):
-        src_words = self.state.src
+    def forward(self,
+                state: TransformerDecoderState,
+                tgt: Tensor,
+                memory_bank: Tensor,
+                step: Optional[int] = None,
+                memory_masks: Optional[Tensor] = None):
+        src_words = state.src
         tgt_words = tgt
         src_batch, src_len = src_words.size()
         tgt_batch, tgt_len = tgt_words.size()
@@ -487,33 +478,36 @@ class TransformerDecoder(nn.Module):  # Decoder
         src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
             .expand(src_batch, tgt_len, src_len)
 
-        if self.state.cache is None:
+        if state.cache is None:
             saved_inputs = []
         attns = []
         for i in range(self.num_layers):
             prev_layer_input = None
-            if self.state.cache is None:
-                if self.state.previous_input is not None:
-                    prev_layer_input = self.state.previous_layer_inputs[i]
+            if state.cache is None:
+                if state.previous_input is not None:
+                    prev_layer_input = state.previous_layer_inputs[i]
             output, attn, all_input \
-                = self.transformer_layers[i](output, src_memory_bank, src_pad_mask, tgt_pad_mask,
-                                             previous_input=prev_layer_input,
-                                             layer_cache=self.state.cache['layer_{}'.format(i)]
-                                             if self.state.cache is not None else None, step=step)
-            if self.state.cache is None:
+                = self.transformer_layers[i](
+                    output, src_memory_bank,
+                    src_pad_mask, tgt_pad_mask,
+                    previous_input=prev_layer_input,
+                    layer_cache=state.cache['layer_{}'.format(i)]
+                    if state.cache is not None else None,
+                    step=step)
+            if state.cache is None:
                 saved_inputs.append(all_input)
             attns.append(attn)
 
-        if self.state.cache is None:
+        if state.cache is None:
             saved_inputs = torch.stack(saved_inputs)
 
         output = self.layer_norm(output)
 
         # Process the result and update the attentions.
-        if self.state.cache is None:
-            self.state.update_state(tgt, saved_inputs)
+        if state.cache is None:
+            state.update_state(tgt, saved_inputs)
 
-        return output, attns
+        return output, attns, state
 
 
 class PalmPointerGenerator(nn.Module):
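Note (not part of the patch): src_pad_mask above broadcasts a [batch, src_len]
padding pattern up to [batch, tgt_len, src_len] so every target position masks
the same padded source positions. A self-contained sketch, assuming
padding_idx=0:

    import torch

    padding_idx = 0
    src_words = torch.tensor([[5, 9, 0, 0]])     # [batch=1, src_len=4]
    tgt_len = 3
    src_pad_mask = src_words.eq(padding_idx).unsqueeze(1) \
        .expand(1, tgt_len, 4)                   # [1, 3, 4] boolean mask
    print(src_pad_mask[0, 0])  # tensor([False, False,  True,  True])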
|
|
|
@@ -570,10 +564,11 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
         if (config.max_pos > 512):
             my_pos_embeddings = nn.Embedding(
                 config.max_pos, self.bert.model.config.hidden_size)
-            my_pos_embeddings.weight.data[:512] = \
-                self.bert.embeddings.position_embeddings.weight.data
-            my_pos_embeddings.weight.data[512:] = \
-                self.bert.embeddings.position_embeddings.weight.data[-1][None, :].repeat(config.max_pos - 512, 1)
+            my_pos_embeddings.weight.data[:
+                                          512] = self.bert.embeddings.position_embeddings.weight.data
+            my_pos_embeddings.weight.data[
+                512:] = self.bert.embeddings.position_embeddings.weight.data[
+                    -1][None, :].repeat(config.max_pos - 512, 1)
             self.bert.model.embeddings.position_embeddings = my_pos_embeddings
         self.vocab_size = self.bert.config.vocab_size
         tgt_embeddings = nn.Embedding(
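Note (not part of the patch): the rewrapped assignments above are behavior
preserving: copy BERT's 512 learned position vectors, then fill every row from
512 up to max_pos with copies of the last learned row. A compact equivalent
with stand-in sizes (768-dim, max_pos=1024):

    import torch
    from torch import nn

    old = nn.Embedding(512, 768)   # stand-in for BERT's position table
    new = nn.Embedding(1024, 768)  # config.max_pos > 512
    with torch.no_grad():
        new.weight[:512] = old.weight
        new.weight[512:] = old.weight[-1].expand(1024 - 512, -1)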
|
|
|
@@ -633,8 +628,8 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
 
     def forward(self, src, tgt, mask_src):
         top_vec, _ = self.bert(src, mask_src, return_dict=False)
-        self.decoder.init_state(src)
-        decoder_outputs, attns = self.decoder(tgt[:, :-1], top_vec)
+        state = TransformerDecoderState(src)
+        decoder_outputs, attns, _ = self.decoder(state, tgt[:, :-1], top_vec)
         return decoder_outputs, attns[-1], top_vec
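Note (not part of the patch): tgt[:, :-1] is the usual teacher-forcing shift:
the decoder sees everything but the last token, so position i is trained to
predict token i + 1. Tiny illustration:

    import torch

    tgt = torch.tensor([[101, 7, 8, 9, 102]])  # [BOS, w1, w2, w3, EOS]
    decoder_input = tgt[:, :-1]                # [[101, 7, 8, 9]]
    labels = tgt[:, 1:]                        # [[7, 8, 9, 102]]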
|
|
|
|
|
|
|
|
|
|
|
@@ -776,9 +771,9 @@ class Translator(nn.Module):
                 translation_batch['predictions']))
         batch_size = batch.batch_size
 
-        preds, pred_score, _, tgt_str, src, src_str = \
-            translation_batch['predictions'], translation_batch['scores'], translation_batch['gold_score'], \
-            batch.tgt_str, batch.src, batch.src_str
+        preds, pred_score, tgt_str, src, src_str = translation_batch[
+            'predictions'], translation_batch[
+                'scores'], batch.tgt_str, batch.src, batch.src_str
         query_id = batch.query_id
         '''
         try:
@@ -903,17 +898,13 @@ class Translator(nn.Module):
                     '</s>', '').replace('<unk>', ' ').strip()
                 if (self.args.recall_eval):
                     _pred_str = ''
-                    # gap = 1e3
                     for sent in pred_str.split('<q>'):
                         can_pred_str = _pred_str + '<q>' + sent.strip()
-                        # can_gap = math.fabs(len(_pred_str.split()) - len(gold_str.split()))
-                        # if(can_gap>=gap):
                         if len(can_pred_str.split()) >= len(
                                 gold_str.split()) + 10:
                             pred_str = _pred_str
                             break
                         else:
-                            # gap = can_gap
                             _pred_str = can_pred_str
 
                 if self.args.dataset == 'marco' or self.args.dataset == 'squad' or self.args.dataset == 'qg_ranking':
@@ -967,10 +958,6 @@ class Translator(nn.Module):
                     pred_str = [pred_str]
                 pred_dict[query_id] = normalize([pred_str[0]])
                 ref_dict[query_id] = normalize([gold_str])
-                # pred_str_list = [src] + pred_str
-                # self.can_out_file.write("\t".join(pred_str_list)+"\n")
-                # self.can_out_file.write("\t".join(pred_str_list)+"\n")
-                # self.gold_out_file.write("\t".join([src, pred_str[0], gold_str])+"\n")
                 self.pred_json_score_out_file.write(
                     '\t'.join([str(query_id), src, gold_str, pred_str[0]])
                     + '\n')
@@ -1027,8 +1014,6 @@ class Translator(nn.Module):
                         preds[idx] = '。'
                 return preds, labels
 
-            # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
-            # self.logger.info('Dev eval result: {}'.format(bleu_rouge))
             pred_results, gold_results = postprocess_text(
                 pred_results, gold_results)
             pred_dict = {str(i): tmp for i, tmp in enumerate(pred_results)}
@@ -1037,8 +1022,6 @@ class Translator(nn.Module):
             print(bleu_rouge)
         # unreachable
         elif self.args.dataset == 'dureader' or self.args.dataset == 'paraphrase':
-            # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
-            # self.logger.info('Dev eval result: {}'.format(bleu_rouge))
             pred_results, gold_results = postprocess_text(
                 pred_results, gold_results)
             bleu_score = cal_bleu(pred_results, gold_results)
@@ -1134,11 +1117,11 @@ class Translator(nn.Module):
         mask_src = batch.mask_src
 
         src_features, _ = self.model.bert(src, mask_src, return_dict=False)
-        self.model.decoder.init_state(src, with_cache=True)
+        state = TransformerDecoderState(src, self.model.decoder.num_layers)
         device = src_features.device
 
         # Tile states and memory beam_size times.
-        self.model.decoder.state.map_batch_fn(
+        state.map_batch_fn(
             lambda state, dim: self._tile(state, beam_size, dim=dim))
         src_features = self._tile(src_features, beam_size, dim=0)
         batch_offset = torch.arange(
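Note (not part of the patch): _tile (defined elsewhere in this class) appears
to repeat each batch entry beam_size times along the given dim so that every
beam owns a copy of the encoder output and cached state; on current PyTorch
the same effect can be had with repeat_interleave:

    import torch

    x = torch.tensor([[1, 2], [3, 4]])  # [batch=2, d=2]
    beam_size = 3
    tiled = x.repeat_interleave(beam_size, dim=0)
    # [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]] -> [batch * beam, d]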
|
|
|
@@ -1174,8 +1157,8 @@ class Translator(nn.Module):
 
             # Decoder forward.
             decoder_input = decoder_input.transpose(0, 1)
-            dec_out, attns = self.model.decoder(
-                decoder_input, src_features, step=step)
+            dec_out, attns, state = self.model.decoder(
+                state, decoder_input, src_features, step=step)
 
             # Generator forward.
             log_probs = self.generator.forward(
@@ -1188,7 +1171,6 @@ class Translator(nn.Module):
             # Multiply probs by the beam probability.
 
             length_penalty = ((5.0 + (step + 1)) / 6.0)**self.alpha
-            # '''
             if self.args.sample_topk:
                 temperature = self.args.temperature
                 _scores = log_probs / temperature
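Note (not part of the patch): this is the GNMT length penalty,
lp(Y) = ((5 + |Y|) / 6) ** alpha; dividing cumulative log-probabilities by it
keeps longer hypotheses competitive during beam search. Worked example with an
assumed alpha:

    alpha = 0.6  # hypothetical value of self.alpha
    step = 4     # i.e. five tokens generated so far
    length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha  # ~= 1.36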
|
|
|
@@ -1211,13 +1193,11 @@ class Translator(nn.Module):
                 _scores = _scores / length_penalty
                 topk_scores = torch.gather(
                     _scores, -1, topk_ids)  # (batch_size * num_beams, 2)
-                # log_probs += # (batch_size * num_beams, 2)
                 # Match shape of greedy beam search
                 topk_ids = topk_ids.view(
                     -1, beam_size)  # (batch_size, 2 * num_beams)
                 topk_scores = topk_scores.view(
                     -1, beam_size)  # (batch_size, 2 * num_beams)
-                # '''
             else:
                 log_probs += topk_log_probs.view(-1).unsqueeze(1)
                 curr_scores = log_probs / length_penalty
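Note (not part of the patch): a minimal sketch of the temperature plus top-k
sampling pattern used in the branch above, with illustrative sizes (6 beams,
vocab of 30000, k=20):

    import torch

    log_probs = torch.log_softmax(torch.randn(6, 30000), -1)  # [beams, vocab]
    temperature, k = 0.8, 20
    scores = log_probs / temperature
    top_scores, top_ids = scores.topk(k, dim=-1)
    pick = torch.multinomial(torch.softmax(top_scores, -1), num_samples=1)
    topk_ids = top_ids.gather(-1, pick)        # sampled token id per beam
    topk_scores = scores.gather(-1, topk_ids)  # score of the sampled token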
|
|
|
@@ -1231,7 +1211,6 @@ class Translator(nn.Module):
                     fail = False
                     words = [int(w) for w in alive_seq[i]]
                     if self.args.encoder == 'roberta':
-                        # words = [self.vocab.convert_ids_to_tokens[w] for w in words]
                         words = self.vocab.decode(words).strip().split()
                     else:
                         words = [
@@ -1252,7 +1231,6 @@ class Translator(nn.Module):
             topk_log_probs = topk_scores * length_penalty
 
             # Resolve beam origin and true word ids.
-            # topk_beam_index = topk_ids.div(vocab_size)
             topk_beam_index = topk_ids // vocab_size
             topk_ids = topk_ids.fmod(vocab_size)
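Note (not part of the patch): after topk over the flattened beam * vocab axis,
integer division recovers the beam a candidate came from and fmod (or %)
recovers the token id within the vocabulary. For vocab_size = 30000:

    vocab_size = 30000
    flat_id = 61234
    beam = flat_id // vocab_size   # 2
    token = flat_id % vocab_size   # 1234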
|
|
|
|
|
|
|
@@ -1313,7 +1291,7 @@ class Translator(nn.Module):
             # Reorder states.
             select_indices = batch_index.view(-1)
             src_features = src_features.index_select(0, select_indices)
-            self.model.decoder.state.map_batch_fn(
+            state.map_batch_fn(
                 lambda state, dim: state.index_select(dim, select_indices))
 
         return results