|
|
|
@@ -1725,7 +1725,116 @@ class BertLMHeadModel(BertPreTrainedModel):

        return reordered_past


class BertPrefixModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r'pooler']
    _keys_to_ignore_on_load_missing = [
        r'position_ids', r'predictions.decoder.bias'
    ]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

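    # These two hooks let the Hugging Face base class tie the output
    # projection to the input embeddings and resize it when tokens are
    # added via resize_token_embeddings(); both operate on
    # self.cls.predictions.decoder, the nn.Linear that maps hidden states
    # to vocabulary logits.
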
    @add_start_docstrings_to_model_forward(
        BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint='bert-base-uncased',
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=True,
        reduction='mean',
        soft_labels=None,
        alpha=0,
        return_logits=False,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            # honour the `reduction` argument: rank_answer below calls this
            # decoder with reduction='none' and needs per-token losses
            loss_fct = CrossEntropyLoss(reduction=reduction)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1))
        if soft_labels is not None:
            # distillation loss against the momentum model's softened
            # distribution; log_softmax must normalize over the vocabulary
            # axis (dim=-1), not the sequence axis
            loss_distill = -torch.sum(
                F.log_softmax(shifted_prediction_scores, dim=-1) * soft_labels,
                dim=-1)
            loss_distill = loss_distill[labels != -100].mean()
            lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill

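        # Worked illustration of the shift above (values hypothetical): for
        # ids [BOS, a, b, c], prediction_scores[:, :-1] keeps the scores at
        # positions (BOS, a, b) while labels[:, 1:] keeps (a, b, c), so each
        # position is trained to predict the token that follows it:
        #
        #   V = self.config.vocab_size
        #   scores = torch.randn(1, 4, V)                  # [batch, seq, vocab]
        #   labels = torch.tensor([[101, 7592, 2088, 102]])
        #   loss = CrossEntropyLoss()(
        #       scores[:, :-1, :].reshape(-1, V),          # 3 predictions
        #       labels[:, 1:].reshape(-1))                 # 3 targets
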
        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return ((lm_loss, ) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


class MPlug(PreTrainedModel):
    config_class = MPlugConfig

    def __init__(self, config):

@@ -1739,16 +1848,19 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):

            self.config_encoder, add_pooling_layer=False)
        self.fusion_encoder = FusionModel(
            self.config_fusion, add_pooling_layer=False)

    @classmethod
    def from_pretrained(cls, model_dir, load_checkpoint=True):
        from modelscope.utils.constant import Tasks

        # dispatch to the task-specific subclass
        task_mapping = {
            Tasks.visual_question_answering: MPlugForVisualQuestionAnswering,
            Tasks.image_captioning: MPLUGForImageCaption
        }
        config = cls.config_class.from_yaml_file(
            os.path.join(model_dir, CONFIG_NAME))
        config.model_dir = model_dir
        model = task_mapping[config.task](config)
        if load_checkpoint:
            checkpoint_path = os.path.join(model_dir,
                                           ModelFile.TORCH_MODEL_BIN_FILE)

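        # Dispatch sketch (the directory path below is hypothetical): the
        # task recorded in the YAML config selects the concrete subclass,
        # so callers go through the base class:
        #
        #   model = MPlug.from_pretrained('/path/to/mplug_dir')
        #   # -> MPlugForVisualQuestionAnswering or MPLUGForImageCaption,
        #   #    with weights restored from ModelFile.TORCH_MODEL_BIN_FILE
        #   #    when load_checkpoint=True
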
@@ -1803,6 +1915,161 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):

        clip_model.visual.positional_embedding = pos_embed
        return clip_model

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def module_setting(self, config):
        bert_config_path = os.path.join(config.model_dir, config.bert_config)
        self.config_encoder = BertConfig.from_json_file(bert_config_path)
        self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers
        self.config_fusion = BertConfig.from_json_file(bert_config_path)
        self.config_decoder = BertConfig.from_json_file(bert_config_path)
        self.config_decoder.add_cross_attention = True
        self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers
        self.large = False
        if self.config_encoder.hidden_size != config.vision_width:
            # project CLIP visual features into the text encoder's width
            self.visn_fc = nn.Linear(config.vision_width,
                                     self.config_encoder.hidden_size)
            self.visn_layer_norm = nn.LayerNorm(
                self.config_encoder.hidden_size, eps=1e-12)
            self.dropout = nn.Dropout(self.config_encoder.hidden_dropout_prob)
            self.large = True

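    # All three sub-configs are parsed from the same BERT JSON file and
    # differ only in depth (text_encoder_layers vs. text_decode_layers) and
    # in the decoder's add_cross_attention=True, which gives the decoder
    # cross-attention blocks over the fused image/question states. The
    # visn_fc/visn_layer_norm/dropout branch only exists when the CLIP
    # feature width differs from the BERT hidden size (the `large` setup).
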
    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(),
                                      model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(),
                                      model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (
                    1. - self.momentum)

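    # The update above is an exponential moving average with m = 0.995:
    #
    #   param_m <- m * param_m + (1 - m) * param
    #
    # so the momentum copy trails the online weights with an effective
    # averaging window of roughly 1 / (1 - m) = 200 optimisation steps,
    # which is what makes its soft labels stable distillation targets.
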
    def generation(self, question_states, question_atts, out_size=1):
        # beam search over the decoder, conditioned on the fused
        # question/image states; returns the top `out_size` sequences per
        # input together with their scores
        encoder_inputs = [question_states, question_atts]
        topk_ids, topk_scores = self.beam_generator.translate_batch(
            encoder_inputs, out_size=out_size)
        return topk_ids, topk_scores

    @staticmethod
    def _tile(x, dim, n_tile):
        import numpy as np
        init_dim = x.size(dim)
        repeat_idx = [1] * x.dim()
        repeat_idx[dim] = n_tile
        x = x.repeat(*repeat_idx)
        order_index = torch.LongTensor(
            np.concatenate(
                [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
        return torch.index_select(x, dim, order_index.to(x.device))

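    # _tile repeats interleaved along `dim` (each slice is duplicated in
    # place rather than the whole tensor being appended). Illustrative
    # values:
    #
    #   x = torch.tensor([[1, 2], [3, 4]])
    #   MPlug._tile(x, 0, 2)   # -> [[1, 2], [1, 2], [3, 4], [3, 4]]
    #   x.repeat(2, 1)         # -> [[1, 2], [3, 4], [1, 2], [3, 4]]
    #
    # rank_answer relies on this ordering so the k copies of each question
    # stay adjacent, matching the k candidate answers laid out by
    # index_select below.
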
    def rank_answer(self, question_states, question_atts, answer_ids,
                    answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(
            start_ids,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            return_dict=True,
            reduction='none')
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(
            logits, dim=1).index_select(
                dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = self._tile(question_states, 0, k)
        question_atts = self._tile(question_atts, 0, k)

        output = self.text_decoder(
            input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=targets_ids,
            return_dict=True,
            reduction='none')

        answer_loss = output.loss
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # topk_prob: first token probability
        topk_probs = topk_probs.view(-1, 1)
        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

        # re-calculate log probabilities for the answer sequences using chain rule
        log_probs_sum = log_probs.sum(1)
        log_probs_sum = log_probs_sum.view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)
        # get top-k after re-ranking
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs

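    # rank_answer scores candidates by chain-rule log-likelihood: the
    # first-token probability comes from a one-step decode over the [BOS]
    # input (topk_probs), while -answer_loss supplies per-token log
    # probabilities of the full candidate, since with reduction='none' the
    # decoder's loss is the per-token negative log-likelihood. Summing
    # along dim 1 accumulates the log p(token | prefix, question) terms,
    # and the k candidates are re-ranked by a softmax over the totals.

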
class MPlugForVisualQuestionAnswering(MPlug):

    def __init__(self, config):
        super().__init__(config)
        self.text_decoder = BertLMHeadModel(self.config_decoder)
        self.beam_generator = TextGenerator(config, self.text_decoder)
        self.init_distill(config)

    def init_distill(self, config):
        self.distill = config.distill
        if self.distill:
            self.visual_encoder_m = self._initialize_clip(config)
            self.text_encoder_m = BertModel(
                self.config_encoder, add_pooling_layer=False)
            self.fusion_encoder_m = FusionModel(
                self.config_fusion, add_pooling_layer=False)
            self.text_decoder_m = BertLMHeadModel(self.config_decoder)
            self.model_pairs = [
                [self.visual_encoder, self.visual_encoder_m],
                [self.text_encoder, self.text_encoder_m],
                [self.text_decoder, self.text_decoder_m],
            ]
            if self.config_encoder.hidden_size != config.vision_width:
                self.visn_fc_m = nn.Linear(config.vision_width,
                                           self.config_encoder.hidden_size)
                self.visn_layer_norm_m = nn.LayerNorm(
                    self.config_encoder.hidden_size, eps=1e-12)
                self.dropout_m = nn.Dropout(
                    self.config_encoder.hidden_dropout_prob)
                self.model_pairs.extend(
                    [[self.visn_fc, self.visn_fc_m],
                     [self.visn_layer_norm, self.visn_layer_norm_m]])
            self.copy_params()
            self.momentum = 0.995

    def forward(self,
                image,
                question,

@@ -1935,145 +2202,110 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):

                                                   merge_text_attention)
            return topk_ids, topk_probs


class MPLUGForImageCaption(MPlug):

    def __init__(self, config):
        super().__init__(config)
        self.text_decoder = BertPrefixModel(self.config_decoder)
        self.beam_generator = TextGenerator(config, self.text_decoder)

    def beam_search(self,
                    image,
                    question,
                    answer=None,
                    train=True,
                    out_size=5):
        image_embeds = self.visual_encoder.visual(image, skip_last_layer=True)
        if self.large:
            image_embeds = self.dropout(
                self.visn_layer_norm(self.visn_fc(image_embeds)))
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        text_output = self.text_encoder(
            question.input_ids,
            attention_mask=question.attention_mask,
            return_dict=True)
        text_embeds = text_output.last_hidden_state
        fusion_output = self.fusion_encoder(
            encoder_embeds=text_embeds,
            attention_mask=question.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=False)
        image_output, question_output = fusion_output
        question_output = torch.cat([image_output, question_output], 1)
        merge_text_attention = torch.cat([image_atts, question.attention_mask],
                                         1)
        topk_ids, topk_probs = self.generation(
            question_output, merge_text_attention, out_size=out_size)
        return topk_ids, topk_probs

    def forward(self,
                image,
                question,
                answer=None,
                train=True,
                out_size=5,
                scst=False):
        if scst:
            # self-critical sequence training: draw out_size beam
            # candidates per image instead of computing a loss
            return self.beam_search(
                image, question, answer, train=True, out_size=out_size)
        image = image.to(dtype=next(self.parameters()).dtype)
        image_embeds = self.visual_encoder.visual(image, skip_last_layer=True)
        if self.large:
            image_embeds = self.dropout(
                self.visn_layer_norm(self.visn_fc(image_embeds)))
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        if train:
            answer_targets = answer.input_ids.masked_fill(
                answer.input_ids == self.tokenizer.pad_token_id, -100)
            text_output = self.text_encoder(
                question.input_ids,
                attention_mask=question.attention_mask,
                return_dict=True)
            text_embeds = text_output.last_hidden_state
            fusion_output = self.fusion_encoder(
                encoder_embeds=text_embeds,
                attention_mask=question.attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=False)
            image_output, question_output = fusion_output
            question_output = torch.cat([image_output, question_output], 1)
            merge_text_attention = torch.cat(
                [image_atts, question.attention_mask], 1)
            answer_output = self.text_decoder(
                answer.input_ids,
                attention_mask=answer.attention_mask,
                encoder_hidden_states=question_output,
                encoder_attention_mask=merge_text_attention,
                labels=answer_targets,
                return_dict=True,
                reduction='none')
            loss = answer_output.loss
            return loss
        else:
            text_output = self.text_encoder(
                question.input_ids,
                attention_mask=question.attention_mask,
                return_dict=True)
            text_embeds = text_output.last_hidden_state
            fusion_output = self.fusion_encoder(
                encoder_embeds=text_embeds,
                attention_mask=question.attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=False)
            image_output, question_output = fusion_output
            question_output = torch.cat([image_output, question_output], 1)
            merge_text_attention = torch.cat(
                [image_atts, question.attention_mask], 1)
            topk_ids, topk_probs = self.generation(question_output,
                                                   merge_text_attention)
            return topk_ids, topk_probs
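
    # Usage sketch (tensor shapes and the tokenizer call are assumptions,
    # not spelled out in this diff): `image` is a batch of pixel tensors
    # and `question` a tokenizer output carrying input_ids and
    # attention_mask. Training returns the per-token caption loss;
    # inference returns beam-searched caption ids with probabilities:
    #
    #   inputs = tokenizer(prompts, padding=True, return_tensors='pt')
    #   loss = model(image, inputs, answer=captions, train=True)
    #   topk_ids, topk_probs = model(image, inputs, train=False)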