|
|
|
@@ -35,6 +35,8 @@ from transformers.utils import logging |
|
|
|
from .configuration_ofa import OFAConfig |
|
|
|
from .generate import utils |
|
|
|
from .resnet import ResNet |
|
|
|
from .utils.utils import DropPath |
|
|
|
from .vit import vit_base, vit_huge, vit_large, vit_large_336 |
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
@@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList): |
|
|
|
yield m |
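# Illustrative sketch of the LayerDrop idea behind LayerDropModuleList: while
# training, each layer is skipped with probability p as the list is iterated;
# at inference every layer is yielded. A simplified stand-alone version
# (assumes only PyTorch, not the exact class above):
import torch
import torch.nn as nn


class TinyLayerDropList(nn.ModuleList):

    def __init__(self, p, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            if not self.training or (dropout_probs[i] > self.p):
                yield m


# Roughly p * len(layers) layers are skipped per training forward pass.
layers = TinyLayerDropList(p=0.25, modules=[nn.Linear(8, 8) for _ in range(4)])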
|
|
|
|
|
|
|
|
|
|
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False): |
|
|
|
r""" |
|
|
|
Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
|
|
|
|
|
|
|
Args: |
|
|
|
x (`torch.Tensor`): the input tensor.
|
|
|
drop_prob (`float`): drop path ratio. |
|
|
|
training (`bool`): whether the module is in training or inference mode.
|
|
|
""" |
|
|
|
if drop_prob == 0.0 or not training: |
|
|
|
return x |
|
|
|
keep_prob = 1 - drop_prob |
|
|
|
shape = (1, x.shape[1], 1) |
|
|
|
random_tensor = keep_prob + torch.rand( |
|
|
|
shape, dtype=x.dtype, device=x.device) |
|
|
|
random_tensor.floor_() # binarize |
|
|
|
output = x.div(keep_prob) * random_tensor |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
class DropPath(nn.Module): |
|
|
|
r""" |
|
|
|
Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
|
|
|
|
|
|
|
Args: |
|
|
|
drop_prob (`float`): drop path ratio.
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, drop_prob=None): |
|
|
|
super().__init__() |
|
|
|
self.drop_prob = drop_prob |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
return drop_path(x, self.drop_prob, self.training) |
|
|
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
|
return 'p={}'.format(self.drop_prob) |
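# Illustrative check of the stochastic-depth property of drop_path above: each
# sample's residual branch is either zeroed (with probability drop_prob) or
# rescaled by 1 / keep_prob, so the expected activation is unchanged. Assumes
# inputs shaped (seq_len, batch, dim), matching the (1, x.shape[1], 1) mask.
import torch

x = torch.ones(5, 4, 8)  # (seq_len, batch, dim)
out = drop_path(x, drop_prob=0.5, training=True)
# Per batch element, `out` is either all zeros or all 2.0 (= 1 / keep_prob),
# e.g. sorted(out[0, :, 0].tolist()) might be [0.0, 0.0, 2.0, 2.0].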
|
|
|
|
|
|
|
|
|
|
|
class OFAAttention(nn.Module): |
|
|
|
r""" |
|
|
|
Multi-headed attention, with additional support for NormFormer.
|
|
|
@@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel): |
|
|
|
self.padding_idx) |
|
|
|
|
|
|
|
if config.add_type_embedding: |
|
|
|
self.type_embedding = Embedding(2, embed_dim, padding_idx=None) |
|
|
|
if config.use_image_feature: |
|
|
|
self.type_embedding = Embedding(2, embed_dim, padding_idx=None) |
|
|
|
else: |
|
|
|
self.type_embedding = Embedding(1, embed_dim, padding_idx=None) |
|
|
|
else: |
|
|
|
self.type_embedding = None |
|
|
|
|
|
|
|
if config.resnet_type == 'resnet18': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet34': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet50': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet101': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 4, 23], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet152': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 8, 36], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
else: |
|
|
|
raise NotImplementedError |
|
|
|
if config.use_image_feature: |
|
|
|
if config.use_ofasys: |
|
|
|
vit_backbone = { |
|
|
|
'vit_base': vit_base, |
|
|
|
'vit_large': vit_large, |
|
|
|
'vit_large_336': vit_large_336, |
|
|
|
'vit_huge': vit_huge, |
|
|
|
}[config.vit_type] |
|
|
|
self.embed_images = vit_backbone(config.vit_drop_path_rate) |
|
|
|
|
|
|
|
self.image_proj = Linear(1024, embed_dim) |
|
|
|
self.image_proj = Linear(self.embed_images.width, embed_dim) |
|
|
|
|
|
|
|
if config.resnet_model_path: |
|
|
|
else: |
|
|
|
if config.resnet_type == 'resnet18': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet34': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet50': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet101': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 4, 23], |
|
|
|
drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
elif config.resnet_type == 'resnet152': |
|
|
|
self.embed_images = ResNet( |
|
|
|
[3, 8, 36], |
|
|
|
drop_path_rate=config.resnet_drop_path_rate) |
|
|
|
else: |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
self.image_proj = Linear(1024, embed_dim) |
|
|
|
|
|
|
|
if not config.use_ofasys and config.resnet_model_path: |
|
|
|
print('load resnet {}'.format(config.resnet_model_path)) |
|
|
|
resnet_state_dict = torch.load(config.resnet_model_path) |
|
|
|
self.embed_images.load_state_dict(resnet_state_dict) |
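# Illustrative sketch of the backbone/projection wiring above: with use_ofasys
# a ViT variant supplies the image features and image_proj follows the
# backbone width (self.embed_images.width); otherwise the truncated ResNet
# path keeps the hard-coded 1024-channel projection. `backbone_width` below is
# a stand-in for that width attribute, with made-up sizes.
import torch.nn as nn


def build_image_proj(use_ofasys: bool, embed_dim: int,
                     backbone_width: int) -> nn.Linear:
    if use_ofasys:
        # ViT path: projection input size tracks the chosen backbone.
        return nn.Linear(backbone_width, embed_dim)
    # ResNet path: the truncated ResNet in this file ends at 1024 channels.
    return nn.Linear(1024, embed_dim)


proj = build_image_proj(use_ofasys=True, embed_dim=768, backbone_width=1280)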
|
|
|
@@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel): |
|
|
|
|
|
|
|
self.embed_positions = Embedding(self.max_source_positions + 2, |
|
|
|
embed_dim) |
|
|
|
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, |
|
|
|
embed_dim) |
|
|
|
self.pos_ln = LayerNorm(embed_dim) |
|
|
|
self.image_pos_ln = LayerNorm(embed_dim) |
|
|
|
|
|
|
|
if config.use_image_feature: |
|
|
|
self.embed_image_positions = Embedding( |
|
|
|
config.image_bucket_size**2 + 1, embed_dim) |
|
|
|
if not config.use_ofasys: |
|
|
|
self.pos_ln = LayerNorm(embed_dim) |
|
|
|
|
|
|
|
if config.use_image_feature: |
|
|
|
self.image_pos_ln = LayerNorm(embed_dim) |
|
|
|
self.pos_scaling = float(embed_dim / self.num_attention_heads |
|
|
|
* config.attn_scale_factor)**-0.5 |
|
|
|
self.pos_q_linear = nn.Linear(embed_dim, embed_dim) |
|
|
|
self.pos_k_linear = nn.Linear(embed_dim, embed_dim) |
|
|
|
|
|
|
|
if not (config.use_ofasys and config.entangle_position_embedding): |
|
|
|
self.pos_q_linear = nn.Linear(embed_dim, embed_dim) |
|
|
|
self.pos_k_linear = nn.Linear(embed_dim, embed_dim) |
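# Worked example of the pos_scaling formula above: it is the usual 1/sqrt(d_k)
# attention scaling for the positional projections, with attn_scale_factor
# folded into d_k. Made-up sizes (embed_dim=768, 12 heads, factor=2):
embed_dim_ex, num_heads_ex, attn_scale_factor_ex = 768, 12, 2
pos_scaling_ex = float(embed_dim_ex / num_heads_ex * attn_scale_factor_ex)**-0.5
# 768 / 12 = 64 (head dim); (64 * 2) ** -0.5 = 128 ** -0.5 ≈ 0.0884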
|
|
|
|
|
|
|
if self.encoder_layerdrop > 0.0: |
|
|
|
self.layers = LayerDropModuleList(p=self.encoder_layerdrop) |
|
|
|
@@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel): |
|
|
|
self.token_bucket_size = config.token_bucket_size |
|
|
|
token_num_rel_dis = 2 * config.token_bucket_size - 1 |
|
|
|
token_rp_bucket = make_token_bucket_position(config.token_bucket_size) |
|
|
|
self.share_attn_bias = config.share_attn_bias |
|
|
|
num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers |
|
|
|
self.token_rel_pos_table_list = nn.ModuleList([ |
|
|
|
Embedding( |
|
|
|
token_num_rel_dis, self.num_attention_heads, zero_init=True) |
|
|
|
for _ in range(config.encoder_layers) |
|
|
|
for _ in range(num_rel_pos_tables) |
|
|
|
]) |
|
|
|
|
|
|
|
self.image_bucket_size = config.image_bucket_size |
|
|
|
image_num_rel_dis = (2 * config.image_bucket_size |
|
|
|
- 1) * (2 * config.image_bucket_size - 1) + 3 |
|
|
|
image_rp_bucket = make_image_bucket_position(config.image_bucket_size, |
|
|
|
image_num_rel_dis) |
|
|
|
self.image_rel_pos_table_list = nn.ModuleList([ |
|
|
|
Embedding( |
|
|
|
image_num_rel_dis, self.num_attention_heads, zero_init=True) |
|
|
|
for _ in range(config.encoder_layers) |
|
|
|
]) |
|
|
|
if config.use_image_feature: |
|
|
|
self.image_bucket_size = config.image_bucket_size |
|
|
|
image_num_rel_dis = (2 * config.image_bucket_size |
|
|
|
- 1) * (2 * config.image_bucket_size - 1) + 3 |
|
|
|
image_rp_bucket = make_image_bucket_position( |
|
|
|
config.image_bucket_size, image_num_rel_dis) |
|
|
|
self.image_rel_pos_table_list = nn.ModuleList([ |
|
|
|
Embedding( |
|
|
|
image_num_rel_dis, |
|
|
|
self.num_attention_heads, |
|
|
|
zero_init=True) for _ in range(num_rel_pos_tables) |
|
|
|
]) |
|
|
|
|
|
|
|
self.register_buffer('image_rp_bucket', image_rp_bucket) |
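# Illustrative sketch of how these relative-position tables are used: every
# (query, key) offset is mapped to a bucket id, and a small Embedding turns
# the bucket into one bias value per attention head; with share_attn_bias=True
# a single table (num_rel_pos_tables == 1) serves all layers. Toy version with
# plain clipped offsets instead of the real make_token_bucket_position:
import torch
import torch.nn as nn

bucket_size, num_heads, seq_len = 4, 2, 6
num_rel_dis = 2 * bucket_size - 1  # 7 distinct relative distances
table = nn.Embedding(num_rel_dis, num_heads)

pos = torch.arange(seq_len)
rel = (pos[None, :] - pos[:, None]).clamp(-(bucket_size - 1), bucket_size - 1)
rp_bucket = rel + (bucket_size - 1)  # shift into [0, num_rel_dis)

bias = table(rp_bucket)       # (seq_len, seq_len, num_heads)
bias = bias.permute(2, 0, 1)  # (num_heads, seq_len, seq_len), added to logits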
|
|
|
|
|
|
|
if config.layernorm_embedding: |
|
|
|
self.layernorm_embedding = LayerNorm(embed_dim) |
|
|
|
@@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel): |
|
|
|
self.layernorm_embedding = None |
|
|
|
|
|
|
|
self.register_buffer('token_rp_bucket', token_rp_bucket) |
|
|
|
self.register_buffer('image_rp_bucket', image_rp_bucket) |
|
|
|
self.entangle_position_embedding = config.entangle_position_embedding |
|
|
|
|
|
|
|
self.gradient_checkpointing = False |
|
|
|
# Initialize weights and apply final processing |
|
|
|
self.post_init() |
|
|
|
self.use_ofasys = config.use_ofasys |
|
|
|
|
|
|
|
def get_input_embeddings(self): |
|
|
|
r""" |
|
|
|
@@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel): |
|
|
|
if has_pads: |
|
|
|
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) |
|
|
|
|
|
|
|
pos_embed = self.pos_ln(pos_embed) |
|
|
|
if patch_images is not None: |
|
|
|
image_pos_embed = self.image_pos_ln(image_pos_embed) |
|
|
|
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) |
|
|
|
if patch_images_2 is not None: |
|
|
|
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) |
|
|
|
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) |
|
|
|
if self.use_ofasys: |
|
|
|
if patch_images is not None: |
|
|
|
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) |
|
|
|
if patch_images_2 is not None: |
|
|
|
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) |
|
|
|
else: |
|
|
|
pos_embed = self.pos_ln(pos_embed) |
|
|
|
if patch_images is not None: |
|
|
|
image_pos_embed = self.image_pos_ln(image_pos_embed) |
|
|
|
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) |
|
|
|
if patch_images_2 is not None: |
|
|
|
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) |
|
|
|
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) |
|
|
|
|
|
|
|
def build_abs_pos_bias(pos_embed): |
|
|
|
batch_size, seq_length = pos_embed.size(0), pos_embed.size(1) |
|
|
|
if not (self.use_ofasys and self.entangle_position_embedding): |
|
|
|
pos_q = self.pos_q_linear(pos_embed).view( |
|
|
|
batch_size, seq_length, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) * self.pos_scaling |
|
|
|
pos_k = self.pos_k_linear(pos_embed).view( |
|
|
|
batch_size, seq_length, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) |
|
|
|
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) |
|
|
|
else: |
|
|
|
abs_pos_bias = torch.zeros( |
|
|
|
batch_size, |
|
|
|
self.num_attention_heads, |
|
|
|
seq_length, |
|
|
|
seq_length, |
|
|
|
dtype=pos_embed.dtype, |
|
|
|
device=pos_embed.device) |
|
|
|
return abs_pos_bias |
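# Stand-alone version of the decoupled absolute-position bias computed above:
# attention over the position embeddings alone (separate q/k projections, the
# usual scaling, then a dot product), yielding a (batch, heads, seq, seq)
# additive bias for the token attention logits. Sizes are made up:
import torch
import torch.nn as nn

batch, seq_len, embed_dim, num_heads, attn_scale_factor = 2, 5, 16, 4, 2
pos_scaling = float(embed_dim / num_heads * attn_scale_factor)**-0.5

pos_q_linear = nn.Linear(embed_dim, embed_dim)
pos_k_linear = nn.Linear(embed_dim, embed_dim)
pos_embed = torch.randn(batch, seq_len, embed_dim)

pos_q = pos_q_linear(pos_embed).view(
    batch, seq_len, num_heads, -1).transpose(1, 2) * pos_scaling
pos_k = pos_k_linear(pos_embed).view(
    batch, seq_len, num_heads, -1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))  # (batch, heads, seq, seq)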
|
|
|
|
|
|
|
pos_q = self.pos_q_linear(pos_embed).view( |
|
|
|
x.size(0), x.size(1), self.num_attention_heads, -1).transpose( |
|
|
|
1, 2) * self.pos_scaling |
|
|
|
pos_k = self.pos_k_linear(pos_embed).view( |
|
|
|
x.size(0), x.size(1), self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) |
|
|
|
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) |
|
|
|
abs_pos_bias = build_abs_pos_bias(pos_embed) |
|
|
|
|
|
|
|
# expand attention_mask |
|
|
|
if has_pads: |
|
|
|
@@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel): |
|
|
|
if output_hidden_states: |
|
|
|
encoder_states += (x, ) |
|
|
|
self_attn_bias = abs_pos_bias.clone() |
|
|
|
|
|
|
|
real_idx = 0 if self.share_attn_bias else idx |
|
|
|
|
|
|
|
self_attn_bias[:, :, -input_ids.size(1):, |
|
|
|
-input_ids.size(1):] += self.get_rel_pos_bias( |
|
|
|
input_ids, idx) |
|
|
|
input_ids, real_idx) |
|
|
|
if patch_images_2 is not None: |
|
|
|
self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ |
|
|
|
self.get_image_rel_pos_bias(image_position_ids_2, idx) |
|
|
|
self.get_image_rel_pos_bias(image_position_ids_2, real_idx) |
|
|
|
self_attn_bias[:, :, |
|
|
|
image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa |
|
|
|
image_num_patches_2:image_num_patches_2 + image_num_patches] += \ |
|
|
|
self.get_image_rel_pos_bias(image_position_ids, idx) # noqa |
|
|
|
self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa |
|
|
|
elif patch_images is not None: |
|
|
|
self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ |
|
|
|
self.get_image_rel_pos_bias(image_position_ids, idx) |
|
|
|
self.get_image_rel_pos_bias(image_position_ids, real_idx) |
|
|
|
self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) |
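# Illustrative sketch of the per-layer bias assembly above: the encoder
# sequence is [image patches | text tokens], so the token relative-position
# bias is added to the bottom-right (text x text) block and the image bias to
# the top-left (image x image) block; with share_attn_bias the same table
# (real_idx == 0) is reused for every layer. Toy shapes only:
import torch

batch, heads, num_patches, num_tokens = 2, 4, 3, 5
seq = num_patches + num_tokens
abs_pos_bias = torch.zeros(batch, heads, seq, seq)

token_rel_bias = torch.randn(batch, heads, num_tokens, num_tokens)
image_rel_bias = torch.randn(batch, heads, num_patches, num_patches)

self_attn_bias = abs_pos_bias.clone()
self_attn_bias[:, :, -num_tokens:, -num_tokens:] += token_rel_bias
self_attn_bias[:, :, :num_patches, :num_patches] += image_rel_bias
self_attn_bias = self_attn_bias.reshape(-1, seq, seq)  # (batch * heads, seq, seq)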
|
|
|
|
|
|
|
hidden_outputs = layer( |
|
|
|
@@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel): |
|
|
|
self._future_mask = torch.empty(0) |
|
|
|
self.share_input_output_embed = config.share_decoder_input_output_embed |
|
|
|
self.num_attention_heads = config.decoder_attention_heads |
|
|
|
self.use_ofasys = config.use_ofasys |
|
|
|
self.disable_entangle = config.disable_entangle |
|
|
|
|
|
|
|
if embed_tokens is not None: |
|
|
|
self.embed_tokens = embed_tokens |
|
|
|
@@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel): |
|
|
|
else: |
|
|
|
self.layernorm_embedding = None |
|
|
|
|
|
|
|
if config.use_ofasys: |
|
|
|
if config.add_type_embedding: |
|
|
|
self.type_embedding = Embedding( |
|
|
|
1, self.embed_dim, padding_idx=None) |
|
|
|
else: |
|
|
|
self.type_embedding = None |
|
|
|
|
|
|
|
self.window_size = config.code_image_size // 8 |
|
|
|
|
|
|
|
self.embed_positions = Embedding(self.max_target_positions + 2, |
|
|
|
self.embed_dim) |
|
|
|
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, |
|
|
|
self.embed_dim) |
|
|
|
self.pos_ln = LayerNorm(self.embed_dim) |
|
|
|
self.image_pos_ln = LayerNorm(self.embed_dim) |
|
|
|
|
|
|
|
if not config.use_ofasys: |
|
|
|
self.embed_image_positions = Embedding( |
|
|
|
config.image_bucket_size**2 + 1, self.embed_dim) |
|
|
|
if not config.use_ofasys: |
|
|
|
self.pos_ln = LayerNorm(self.embed_dim) |
|
|
|
self.image_pos_ln = LayerNorm(self.embed_dim) |
|
|
|
self.pos_scaling = float(self.embed_dim / self.num_attention_heads |
|
|
|
* config.attn_scale_factor)**-0.5 |
|
|
|
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
|
|
|
|
if not (config.use_ofasys and config.entangle_position_embedding): |
|
|
|
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
|
|
|
|
self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
|
|
|
|
@@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel): |
|
|
|
self.token_bucket_size = config.token_bucket_size |
|
|
|
token_num_rel_dis = 2 * config.token_bucket_size - 1 |
|
|
|
token_rp_bucket = make_token_bucket_position(config.token_bucket_size) |
|
|
|
|
|
|
|
self.share_attn_bias = config.share_attn_bias |
|
|
|
num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers |
|
|
|
self.token_rel_pos_table_list = nn.ModuleList([ |
|
|
|
Embedding( |
|
|
|
token_num_rel_dis, self.num_attention_heads, zero_init=True) |
|
|
|
for _ in range(config.decoder_layers) |
|
|
|
for _ in range(num_rel_pos_tables) |
|
|
|
]) |
|
|
|
|
|
|
|
self.image_bucket_size = config.image_bucket_size |
|
|
|
image_num_rel_dis = (2 * config.image_bucket_size |
|
|
|
- 1) * (2 * config.image_bucket_size - 1) + 3 |
|
|
|
image_rp_bucket = make_image_bucket_position(config.image_bucket_size, |
|
|
|
image_num_rel_dis) |
|
|
|
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ |
|
|
|
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa |
|
|
|
image_position_idx = torch.cat( |
|
|
|
[torch.tensor([0]), image_position_idx.view(-1)]) |
|
|
|
image_position_idx = torch.cat( |
|
|
|
[image_position_idx, |
|
|
|
torch.tensor([1024] * 768)]) |
|
|
|
self.image_rel_pos_table_list = nn.ModuleList([ |
|
|
|
Embedding( |
|
|
|
image_num_rel_dis, self.num_attention_heads, zero_init=True) |
|
|
|
for _ in range(config.decoder_layers) |
|
|
|
]) |
|
|
|
if config.use_image_feature: |
|
|
|
if not config.use_ofasys: |
|
|
|
self.image_bucket_size = config.image_bucket_size |
|
|
|
image_num_rel_dis = (2 * config.image_bucket_size - 1) * ( |
|
|
|
2 * config.image_bucket_size - 1) + 3 |
|
|
|
image_rp_bucket = make_image_bucket_position( |
|
|
|
config.image_bucket_size, image_num_rel_dis) |
|
|
|
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ |
|
|
|
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa |
|
|
|
image_position_idx = torch.cat( |
|
|
|
[torch.tensor([0]), |
|
|
|
image_position_idx.view(-1)]) |
|
|
|
image_position_idx = torch.cat( |
|
|
|
[image_position_idx, |
|
|
|
torch.tensor([1024] * 768)]) |
|
|
|
self.register_buffer('image_position_idx', image_position_idx) |
|
|
|
|
|
|
|
self.image_rel_pos_table_list = nn.ModuleList([ |
|
|
|
Embedding( |
|
|
|
image_num_rel_dis, |
|
|
|
self.num_attention_heads, |
|
|
|
zero_init=True) for _ in range(num_rel_pos_tables) |
|
|
|
]) |
|
|
|
self.register_buffer('image_rp_bucket', image_rp_bucket) |
|
|
|
|
|
|
|
self.register_buffer('token_rp_bucket', token_rp_bucket) |
|
|
|
self.register_buffer('image_rp_bucket', image_rp_bucket) |
|
|
|
self.register_buffer('image_position_idx', image_position_idx) |
|
|
|
self.entangle_position_embedding = config.entangle_position_embedding |
|
|
|
|
|
|
|
self.gradient_checkpointing = False |
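# Worked example of the image_position_idx construction above: each patch in a
# window_size x window_size grid gets the index row * image_bucket_size + col + 1,
# so 2D patch positions can index the 1D image position embedding, with index 0
# prepended for the leading token. Tiny sizes for illustration
# (window_size=2, image_bucket_size=4):
import torch

window_size, image_bucket_size = 2, 4
idx = torch.arange(window_size).unsqueeze(0).expand(window_size, window_size) \
    + torch.arange(window_size).unsqueeze(1) * image_bucket_size + 1
# idx == tensor([[1, 2],
#                [5, 6]])   # row-major positions inside the bucket grid
idx = torch.cat([torch.tensor([0]), idx.view(-1)])  # prepend the leading index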
|
|
|
@@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel): |
|
|
|
|
|
|
|
batch_size = tgt_pos_embed.size(0) |
|
|
|
tgt_len = tgt_pos_embed.size(1) |
|
|
|
tgt_pos_embed = self.image_pos_ln( |
|
|
|
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) |
|
|
|
if not self.use_ofasys: |
|
|
|
tgt_pos_embed = self.image_pos_ln( |
|
|
|
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) |
|
|
|
|
|
|
|
if src_pos_embed is not None: |
|
|
|
src_len = src_pos_embed.size(1) |
|
|
|
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( |
|
|
|
batch_size, tgt_len, self.num_attention_heads, -1).transpose( |
|
|
|
1, 2) * self.pos_scaling |
|
|
|
pos_k = self.cross_pos_k_linear(src_pos_embed).view( |
|
|
|
batch_size, src_len, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) |
|
|
|
if not (self.entangle_position_embedding and self.use_ofasys): |
|
|
|
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( |
|
|
|
batch_size, tgt_len, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) * self.pos_scaling |
|
|
|
pos_k = self.cross_pos_k_linear(src_pos_embed).view( |
|
|
|
batch_size, src_len, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) |
|
|
|
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) |
|
|
|
else: |
|
|
|
abs_pos_bias = torch.zeros( |
|
|
|
batch_size, |
|
|
|
self.num_attention_heads, |
|
|
|
tgt_len, |
|
|
|
src_len, |
|
|
|
dtype=tgt_pos_embed.dtype, |
|
|
|
device=tgt_pos_embed.device) |
|
|
|
else: |
|
|
|
src_len = tgt_pos_embed.size(1) |
|
|
|
pos_q = self.self_pos_q_linear(tgt_pos_embed).view( |
|
|
|
batch_size, tgt_len, self.num_attention_heads, -1).transpose( |
|
|
|
1, 2) * self.pos_scaling |
|
|
|
pos_k = self.self_pos_k_linear(tgt_pos_embed).view( |
|
|
|
batch_size, src_len, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) |
|
|
|
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) |
|
|
|
# batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1) |
|
|
|
if not (self.entangle_position_embedding and self.use_ofasys): |
|
|
|
pos_q = self.self_pos_q_linear(tgt_pos_embed).view( |
|
|
|
batch_size, tgt_len, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) * self.pos_scaling |
|
|
|
pos_k = self.self_pos_k_linear(tgt_pos_embed).view( |
|
|
|
batch_size, tgt_len, self.num_attention_heads, |
|
|
|
-1).transpose(1, 2) |
|
|
|
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) |
|
|
|
else: |
|
|
|
abs_pos_bias = torch.zeros( |
|
|
|
batch_size, |
|
|
|
self.num_attention_heads, |
|
|
|
tgt_len, |
|
|
|
tgt_len, |
|
|
|
dtype=tgt_pos_embed.dtype, |
|
|
|
device=tgt_pos_embed.device) |
|
|
|
|
|
|
|
return abs_pos_bias |
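# Same decoupled positional bias as in the encoder, but the key side differs:
# for cross-attention the keys are encoder positions, so the bias has shape
# (batch, heads, tgt_len, src_len); for self-attention both sides are decoder
# positions (tgt_len x tgt_len). With entangled positions under OFA-Sys it is
# simply zeros of the same shape. Shape sketch with made-up sizes:
import torch
import torch.nn as nn

batch, tgt_len, src_len, embed_dim, heads = 2, 3, 7, 16, 4
pos_scaling = float(embed_dim / heads * 2)**-0.5  # attn_scale_factor = 2
cross_pos_q = nn.Linear(embed_dim, embed_dim)
cross_pos_k = nn.Linear(embed_dim, embed_dim)

tgt_pos = torch.randn(batch, tgt_len, embed_dim)
src_pos = torch.randn(batch, src_len, embed_dim)
pos_q = cross_pos_q(tgt_pos).view(
    batch, tgt_len, heads, -1).transpose(1, 2) * pos_scaling
pos_k = cross_pos_k(src_pos).view(batch, src_len, heads, -1).transpose(1, 2)
cross_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))  # (2, 4, 3, 7)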
|
|
|
|
|
|
|
@@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel): |
|
|
|
past_key_values) > 0 else None |
|
|
|
|
|
|
|
self_attn_bias = self_abs_pos_bias.clone() |
|
|
|
real_idx = 0 if self.share_attn_bias else idx |
|
|
|
if code_masks is None or not code_masks.any(): |
|
|
|
self_attn_bias += self.get_rel_pos_bias( |
|
|
|
all_prev_output_tokens, idx).unsqueeze(0) |
|
|
|
all_prev_output_tokens, real_idx).unsqueeze(0) |
|
|
|
elif code_masks is not None and code_masks.all(): |
|
|
|
self_attn_bias += self.get_image_rel_pos_bias( |
|
|
|
all_prev_output_tokens, idx).unsqueeze(0) |
|
|
|
all_prev_output_tokens, real_idx).unsqueeze(0) |
|
|
|
else: |
|
|
|
self_attn_bias[~code_masks] += self.get_rel_pos_bias( |
|
|
|
all_prev_output_tokens, idx).unsqueeze(0) |
|
|
|
all_prev_output_tokens, real_idx).unsqueeze(0) |
|
|
|
self_attn_bias[code_masks] += self.get_image_rel_pos_bias( |
|
|
|
all_prev_output_tokens, idx).unsqueeze(0) |
|
|
|
all_prev_output_tokens, real_idx).unsqueeze(0) |
|
|
|
self_attn_bias = self_attn_bias.reshape( |
|
|
|
-1, |
|
|
|
*self_attn_bias.size()[-2:]) |
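# Illustrative sketch of the code_masks routing above: code_masks marks, per
# sample, whether the decoder is producing image codes, and the matching
# relative-position bias (text vs image) is added for that sample; real_idx
# again collapses to 0 when share_attn_bias is set. Toy stand-in biases:
import torch

batch, heads, tgt_len = 4, 2, 5
self_attn_bias = torch.zeros(batch, heads, tgt_len, tgt_len)
text_bias = torch.randn(1, heads, tgt_len, tgt_len)   # broadcasts over batch
image_bias = torch.randn(1, heads, tgt_len, tgt_len)

code_masks = torch.tensor([False, True, False, True])  # image-code samples
self_attn_bias[~code_masks] += text_bias
self_attn_bias[code_masks] += image_bias
self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])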
|
|
|
@@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel): |
|
|
|
|
|
|
|
self.encoder = OFAEncoder(config, shared) |
|
|
|
self.decoder = OFADecoder(config, shared) |
|
|
|
self.use_ofasys = config.use_ofasys |
|
|
|
|
|
|
|
# Initialize weights and apply final processing |
|
|
|
self.post_init() |
|
|
|
|