| @@ -73,7 +73,7 @@ class Pipelines(object): | |||
| asr_inference = 'asr-inference' | |||
| # multi-modal tasks | |||
| -image_caption = 'image-captioning' | |||
| +image_captioning = 'image-captioning' | |||
| multi_modal_embedding = 'multi-modal-embedding' | |||
| visual_question_answering = 'visual-question-answering' | |||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||
| @@ -1,5 +1,5 @@ | |||
| from .clip.clip_model import CLIPForMultiModalEmbedding | |||
| -from .image_captioning_model import OfaForImageCaptioning | |||
| from .imagen.imagen_model import ImagenForTextToImageSynthesis | |||
| from .mplug_for_visual_question_answering import \ | |||
| MPlugForVisualQuestionAnswering | |||
| +from .ofa_for_image_captioning_model import OfaForImageCaptioning | |||
| @@ -0,0 +1,2 @@ | |||
| from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel | |||
| from .tokenization_ofa import OFATokenizer | |||
| @@ -0,0 +1,194 @@ | |||
| # Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ OFA model configuration""" | |||
| import warnings | |||
| from transformers import PretrainedConfig | |||
| from transformers.utils import logging | |||
| logger = logging.get_logger(__name__) | |||
| OFA_PRETRAINED_CONFIG_ARCHIVE_MAP = { | |||
| 'ofa-medium': 'https://huggingface.co/ofa-medium/resolve/main/config.json', | |||
| # OFA models are implemented to be compatible with both huggingface | |||
| # and modelscope frameworks. For all OFA models available on huggingface, | |||
| # please refer to https://huggingface.co/models?filter=ofa | |||
| } | |||
| class OFAConfig(PretrainedConfig): | |||
| r""" | |||
| This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA | |||
| model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | |||
| defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) | |||
| architecture. | |||
| Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | |||
| documentation from [`PretrainedConfig`] for more information. | |||
| Args: | |||
| vocab_size (`int`, *optional*, defaults to 59457): | |||
| Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the | |||
| `input_ids` passed when calling [`~OFAModel`]. | |||
| d_model (`int`, *optional*, defaults to 512): | |||
| Dimension of the layers and the pooler layer. | |||
| encoder_layers (`int`, *optional*, defaults to 4): | |||
| Number of encoder layers. | |||
| decoder_layers (`int`, *optional*, defaults to 4): | |||
| Number of decoder layers. | |||
| encoder_attention_heads (`int`, *optional*, defaults to 8): | |||
| Number of attention heads for each attention layer in the Transformer encoder. | |||
| decoder_attention_heads (`int`, *optional*, defaults to 8): | |||
| Number of attention heads for each attention layer in the Transformer decoder. | |||
| decoder_ffn_dim (`int`, *optional*, defaults to 2048): | |||
| Dimension of the "intermediate" (often named feed-forward) layer in the decoder. | |||
| encoder_ffn_dim (`int`, *optional*, defaults to 2048): | |||
| Dimension of the "intermediate" (often named feed-forward) layer in the encoder. | |||
| activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): | |||
| The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, | |||
| `"relu"`, `"silu"` and `"gelu_new"` are supported. | |||
| dropout (`float`, *optional*, defaults to 0.1): | |||
| The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||
| attention_dropout (`float`, *optional*, defaults to 0.0): | |||
| The dropout ratio for the attention probabilities. | |||
| activation_dropout (`float`, *optional*, defaults to 0.0): | |||
| The dropout ratio for activations inside the fully connected layer. | |||
| classifier_dropout (`float`, *optional*, defaults to 0.0): | |||
| The dropout ratio for the classifier. | |||
| max_position_embeddings (`int`, *optional*, defaults to 1024): | |||
| The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
| just in case (e.g., 512 or 1024 or 2048). | |||
| init_std (`float`, *optional*, defaults to 0.02): | |||
| The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
| encoder_layerdrop (`float`, *optional*, defaults to 0.0): | |||
| The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) | |||
| for more details. | |||
| decoder_layerdrop (`float`, *optional*, defaults to 0.0): | |||
| The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) | |||
| for more details. | |||
| use_cache (`bool`, *optional*, defaults to `True`): | |||
| Whether or not the model should return the last key/values attentions (not used by all models). | |||
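| Example: | |||
| >>> # Illustrative sketch; the import paths below are assumptions about how this | |||
| >>> # package is laid out and are not taken from the diff itself. | |||
| >>> from modelscope.models.multi_modal.ofa.configuration_ofa import OFAConfig | |||
| >>> from modelscope.models.multi_modal.ofa.modeling_ofa import OFAModel | |||
| >>> configuration = OFAConfig()      # uses the defaults defined below | |||
| >>> model = OFAModel(configuration)  # randomly initialized weights | |||
| >>> configuration = model.config     # read the configuration back from the model | |||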
| """ | |||
| model_type = 'ofa' | |||
| keys_to_ignore_at_inference = ['past_key_values'] | |||
| attribute_map = { | |||
| 'num_attention_heads': 'encoder_attention_heads', | |||
| 'hidden_size': 'd_model' | |||
| } | |||
| def __init__(self, | |||
| vocab_size=59457, | |||
| max_position_embeddings=1024, | |||
| encoder_layers=4, | |||
| encoder_ffn_dim=512 * 4, | |||
| encoder_attention_heads=8, | |||
| decoder_layers=4, | |||
| decoder_ffn_dim=512 * 4, | |||
| decoder_attention_heads=8, | |||
| encoder_layerdrop=0.0, | |||
| decoder_layerdrop=0.0, | |||
| use_cache=True, | |||
| is_encoder_decoder=True, | |||
| activation_function='gelu', | |||
| d_model=512, | |||
| dropout=0.1, | |||
| attention_dropout=0.0, | |||
| activation_dropout=0.0, | |||
| init_std=0.02, | |||
| classifier_dropout=0.0, | |||
| scale_embedding=False, | |||
| pad_token_id=1, | |||
| bos_token_id=0, | |||
| decoder_start_token_id=0, | |||
| eos_token_id=2, | |||
| forced_eos_token_id=2, | |||
| encoder_normalize_before=True, | |||
| decoder_normalize_before=True, | |||
| normformer=True, | |||
| encoder_drop_path_rate=0.0, | |||
| decoder_drop_path_rate=0.0, | |||
| layernorm_embedding=True, | |||
| patch_layernorm_embedding=True, | |||
| resnet_type='resnet101', | |||
| resnet_model_path=None, | |||
| resnet_drop_path_rate=0.0, | |||
| token_bucket_size=256, | |||
| image_bucket_size=42, | |||
| add_type_embedding=True, | |||
| share_decoder_input_output_embed=True, | |||
| attn_scale_factor=2., | |||
| code_layernorm_embedding=True, | |||
| code_image_size=128, | |||
| entangle_position_embedding=False, | |||
| **kwargs): | |||
| self.vocab_size = vocab_size | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.d_model = d_model | |||
| self.encoder_ffn_dim = encoder_ffn_dim | |||
| self.encoder_layers = encoder_layers | |||
| self.encoder_attention_heads = encoder_attention_heads | |||
| self.decoder_ffn_dim = decoder_ffn_dim | |||
| self.decoder_layers = decoder_layers | |||
| self.decoder_attention_heads = decoder_attention_heads | |||
| self.dropout = dropout | |||
| self.attention_dropout = attention_dropout | |||
| self.activation_dropout = activation_dropout | |||
| self.activation_function = activation_function | |||
| self.init_std = init_std | |||
| self.encoder_layerdrop = encoder_layerdrop | |||
| self.decoder_layerdrop = decoder_layerdrop | |||
| self.classifier_dropout = classifier_dropout | |||
| self.use_cache = use_cache | |||
| self.num_hidden_layers = encoder_layers | |||
| self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True | |||
| self.encoder_normalize_before = encoder_normalize_before | |||
| self.decoder_normalize_before = decoder_normalize_before | |||
| self.normformer = normformer | |||
| self.encoder_drop_path_rate = encoder_drop_path_rate | |||
| self.decoder_drop_path_rate = decoder_drop_path_rate | |||
| self.layernorm_embedding = layernorm_embedding | |||
| self.patch_layernorm_embedding = patch_layernorm_embedding | |||
| self.resnet_type = resnet_type | |||
| self.resnet_model_path = resnet_model_path | |||
| self.resnet_drop_path_rate = resnet_drop_path_rate | |||
| self.token_bucket_size = token_bucket_size | |||
| self.image_bucket_size = image_bucket_size | |||
| self.add_type_embedding = add_type_embedding | |||
| self.share_decoder_input_output_embed = share_decoder_input_output_embed | |||
| self.attn_scale_factor = attn_scale_factor | |||
| self.code_layernorm_embedding = code_layernorm_embedding | |||
| self.code_image_size = code_image_size | |||
| self.entangle_position_embedding = entangle_position_embedding | |||
| super().__init__( | |||
| pad_token_id=pad_token_id, | |||
| bos_token_id=bos_token_id, | |||
| eos_token_id=eos_token_id, | |||
| is_encoder_decoder=is_encoder_decoder, | |||
| decoder_start_token_id=decoder_start_token_id, | |||
| forced_eos_token_id=forced_eos_token_id, | |||
| **kwargs, | |||
| ) | |||
| # ensure backward compatibility for BART CNN models | |||
| if self.forced_bos_token_id is None and kwargs.get( | |||
| 'force_bos_token_to_be_generated', False): | |||
| self.forced_bos_token_id = self.bos_token_id | |||
| warnings.warn( | |||
| f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' | |||
| 'The config can simply be saved and uploaded again to be fixed.' | |||
| ) | |||
| @@ -0,0 +1,51 @@ | |||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||
| # | |||
| # This source code is licensed under the MIT license which can be found at | |||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
| import uuid | |||
| from typing import Dict, Optional | |||
| from torch import Tensor | |||
| class FairseqIncrementalState(object): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.init_incremental_state() | |||
| def init_incremental_state(self): | |||
| self._incremental_state_id = str(uuid.uuid4()) | |||
| def _get_full_incremental_state_key(self, key: str) -> str: | |||
| return '{}.{}'.format(self._incremental_state_id, key) | |||
| def get_incremental_state( | |||
| self, | |||
| incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], | |||
| key: str, | |||
| ) -> Optional[Dict[str, Optional[Tensor]]]: | |||
| """Helper for getting incremental state for an nn.Module.""" | |||
| full_key = self._get_full_incremental_state_key(key) | |||
| if incremental_state is None or full_key not in incremental_state: | |||
| return None | |||
| return incremental_state[full_key] | |||
| def set_incremental_state( | |||
| self, | |||
| incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], | |||
| key: str, | |||
| value: Dict[str, Optional[Tensor]], | |||
| ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: | |||
| """Helper for setting incremental state for an nn.Module.""" | |||
| if incremental_state is not None: | |||
| full_key = self._get_full_incremental_state_key(key) | |||
| incremental_state[full_key] = value | |||
| return incremental_state | |||
| def with_incremental_state(cls): | |||
| cls.__bases__ = (FairseqIncrementalState, ) + tuple( | |||
| b for b in cls.__bases__ if b != FairseqIncrementalState) | |||
| return cls | |||
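| # Illustrative sketch (comment only, not part of the upstream file): a class decorated | |||
| # with @with_incremental_state gains get_incremental_state / set_incremental_state and | |||
| # can cache tensors across incremental decoding steps, e.g. (assuming `torch` and | |||
| # `torch.nn as nn` are imported by the defining module): | |||
| # | |||
| #     @with_incremental_state | |||
| #     class CachingModule(nn.Module): | |||
| #         def forward(self, x, incremental_state=None): | |||
| #             saved = self.get_incremental_state(incremental_state, 'cache') or {} | |||
| #             prev = saved.get('prev_x') | |||
| #             saved['prev_x'] = x if prev is None else torch.cat([prev, x], dim=0) | |||
| #             self.set_incremental_state(incremental_state, 'cache', saved) | |||
| #             return saved['prev_x'] | |||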
| @@ -0,0 +1,510 @@ | |||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||
| # | |||
| # This source code is licensed under the MIT license which can be found at | |||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
| import math | |||
| from typing import Dict, Optional, Tuple | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from fairseq import utils | |||
| from fairseq.incremental_decoding_utils import with_incremental_state | |||
| from fairseq.modules.fairseq_dropout import FairseqDropout | |||
| from fairseq.modules.quant_noise import quant_noise | |||
| from torch import Tensor, nn | |||
| from torch.nn import Parameter | |||
| @with_incremental_state | |||
| class MultiheadAttention(nn.Module): | |||
| """Multi-headed attention. | |||
| See "Attention Is All You Need" for more details. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| embed_dim, | |||
| num_heads, | |||
| kdim=None, | |||
| vdim=None, | |||
| dropout=0.0, | |||
| bias=True, | |||
| add_bias_kv=False, | |||
| add_zero_attn=False, | |||
| self_attention=False, | |||
| encoder_decoder_attention=False, | |||
| q_noise=0.0, | |||
| qn_block_size=8, | |||
| ): | |||
| super().__init__() | |||
| self.embed_dim = embed_dim | |||
| self.kdim = kdim if kdim is not None else embed_dim | |||
| self.vdim = vdim if vdim is not None else embed_dim | |||
| self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim | |||
| self.num_heads = num_heads | |||
| self.dropout_module = FairseqDropout( | |||
| dropout, module_name=self.__class__.__name__) | |||
| self.head_dim = embed_dim // num_heads | |||
| assert (self.head_dim * num_heads == self.embed_dim | |||
| ), 'embed_dim must be divisible by num_heads' | |||
| self.scaling = self.head_dim**-0.5 | |||
| self.self_attention = self_attention | |||
| self.encoder_decoder_attention = encoder_decoder_attention | |||
| assert not self.self_attention or self.qkv_same_dim, ( | |||
| 'Self-attention requires query, key and ' | |||
| 'value to be of the same size') | |||
| self.k_proj = quant_noise( | |||
| nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
| self.v_proj = quant_noise( | |||
| nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
| self.q_proj = quant_noise( | |||
| nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
| self.out_proj = quant_noise( | |||
| nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) | |||
| if add_bias_kv: | |||
| self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) | |||
| self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) | |||
| else: | |||
| self.bias_k = self.bias_v = None | |||
| self.add_zero_attn = add_zero_attn | |||
| self.reset_parameters() | |||
| self.onnx_trace = False | |||
| def prepare_for_onnx_export_(self): | |||
| self.onnx_trace = True | |||
| def reset_parameters(self): | |||
| if self.qkv_same_dim: | |||
| # Empirically observed the convergence to be much better with | |||
| # the scaled initialization | |||
| nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) | |||
| nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) | |||
| nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) | |||
| else: | |||
| nn.init.xavier_uniform_(self.k_proj.weight) | |||
| nn.init.xavier_uniform_(self.v_proj.weight) | |||
| nn.init.xavier_uniform_(self.q_proj.weight) | |||
| nn.init.xavier_uniform_(self.out_proj.weight) | |||
| if self.out_proj.bias is not None: | |||
| nn.init.constant_(self.out_proj.bias, 0.0) | |||
| if self.bias_k is not None: | |||
| nn.init.xavier_normal_(self.bias_k) | |||
| if self.bias_v is not None: | |||
| nn.init.xavier_normal_(self.bias_v) | |||
| def forward( | |||
| self, | |||
| query, | |||
| key: Optional[Tensor], | |||
| value: Optional[Tensor], | |||
| key_padding_mask: Optional[Tensor] = None, | |||
| incremental_state: Optional[Dict[str, Dict[str, | |||
| Optional[Tensor]]]] = None, | |||
| need_weights: bool = True, | |||
| static_kv: bool = False, | |||
| attn_mask: Optional[Tensor] = None, | |||
| before_softmax: bool = False, | |||
| need_head_weights: bool = False, | |||
| ) -> Tuple[Tensor, Optional[Tensor]]: | |||
| """Input shape: Time x Batch x Channel | |||
| Args: | |||
| key_padding_mask (ByteTensor, optional): mask to exclude | |||
| keys that are pads, of shape `(batch, src_len)`, where | |||
| padding elements are indicated by 1s. | |||
| need_weights (bool, optional): return the attention weights, | |||
| averaged over heads (default: True). | |||
| attn_mask (ByteTensor, optional): typically used to | |||
| implement causal attention, where the mask prevents the | |||
| attention from looking forward in time (default: None). | |||
| before_softmax (bool, optional): return the raw attention | |||
| weights and values before the attention softmax. | |||
| need_head_weights (bool, optional): return the attention | |||
| weights for each head. Implies *need_weights*. Default: | |||
| return the average attention weights over all heads. | |||
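| Example (illustrative sketch, not part of the upstream file; the shapes are arbitrary | |||
| and follow the Time x Batch x Channel convention above): | |||
| >>> import torch | |||
| >>> mha = MultiheadAttention(embed_dim=512, num_heads=8, self_attention=True) | |||
| >>> x = torch.randn(10, 2, 512)  # (tgt_len, bsz, embed_dim) | |||
| >>> out, avg_weights = mha(query=x, key=x, value=x) | |||
| >>> out.shape | |||
| torch.Size([10, 2, 512]) | |||
| >>> avg_weights.shape  # attention weights averaged over heads | |||
| torch.Size([2, 10, 10]) | |||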
| """ | |||
| if need_head_weights: | |||
| need_weights = True | |||
| is_tpu = query.device.type == 'xla' | |||
| tgt_len, bsz, embed_dim = query.size() | |||
| src_len = tgt_len | |||
| assert embed_dim == self.embed_dim, f'query dim {embed_dim} != {self.embed_dim}' | |||
| assert list(query.size()) == [tgt_len, bsz, embed_dim] | |||
| if key is not None: | |||
| src_len, key_bsz, _ = key.size() | |||
| if not torch.jit.is_scripting(): | |||
| assert key_bsz == bsz | |||
| assert value is not None | |||
| assert value.shape[:2] == (src_len, bsz) | |||
| if (not self.onnx_trace | |||
| and not is_tpu # don't use PyTorch version on TPUs | |||
| and incremental_state is None and not static_kv | |||
| # A workaround for quantization to work. Otherwise JIT compilation | |||
| # treats bias in linear module as method. | |||
| and not torch.jit.is_scripting()): | |||
| assert key is not None and value is not None | |||
| return F.multi_head_attention_forward( | |||
| query, | |||
| key, | |||
| value, | |||
| self.embed_dim, | |||
| self.num_heads, | |||
| torch.empty([0]), | |||
| torch.cat( | |||
| (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), | |||
| self.bias_k, | |||
| self.bias_v, | |||
| self.add_zero_attn, | |||
| self.dropout_module.p, | |||
| self.out_proj.weight, | |||
| self.out_proj.bias, | |||
| self.training or self.dropout_module.apply_during_inference, | |||
| key_padding_mask, | |||
| need_weights, | |||
| attn_mask, | |||
| use_separate_proj_weight=True, | |||
| q_proj_weight=self.q_proj.weight, | |||
| k_proj_weight=self.k_proj.weight, | |||
| v_proj_weight=self.v_proj.weight, | |||
| ) | |||
| if incremental_state is not None: | |||
| saved_state = self._get_input_buffer(incremental_state) | |||
| if saved_state is not None and 'prev_key' in saved_state: | |||
| # previous time steps are cached - no need to recompute | |||
| # key and value if they are static | |||
| if static_kv: | |||
| assert self.encoder_decoder_attention and not self.self_attention | |||
| key = value = None | |||
| else: | |||
| saved_state = None | |||
| if self.self_attention: | |||
| q = self.q_proj(query) | |||
| k = self.k_proj(query) | |||
| v = self.v_proj(query) | |||
| elif self.encoder_decoder_attention: | |||
| # encoder-decoder attention | |||
| q = self.q_proj(query) | |||
| if key is None: | |||
| assert value is None | |||
| k = v = None | |||
| else: | |||
| k = self.k_proj(key) | |||
| v = self.v_proj(key) | |||
| else: | |||
| assert key is not None and value is not None | |||
| q = self.q_proj(query) | |||
| k = self.k_proj(key) | |||
| v = self.v_proj(value) | |||
| q *= self.scaling | |||
| if self.bias_k is not None: | |||
| assert self.bias_v is not None | |||
| k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) | |||
| v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) | |||
| if attn_mask is not None: | |||
| attn_mask = torch.cat( | |||
| [attn_mask, | |||
| attn_mask.new_zeros(attn_mask.size(0), 1)], | |||
| dim=1) | |||
| if key_padding_mask is not None: | |||
| key_padding_mask = torch.cat( | |||
| [ | |||
| key_padding_mask, | |||
| key_padding_mask.new_zeros( | |||
| key_padding_mask.size(0), 1), | |||
| ], | |||
| dim=1, | |||
| ) | |||
| q = ( | |||
| q.contiguous().view(tgt_len, bsz * self.num_heads, | |||
| self.head_dim).transpose(0, 1)) | |||
| if k is not None: | |||
| k = ( | |||
| k.contiguous().view(-1, bsz * self.num_heads, | |||
| self.head_dim).transpose(0, 1)) | |||
| if v is not None: | |||
| v = ( | |||
| v.contiguous().view(-1, bsz * self.num_heads, | |||
| self.head_dim).transpose(0, 1)) | |||
| if saved_state is not None: | |||
| # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) | |||
| if 'prev_key' in saved_state: | |||
| _prev_key = saved_state['prev_key'] | |||
| assert _prev_key is not None | |||
| prev_key = _prev_key.view(bsz * self.num_heads, -1, | |||
| self.head_dim) | |||
| if static_kv: | |||
| k = prev_key | |||
| else: | |||
| assert k is not None | |||
| k = torch.cat([prev_key, k], dim=1) | |||
| src_len = k.size(1) | |||
| if 'prev_value' in saved_state: | |||
| _prev_value = saved_state['prev_value'] | |||
| assert _prev_value is not None | |||
| prev_value = _prev_value.view(bsz * self.num_heads, -1, | |||
| self.head_dim) | |||
| if static_kv: | |||
| v = prev_value | |||
| else: | |||
| assert v is not None | |||
| v = torch.cat([prev_value, v], dim=1) | |||
| prev_key_padding_mask: Optional[Tensor] = None | |||
| if 'prev_key_padding_mask' in saved_state: | |||
| prev_key_padding_mask = saved_state['prev_key_padding_mask'] | |||
| assert k is not None and v is not None | |||
| key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( | |||
| key_padding_mask=key_padding_mask, | |||
| prev_key_padding_mask=prev_key_padding_mask, | |||
| batch_size=bsz, | |||
| src_len=k.size(1), | |||
| static_kv=static_kv, | |||
| ) | |||
| saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, | |||
| self.head_dim) | |||
| saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, | |||
| self.head_dim) | |||
| saved_state['prev_key_padding_mask'] = key_padding_mask | |||
| # In this branch incremental_state is never None | |||
| assert incremental_state is not None | |||
| incremental_state = self._set_input_buffer(incremental_state, | |||
| saved_state) | |||
| assert k is not None | |||
| assert k.size(1) == src_len | |||
| # This is part of a workaround to get around fork/join parallelism | |||
| # not supporting Optional types. | |||
| if key_padding_mask is not None and key_padding_mask.dim() == 0: | |||
| key_padding_mask = None | |||
| if key_padding_mask is not None: | |||
| assert key_padding_mask.size(0) == bsz | |||
| assert key_padding_mask.size(1) == src_len | |||
| if self.add_zero_attn: | |||
| assert v is not None | |||
| src_len += 1 | |||
| k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], | |||
| dim=1) | |||
| v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], | |||
| dim=1) | |||
| if attn_mask is not None: | |||
| attn_mask = torch.cat( | |||
| [attn_mask, | |||
| attn_mask.new_zeros(attn_mask.size(0), 1)], | |||
| dim=1) | |||
| if key_padding_mask is not None: | |||
| key_padding_mask = torch.cat( | |||
| [ | |||
| key_padding_mask, | |||
| torch.zeros(key_padding_mask.size(0), | |||
| 1).type_as(key_padding_mask), | |||
| ], | |||
| dim=1, | |||
| ) | |||
| attn_weights = torch.bmm(q, k.transpose(1, 2)) | |||
| attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, | |||
| bsz) | |||
| assert list( | |||
| attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] | |||
| if attn_mask is not None: | |||
| attn_mask = attn_mask.unsqueeze(0) | |||
| if self.onnx_trace: | |||
| attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) | |||
| attn_weights += attn_mask | |||
| if key_padding_mask is not None: | |||
| # don't attend to padding symbols | |||
| attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, | |||
| src_len) | |||
| if not is_tpu: | |||
| attn_weights = attn_weights.masked_fill( | |||
| key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), | |||
| float('-inf'), | |||
| ) | |||
| else: | |||
| attn_weights = attn_weights.transpose(0, 2) | |||
| attn_weights = attn_weights.masked_fill( | |||
| key_padding_mask, float('-inf')) | |||
| attn_weights = attn_weights.transpose(0, 2) | |||
| attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, | |||
| src_len) | |||
| if before_softmax: | |||
| return attn_weights, v | |||
| attn_weights_float = utils.softmax( | |||
| attn_weights, dim=-1, onnx_trace=self.onnx_trace) | |||
| attn_weights = attn_weights_float.type_as(attn_weights) | |||
| attn_probs = self.dropout_module(attn_weights) | |||
| assert v is not None | |||
| attn = torch.bmm(attn_probs, v) | |||
| assert list( | |||
| attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] | |||
| if self.onnx_trace and attn.size(1) == 1: | |||
| # when ONNX tracing a single decoder step (sequence length == 1) | |||
| # the transpose is a no-op copy before view, thus unnecessary | |||
| attn = attn.contiguous().view(tgt_len, bsz, embed_dim) | |||
| else: | |||
| attn = attn.transpose(0, | |||
| 1).contiguous().view(tgt_len, bsz, embed_dim) | |||
| attn = self.out_proj(attn) | |||
| attn_weights: Optional[Tensor] = None | |||
| if need_weights: | |||
| attn_weights = attn_weights_float.view(bsz, self.num_heads, | |||
| tgt_len, | |||
| src_len).transpose(1, 0) | |||
| if not need_head_weights: | |||
| # average attention weights over heads | |||
| attn_weights = attn_weights.mean(dim=0) | |||
| return attn, attn_weights | |||
| @staticmethod | |||
| def _append_prev_key_padding_mask( | |||
| key_padding_mask: Optional[Tensor], | |||
| prev_key_padding_mask: Optional[Tensor], | |||
| batch_size: int, | |||
| src_len: int, | |||
| static_kv: bool, | |||
| ) -> Optional[Tensor]: | |||
| # saved key padding masks have shape (bsz, seq_len) | |||
| if prev_key_padding_mask is not None and static_kv: | |||
| new_key_padding_mask = prev_key_padding_mask | |||
| elif prev_key_padding_mask is not None and key_padding_mask is not None: | |||
| new_key_padding_mask = torch.cat( | |||
| [prev_key_padding_mask.float(), | |||
| key_padding_mask.float()], | |||
| dim=1) | |||
| # During incremental decoding, as the padding token enters and | |||
| # leaves the frame, there will be a time when prev or current | |||
| # is None | |||
| elif prev_key_padding_mask is not None: | |||
| if src_len > prev_key_padding_mask.size(1): | |||
| filler = torch.zeros( | |||
| (batch_size, src_len - prev_key_padding_mask.size(1)), | |||
| device=prev_key_padding_mask.device, | |||
| ) | |||
| new_key_padding_mask = torch.cat( | |||
| [prev_key_padding_mask.float(), | |||
| filler.float()], dim=1) | |||
| else: | |||
| new_key_padding_mask = prev_key_padding_mask.float() | |||
| elif key_padding_mask is not None: | |||
| if src_len > key_padding_mask.size(1): | |||
| filler = torch.zeros( | |||
| (batch_size, src_len - key_padding_mask.size(1)), | |||
| device=key_padding_mask.device, | |||
| ) | |||
| new_key_padding_mask = torch.cat( | |||
| [filler.float(), key_padding_mask.float()], dim=1) | |||
| else: | |||
| new_key_padding_mask = key_padding_mask.float() | |||
| else: | |||
| new_key_padding_mask = prev_key_padding_mask | |||
| return new_key_padding_mask | |||
| @torch.jit.export | |||
| def reorder_incremental_state( | |||
| self, | |||
| incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |||
| new_order: Tensor, | |||
| ): | |||
| """Reorder buffered internal state (for incremental generation).""" | |||
| input_buffer = self._get_input_buffer(incremental_state) | |||
| if input_buffer is not None: | |||
| for k in input_buffer.keys(): | |||
| input_buffer_k = input_buffer[k] | |||
| if input_buffer_k is not None: | |||
| if self.encoder_decoder_attention and input_buffer_k.size( | |||
| 0) == new_order.size(0): | |||
| break | |||
| input_buffer[k] = input_buffer_k.index_select(0, new_order) | |||
| incremental_state = self._set_input_buffer(incremental_state, | |||
| input_buffer) | |||
| return incremental_state | |||
| def _get_input_buffer( | |||
| self, incremental_state: Optional[Dict[str, Dict[str, | |||
| Optional[Tensor]]]] | |||
| ) -> Dict[str, Optional[Tensor]]: | |||
| result = self.get_incremental_state(incremental_state, 'attn_state') | |||
| if result is not None: | |||
| return result | |||
| else: | |||
| empty_result: Dict[str, Optional[Tensor]] = {} | |||
| return empty_result | |||
| def _set_input_buffer( | |||
| self, | |||
| incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |||
| buffer: Dict[str, Optional[Tensor]], | |||
| ): | |||
| return self.set_incremental_state(incremental_state, 'attn_state', | |||
| buffer) | |||
| def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, | |||
| bsz: int): | |||
| return attn_weights | |||
| def upgrade_state_dict_named(self, state_dict, name): | |||
| prefix = name + '.' if name != '' else '' | |||
| items_to_add = {} | |||
| keys_to_remove = [] | |||
| for k in state_dict.keys(): | |||
| if k.endswith(prefix + 'in_proj_weight'): | |||
| # in_proj_weight used to be q + k + v with same dimensions | |||
| dim = int(state_dict[k].shape[0] / 3) | |||
| items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim] | |||
| items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2 | |||
| * dim] | |||
| items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2 | |||
| * dim:] | |||
| keys_to_remove.append(k) | |||
| k_bias = prefix + 'in_proj_bias' | |||
| if k_bias in state_dict.keys(): | |||
| dim = int(state_dict[k].shape[0] / 3) | |||
| items_to_add[prefix | |||
| + 'q_proj.bias'] = state_dict[k_bias][:dim] | |||
| items_to_add[prefix | |||
| + 'k_proj.bias'] = state_dict[k_bias][dim:2 | |||
| * dim] | |||
| items_to_add[prefix | |||
| + 'v_proj.bias'] = state_dict[k_bias][2 | |||
| * dim:] | |||
| keys_to_remove.append(prefix + 'in_proj_bias') | |||
| for k in keys_to_remove: | |||
| del state_dict[k] | |||
| for key, value in items_to_add.items(): | |||
| state_dict[key] = value | |||
| @@ -0,0 +1,155 @@ | |||
| # Originally from Microsoft Corporation. | |||
| # Licensed under the MIT License. | |||
| """ Wrapper for ngram_repeat_block cuda extension """ | |||
| import math | |||
| import warnings | |||
| from typing import Dict, List | |||
| import torch | |||
| from torch import nn | |||
| try: | |||
| from fairseq import ngram_repeat_block_cuda | |||
| EXTENSION_BUILT = True | |||
| except ImportError: | |||
| EXTENSION_BUILT = False | |||
| def is_cuda_extension_usable() -> bool: | |||
| """Check whether ngram_repeat_block_cuda is built properly""" | |||
| if not EXTENSION_BUILT or not torch.cuda.is_available(): | |||
| return False | |||
| bsz = 2 | |||
| tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], | |||
| dtype=torch.long, | |||
| device='cuda') | |||
| lprobs = torch.rand((8, 12), device='cuda') | |||
| try: | |||
| outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) | |||
| outputs = outputs + 4 # This line breaks if the extension is built incorrectly. | |||
| return True | |||
| except RuntimeError: | |||
| warnings.warn( | |||
| 'NGramRepeatBlock extension must be rebuilt. ' | |||
| 'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' | |||
| ) | |||
| return False | |||
| class NGramRepeatBlock(nn.Module): | |||
| """ Wrapper class for calling ngram_repeat_block cuda extension """ | |||
| def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): | |||
| super().__init__() | |||
| self.use_extension = is_cuda_extension_usable( | |||
| ) if use_extension else False | |||
| self.no_repeat_ngram_size = no_repeat_ngram_size | |||
| def reset_parameters(self): | |||
| pass | |||
| @torch.jit.unused | |||
| def call_cuda_extension( | |||
| self, | |||
| tokens, | |||
| lprobs, | |||
| bsz: int, | |||
| beam_size: int, | |||
| step: int, | |||
| ): | |||
| return ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, step, | |||
| beam_size, | |||
| self.no_repeat_ngram_size) | |||
| def forward( | |||
| self, | |||
| tokens, | |||
| lprobs, | |||
| bsz: int, | |||
| beam_size: int, | |||
| step: int, | |||
| ): | |||
| """ | |||
| Args: | |||
| tokens(Tensor): Input tokens(Bsz*beam, seq_len) | |||
| lprobs(Tensor): likelihood probability, | |||
| Expected to be updated in place.(Bsz*beam, vocab_size) | |||
| bsz(int): batch size | |||
| step(int): current step | |||
| beam_size(int): beam size | |||
| no_repeat_ngram_size(int): Ngram size | |||
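| Example (illustrative sketch, not part of the upstream file; exercises the | |||
| pure-Python fallback by passing `use_extension=False`): | |||
| >>> import torch | |||
| >>> blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False) | |||
| >>> tokens = torch.tensor([[5, 6, 5, 6]])  # bsz * beam_size = 1, seq_len = 4 | |||
| >>> lprobs = torch.zeros(1, 10)            # toy vocabulary of 10 tokens | |||
| >>> lprobs = blocker(tokens, lprobs, bsz=1, beam_size=1, step=3) | |||
| >>> lprobs[0, 5].item()  # the bigram "6 5" was already generated, so 5 is banned | |||
| -inf | |||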
| """ | |||
| msg = f'expected {bsz * beam_size} got' | |||
| assert tokens.size(0) == bsz * beam_size, f'{msg} {tokens.size(0)}' | |||
| assert lprobs.size(0) == bsz * beam_size, f'{msg} {lprobs.size(0)}' | |||
| if self.use_extension: | |||
| return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, | |||
| step) | |||
| else: | |||
| return self._no_repeat_ngram( | |||
| tokens, | |||
| lprobs, | |||
| bsz, | |||
| beam_size, | |||
| step, | |||
| ) | |||
| def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, | |||
| step: int): | |||
| """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" | |||
| gen_ngrams: List[Dict[str, List[int]]] = [ | |||
| torch.jit.annotate(Dict[str, List[int]], {}) | |||
| for bbsz_idx in range(bsz * beam_size) | |||
| ] | |||
| cpu_tokens = tokens.cpu() | |||
| for bbsz_idx in range(bsz * beam_size): | |||
| gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() | |||
| for ngram in self.transpose_list([ | |||
| gen_tokens[i:] for i in range(self.no_repeat_ngram_size) | |||
| ]): # noqa | |||
| key = ','.join([str(x) for x in ngram[:-1]]) | |||
| gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( | |||
| key, torch.jit.annotate(List[int], [])) + [ngram[-1]] | |||
| if step + 2 - self.no_repeat_ngram_size >= 0: | |||
| # enough tokens have been generated to form a full ngram, so compute banned continuations | |||
| banned_tokens = [ | |||
| self.calculate_banned_tokens(tokens, step, gen_ngrams, | |||
| self.no_repeat_ngram_size, | |||
| bbsz_idx) | |||
| for bbsz_idx in range(bsz * beam_size) | |||
| ] | |||
| else: | |||
| banned_tokens = [ | |||
| torch.jit.annotate(List[int], []) | |||
| for bbsz_idx in range(bsz * beam_size) | |||
| ] | |||
| for bbsz_idx in range(bsz * beam_size): | |||
| lprobs[bbsz_idx][torch.tensor( | |||
| banned_tokens[bbsz_idx], | |||
| dtype=torch.int64)] = torch.tensor(-math.inf).to(lprobs) | |||
| return lprobs | |||
| @staticmethod | |||
| def calculate_banned_tokens( | |||
| tokens, | |||
| step: int, | |||
| gen_ngrams: List[Dict[str, List[int]]], | |||
| no_repeat_ngram_size: int, | |||
| bbsz_idx: int, | |||
| ): | |||
| tokens_list: List[int] = tokens[bbsz_idx, | |||
| step + 2 - no_repeat_ngram_size:step | |||
| + 1].tolist() # noqa | |||
| # before decoding the next token, prevent decoding of ngrams that have already appeared | |||
| ngram_index = ','.join([str(x) for x in tokens_list]) | |||
| return gen_ngrams[bbsz_idx].get(ngram_index, | |||
| torch.jit.annotate(List[int], [])) | |||
| @staticmethod | |||
| def transpose_list(l: List[List[int]]): # noqa | |||
| # GeneratorExp aren't supported in TS so ignoring the lint | |||
| min_len = min([len(x) for x in l]) # noqa | |||
| l2 = [[row[i] for row in l] for i in range(min_len)] | |||
| return l2 | |||
| @@ -0,0 +1,848 @@ | |||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||
| # | |||
| # This source code is licensed under the MIT license which can be found at | |||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
| import math | |||
| from typing import List, Optional | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch import Tensor | |||
| from .token_generation_constraints import (ConstraintState, | |||
| OrderedConstraintState, | |||
| UnorderedConstraintState) | |||
| class Search(nn.Module): | |||
| def __init__(self, tokenizer): | |||
| super().__init__() | |||
| self.pad = tokenizer.pad_token_id | |||
| self.unk = tokenizer.unk_token_id | |||
| self.eos = tokenizer.eos_token_id | |||
| tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} | |||
| added = { | |||
| value: key | |||
| for key, value in tokenizer.get_added_vocab().items() | |||
| } | |||
| tgt_dict.update(added) | |||
| self.vocab_size = len(tgt_dict) | |||
| self.src_lengths = torch.tensor(-1) | |||
| self.supports_constraints = False | |||
| self.stop_on_max_len = False | |||
| def step(self, | |||
| step, | |||
| lprobs, | |||
| scores, | |||
| prev_output_tokens=None, | |||
| original_batch_idxs=None): | |||
| """Take a single search step. | |||
| Args: | |||
| step: the current search step, starting at 0 | |||
| lprobs: (bsz x input_beam_size x vocab_size) | |||
| the model's log-probabilities over the vocabulary at the current step | |||
| scores: (bsz x input_beam_size x step) | |||
| the historical model scores of each hypothesis up to this point | |||
| prev_output_tokens: (bsz x step) | |||
| the previously generated output tokens | |||
| original_batch_idxs: (bsz) | |||
| the tensor with the batch indices, in the range [0, bsz) | |||
| this is useful in case a re-ordering has been applied | |||
| and we need to know the original indices | |||
| Return: A tuple of (scores, indices, beams) where: | |||
| scores: (bsz x output_beam_size) | |||
| the scores of the chosen elements; output_beam_size can be | |||
| larger than input_beam_size, e.g., we may return | |||
| 2*input_beam_size to account for EOS | |||
| indices: (bsz x output_beam_size) | |||
| the indices of the chosen elements | |||
| beams: (bsz x output_beam_size) | |||
| the hypothesis ids of the chosen elements, in the range [0, input_beam_size) | |||
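| Example (illustrative sketch, not part of the upstream file; `tokenizer` stands for | |||
| any HF-style tokenizer exposing pad/unk/eos ids, `get_vocab()` and `get_added_vocab()`, | |||
| and `BeamSearch` is the concrete subclass defined below): | |||
| >>> import torch | |||
| >>> search = BeamSearch(tokenizer) | |||
| >>> bsz, beam_size = 2, 5 | |||
| >>> lprobs = torch.log_softmax(torch.randn(bsz, beam_size, search.vocab_size), dim=-1) | |||
| >>> scores, indices, beams = search.step(0, lprobs, None) | |||
| >>> scores.shape  # up to 2 * beam_size candidates per sentence | |||
| torch.Size([2, 10]) | |||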
| """ | |||
| raise NotImplementedError | |||
| @torch.jit.export | |||
| def set_src_lengths(self, src_lengths): | |||
| self.src_lengths = src_lengths | |||
| @torch.jit.export | |||
| def init_constraints(self, batch_constraints: Optional[Tensor], | |||
| beam_size: int): | |||
| """Initialize constraint states for constrained decoding (if supported). | |||
| Args: | |||
| batch_constraints: (torch.Tensor, optional) | |||
| the list of constraints, in packed form | |||
| beam_size: (int) | |||
| the beam size | |||
| """ | |||
| pass | |||
| def prune_sentences(self, batch_idxs: Tensor): | |||
| """ | |||
| Removes constraint states for completed sentences (if supported). | |||
| This is called from sequence_generator._generate() when sentences are | |||
| deleted from the batch. | |||
| Args: | |||
| batch_idxs: Indices of *sentences* whose constraint state should be *kept*. | |||
| """ | |||
| pass | |||
| def update_constraints(self, active_hypos: Tensor): | |||
| """ | |||
| Updates the constraint states by selecting the beam items that are retained. | |||
| This is called at each time step of sequence_generator._generate() when | |||
| the set of 2 * {beam_size} candidate hypotheses is reduced to the beam size. | |||
| Args: | |||
| active_hypos: (batch size, beam size) | |||
| list of integers denoting, for each sentence, which beam candidate items | |||
| should be kept. | |||
| """ | |||
| pass | |||
| class BeamSearch(Search): | |||
| def __init__(self, tgt_dict): | |||
| super().__init__(tgt_dict) | |||
| self.constraint_states = None | |||
| @torch.jit.export | |||
| def step( | |||
| self, | |||
| step: int, | |||
| lprobs, | |||
| scores: Optional[Tensor], | |||
| prev_output_tokens: Optional[Tensor] = None, | |||
| original_batch_idxs: Optional[Tensor] = None, | |||
| ): | |||
| bsz, beam_size, vocab_size = lprobs.size() | |||
| if step == 0: | |||
| # at the first step all hypotheses are equally likely, so use | |||
| # only the first beam | |||
| lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
| else: | |||
| # make probs contain cumulative scores for each hypothesis | |||
| assert scores is not None | |||
| lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) | |||
| top_prediction = torch.topk( | |||
| lprobs.view(bsz, -1), | |||
| k=min( | |||
| # Take the best 2 x beam_size predictions. We'll choose the first | |||
| # beam_size of these which don't predict eos to continue with. | |||
| beam_size * 2, | |||
| lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad | |||
| ), | |||
| ) | |||
| scores_buf = top_prediction[0] | |||
| indices_buf = top_prediction[1] | |||
| # Project back into relative indices and beams | |||
| beams_buf = indices_buf // vocab_size | |||
| indices_buf = indices_buf.fmod(vocab_size) | |||
| # At this point, beams_buf and indices_buf are single-dim and contain relative indices | |||
| return scores_buf, indices_buf, beams_buf | |||
| class PrefixConstrainedBeamSearch(Search): | |||
| def __init__(self, tgt_dict, prefix_allowed_tokens_fn): | |||
| super().__init__(tgt_dict) | |||
| self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn | |||
| self.stop_on_max_len = True | |||
| @torch.jit.export | |||
| def apply_mask(self, x, prev_output_tokens, original_batch_idxs): | |||
| beam_size = x.shape[0] // original_batch_idxs.shape[0] | |||
| original_batch_idxs = ( | |||
| original_batch_idxs.unsqueeze(-1).repeat( | |||
| (1, beam_size)).flatten().tolist()) | |||
| mask = torch.full_like(x, -math.inf) | |||
| for sent_i, (sent, batch_i) in enumerate( | |||
| zip(prev_output_tokens, original_batch_idxs)): | |||
| mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 | |||
| return mask | |||
| @torch.jit.export | |||
| def step( | |||
| self, | |||
| step: int, | |||
| lprobs: Tensor, | |||
| scores: Tensor, | |||
| prev_output_tokens: Tensor, | |||
| original_batch_idxs: Tensor, | |||
| ): | |||
| bsz, beam_size, vocab_size = lprobs.size() | |||
| lprobs += self.apply_mask( | |||
| lprobs.view(bsz * beam_size, 1, vocab_size), | |||
| prev_output_tokens, | |||
| original_batch_idxs, | |||
| ).view(bsz, beam_size, vocab_size) | |||
| if step == 0: | |||
| # at the first step all hypotheses are equally likely, so use | |||
| # only the first beam | |||
| lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
| else: | |||
| # make probs contain cumulative scores for each hypothesis | |||
| assert scores is not None | |||
| lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) | |||
| top_prediction = torch.topk( | |||
| lprobs.view(bsz, -1), | |||
| k=min( | |||
| # Take the best beam_size predictions. We'll choose the first | |||
| # beam_size of these which don't predict eos to continue with. | |||
| beam_size, | |||
| lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad | |||
| ), | |||
| ) | |||
| scores_buf = top_prediction[0] | |||
| indices_buf = top_prediction[1] | |||
| beams_buf = indices_buf // vocab_size | |||
| indices_buf = indices_buf.fmod(vocab_size) | |||
| return scores_buf, indices_buf, beams_buf | |||
| class LexicallyConstrainedBeamSearch(Search): | |||
| """Implements lexically constrained beam search as described in | |||
| Fast Lexically Constrained Decoding with Dynamic Beam | |||
| Allocation for Neural Machine Translation. Post & Vilar, | |||
| NAACL 2018. https://www.aclweb.org/anthology/N18-1119/ | |||
| and | |||
| Improved Lexically Constrained Decoding for Translation and | |||
| Monolingual Rewriting. Hu et al, NAACL | |||
| 2019. https://www.aclweb.org/anthology/N19-1090/ | |||
| This is accomplished by maintaining, for each beam hypothesis, a | |||
| ConstraintState object (see constraints.py) that tracks which | |||
| constraints have been generated and using this information to | |||
| shape the beam for each input sentence. | |||
| """ | |||
| def __init__(self, tokenizer, representation): | |||
| super().__init__(tokenizer) | |||
| self.representation = representation | |||
| tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} | |||
| added = { | |||
| value: key | |||
| for key, value in tokenizer.get_added_vocab().items() | |||
| } | |||
| tgt_dict.update(added) | |||
| self.vocab_size = len(tgt_dict) | |||
| self.num_cands = 0 | |||
| self.supports_constraints = True | |||
| @torch.jit.export | |||
| def init_constraints(self, batch_constraints: Optional[Tensor], | |||
| beam_size: int): | |||
| self.constraint_states = [] | |||
| for constraint_tensor in batch_constraints: | |||
| if self.representation == 'ordered': | |||
| constraint_state = OrderedConstraintState.create( | |||
| constraint_tensor) | |||
| elif self.representation == 'unordered': | |||
| constraint_state = UnorderedConstraintState.create( | |||
| constraint_tensor) | |||
| self.constraint_states.append( | |||
| [constraint_state for i in range(beam_size)]) | |||
| @torch.jit.export | |||
| def prune_sentences(self, batch_idxs: Tensor): | |||
| self.constraint_states = [ | |||
| self.constraint_states[i] for i in batch_idxs.tolist() | |||
| ] | |||
| @torch.jit.export | |||
| def update_constraints(self, active_hypos: Tensor): | |||
| if self.constraint_states: | |||
| batch_size = active_hypos.size(0) | |||
| for sentid in range(batch_size): | |||
| self.constraint_states[sentid] = [ | |||
| self.constraint_states[sentid][i] | |||
| for i in active_hypos[sentid] | |||
| ] | |||
| @torch.jit.export | |||
| def step( | |||
| self, | |||
| step: int, | |||
| lprobs: Tensor, | |||
| scores: Optional[Tensor], | |||
| prev_output_tokens: Optional[Tensor] = None, | |||
| original_batch_idxs: Optional[Tensor] = None, | |||
| ): | |||
| """ | |||
| A constrained step builds a large candidates list from the following: | |||
| - the top 2 * {beam_size} items over the whole beam | |||
| - for each item in the beam | |||
| - the top {each_k} (default 1) | |||
| - all next constraints | |||
| We then compute the constrained state of each beam item, and assign | |||
| stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so | |||
| on. We then sort by (stripe, score), and truncate the list at | |||
| 2 * beam size. | |||
| Args: | |||
| step: the decoder step | |||
| lprobs: (batch size, beam size, target vocab) | |||
| the target-vocab distributions for each item in the beam. | |||
| Return: A tuple of (scores, indices, beams, constraints) where: | |||
| scores: (batch, output beam size) | |||
| the scores of the chosen elements | |||
| indices: (batch, output beam size) | |||
| the target vocab indices of the chosen elements | |||
| beams: (batch, output beam size) | |||
| the 0-indexed hypothesis ids of the chosen elements | |||
| constraints: (batch, output beam size) | |||
| the new constraint states | |||
| """ | |||
| each_k = 1 | |||
| device = lprobs.device | |||
| batch_size, beam_size, vocab_size = lprobs.size() | |||
| self.num_cands = min( | |||
| # Just take the k-best. We'll get another k from the 1-best from each | |||
| # row, plus more from the constraints | |||
| beam_size * 2, | |||
| lprobs.view(batch_size, -1).size(1) | |||
| - 1, # -1 so we never select pad | |||
| ) | |||
| # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items | |||
| constraint_states = self.constraint_states | |||
| if constraint_states and step > 0: | |||
| not_finished_indices = [] | |||
| for sentno, sent_constraints in enumerate(constraint_states): | |||
| for beamno, state in enumerate(sent_constraints): | |||
| index = sentno * beam_size + beamno | |||
| if not state.finished: | |||
| not_finished_indices.append(index) | |||
| not_finished_indices = torch.tensor(not_finished_indices) | |||
| if not_finished_indices.numel() > 0: | |||
| lprobs.view(batch_size * beam_size, -1)[not_finished_indices, | |||
| self.eos] = -math.inf | |||
| if step == 0: | |||
| # at the first step all hypotheses are equally likely, so use | |||
| # only the first beam entry for each batch item | |||
| lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
| else: | |||
| # make probs contain cumulative scores for each hypothesis | |||
| assert scores is not None | |||
| lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) | |||
| top_prediction = torch.topk( | |||
| lprobs.view(batch_size, -1), | |||
| self.num_cands, | |||
| ) | |||
| scores_buf, indices_buf = top_prediction | |||
| # Project back into relative indices and beams | |||
| beams_buf = indices_buf // vocab_size | |||
| indices_buf = indices_buf.fmod(vocab_size) | |||
| # Short circuit if there are no constraints in this batch | |||
| if not constraint_states: | |||
| return scores_buf, indices_buf, beams_buf | |||
| # STEP 1: get top-1 from each hypothesis across all sentences in the batch | |||
| if step > 0: | |||
| top_scores, top_indices = torch.topk( | |||
| lprobs.view(batch_size * beam_size, -1), | |||
| k=each_k, | |||
| dim=1, | |||
| ) | |||
| top_scores = top_scores.view(batch_size, -1) | |||
| top_indices = top_indices.view(batch_size, -1) | |||
| scores_buf = torch.cat((scores_buf, top_scores), dim=1) | |||
| indices_buf = torch.cat((indices_buf, top_indices), dim=1) | |||
| new_beams = torch.arange( | |||
| 0, beam_size, device=device).repeat(batch_size, 1) | |||
| beams_buf = torch.cat((beams_buf, new_beams), dim=1) | |||
| # Now, process sentences in the batch one by one. | |||
| new_scores_buf = torch.zeros((batch_size, 2 * beam_size), | |||
| device=device) | |||
| new_indices_buf = torch.zeros((batch_size, 2 * beam_size), | |||
| device=device).long() | |||
| new_beams_buf = torch.zeros((batch_size, 2 * beam_size), | |||
| device=device).long() | |||
| for sentno, states in enumerate(constraint_states): | |||
| scores, indices, beams, new_states = self.step_sentence( | |||
| step, | |||
| sentno, | |||
| lprobs[sentno], | |||
| constraint_states[sentno], | |||
| beams_buf[sentno].clone(), | |||
| indices_buf[sentno].clone(), | |||
| scores_buf[sentno].clone(), | |||
| ) | |||
| new_scores_buf[sentno] = scores | |||
| new_indices_buf[sentno] = indices | |||
| new_beams_buf[sentno] = beams | |||
| self.constraint_states[sentno] = new_states | |||
| return new_scores_buf, new_indices_buf, new_beams_buf | |||
| @torch.jit.export | |||
| def step_sentence( | |||
| self, | |||
| step: int, | |||
| sentno: int, | |||
| lprobs: Tensor, | |||
| constraint_states: List[List[ConstraintState]], | |||
| beams_buf: Tensor, | |||
| indices_buf: Tensor, | |||
| scores_buf: Tensor, | |||
| ): | |||
| """Does per-sentence processing. Adds all constraints for each | |||
| hypothesis to the list of candidates; then removes duplicates, | |||
| sorts, and dynamically stripes across the banks. All tensor inputs | |||
| are collapsed to those pertaining to a single input sentence. | |||
| """ | |||
| device = lprobs.device | |||
| # STEP 2: Add all constraints for each beam item | |||
| for beamno, state in enumerate(constraint_states): | |||
| next_tokens = torch.tensor( | |||
| list(state.next_tokens()), device=device).long() | |||
| if next_tokens.numel() != 0: | |||
| indices_buf = torch.cat((indices_buf, next_tokens)) | |||
| next_beams = ( | |||
| torch.tensor(beamno, device=device).repeat( | |||
| next_tokens.size(0)).long()) | |||
| beams_buf = torch.cat((beams_buf, next_beams)) | |||
| next_values = lprobs[beamno].take(next_tokens.view(-1)) | |||
| scores_buf = torch.cat((scores_buf, next_values)) | |||
| # At the 0th time step, there is just one beam item | |||
| if step == 0: | |||
| break | |||
| # STEP 3: Compute the "bank" for each candidate. This is the | |||
| # number of constraints it's generated. We need this so that | |||
| # we can do round-robin allocation of the beam across these | |||
| # banks. If C is the number of constraints, we select the best | |||
| # item in bank C, then the best in bank C-1, etc, followed by | |||
| # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so | |||
| # on, until the maximum beam size. We accomplish this by | |||
| # creating a sort key and striping across the banks. | |||
| # Compute the new states for all candidates | |||
| cands_size = indices_buf.size(0) | |||
| constraint_states = [ | |||
| constraint_states[beams_buf[i]].advance(indices_buf[i]) | |||
| for i in range(cands_size) | |||
| ] | |||
| banks = torch.tensor([state.bank for state in constraint_states], | |||
| device=device) | |||
| # STEP 4: Sort | |||
| num_constraint_tokens = len(state.tokens) | |||
| # Sort by keys (bank, score) (i.e., sort banks together, and scores | |||
| # within banks). AFAIK pytorch doesn't support either stable sort or | |||
| # multi-key sorting, so we have to hack this. | |||
| MAX_SCORE = -100 | |||
| sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf | |||
| sort_values, sort_indices = sort_key.sort(dim=0, descending=True) | |||
| scores_buf = scores_buf[sort_indices] | |||
| indices_buf = indices_buf[sort_indices] | |||
| beams_buf = beams_buf[sort_indices] | |||
| banks = banks[sort_indices] | |||
| # Sort the constraints to follow suit | |||
| constraint_states = [constraint_states[i] for i in sort_indices] | |||
| # STEP 5: Remove duplicates. The topk calls (overall and | |||
| # per-row) plus the per-row generation of constraints will | |||
| # produce duplicates. Here we remove them. | |||
| def roll(t): | |||
| """Rolls a 1d tensor left by 1. | |||
| [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] | |||
| """ | |||
| return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) | |||
| # We map candidates (beam, token_id) to a single dimension. | |||
| # This is then shifted by 1. We can then easily identify | |||
| # duplicates and create a mask that identifies unique | |||
| # extensions. | |||
| uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf | |||
| uniques_mask = roll(uniques_mask) != uniques_mask | |||
| # Use the mask to pare down the data structures | |||
| scores_buf = torch.masked_select(scores_buf, uniques_mask) | |||
| indices_buf = torch.masked_select(indices_buf, uniques_mask) | |||
| beams_buf = torch.masked_select(beams_buf, uniques_mask) | |||
| banks = torch.masked_select(banks, uniques_mask) | |||
| i = 1 | |||
| for mask in uniques_mask[1:]: | |||
| if not mask: | |||
| constraint_states.pop(i) | |||
| i += mask | |||
| # STEP 6: Assign IDs round-robin across banks, sort, and | |||
| # truncate. Now that the candidates are sorted by (bank, | |||
| # score) and uniqed, we dynamically allocate the {beam_size} | |||
| # beam by striping across the candidates. These stripes will | |||
| # be used as sort keys to do round-robin selection. This is | |||
| # accomplished in a single pass with offsets. Sorting by | |||
| # highest-banks (furthest-along hypotheses) first ensures | |||
| # progress through the constraints. | |||
| # | |||
| # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 | |||
| # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 | |||
| # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 | |||
| # = 0 5 10 1 6 11 13 2 7 12 3 8 | |||
| # | |||
| # Sorting by this then gives the following banks: | |||
| # | |||
| # 3 2 1 0 3 2 1 0 3 2 1 2 | |||
| # | |||
| # We'll take the top {beam_size} of these. | |||
| stripe_offsets = [ | |||
| offset * (len(banks) + 1) for offset in range(len(banks) + 1) | |||
| ] | |||
| stripes = torch.zeros_like(banks) | |||
| cur_bank_count = -1 | |||
| cur_bank = banks[0] | |||
| for i, bank in enumerate(banks): | |||
| if bank != cur_bank: | |||
| cur_bank_count = 0 | |||
| cur_bank = bank | |||
| else: | |||
| cur_bank_count += 1 | |||
| stripes[i] = num_constraint_tokens - bank + stripe_offsets[ | |||
| cur_bank_count] | |||
| # STEP 7: Sort by the stripes values | |||
| sort_values, sort_indices = stripes.sort(dim=0) | |||
| scores_buf = scores_buf[sort_indices] | |||
| indices_buf = indices_buf[sort_indices] | |||
| beams_buf = beams_buf[sort_indices] | |||
| constraint_states = [constraint_states[i] for i in sort_indices] | |||
| # STEP 8: Truncate to the candidates size! | |||
| scores_buf = scores_buf[:self.num_cands] | |||
| indices_buf = indices_buf[:self.num_cands] | |||
| beams_buf = beams_buf[:self.num_cands] | |||
| return scores_buf, indices_buf, beams_buf, constraint_states | |||
| class LengthConstrainedBeamSearch(Search): | |||
| def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): | |||
| super().__init__(tgt_dict) | |||
| self.min_len_a = min_len_a | |||
| self.min_len_b = min_len_b | |||
| self.max_len_a = max_len_a | |||
| self.max_len_b = max_len_b | |||
| self.beam = BeamSearch(tgt_dict) | |||
| self.needs_src_lengths = True | |||
| def step( | |||
| self, | |||
| step: int, | |||
| lprobs, | |||
| scores, | |||
| prev_output_tokens: Optional[Tensor] = None, | |||
| original_batch_idxs: Optional[Tensor] = None, | |||
| ): | |||
| min_lens = self.min_len_a * self.src_lengths + self.min_len_b | |||
| max_lens = self.max_len_a * self.src_lengths + self.max_len_b | |||
| lprobs[step < min_lens, :, self.eos] = -math.inf | |||
| lprobs[step >= max_lens, :, self.eos] = 0 | |||
| return self.beam.step(step, lprobs, scores) | |||
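| # Example (assumed values): with min_len_a=0.5, min_len_b=2 and a source of | |||
| # length 10, min_lens = 7, so EOS is disallowed (-inf) for steps 0..6, and | |||
| # its log-prob is forced to 0 (certainty) once step >= max_lens. | |||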
| class DiverseBeamSearch(Search): | |||
| """Diverse Beam Search. | |||
| See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence | |||
| Models" for details. | |||
| We only implement the Hamming Diversity penalty here, which performed best | |||
| in the original paper. | |||
| """ | |||
| def __init__(self, tgt_dict, num_groups, diversity_strength): | |||
| super().__init__(tgt_dict) | |||
| self.num_groups = num_groups | |||
| self.diversity_strength = -diversity_strength | |||
| self.beam = BeamSearch(tgt_dict) | |||
| @torch.jit.export | |||
| def step( | |||
| self, | |||
| step: int, | |||
| lprobs, | |||
| scores, | |||
| prev_output_tokens: Optional[Tensor] = None, | |||
| original_batch_idxs: Optional[Tensor] = None, | |||
| ): | |||
| bsz, beam_size, vocab_size = lprobs.size() | |||
| if beam_size % self.num_groups != 0: | |||
| raise ValueError( | |||
| 'DiverseBeamSearch requires --beam to be divisible by the number of groups' | |||
| ) | |||
| # initialize diversity penalty | |||
| diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) | |||
| scores_G, indices_G, beams_G = [], [], [] | |||
| for g in range(self.num_groups): | |||
| lprobs_g = lprobs[:, g::self.num_groups, :] | |||
| scores_g = scores[:, g::self.num_groups, :] if step > 0 else None | |||
| # apply diversity penalty | |||
| if g > 0: | |||
| lprobs_g = torch.add( | |||
| lprobs_g, | |||
| other=diversity_buf.unsqueeze(1), | |||
| alpha=self.diversity_strength, | |||
| ) | |||
| else: | |||
| lprobs_g = lprobs_g.contiguous() | |||
| scores_buf, indices_buf, beams_buf = self.beam.step( | |||
| step, lprobs_g, scores_g) | |||
| beams_buf.mul_(self.num_groups).add_(g) | |||
| scores_G.append(scores_buf.clone()) | |||
| indices_G.append(indices_buf.clone()) | |||
| beams_G.append(beams_buf.clone()) | |||
| # update diversity penalty | |||
| diversity_buf.scatter_add_( | |||
| 1, indices_buf, | |||
| torch.ones(indices_buf.size()).to(diversity_buf)) | |||
| # interleave results from different groups | |||
| scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) | |||
| indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1) | |||
| beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) | |||
| return scores_buf, indices_buf, beams_buf | |||
| class Sampling(Search): | |||
| sampling_topk: int | |||
| sampling_topp: float | |||
| def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): | |||
| super().__init__(tgt_dict) | |||
| self.sampling_topk = sampling_topk | |||
| self.sampling_topp = sampling_topp | |||
| def _sample_topp(self, lprobs): | |||
| """Sample among the smallest set of elements whose cumulative probability mass exceeds p. | |||
| See `"The Curious Case of Neural Text Degeneration" | |||
| (Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_. | |||
| Args: | |||
| lprobs: (bsz x input_beam_size x vocab_size) | |||
| the model's log-probabilities over the vocabulary at the current step | |||
| Return: A tuple of (trimmed_probs, truncated_indices) where: | |||
| trimmed_probs: (bsz x input_beam_size x ?) | |||
| the model's probabilities over the elements selected to sample from. The | |||
| width of the third dimension is determined by top-P. | |||
| truncated_indices: (bsz x input_beam_size x ?) | |||
| the indices of the chosen elements. | |||
| """ | |||
| probs = lprobs.exp_() | |||
| # sort the last dimension (vocab dimension) in descending order | |||
| sorted_probs, sorted_indices = probs.sort(descending=True) | |||
| # compute a mask to indicate the words to be included in the top-P set. | |||
| cumsum_probs = sorted_probs.cumsum(dim=2) | |||
| mask = cumsum_probs.lt(self.sampling_topp) | |||
| # note that mask was computed by 'lt'. One more word needs to be included | |||
| # so that the cumulative probability mass can exceed p. | |||
| cumsum_mask = mask.cumsum(dim=2) | |||
| last_included = cumsum_mask[:, :, -1:] | |||
| last_included.clamp_(0, mask.size()[2] - 1) | |||
| mask = mask.scatter_(2, last_included, 1) | |||
| # truncate unnecessary dims. | |||
| max_dim = last_included.max() | |||
| truncated_mask = mask[:, :, :max_dim + 1] | |||
| truncated_probs = sorted_probs[:, :, :max_dim + 1] | |||
| truncated_indices = sorted_indices[:, :, :max_dim + 1] | |||
| # trim the words that are not in top-P by setting their probabilities | |||
| # to 0, so that they would not be sampled later. | |||
| trim_mask = ~truncated_mask | |||
| trimmed_probs = truncated_probs.masked_fill_(trim_mask, 0) | |||
| return trimmed_probs, truncated_indices | |||
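| # Worked example (assumed values): with sampling_topp = 0.9 and sorted | |||
| # probs [0.5, 0.3, 0.15, 0.05], cumsum = [0.5, 0.8, 0.95, 1.0]; lt(0.9) | |||
| # keeps the first two elements and one more is added so the kept mass | |||
| # (0.95) exceeds p, giving trimmed_probs = [0.5, 0.3, 0.15]. | |||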
| @torch.jit.export | |||
| def step( | |||
| self, | |||
| step: int, | |||
| lprobs, | |||
| scores, | |||
| prev_output_tokens: Optional[Tensor] = None, | |||
| original_batch_idxs: Optional[Tensor] = None, | |||
| ): | |||
| bsz, beam_size, vocab_size = lprobs.size() | |||
| if step == 0: | |||
| # at the first step all hypotheses are equally likely, so use | |||
| # only the first beam | |||
| lprobs = lprobs[:, ::beam_size, :].contiguous() | |||
| if self.sampling_topp > 0: | |||
| # only sample from the smallest set of words whose cumulative probability mass exceeds p | |||
| probs, top_indices = self._sample_topp(lprobs) | |||
| elif self.sampling_topk > 0: | |||
| # only sample from top-k candidates | |||
| lprobs, top_indices = lprobs.topk(self.sampling_topk) | |||
| probs = lprobs.exp_() | |||
| else: | |||
| probs = lprobs.exp_() | |||
| # dummy data to be consistent with true branch for type check | |||
| top_indices = torch.empty(0).to(probs) | |||
| # sample | |||
| if step == 0: | |||
| indices_buf = torch.multinomial( | |||
| probs.view(bsz, -1), | |||
| beam_size, | |||
| replacement=True, | |||
| ).view(bsz, beam_size) | |||
| else: | |||
| indices_buf = torch.multinomial( | |||
| probs.view(bsz * beam_size, -1), | |||
| 1, | |||
| replacement=True, | |||
| ).view(bsz, beam_size) | |||
| if step == 0: | |||
| # expand to beam size | |||
| probs = probs.expand(bsz, beam_size, -1) | |||
| # gather scores | |||
| scores_buf = torch.gather( | |||
| probs, dim=2, index=indices_buf.unsqueeze(-1)) | |||
| scores_buf = scores_buf.log_().view(bsz, -1) | |||
| # remap indices if using top-k or top-P sampling | |||
| if self.sampling_topk > 0 or self.sampling_topp > 0: | |||
| indices_buf = torch.gather( | |||
| top_indices.expand(bsz, beam_size, -1), | |||
| dim=2, | |||
| index=indices_buf.unsqueeze(-1), | |||
| ).squeeze(2) | |||
| if step == 0: | |||
| beams_buf = indices_buf.new_zeros(bsz, beam_size) | |||
| else: | |||
| beams_buf = torch.arange(0, | |||
| beam_size).to(indices_buf).repeat(bsz, 1) | |||
| # make scores cumulative | |||
| scores_buf.add_( | |||
| torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf)) | |||
| return scores_buf, indices_buf, beams_buf | |||
| class DiverseSiblingsSearch(Search): | |||
| """ | |||
| Beam search with diverse siblings. | |||
| See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. | |||
| https://arxiv.org/abs/1611.08562 | |||
| 1/ Calculate hypotheses for each beam | |||
| 2/ Intra-sibling ordering | |||
| 3/ Rewrite scores | |||
| 4/ Choose top K hypotheses | |||
| If diversity_rate == 0, this is equivalent to BeamSearch. | |||
| """ | |||
| def __init__(self, tgt_dict, diversity_rate): | |||
| super().__init__(tgt_dict) | |||
| self.diversity_rate = diversity_rate | |||
| self.beam = BeamSearch(tgt_dict) | |||
| def step( | |||
| self, | |||
| step: int, | |||
| lprobs, | |||
| scores, | |||
| prev_output_tokens: Optional[Tensor] = None, | |||
| original_batch_idxs: Optional[Tensor] = None, | |||
| ): | |||
| bsz, beam_size, vocab_size = lprobs.size() | |||
| k = min( | |||
| # Take the best 2 x beam_size predictions. We'll choose the first | |||
| # beam_size of these which don't predict eos to continue with. | |||
| beam_size * 2, | |||
| lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad | |||
| ) | |||
| s_list: List[Tensor] | |||
| i_list: List[Tensor] | |||
| s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] | |||
| i_list = [ | |||
| torch.LongTensor().to(device=lprobs.device) | |||
| for i in range(beam_size) | |||
| ] | |||
| sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate | |||
| if step == 0: | |||
| return self.beam.step(step, lprobs, scores) | |||
| lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) | |||
| # 1/ Calculate hypotheses for each beam | |||
| for i in range(beam_size): | |||
| torch.topk( | |||
| lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) | |||
| i_list[i].fmod_(vocab_size) | |||
| # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores | |||
| s_list[i].sub_(sibling_score) | |||
| # 4/ Choose top K hypotheses | |||
| indices = torch.stack(i_list, dim=1).view(bsz, -1) | |||
| final_scores = torch.empty(0).to(lprobs) | |||
| final_indices = torch.LongTensor().to(device=lprobs.device) | |||
| final_beams = torch.LongTensor().to(device=lprobs.device) | |||
| (final_scores, final_indices) = torch.topk( | |||
| torch.stack(s_list, dim=1).view(bsz, -1), | |||
| k, | |||
| ) | |||
| final_beams = final_indices // k | |||
| for i in range(bsz): | |||
| final_indices[i] = indices[i][final_indices[i]] | |||
| return final_scores, final_indices, final_beams | |||
| @@ -0,0 +1,996 @@ | |||
| # Copyright 2022 The OFA-Sys Team. | |||
| # All rights reserved. | |||
| # This source code is licensed under the Apache 2.0 license | |||
| # You may obtain a copy of the License at | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| import math | |||
| import sys | |||
| from typing import Dict, List, Optional, Tuple | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch import Tensor | |||
| from ..generate import search | |||
| from .ngram_repeat_block import NGramRepeatBlock | |||
| def _expand_mask(mask: torch.Tensor, | |||
| dtype: torch.dtype, | |||
| tgt_len: Optional[int] = None): | |||
| r""" | |||
| Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. | |||
| """ | |||
| bsz, src_len = mask.size() | |||
| tgt_len = tgt_len if tgt_len is not None else src_len | |||
| expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, | |||
| src_len).to(dtype) | |||
| return expanded_mask.masked_fill(expanded_mask.bool(), | |||
| torch.finfo(dtype).min) | |||
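| # Note: `mask` is expected to be nonzero/True at padded positions (a padding | |||
| # mask, not a keep mask); the returned additive mask is torch.finfo(dtype).min | |||
| # at those positions and 0 elsewhere. | |||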
| class SequenceGenerator(nn.Module): | |||
| def __init__(self, | |||
| tokenizer, | |||
| beam_size=1, | |||
| max_len_a=0, | |||
| max_len_b=200, | |||
| max_len=0, | |||
| min_len=1, | |||
| normalize_scores=True, | |||
| len_penalty=1.0, | |||
| unk_penalty=0.0, | |||
| temperature=1.0, | |||
| match_source_len=False, | |||
| no_repeat_ngram_size=0, | |||
| search_strategy=None, | |||
| eos=None, | |||
| symbols_to_strip_from_output=None, | |||
| lm_model=None, | |||
| lm_weight=1.0, | |||
| constraint_trie=None, | |||
| constraint_range=None, | |||
| gen_code=False, | |||
| gen_box=False, | |||
| ignore_eos=False, | |||
| zero_shot=False): | |||
| """Generates translations of a given source sentence. | |||
| Args: | |||
| tokenizer: tokenizer that provides the target vocabulary and the | |||
| special token ids (pad/unk/bos/eos) used during generation | |||
| beam_size (int, optional): beam width (default: 1) | |||
| max_len_a/b (int, optional): generate sequences of maximum length | |||
| ax + b, where x is the source length | |||
| max_len (int, optional): the maximum length of the generated output | |||
| (not including end-of-sentence) | |||
| min_len (int, optional): the minimum length of the generated output | |||
| (not including end-of-sentence) | |||
| normalize_scores (bool, optional): normalize scores by the length | |||
| of the output (default: True) | |||
| len_penalty (float, optional): length penalty, where <1.0 favors | |||
| shorter, >1.0 favors longer sentences (default: 1.0) | |||
| unk_penalty (float, optional): unknown word penalty, where <0 | |||
| produces more unks, >0 produces fewer (default: 0.0) | |||
| temperature (float, optional): temperature, where values | |||
| >1.0 produce more uniform samples and values <1.0 produce | |||
| sharper samples (default: 1.0) | |||
| match_source_len (bool, optional): outputs should match the source | |||
| length (default: False) | |||
| """ | |||
| super().__init__() | |||
| self.gen_code = gen_code | |||
| self.gen_box = gen_box | |||
| self.ignore_eos = ignore_eos | |||
| self.tokenizer = tokenizer | |||
| self.tgt_dict = { | |||
| value: key | |||
| for key, value in tokenizer.get_vocab().items() | |||
| } | |||
| added = { | |||
| value: key | |||
| for key, value in tokenizer.get_added_vocab().items() | |||
| } | |||
| self.tgt_dict.update(added) | |||
| self.pad = tokenizer.pad_token_id | |||
| self.unk = tokenizer.unk_token_id | |||
| self.bos = tokenizer.bos_token_id | |||
| self.eos = tokenizer.eos_token_id | |||
| self.symbols_to_strip_from_output = ( | |||
| symbols_to_strip_from_output.union({self.eos}) if | |||
| symbols_to_strip_from_output is not None else {self.bos, self.eos}) | |||
| self.vocab_size = len(self.tgt_dict) | |||
| self.beam_size = beam_size | |||
| # the max beam size is the dictionary size - 1, since we never select pad | |||
| self.beam_size = min(beam_size, self.vocab_size - 1) | |||
| self.max_len_a = max_len_a | |||
| self.max_len_b = max_len_b | |||
| self.min_len = min_len | |||
| self.max_len = max_len | |||
| self.normalize_scores = normalize_scores | |||
| self.len_penalty = len_penalty | |||
| self.unk_penalty = unk_penalty | |||
| self.temperature = temperature | |||
| self.match_source_len = match_source_len | |||
| self.zero_shot = zero_shot | |||
| if no_repeat_ngram_size > 0: | |||
| self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) | |||
| else: | |||
| self.repeat_ngram_blocker = None | |||
| assert temperature > 0, '--temperature must be greater than 0' | |||
| self.search = ( | |||
| search.BeamSearch(self.tokenizer) | |||
| if search_strategy is None else search_strategy) | |||
| # We only need to set src_lengths in LengthConstrainedBeamSearch. | |||
| # As a module attribute, setting it would break in multithread | |||
| # settings when the model is shared. | |||
| self.should_set_src_lengths = ( | |||
| hasattr(self.search, 'needs_src_lengths') | |||
| and self.search.needs_src_lengths) | |||
| self.lm_model = lm_model | |||
| self.lm_weight = lm_weight | |||
| if self.lm_model is not None: | |||
| self.lm_model.eval() | |||
| self.constraint_trie = constraint_trie | |||
| self.constraint_start = None | |||
| self.constraint_end = None | |||
| if constraint_range is not None: | |||
| constraint_start, constraint_end = constraint_range.split(',') | |||
| self.constraint_start = int(constraint_start) | |||
| self.constraint_end = int(constraint_end) | |||
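| # For example (assumed value), constraint_range='4,59457' restricts decoding | |||
| # to token ids in [constraint_start, constraint_end) plus the first four | |||
| # special ids (see the constraint handling in EnsembleModel.forward_decoder). | |||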
| @torch.no_grad() | |||
| def forward( | |||
| self, | |||
| sample: Dict[str, Dict[str, Tensor]], | |||
| prefix_tokens: Optional[Tensor] = None, | |||
| bos_token: Optional[int] = None, | |||
| ): | |||
| """Generate a batch of translations. | |||
| Args: | |||
| sample (dict): batch | |||
| prefix_tokens (torch.LongTensor, optional): force decoder to begin | |||
| with these tokens | |||
| bos_token (int, optional): beginning of sentence token | |||
| (default: self.eos) | |||
| """ | |||
| return self._generate(sample, prefix_tokens, bos_token=bos_token) | |||
| @torch.no_grad() | |||
| def generate(self, models, sample: Dict[str, Dict[str, Tensor]], | |||
| **kwargs) -> List[List[Dict[str, Tensor]]]: | |||
| """Generate translations. Match the api of other fairseq generators. | |||
| Args: | |||
| models (List[~fairseq.models.FairseqModel]): ensemble of models | |||
| sample (dict): batch | |||
| prefix_tokens (torch.LongTensor, optional): force decoder to begin | |||
| with these tokens | |||
| constraints (torch.LongTensor, optional): force decoder to include | |||
| the list of constraints | |||
| bos_token (int, optional): beginning of sentence token | |||
| (default: self.eos) | |||
| """ | |||
| return self._generate(models, sample, **kwargs) | |||
| def _generate( | |||
| self, | |||
| models, | |||
| sample: Dict[str, Dict[str, Tensor]], | |||
| prefix_tokens: Optional[Tensor] = None, | |||
| constraints: Optional[Tensor] = None, | |||
| bos_token: Optional[int] = None, | |||
| ): | |||
| model = EnsembleModel(models) | |||
| # incremental_states = torch.jit.annotate( | |||
| # List[Dict[str, Dict[str, Optional[Tensor]]]], | |||
| # [ | |||
| # torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) | |||
| # for i in range(model.models_size) | |||
| # ], | |||
| # ) | |||
| incremental_states = torch.jit.annotate( | |||
| List[Tuple[Tuple[torch.Tensor]]], | |||
| [ | |||
| torch.jit.annotate(Tuple[Tuple[torch.Tensor]], {}) | |||
| for i in range(model.models_size) | |||
| ], | |||
| ) | |||
| # print("incremental_states",incremental_states) | |||
| # print("incremental_states[0]",incremental_states[0]) | |||
| net_input = sample['net_input'] | |||
| if 'src_tokens' in net_input: | |||
| src_tokens = net_input['src_tokens'] | |||
| # source length = number of tokens, excluding EOS and padding | |||
| src_lengths = ((src_tokens.ne(self.eos) | |||
| & src_tokens.ne(self.pad)).long().sum(dim=1)) | |||
| elif 'input_ids' in net_input: | |||
| src_tokens = net_input['input_ids'] | |||
| # source length = number of tokens, excluding EOS and padding | |||
| src_lengths = ((src_tokens.ne(self.eos) | |||
| & src_tokens.ne(self.pad)).long().sum(dim=1)) | |||
| elif 'source' in net_input: | |||
| src_tokens = net_input['source'] | |||
| src_lengths = ( | |||
| net_input['padding_mask'].size(-1) | |||
| - net_input['padding_mask'].sum(-1) | |||
| if net_input['padding_mask'] is not None else torch.tensor( | |||
| src_tokens.size(-1)).to(src_tokens)) | |||
| elif 'features' in net_input: | |||
| src_tokens = net_input['features'] | |||
| src_lengths = ( | |||
| net_input['padding_mask'].size(-1) | |||
| - net_input['padding_mask'].sum(-1) | |||
| if net_input['padding_mask'] is not None else torch.tensor( | |||
| src_tokens.size(-1)).to(src_tokens)) | |||
| else: | |||
| raise Exception( | |||
| 'expected src_tokens or source in net input. input keys: ' | |||
| + str(net_input.keys())) | |||
| # bsz: total number of sentences in beam | |||
| # Note that src_tokens may have more than 2 dimensions (e.g. audio features) | |||
| bsz, src_len = src_tokens.size()[:2] | |||
| beam_size = self.beam_size | |||
| if constraints is not None and not self.search.supports_constraints: | |||
| raise NotImplementedError( | |||
| "Target-side constraints were provided, but search method doesn't support them" | |||
| ) | |||
| # Initialize constraints, when active | |||
| self.search.init_constraints(constraints, beam_size) | |||
| max_len: int = -1 | |||
| if self.match_source_len: | |||
| max_len = src_lengths.max().item() | |||
| else: | |||
| max_len = int(self.max_len_a * src_len + self.max_len_b) | |||
| assert ( | |||
| self.min_len <= max_len | |||
| ), 'min_len cannot be larger than max_len, please adjust these!' | |||
| # compute the encoder output for each beam | |||
| with torch.autograd.profiler.record_function( | |||
| 'EnsembleModel: forward_encoder'): | |||
| encoder_outs = model.forward_encoder(net_input) | |||
| # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores | |||
| new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) | |||
| new_order = new_order.to(src_tokens.device).long() | |||
| encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) | |||
| # ensure encoder_outs is a List. | |||
| assert encoder_outs is not None | |||
| # initialize buffers | |||
| scores = (torch.zeros(bsz * beam_size, | |||
| max_len + 1).to(src_tokens).float() | |||
| ) # +1 for eos; pad is never chosen for scoring | |||
| tokens = (torch.zeros(bsz * beam_size, | |||
| max_len + 2).to(src_tokens).long().fill_( | |||
| self.pad)) # +2 for eos and pad | |||
| # tokens[:, 0] = self.eos if bos_token is None else bos_token | |||
| tokens[:, 0] = self.bos | |||
| attn: Optional[Tensor] = None | |||
| # A list that indicates candidates that should be ignored. | |||
| # For example, suppose we're sampling and have already finalized 2/5 | |||
| # samples. Then cands_to_ignore would mark 2 positions as being ignored, | |||
| # so that we only finalize the remaining 3 samples. | |||
| cands_to_ignore = (torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) | |||
| ) # forward and backward-compatible False mask | |||
| # list of completed sentences | |||
| finalized = torch.jit.annotate( | |||
| List[List[Dict[str, Tensor]]], | |||
| [ | |||
| torch.jit.annotate(List[Dict[str, Tensor]], []) | |||
| for i in range(bsz) | |||
| ], | |||
| ) # contains lists of dictionaries of information about the hypothesis being finalized at each step | |||
| # a boolean array indicating if the sentence at the index is finished or not | |||
| finished = [False for i in range(bsz)] | |||
| num_remaining_sent = bsz # number of sentences remaining | |||
| # number of candidate hypos per step | |||
| cand_size = 2 * beam_size # 2 x beam size in case half are EOS | |||
| # offset arrays for converting between different indexing schemes | |||
| bbsz_offsets = ((torch.arange(0, bsz) | |||
| * beam_size).unsqueeze(1).type_as(tokens).to( | |||
| src_tokens.device)) | |||
| cand_offsets = torch.arange(0, cand_size).type_as(tokens).to( | |||
| src_tokens.device) | |||
| reorder_state: Optional[Tensor] = None | |||
| batch_idxs: Optional[Tensor] = None | |||
| original_batch_idxs: Optional[Tensor] = None | |||
| if 'id' in sample and isinstance(sample['id'], Tensor): | |||
| original_batch_idxs = sample['id'] | |||
| else: | |||
| original_batch_idxs = torch.arange(0, bsz).type_as(tokens) | |||
| for step in range(max_len + 1): # one extra step for EOS marker | |||
| # reorder decoder internal states based on the prev choice of beams | |||
| if reorder_state is not None: | |||
| if batch_idxs is not None: | |||
| # update beam indices to take into account removed sentences | |||
| corr = batch_idxs - torch.arange( | |||
| batch_idxs.numel()).type_as(batch_idxs) | |||
| reorder_state.view(-1, beam_size).add_( | |||
| corr.unsqueeze(-1) * beam_size) | |||
| original_batch_idxs = original_batch_idxs[batch_idxs] | |||
| model.reorder_incremental_state(incremental_states, | |||
| reorder_state) # todo | |||
| encoder_outs = model.reorder_encoder_out( | |||
| encoder_outs, reorder_state) | |||
| with torch.autograd.profiler.record_function( | |||
| 'EnsembleModel: forward_decoder'): | |||
| lprobs, avg_attn_scores = model.forward_decoder( | |||
| tokens[:, :step + 1], | |||
| encoder_outs, | |||
| incremental_states, | |||
| self.temperature, | |||
| constraint_trie=self.constraint_trie, | |||
| constraint_start=self.constraint_start, | |||
| constraint_end=self.constraint_end, | |||
| gen_code=self.gen_code, | |||
| zero_shot=self.zero_shot, | |||
| prefix_tokens=prefix_tokens) | |||
| if self.lm_model is not None: | |||
| lm_out = self.lm_model(tokens[:, :step + 1]) | |||
| probs = self.lm_model.get_normalized_probs( | |||
| lm_out, log_probs=True, sample=None) | |||
| probs = probs[:, -1, :] * self.lm_weight | |||
| lprobs += probs | |||
| # handle prefix tokens (possibly with different lengths) | |||
| if (prefix_tokens is not None and step < prefix_tokens.size(1) | |||
| and step < max_len): | |||
| lprobs, tokens, scores = self._prefix_tokens( | |||
| step, lprobs, scores, tokens, prefix_tokens, beam_size) | |||
| elif step < self.min_len: | |||
| # minimum length constraint (does not apply if using prefix_tokens) | |||
| lprobs[:, self.eos] = -math.inf | |||
| lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) | |||
| lprobs[:, self.pad] = -math.inf # never select pad | |||
| lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty | |||
| if (self.gen_code or self.gen_box) and step < max_len: | |||
| lprobs[:, :4] = -math.inf | |||
| if self.gen_box: | |||
| lprobs[:, -1] = -math.inf | |||
| if (step + 1) % 5 == 0: | |||
| lprobs[:, self.constraint_start:59457] = -math.inf | |||
| else: | |||
| lprobs[:, 59457:] = -math.inf | |||
| # handle max length constraint | |||
| if step >= max_len: | |||
| lprobs[:, :self.eos] = -math.inf | |||
| lprobs[:, self.eos + 1:] = -math.inf | |||
| if self.ignore_eos: | |||
| lprobs[:, self.eos] = 1 | |||
| # Record attention scores; only supported when avg_attn_scores is a Tensor | |||
| if avg_attn_scores is not None: | |||
| if attn is None: | |||
| attn = torch.empty(bsz * beam_size, | |||
| avg_attn_scores.size(1), | |||
| max_len + 2).to(scores) | |||
| # print("+++++++ debug attention shape +++++++") | |||
| # print("attn", attn.shape) | |||
| # print("avg_attn_scores", avg_attn_scores.shape) | |||
| attn[:, :, step + 1].copy_(avg_attn_scores) | |||
| # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape) | |||
| # print("attn", attn.shape) | |||
| scores = scores.type_as(lprobs) | |||
| eos_bbsz_idx = torch.empty(0).to( | |||
| tokens | |||
| ) # indices of hypothesis ending with eos (finished sentences) | |||
| eos_scores = torch.empty(0).to( | |||
| scores | |||
| ) # scores of hypothesis ending with eos (finished sentences) | |||
| if self.should_set_src_lengths: | |||
| self.search.set_src_lengths(src_lengths) | |||
| if self.repeat_ngram_blocker is not None: | |||
| lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, | |||
| beam_size, step) | |||
| # Shape: (batch, cand_size) | |||
| cand_scores, cand_indices, cand_beams = self.search.step( | |||
| step, | |||
| lprobs.view(bsz, -1, self.vocab_size), | |||
| scores.view(bsz, beam_size, -1)[:, :, :step], | |||
| tokens[:, :step + 1], | |||
| original_batch_idxs, | |||
| ) | |||
| # cand_bbsz_idx contains beam indices for the top candidate | |||
| # hypotheses, with a range of values: [0, bsz*beam_size), | |||
| # and dimensions: [bsz, cand_size] | |||
| cand_bbsz_idx = cand_beams.add(bbsz_offsets) | |||
| # finalize hypotheses that end in eos | |||
| # Shape of eos_mask: (batch size, beam size) | |||
| eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) | |||
| eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to( | |||
| eos_mask) | |||
| # only consider eos when it's among the top beam_size indices | |||
| # Now we know what beam item(s) to finish | |||
| # Shape: 1d list of absolute-numbered | |||
| eos_bbsz_idx = torch.masked_select( | |||
| cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]) | |||
| finalized_sents: List[int] = [] | |||
| if eos_bbsz_idx.numel() > 0: | |||
| eos_scores = torch.masked_select( | |||
| cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]) | |||
| finalized_sents = self.finalize_hypos( | |||
| step, | |||
| eos_bbsz_idx, | |||
| eos_scores, | |||
| tokens, | |||
| scores, | |||
| finalized, | |||
| finished, | |||
| beam_size, | |||
| attn, | |||
| src_lengths, | |||
| max_len, | |||
| ) | |||
| num_remaining_sent -= len(finalized_sents) | |||
| assert num_remaining_sent >= 0 | |||
| if num_remaining_sent == 0: | |||
| break | |||
| if self.search.stop_on_max_len and step >= max_len: | |||
| break | |||
| assert step < max_len, f'{step} < {max_len}' | |||
| # Remove finalized sentences (ones for which {beam_size} | |||
| # finished hypotheses have been generated) from the batch. | |||
| if len(finalized_sents) > 0: | |||
| new_bsz = bsz - len(finalized_sents) | |||
| # construct batch_idxs which holds indices of batches to keep for the next pass | |||
| batch_mask = torch.ones( | |||
| bsz, dtype=torch.bool, device=cand_indices.device) | |||
| batch_mask[finalized_sents] = False | |||
| # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it | |||
| batch_idxs = torch.arange( | |||
| bsz, device=cand_indices.device).masked_select(batch_mask) | |||
| # Choose the subset of the hypothesized constraints that will continue | |||
| self.search.prune_sentences(batch_idxs) | |||
| eos_mask = eos_mask[batch_idxs] | |||
| cand_beams = cand_beams[batch_idxs] | |||
| bbsz_offsets.resize_(new_bsz, 1) | |||
| cand_bbsz_idx = cand_beams.add(bbsz_offsets) | |||
| cand_scores = cand_scores[batch_idxs] | |||
| cand_indices = cand_indices[batch_idxs] | |||
| if prefix_tokens is not None: | |||
| prefix_tokens = prefix_tokens[batch_idxs] | |||
| src_lengths = src_lengths[batch_idxs] | |||
| cands_to_ignore = cands_to_ignore[batch_idxs] | |||
| scores = scores.view(bsz, -1)[batch_idxs].view( | |||
| new_bsz * beam_size, -1) | |||
| tokens = tokens.view(bsz, -1)[batch_idxs].view( | |||
| new_bsz * beam_size, -1) | |||
| if attn is not None: | |||
| attn = attn.view(bsz, -1)[batch_idxs].view( | |||
| new_bsz * beam_size, attn.size(1), -1) | |||
| bsz = new_bsz | |||
| else: | |||
| batch_idxs = None | |||
| # Set active_mask so that values > cand_size indicate eos hypos | |||
| # and values < cand_size indicate candidate active hypos. | |||
| # After, the min values per row are the top candidate active hypos | |||
| # Rewrite the operator since the element wise or is not supported in torchscript. | |||
| eos_mask[:, :beam_size] = ~( # noqa | |||
| (~cands_to_ignore) & (~eos_mask[:, :beam_size])) # noqa | |||
| active_mask = torch.add( | |||
| eos_mask.type_as(cand_offsets) * cand_size, | |||
| cand_offsets[:eos_mask.size(1)], | |||
| ) | |||
| # get the top beam_size active hypotheses, which are just | |||
| # the hypos with the smallest values in active_mask. | |||
| # {active_hypos} indicates which {beam_size} hypotheses | |||
| # from the list of {2 * beam_size} candidates were | |||
| # selected. Shapes: (batch size, beam size) | |||
| new_cands_to_ignore, active_hypos = torch.topk( | |||
| active_mask, k=beam_size, dim=1, largest=False) | |||
| # update cands_to_ignore to ignore any finalized hypos. | |||
| cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] | |||
| # Make sure there is at least one active item for each sentence in the batch. | |||
| assert (~cands_to_ignore).any(dim=1).all() | |||
| # update cands_to_ignore to ignore any finalized hypos | |||
| # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam | |||
| # can be selected more than once). | |||
| active_bbsz_idx = torch.gather( | |||
| cand_bbsz_idx, dim=1, index=active_hypos) | |||
| active_scores = torch.gather( | |||
| cand_scores, dim=1, index=active_hypos) | |||
| active_bbsz_idx = active_bbsz_idx.view(-1) | |||
| active_scores = active_scores.view(-1) | |||
| # copy tokens and scores for active hypotheses | |||
| # Set the tokens for each beam (can select the same row more than once) | |||
| tokens[:, :step + 1] = torch.index_select( | |||
| tokens[:, :step + 1], dim=0, index=active_bbsz_idx) | |||
| # Select the next token for each of them | |||
| tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( | |||
| cand_indices, dim=1, index=active_hypos) | |||
| if step > 0: | |||
| scores[:, :step] = torch.index_select( | |||
| scores[:, :step], dim=0, index=active_bbsz_idx) | |||
| scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( | |||
| cand_scores, dim=1, index=active_hypos) | |||
| # Update constraints based on which candidates were selected for the next beam | |||
| self.search.update_constraints(active_hypos) | |||
| # copy attention for active hypotheses | |||
| if attn is not None: | |||
| attn[:, :, :step + 2] = torch.index_select( | |||
| attn[:, :, :step + 2], dim=0, index=active_bbsz_idx) | |||
| # reorder incremental state in decoder | |||
| reorder_state = active_bbsz_idx | |||
| # sort by score descending | |||
| for sent in range(len(finalized)): | |||
| scores = torch.tensor( | |||
| [float(elem['score'].item()) for elem in finalized[sent]]) | |||
| _, sorted_scores_indices = torch.sort(scores, descending=True) | |||
| finalized[sent] = [ | |||
| finalized[sent][ssi] for ssi in sorted_scores_indices | |||
| ] | |||
| finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], | |||
| finalized[sent]) | |||
| return finalized | |||
| def _prefix_tokens(self, step: int, lprobs, scores, tokens, prefix_tokens, | |||
| beam_size: int): | |||
| """Handle prefix tokens""" | |||
| prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat( | |||
| 1, beam_size).view(-1) | |||
| prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) | |||
| prefix_mask = prefix_toks.ne(self.pad) | |||
| if self.constraint_trie is None: | |||
| lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1 | |||
| else: | |||
| lprobs[prefix_mask] = -math.inf | |||
| lprobs[prefix_mask] = lprobs[prefix_mask].scatter( | |||
| -1, prefix_toks[prefix_mask].unsqueeze(-1), | |||
| prefix_lprobs[prefix_mask]) | |||
| # if prefix includes eos, then we should make sure tokens and | |||
| # scores are the same across all beams | |||
| eos_mask = prefix_toks.eq(self.eos) | |||
| if eos_mask.any(): | |||
| # validate that the first beam matches the prefix | |||
| first_beam = tokens[eos_mask].view(-1, beam_size, | |||
| tokens.size(-1))[:, 0, | |||
| 1:step + 1] | |||
| eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] | |||
| target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] | |||
| assert (first_beam == target_prefix).all() | |||
| # copy tokens, scores and lprobs from the first beam to all beams | |||
| tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, | |||
| beam_size) | |||
| scores = self.replicate_first_beam(scores, eos_mask_batch_dim, | |||
| beam_size) | |||
| lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, | |||
| beam_size) | |||
| return lprobs, tokens, scores | |||
| def replicate_first_beam(self, tensor, mask, beam_size: int): | |||
| tensor = tensor.view(-1, beam_size, tensor.size(-1)) | |||
| tensor[mask] = tensor[mask][:, :1, :] | |||
| return tensor.view(-1, tensor.size(-1)) | |||
| def finalize_hypos( | |||
| self, | |||
| step: int, | |||
| bbsz_idx, | |||
| eos_scores, | |||
| tokens, | |||
| scores, | |||
| finalized: List[List[Dict[str, Tensor]]], | |||
| finished: List[bool], | |||
| beam_size: int, | |||
| attn: Optional[Tensor], | |||
| src_lengths, | |||
| max_len: int, | |||
| ): | |||
| """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. | |||
| A sentence is finalized when {beam_size} finished items have been collected for it. | |||
| Returns number of sentences (not beam items) being finalized. | |||
| These will be removed from the batch and not processed further. | |||
| Args: | |||
| bbsz_idx (Tensor): | |||
| """ | |||
| assert bbsz_idx.numel() == eos_scores.numel() | |||
| # clone relevant token and attention tensors. | |||
| # tokens is (batch * beam, max_len). So the index_select | |||
| # gets the newly EOS rows, then selects cols 1..{step + 2} | |||
| tokens_clone = tokens.index_select( | |||
| 0, bbsz_idx)[:, 1:step + 2] # skip the first index, which is EOS | |||
| tokens_clone[:, step] = self.eos | |||
| attn_clone = ( | |||
| attn.index_select(0, bbsz_idx)[:, :, 1:step | |||
| + 2] if attn is not None else None) | |||
| # compute scores per token position | |||
| pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] | |||
| pos_scores[:, step] = eos_scores | |||
| # convert from cumulative to per-position scores | |||
| pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] | |||
| # normalize sentence-level scores | |||
| if self.normalize_scores: | |||
| eos_scores /= (step + 1)**self.len_penalty | |||
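| # With the default len_penalty = 1.0 this divides the cumulative score by | |||
| # the hypothesis length (step + 1); values > 1.0 favor longer outputs and | |||
| # values < 1.0 favor shorter ones, matching the docstring above. | |||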
| # cum_unfin records which sentences in the batch are finished. | |||
| # It helps match indexing between (a) the original sentences | |||
| # in the batch and (b) the current, possibly-reduced set of | |||
| # sentences. | |||
| cum_unfin: List[int] = [] | |||
| prev = 0 | |||
| for f in finished: | |||
| if f: | |||
| prev += 1 | |||
| else: | |||
| cum_unfin.append(prev) | |||
| cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) | |||
| unfin_idx = bbsz_idx // beam_size | |||
| sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) | |||
| # Create a set of "{sent}{unfin_idx}", where | |||
| # "unfin_idx" is the index in the current (possibly reduced) | |||
| # list of sentences, and "sent" is the index in the original, | |||
| # unreduced batch | |||
| # For every finished beam item | |||
| # sentence index in the current (possibly reduced) batch | |||
| seen = (sent << 32) + unfin_idx | |||
| unique_seen: List[int] = torch.unique(seen).tolist() | |||
| if self.match_source_len: | |||
| condition = step > torch.index_select(src_lengths, 0, unfin_idx) | |||
| eos_scores = torch.where(condition, torch.tensor(-math.inf), | |||
| eos_scores) | |||
| sent_list: List[int] = sent.tolist() | |||
| for i in range(bbsz_idx.size()[0]): | |||
| # An input sentence (among those in a batch) is finished when | |||
| # beam_size hypotheses have been collected for it | |||
| if len(finalized[sent_list[i]]) < beam_size: | |||
| if attn_clone is not None: | |||
| # remove padding tokens from attn scores | |||
| hypo_attn = attn_clone[i] | |||
| else: | |||
| hypo_attn = torch.empty(0) | |||
| finalized[sent_list[i]].append({ | |||
| 'tokens': | |||
| tokens_clone[i], | |||
| 'score': | |||
| eos_scores[i], | |||
| 'attention': | |||
| hypo_attn, # src_len x tgt_len | |||
| 'alignment': | |||
| torch.empty(0), | |||
| 'positional_scores': | |||
| pos_scores[i], | |||
| }) | |||
| newly_finished: List[int] = [] | |||
| for unique_s in unique_seen: | |||
| # check termination conditions for this sentence | |||
| unique_sent: int = unique_s >> 32 | |||
| unique_unfin_idx: int = unique_s - (unique_sent << 32) | |||
| if not finished[unique_sent] and self.is_finished( | |||
| step, unique_unfin_idx, max_len, len( | |||
| finalized[unique_sent]), beam_size): | |||
| finished[unique_sent] = True | |||
| newly_finished.append(unique_unfin_idx) | |||
| return newly_finished | |||
| def is_finished( | |||
| self, | |||
| step: int, | |||
| unfin_idx: int, | |||
| max_len: int, | |||
| finalized_sent_len: int, | |||
| beam_size: int, | |||
| ): | |||
| """ | |||
| Check whether decoding for a sentence is finished, which | |||
| occurs when the list of finalized sentences has reached the | |||
| beam size, or when we reach the maximum length. | |||
| """ | |||
| assert finalized_sent_len <= beam_size | |||
| if finalized_sent_len == beam_size or step == max_len: | |||
| return True | |||
| return False | |||
| class EnsembleModel(nn.Module): | |||
| """A wrapper around an ensemble of models.""" | |||
| def __init__(self, models): | |||
| super().__init__() | |||
| self.models_size = len(models) | |||
| # method '__len__' is not supported in ModuleList for torch script | |||
| self.single_model = models[0] | |||
| self.models = nn.ModuleList(models) | |||
| # self.has_incremental: bool = False | |||
| # if all( | |||
| # hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) | |||
| # for m in models | |||
| # ): | |||
| # self.has_incremental = True | |||
| self.has_incremental = True | |||
| def forward(self): | |||
| pass | |||
| def has_encoder(self): | |||
| return hasattr(self.single_model, 'encoder') | |||
| def has_incremental_states(self): | |||
| return self.has_incremental | |||
| def max_decoder_positions(self): | |||
| return min([ | |||
| m.max_decoder_positions() | |||
| for m in self.models if hasattr(m, 'max_decoder_positions') | |||
| ] + [sys.maxsize]) # | |||
| @torch.jit.export | |||
| def forward_encoder(self, net_input: Dict[str, Tensor]): | |||
| if not self.has_encoder(): | |||
| return None | |||
| encoder_input = { | |||
| k: v | |||
| for k, v in net_input.items() if k != 'decoder_input_ids' | |||
| } | |||
| encoder_input['output_hidden_states'] = True | |||
| return [ | |||
| model.encoder.forward(**encoder_input) for model in self.models | |||
| ] | |||
| @torch.jit.export | |||
| def forward_decoder(self, | |||
| tokens, | |||
| encoder_outs: List[Dict[str, List[Tensor]]], | |||
| incremental_states: List[Optional[torch.Tensor]], | |||
| temperature: float = 1.0, | |||
| constraint_trie=None, | |||
| constraint_start=None, | |||
| constraint_end=None, | |||
| gen_code=False, | |||
| zero_shot=False, | |||
| prefix_tokens=None): | |||
| log_probs = [] | |||
| avg_attn: Optional[Tensor] = None | |||
| encoder_out: Optional[Dict[str, List[Tensor]]] = None | |||
| code_mask = (tokens.new_ones(tokens.size(0)) * gen_code).bool() | |||
| for i, model in enumerate(self.models): | |||
| if self.has_encoder(): | |||
| encoder_out = encoder_outs[i] | |||
| encoder_hidden_states = encoder_out.last_hidden_state | |||
| encoder_attention_mask = _expand_mask( | |||
| encoder_out.padding_mask, encoder_hidden_states.dtype, | |||
| tokens.shape[-1]) | |||
| src_pos_embed = encoder_out.position_embedding | |||
| # if tokens.eq(self.single_model.config.pad_token_id).any(): | |||
| attention_mask = tokens.eq(self.single_model.padding_idx) | |||
| # decode each model | |||
| if self.has_incremental_states(): | |||
| decoder_out = model.decoder.forward( # TODO: model inputs differ | |||
| input_ids=tokens, | |||
| attention_mask=attention_mask, | |||
| encoder_hidden_states=encoder_hidden_states, | |||
| encoder_attention_mask=encoder_attention_mask, | |||
| code_masks=code_mask, | |||
| src_pos_embed=src_pos_embed, | |||
| past_key_values=incremental_states[i], | |||
| use_cache=True, | |||
| output_attentions=True) | |||
| else: | |||
| if hasattr(model, 'decoder'): | |||
| # decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out) | |||
| decoder_out = model.decoder.forward( # TODO: model inputs differ | |||
| input_ids=tokens, | |||
| attention_mask=attention_mask, | |||
| encoder_hidden_states=encoder_hidden_states, | |||
| encoder_attention_mask=encoder_attention_mask, | |||
| code_masks=code_mask, | |||
| src_pos_embed=src_pos_embed) | |||
| else: | |||
| decoder_out = model.forward(tokens) | |||
| attn: Optional[Tensor] = None | |||
| decoder_len = len(decoder_out) | |||
| # if decoder_len > 1 and decoder_out[1] is not None: | |||
| # if isinstance(decoder_out[1], Tensor): | |||
| # attn = decoder_out[1] | |||
| # else: | |||
| # attn_holder = decoder_out[1]["attn"] | |||
| # if isinstance(attn_holder, Tensor): | |||
| # attn = attn_holder | |||
| # elif attn_holder is not None: | |||
| # attn = attn_holder[0] | |||
| # if attn is not None: | |||
| # attn = attn[:, -1, :] | |||
| if 'cross_attentions' in decoder_out: | |||
| attn = decoder_out['cross_attentions'][-1].transpose(1, 0) | |||
| attn = attn.mean(dim=0) # (B, tgt_len, src_len) | |||
| if attn is not None: | |||
| attn = attn[:, -1, :] | |||
| # decoder_out_tuple = ( | |||
| # decoder_out[0][:, -1:, :].div_(temperature), | |||
| # None if decoder_len <= 1 else decoder_out[1], | |||
| # ) | |||
| decoder_out_tuple = ( | |||
| decoder_out[0][:, -1:, :].div_(temperature), | |||
| None if decoder_len <= 1 else attn, | |||
| ) | |||
| beam_size = decoder_out_tuple[0].size(0) // prefix_tokens.size( | |||
| 0) if prefix_tokens is not None else 0 | |||
| if constraint_trie is not None and not zero_shot: | |||
| assert constraint_start is None and constraint_end is None | |||
| constraint_masks = decoder_out_tuple[0].new_zeros( | |||
| decoder_out_tuple[0].size()).bool() | |||
| constraint_prefix_tokens = tokens.tolist() | |||
| for token_index, constraint_prefix_token in enumerate( | |||
| constraint_prefix_tokens): | |||
| prefix_len = prefix_tokens[token_index // beam_size].ne( | |||
| 1).sum().item() if prefix_tokens is not None else 0 | |||
| if len(constraint_prefix_token) > prefix_len: | |||
| constraint_prefix_token = [ | |||
| 0 | |||
| ] + constraint_prefix_token[prefix_len + 1:] | |||
| constraint_nodes = constraint_trie.get_next_layer( | |||
| constraint_prefix_token) | |||
| constraint_masks[token_index][:, | |||
| constraint_nodes] = True | |||
| else: | |||
| constraint_masks[token_index] = True | |||
| decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf) | |||
| if constraint_start is not None and constraint_end is not None and not zero_shot: | |||
| assert constraint_trie is None | |||
| decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf | |||
| decoder_out_tuple[0][:, :, constraint_end:] = -math.inf | |||
| probs = model.get_normalized_probs( | |||
| decoder_out_tuple, log_probs=True, sample=None) | |||
| if constraint_trie is not None and zero_shot: | |||
| assert constraint_start is None and constraint_end is None | |||
| constraint_masks = decoder_out_tuple[0].new_zeros( | |||
| decoder_out_tuple[0].size()).bool() | |||
| constraint_prefix_tokens = tokens.tolist() | |||
| for token_index, constraint_prefix_token in enumerate( | |||
| constraint_prefix_tokens): | |||
| constraint_nodes = constraint_trie.get_next_layer( | |||
| constraint_prefix_token) | |||
| constraint_masks[token_index][:, constraint_nodes] = True | |||
| probs.masked_fill_(~constraint_masks, -math.inf) | |||
| if constraint_start is not None and constraint_end is not None and zero_shot: | |||
| assert constraint_trie is None | |||
| probs[:, :, 4:constraint_start] = -math.inf | |||
| probs[:, :, constraint_end:] = -math.inf | |||
| probs = probs[:, -1, :] | |||
| if self.models_size == 1: | |||
| return probs, attn | |||
| log_probs.append(probs) | |||
| if attn is not None: | |||
| if avg_attn is None: | |||
| avg_attn = attn | |||
| else: | |||
| avg_attn.add_(attn) | |||
| avg_probs = torch.logsumexp( | |||
| torch.stack(log_probs, dim=0), dim=0) - math.log(self.models_size) | |||
| if avg_attn is not None: | |||
| avg_attn.div_(self.models_size) | |||
| return avg_probs, avg_attn | |||
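| # The ensemble score is the average probability in log space: | |||
| # logsumexp_i(log p_i) - log(N) = log((1/N) * sum_i p_i). | |||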
| @torch.jit.export | |||
| def reorder_encoder_out(self, | |||
| encoder_outs: Optional[List[Dict[str, | |||
| List[Tensor]]]], | |||
| new_order): | |||
| """ | |||
| Reorder encoder output according to *new_order*. | |||
| Args: | |||
| encoder_out: output from the ``forward()`` method | |||
| new_order (LongTensor): desired order | |||
| Returns: | |||
| *encoder_out* rearranged according to *new_order* | |||
| """ | |||
| new_outs: List[Dict[str, List[Tensor]]] = [] | |||
| if not self.has_encoder(): | |||
| return new_outs | |||
| for i, model in enumerate(self.models): | |||
| assert encoder_outs is not None | |||
| new_outs.append( | |||
| model.encoder.reorder_encoder_out(encoder_outs[i], new_order)) | |||
| return new_outs | |||
| @torch.jit.export | |||
| def reorder_incremental_state( | |||
| self, | |||
| incremental_states: List[Optional[torch.Tensor]], | |||
| new_order, | |||
| ): | |||
| if not self.has_incremental_states(): | |||
| return | |||
| for i, model in enumerate(self.models): | |||
| model.decoder.reorder_incremental_state_scripting( # todo | |||
| incremental_states[i], new_order) | |||
| @@ -0,0 +1,512 @@ | |||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||
| # | |||
| # This source code is licensed under the MIT license which can be found at | |||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
| """Implements tracking of constraints for a beam item. | |||
| A list of constraints is given as a list of one or more token | |||
| sequences, each of length at least one token. For example, for an input sentence | |||
| > Die maschinelle Übersetzung ist schwer zu kontrollieren. | |||
| We could have the constraints: | |||
| * to influence | |||
| * hard | |||
| There are two implementations: | |||
| * OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. | |||
| * UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. | |||
| The difference is that in the first, the constraints are assumed to be | |||
| in order; the algorithm will permit zero or more tokens between them. | |||
| In the second, the constraints are not ordered, so many orderings will | |||
| be explored. | |||
| The same sequence can be present any number of times, and will appear | |||
| that many times in the output. | |||
| """ | |||
| from collections import Counter | |||
| from typing import List, Set | |||
| import torch | |||
| class ConstraintState: | |||
| def __init__(self): | |||
| pass | |||
| def pack_constraints( | |||
| batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: | |||
| """Takes a list of list of constraints in tensor form (a list of | |||
| tensor constraints for each sentence) and transforms it into a | |||
| packed Tensor. For example, here is a batch of size 3 with 3, 0, | |||
| and 1 constraints: | |||
| [ [ [3 1 2], [3], [4 5 6 7], ] | |||
| [], | |||
| [ [1 8 9 10 1 4 11 12], ] | |||
| ] | |||
| Its corresponding packed structure is: | |||
| [ [ 3 3 1 2 0 3 0 4 5 6 7 0], | |||
| [ 0 0 0 0 0 0 0 0 0 0 0 0], | |||
| [ 1 1 8 9 10 1 4 11 12 0 0 0] ] | |||
| The packed tensor has shape (batch size, maxlen), where | |||
| maxlen is defined below. Each row contains concatenated | |||
| constraint tokens for that sentence, with 0 appended after | |||
| each constraint. The first item in each row is the number | |||
| of constraints for that sentence. So maxlen is the maximum | |||
| of | |||
| (number of constraints) + (sum length of constraints) + 1. | |||
| across all sentences in the batch. | |||
| """ | |||
| # The maximum word length of concatenated constraints for any sentence | |||
| max_constraints_len = 1 | |||
| for sentence_constraints in batch_constraints: | |||
| if len(sentence_constraints): | |||
| # number of constraints, plus sum of constraint lengths, plus a zero after each | |||
| constraints_len = (1 | |||
| + sum([c.size(0) for c in sentence_constraints]) | |||
| + len(sentence_constraints)) | |||
| max_constraints_len = max(max_constraints_len, constraints_len) | |||
| batch_size = len(batch_constraints) | |||
| constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() | |||
| for i, sentence_constraints in enumerate(batch_constraints): | |||
| constraints_tensor[i, 0] = len(sentence_constraints) | |||
| offset = 1 | |||
| for j, constraint in enumerate(sentence_constraints): | |||
| this_len = constraint.size(0) | |||
| constraints_tensor[i, offset:offset + this_len] = constraint | |||
| offset += this_len + 1 | |||
| return constraints_tensor.long() | |||
| def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: | |||
| """ | |||
| Transforms *one row* of a packed constraint tensor (e.g., for one | |||
| sentence in the batch) into a list of constraint tensors. | |||
| """ | |||
| constraint_list = [] | |||
| num_constraints = constraint_tensor[0] | |||
| constraints = constraint_tensor.tolist() | |||
| offset = 1 | |||
| for i in range(num_constraints): | |||
| where = constraints.index(0, offset) | |||
| constraint_list.append(constraint_tensor[offset:where]) | |||
| offset = where + 1 | |||
| return constraint_list | |||
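| # Example: unpacking row 0 of the packed tensor from pack_constraints above, | |||
| # unpack_constraints(torch.tensor([3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0])) | |||
| # returns [tensor([3, 1, 2]), tensor([3]), tensor([4, 5, 6, 7])]. | |||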
| class ConstraintNode: | |||
| """ | |||
| Represents a node in a trie managing unordered constraints. | |||
| """ | |||
| def __init__(self, token: int = None, parent=None): | |||
| # The token associated with this node (None for the root) | |||
| self.token = int(token) if token is not None else None | |||
| # The parent (None at the root) | |||
| self.parent = parent | |||
| # Whether this node is a completed constraint | |||
| self.terminal = 0 | |||
| # List of child nodes | |||
| self.children = {} | |||
| # The cumulative number of constraints from this point in the | |||
| # trie forward | |||
| self.num_constraints = 0 | |||
| @property | |||
| def id(self): | |||
| return self.token | |||
| def __str__(self): | |||
| term = self.terminal != 0 | |||
| return f'[{self.token}].{term}#{self.num_constraints}' | |||
| def __getitem__(self, key: int): | |||
| return self.children.get(key, None) | |||
| def next_tokens(self) -> Set[int]: | |||
| """The set of child labels.""" | |||
| return set(self.children.keys()) | |||
| @staticmethod | |||
| def create(constraints: List[List[int]]): | |||
| root = ConstraintNode() | |||
| for sequence in constraints: | |||
| root.add_sequence(sequence) | |||
| return root | |||
| @staticmethod | |||
| def print_graph(node: 'ConstraintNode'): | |||
| if len(node.children) == 0: | |||
| return str(node) | |||
| else: | |||
| s = f'({node}' | |||
| for child in node.children.values(): | |||
| s += ' ' + ConstraintNode.print_graph(child) | |||
| s += ')' | |||
| return s | |||
| def token_counts(self) -> Counter: | |||
| """Returns a counter of the number of times each token is used | |||
| in a constraint. | |||
| """ | |||
| token_counts = Counter() | |||
| kids = list(self.children.values()) | |||
| while len(kids) > 0: | |||
| kid = kids.pop() | |||
| token_counts[kid.id] += kid.num_constraints | |||
| kids += list(kid.children.values()) | |||
| return token_counts | |||
| def tokens(self) -> Set[int]: | |||
| """Returns the set of tokens in constraints.""" | |||
| return set(self.token_counts().keys()) | |||
| def add_sequence(self, sequence: List[int]): | |||
| """Adds a constraint, represented as a list of integers, to | |||
| the trie.""" | |||
| assert len(sequence) > 0 | |||
| token = int(sequence[0]) | |||
| if token not in self.children: | |||
| self.children[token] = ConstraintNode(token, parent=self) | |||
| node = self.children[token] | |||
| if len(sequence) == 1: | |||
| node.terminal += 1 | |||
| node.num_constraints += 1 | |||
| parent = node.parent | |||
| while parent is not None: | |||
| parent.num_constraints += 1 | |||
| parent = parent.parent | |||
| else: | |||
| node.add_sequence(sequence[1:]) | |||
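| # Example: ConstraintNode.create([[3, 1, 2], [3], [4, 5, 6, 7]]) builds a trie | |||
| # whose root has children 3 and 4; node 3 is terminal (for the constraint [3]) | |||
| # and also parents the chain 1 -> 2, and root.num_constraints == 3. | |||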
| class UnorderedConstraintState(ConstraintState): | |||
| """ | |||
| Records progress through the set of constraints for each item in the beam | |||
| using a trie. | |||
| """ | |||
| def __init__(self, | |||
| node: ConstraintNode, | |||
| copy_from: 'ConstraintState' = None): | |||
| self.node = node | |||
| if copy_from is None: | |||
| # The root node | |||
| self.root = node | |||
| # The set of states in the graph that have been completed | |||
| self.completed = Counter() | |||
| # Counts of how many times each trie node has been generated so far | |||
| self.generated = Counter() | |||
| # The list of tokens we need to generate | |||
| self.needed_tokens = self.root.tokens() | |||
| else: | |||
| self.completed = Counter(copy_from.completed) | |||
| self.generated = Counter(copy_from.generated) | |||
| self.root = copy_from.root | |||
| # Mark the node as generated | |||
| if self.node != self.root: | |||
| self.generated[node] += 1 | |||
| @staticmethod | |||
| def create(constraint_tensor: torch.Tensor): | |||
| constraint_list = unpack_constraints(constraint_tensor) | |||
| constraint_trie_root = ConstraintNode.create(constraint_list) | |||
| return UnorderedConstraintState(constraint_trie_root) | |||
| def __str__(self): | |||
| gen_str = ','.join([str(node) for node in self.generated]) | |||
| return f'{self.name}/{self.bank}({gen_str})x{self.num_completed}' | |||
| def __copy__(self): | |||
| copied_state = UnorderedConstraintState(self.node, copy_from=self) | |||
| return copied_state | |||
| def copy(self): | |||
| return self.__copy__() | |||
| @property | |||
| def name(self): | |||
| if self.node.id is None: | |||
| return 'ROOT' | |||
| else: | |||
| return str(self.node.id) | |||
| @property | |||
| def is_root(self): | |||
| return self.node == self.root | |||
| @property | |||
| def bank(self): | |||
| return sum(self.generated.values()) | |||
| @property | |||
| def num_completed(self): | |||
| """The number of constraints (not constraint tokens) that are completed. | |||
| In addition to the already-completed states, we need to account for the | |||
| current state, which might get marked as completed when another token | |||
| is generated. | |||
| """ | |||
| in_final = self.node.terminal and self.completed[ | |||
| self.node] < self.node.terminal | |||
| return sum(self.completed.values()) + in_final | |||
| @property | |||
| def finished(self): | |||
| return self.root.num_constraints - self.num_completed == 0 | |||
| @property | |||
| def token_counts(self): | |||
| return self.root.token_counts() | |||
| @property | |||
| def tokens(self): | |||
| return self.root.tokens() | |||
| @property | |||
| def num_constraint_tokens(self): | |||
| return sum(self.token_counts.values()) | |||
| def next_tokens(self) -> Set[int]: | |||
| """Returns the list of tokens that could come next. | |||
| These are (a) all tokens extending the root state and, for | |||
| non-root states, additionally all tokens extending the current | |||
| state.""" | |||
| if self.node != self.root: | |||
| return self.root.next_tokens().union(self.node.next_tokens()) | |||
| else: | |||
| return self.root.next_tokens() | |||
| def advance(self, token: int): | |||
| """Reads in a token and advances the state. Here's how it works. | |||
| We can advance to the next state if: | |||
| - there is a matching child | |||
| - its path isn't blocked | |||
| A path is blocked when all constraints that are descendants of | |||
| that node have already been generated, in the current state. | |||
| If we are not able to advance from the current state, we "fall | |||
| off the graph" and return to the root state. There, we again | |||
| try to advance, checking the same criteria. | |||
| In any case, when falling off the graph, we need to do some | |||
| bookkeeping. We: | |||
| - check whether any constraints were met (all prefixes of | |||
| current state) | |||
| - if one is found, mark it as completed | |||
| - adjust visited nodes accordingly | |||
| """ | |||
| token = int(token) | |||
| next_state = None | |||
| child = self.node[token] | |||
| if child is not None and self.generated[child] < child.num_constraints: | |||
| next_state = UnorderedConstraintState(child, copy_from=self) | |||
| def rewind(): | |||
| """If we're mid-trie and an "illegal" token is chosen next, we need | |||
| to reset our state to the root state. However, along the way, we need | |||
| to check whether a prefix of the current trie state represents a state | |||
| we could mark as completed. | |||
| """ | |||
| node = self.node | |||
| while node != self.root: | |||
| if node.terminal and self.completed[node] < node.terminal: | |||
| next_state.completed[node] += 1 | |||
| return | |||
| next_state.generated[node] -= 1 | |||
| node = node.parent | |||
| # Fall off the graph, check the root | |||
| if next_state is None and token in self.root.next_tokens(): | |||
| child = self.root[token] | |||
| # We can only traverse this edge if it's not saturated | |||
| if self.generated[child] < child.num_constraints: | |||
| next_state = UnorderedConstraintState(child, copy_from=self) | |||
| else: | |||
| next_state = UnorderedConstraintState( | |||
| self.root, copy_from=self) | |||
| # Rewind | |||
| rewind() | |||
| elif next_state is None: | |||
| next_state = UnorderedConstraintState(self.root, copy_from=self) | |||
| # Rewind | |||
| rewind() | |||
| return next_state | |||
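To make the bookkeeping concrete, here is a hedged walk-through of UnorderedConstraintState (same import assumption as the previous sketch; the trie is built directly instead of going through a packed constraint tensor):

    root = ConstraintNode.create([[5, 6], [7]])   # one phrase constraint, one single-token constraint
    state = UnorderedConstraintState(root)
    for tok in [9, 5, 6, 7]:                      # 9 is an unconstrained token
        state = state.advance(tok)
        print(tok, state.bank, state.num_completed, state.finished)
    # 9 -> bank 0, completed 0              (no matching child, nothing to rewind)
    # 5 -> bank 1, completed 0              (mid-phrase)
    # 6 -> bank 2, completed 1              (the phrase [5, 6] is now satisfied)
    # 7 -> bank 3, completed 2, finished    (all constraints met)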
| class ConstraintSequence: | |||
| def __init__(self, sequences: List[List[int]]): | |||
| """Represents a set of possibly multitoken constraints by | |||
| concatenating them and internally recording the end points. | |||
| """ | |||
| self.sequences = [] | |||
| self.endpoints = [] | |||
| self.num_tokens = 0 | |||
| self.tokens = set() | |||
| for sequence in sequences: | |||
| for token in sequence: | |||
| self.tokens.add(token) | |||
| self.num_tokens += len(sequence) | |||
| self.endpoints += [False | |||
| for x in range(len(sequence) - 1)] + [True] | |||
| self.sequences += sequence | |||
| def __getitem__(self, key: int): | |||
| return self.sequences[key] | |||
| def __len__(self): | |||
| return len(self.sequences) | |||
| def __str__(self): | |||
| return str(self.sequences) | |||
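For instance, the flattening performed above looks like this (hedged sketch, same import assumptions):

    cs = ConstraintSequence([[5, 6], [7]])
    print(cs.sequences)   # [5, 6, 7]
    print(cs.endpoints)   # [False, True, True] -- True marks the last token of each constraint
    print(cs.num_tokens, cs.tokens)   # 3 {5, 6, 7}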
| class OrderedConstraintState(ConstraintState): | |||
| """ | |||
| Records progress through the set of linear nonbranching constraints with gaps. | |||
| """ | |||
| def __init__(self, sequence: ConstraintSequence, state: int = -1): | |||
| self.sequence = sequence | |||
| self.state = state | |||
| @staticmethod | |||
| def create(constraint_tensor: torch.Tensor): | |||
| constraint_list = unpack_constraints(constraint_tensor) | |||
| return OrderedConstraintState(ConstraintSequence(constraint_list), -1) | |||
| def __str__(self): | |||
| return f'{self.state}/{self.bank}x{self.num_completed}' | |||
| def __copy__(self): | |||
| return OrderedConstraintState(self.sequence, self.state) | |||
| def copy(self): | |||
| return self.__copy__() | |||
| @property | |||
| def num_completed(self): | |||
| if self.state == -1: | |||
| return 0 | |||
| count = len( | |||
| list( | |||
| filter(lambda x: x, | |||
| self.sequence.endpoints[0:self.state + 1]))) | |||
| return count | |||
| @property | |||
| def is_root(self): | |||
| return self.state == -1 | |||
| @property | |||
| def name(self): | |||
| if self.state == -1: | |||
| return 'ROOT' | |||
| else: | |||
| return str(self.sequence[self.state]) | |||
| @property | |||
| def bank(self) -> int: | |||
| return self.state + 1 | |||
| @property | |||
| def finished(self): | |||
| return self.state + 1 == len(self.sequence) | |||
| @property | |||
| def token_counts(self): | |||
| # ConstraintSequence defines no token_counts(); count tokens in the flattened sequence directly. | |||
| return Counter(self.sequence.sequences) | |||
| @property | |||
| def tokens(self): | |||
| return self.sequence.tokens | |||
| @property | |||
| def num_constraint_tokens(self): | |||
| return sum(self.token_counts.values()) | |||
| def next_tokens(self) -> Set[int]: | |||
| """Returns the list of tokens that could come next. | |||
| These are (a) all tokens extending the root state and, for | |||
| non-root states, additionally all tokens extending the current | |||
| state.""" | |||
| tokens = set() | |||
| if self.state > 0: | |||
| tokens.add(self.sequence[0]) | |||
| if not self.finished: | |||
| tokens.add(self.sequence[self.state + 1]) | |||
| return tokens | |||
| def advance(self, token: int): | |||
| """Reads in a token and advances the state. Here's how it works. | |||
| We can advance to the next state if: | |||
| - there is a matching child | |||
| - its path isn't blocked | |||
| A path is blocked when all constraints that are descendants of | |||
| that node have already been generated, in the current state. | |||
| If we are not able to advance from the current state, we "fall | |||
| off the graph" and return to the root state. There, we again | |||
| try to advance, checking the same criteria. | |||
| In any case, when falling off the graph, we need to do some | |||
| bookkeeping. We: | |||
| - check whether any constraints were met (all prefixes of | |||
| current state) | |||
| - if one is found, mark it as completed | |||
| - adjust visited nodes accordingly | |||
| """ | |||
| token = int(token) | |||
| # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") | |||
| if self.finished: | |||
| # Accept anything | |||
| next_state = self.copy() | |||
| elif self.sequence[self.state + 1] == token: | |||
| # Advance to the next token | |||
| next_state = OrderedConstraintState(self.sequence, self.state + 1) | |||
| elif self.sequence.endpoints[self.state]: | |||
| # Accept anything between constraints (*) | |||
| next_state = self.copy() | |||
| elif token == self.sequence[0]: | |||
| # Start over having generated the first token | |||
| next_state = OrderedConstraintState(self.sequence, 0) | |||
| else: | |||
| # Start over from the root | |||
| next_state = OrderedConstraintState(self.sequence, -1) | |||
| return next_state | |||
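And the corresponding hedged walk-through for the ordered case, mirroring the unordered example above:

    seq = ConstraintSequence([[5, 6], [7]])
    state = OrderedConstraintState(seq)     # state == -1 is the root
    for tok in [5, 6, 9, 7]:
        state = state.advance(tok)
        print(tok, state.state, state.num_completed, state.finished)
    # 5 -> state 0, completed 0
    # 6 -> state 1, completed 1             (endpoint of the first constraint)
    # 9 -> state 1, completed 1             (free tokens are allowed between constraints)
    # 7 -> state 2, completed 2, finished   (all constraints met, in order)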
| @@ -0,0 +1,124 @@ | |||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||
| # | |||
| # This source code is licensed under the MIT license which can be found at | |||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||
| import collections | |||
| from collections import abc | |||
| from itertools import accumulate | |||
| import torch | |||
| import torch.nn.functional as F | |||
| try: | |||
| from amp_C import multi_tensor_l2norm | |||
| multi_tensor_l2norm_available = True | |||
| except ImportError: | |||
| multi_tensor_l2norm_available = False | |||
| try: | |||
| import torch_xla.core.xla_model as xm | |||
| except ImportError: | |||
| xm = None | |||
| MANIFOLD_PATH_SEP = '|' | |||
| def apply_to_sample(f, sample): | |||
| if hasattr(sample, '__len__') and len(sample) == 0: | |||
| return {} | |||
| def _apply(x): | |||
| if torch.is_tensor(x): | |||
| return f(x) | |||
| elif isinstance(x, collections.OrderedDict): | |||
| # OrderedDict has attributes that needs to be preserved | |||
| od = collections.OrderedDict( | |||
| (key, _apply(value)) for key, value in x.items()) | |||
| od.__dict__ = x.__dict__ | |||
| return od | |||
| elif isinstance(x, dict): | |||
| return {key: _apply(value) for key, value in x.items()} | |||
| elif isinstance(x, list): | |||
| return [_apply(x) for x in x] | |||
| elif isinstance(x, tuple): | |||
| return tuple(_apply(x) for x in x) | |||
| elif isinstance(x, set): | |||
| return {_apply(x) for x in x} | |||
| else: | |||
| return x | |||
| return _apply(sample) | |||
| def move_to_device(batch, device): | |||
| r"""Puts each data field to the device""" | |||
| if isinstance(batch, torch.Tensor): | |||
| return batch.to(device) | |||
| elif isinstance(batch, (list, tuple)): | |||
| return tuple(move_to_device(item, device) for item in batch) | |||
| elif isinstance(batch, abc.Mapping): | |||
| return { | |||
| key: move_to_device(value, device) | |||
| for key, value in batch.items() | |||
| } | |||
| else: | |||
| return batch | |||
| def strip_pad(tensor, pad): | |||
| return tensor[tensor.ne(pad)] | |||
| def get_token_to_word_mapping(tokens, exclude_list): | |||
| n = len(tokens) | |||
| word_start = [int(token not in exclude_list) for token in tokens] | |||
| word_idx = list(accumulate(word_start)) | |||
| token_to_word = {i: word_idx[i] for i in range(n)} | |||
| return token_to_word | |||
| def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): | |||
| tgt_valid = (((tgt_sent != pad) & # noqa | |||
| (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)) | |||
| src_invalid = (((src_sent == pad) | # noqa | |||
| (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)) | |||
| src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) | |||
| tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) | |||
| alignment = [] | |||
| if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): | |||
| attn_valid = attn[tgt_valid] | |||
| attn_valid[:, src_invalid] = float('-inf') | |||
| _, src_indices = attn_valid.max(dim=1) | |||
| for tgt_idx, src_idx in zip(tgt_valid, src_indices): | |||
| alignment.append(( | |||
| src_token_to_word[src_idx.item()] - 1, | |||
| tgt_token_to_word[tgt_idx.item()] - 1, | |||
| )) | |||
| return alignment | |||
| def softmax(x, dim: int, onnx_trace: bool = False): | |||
| if onnx_trace: | |||
| return F.softmax(x.float(), dim=dim) | |||
| else: | |||
| return F.softmax(x, dim=dim, dtype=torch.float32) | |||
| def log_softmax(x, dim: int, onnx_trace: bool = False): | |||
| if onnx_trace: | |||
| return F.log_softmax(x.float(), dim=dim) | |||
| else: | |||
| return F.log_softmax(x, dim=dim, dtype=torch.float32) | |||
| def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): | |||
| tgt_valid = (tgt_sent != pad).nonzero(as_tuple=False) | |||
| src_valid = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1) | |||
| alignment = [] | |||
| if len(tgt_valid) != 0 and len(src_valid) != 0: | |||
| attn_valid = attn[tgt_valid, src_valid] | |||
| alignment = [['{:.6f}'.format(p) for p in src_probs.tolist()] | |||
| for src_probs in attn_valid] | |||
| return alignment | |||
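A hedged usage sketch of the helpers above; the import path is inferred from the relative imports used later in this change and may differ.

    import torch

    from modelscope.models.multi_modal.ofa.generate.utils import (
        apply_to_sample, move_to_device, strip_pad)

    sample = {
        'net_input': {
            'input_ids': torch.tensor([[0, 11, 12, 2, 1, 1]]),  # 1 is the pad id in this toy example
            'patch_masks': torch.tensor([True]),
        }
    }
    # Apply a function to every tensor inside a nested dict/list/tuple/set structure.
    sample_fp16 = apply_to_sample(
        lambda t: t.half() if t.is_floating_point() else t, sample)
    # Recursively move the same structure onto a device.
    sample = move_to_device(sample, torch.device('cpu'))
    # Drop padding positions from a 1-D token tensor.
    print(strip_pad(sample['net_input']['input_ids'][0], 1))  # tensor([0, 11, 12, 2])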
| @@ -0,0 +1,283 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| def drop_path(x, drop_prob: float = 0., training: bool = False): | |||
| """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, | |||
| the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... | |||
| See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for | |||
| changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use | |||
| 'survival rate' as the argument. | |||
| """ | |||
| if drop_prob == 0. or not training: | |||
| return x | |||
| keep_prob = 1 - drop_prob | |||
| shape = (x.shape[0], ) + (1, ) * ( | |||
| x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | |||
| random_tensor = keep_prob + torch.rand( | |||
| shape, dtype=x.dtype, device=x.device) | |||
| random_tensor.floor_() # binarize | |||
| output = x.div(keep_prob) * random_tensor | |||
| return output | |||
| class DropPath(nn.Module): | |||
| """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| """ | |||
| def __init__(self, drop_prob=None): | |||
| super(DropPath, self).__init__() | |||
| self.drop_prob = drop_prob | |||
| def forward(self, x): | |||
| return drop_path(x, self.drop_prob, self.training) | |||
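A quick, hedged sanity check of the stochastic-depth behaviour described in the docstring above:

    import torch

    x = torch.ones(4, 8, 16, 16)
    dp = DropPath(drop_prob=0.5)

    dp.eval()                      # at inference time DropPath is the identity
    assert torch.equal(dp(x), x)

    dp.train()                     # during training whole samples are dropped
    y = dp(x)
    # Per sample the branch is either zeroed or rescaled by 1 / keep_prob,
    # so each entry of y is either 0.0 or 2.0 and the expectation matches x.
    print(y[:, 0, 0, 0])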
| def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||
| """3x3 convolution with padding""" | |||
| return nn.Conv2d( | |||
| in_planes, | |||
| out_planes, | |||
| kernel_size=3, | |||
| stride=stride, | |||
| padding=dilation, | |||
| groups=groups, | |||
| bias=False, | |||
| dilation=dilation) | |||
| def conv1x1(in_planes, out_planes, stride=1): | |||
| """1x1 convolution""" | |||
| return nn.Conv2d( | |||
| in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
| class BasicBlock(nn.Module): | |||
| expansion = 1 | |||
| def __init__(self, | |||
| inplanes, | |||
| planes, | |||
| stride=1, | |||
| downsample=None, | |||
| groups=1, | |||
| base_width=64, | |||
| dilation=1, | |||
| norm_layer=None): | |||
| super(BasicBlock, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| if groups != 1 or base_width != 64: | |||
| raise ValueError( | |||
| 'BasicBlock only supports groups=1 and base_width=64') | |||
| if dilation > 1: | |||
| raise NotImplementedError( | |||
| 'Dilation > 1 not supported in BasicBlock') | |||
| # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||
| self.conv1 = conv3x3(inplanes, planes, stride) | |||
| self.bn1 = norm_layer(planes) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.conv2 = conv3x3(planes, planes) | |||
| self.bn2 = norm_layer(planes) | |||
| self.downsample = downsample | |||
| self.stride = stride | |||
| def forward(self, x): | |||
| # The backbone below builds Bottleneck blocks only, so BasicBlock.forward is intentionally unreachable. | |||
| assert False | |||
| identity = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu(out) | |||
| return out | |||
| class Bottleneck(nn.Module): | |||
| # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) | |||
| # while original implementation places the stride at the first 1x1 convolution(self.conv1) | |||
| # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. | |||
| # This variant is also known as ResNet V1.5 and improves accuracy according to | |||
| # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. | |||
| expansion = 4 | |||
| def __init__(self, | |||
| inplanes, | |||
| planes, | |||
| stride=1, | |||
| downsample=None, | |||
| groups=1, | |||
| base_width=64, | |||
| dilation=1, | |||
| norm_layer=None, | |||
| drop_path_rate=0.0): | |||
| super(Bottleneck, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| width = int(planes * (base_width / 64.)) * groups | |||
| # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||
| self.conv1 = conv1x1(inplanes, width) | |||
| self.bn1 = norm_layer(width) | |||
| self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||
| self.bn2 = norm_layer(width) | |||
| self.conv3 = conv1x1(width, planes * self.expansion) | |||
| self.bn3 = norm_layer(planes * self.expansion) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.downsample = downsample | |||
| self.stride = stride | |||
| self.drop_path = DropPath( | |||
| drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() | |||
| def forward(self, x): | |||
| identity = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| out = self.relu(out) | |||
| out = self.conv3(out) | |||
| out = self.bn3(out) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out = identity + self.drop_path(out) | |||
| out = self.relu(out) | |||
| return out | |||
| class ResNet(nn.Module): | |||
| def __init__(self, | |||
| layers, | |||
| zero_init_residual=False, | |||
| groups=1, | |||
| width_per_group=64, | |||
| replace_stride_with_dilation=None, | |||
| norm_layer=None, | |||
| drop_path_rate=0.0): | |||
| super(ResNet, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| self._norm_layer = norm_layer | |||
| self.inplanes = 64 | |||
| self.dilation = 1 | |||
| if replace_stride_with_dilation is None: | |||
| # each element in the tuple indicates if we should replace | |||
| # the 2x2 stride with a dilated convolution instead | |||
| replace_stride_with_dilation = [False, False, False] | |||
| if len(replace_stride_with_dilation) != 3: | |||
| raise ValueError('replace_stride_with_dilation should be None ' | |||
| 'or a 3-element tuple, got {}'.format( | |||
| replace_stride_with_dilation)) | |||
| self.groups = groups | |||
| self.base_width = width_per_group | |||
| self.conv1 = nn.Conv2d( | |||
| 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) | |||
| self.bn1 = norm_layer(self.inplanes) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
| self.layer1 = self._make_layer( | |||
| Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate) | |||
| self.layer2 = self._make_layer( | |||
| Bottleneck, | |||
| 128, | |||
| layers[1], | |||
| stride=2, | |||
| dilate=replace_stride_with_dilation[0], | |||
| drop_path_rate=drop_path_rate) | |||
| self.layer3 = self._make_layer( | |||
| Bottleneck, | |||
| 256, | |||
| layers[2], | |||
| stride=2, | |||
| dilate=replace_stride_with_dilation[1], | |||
| drop_path_rate=drop_path_rate) | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d): | |||
| nn.init.kaiming_normal_( | |||
| m.weight, mode='fan_out', nonlinearity='relu') | |||
| elif isinstance(m, | |||
| (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)): | |||
| nn.init.constant_(m.weight, 1) | |||
| nn.init.constant_(m.bias, 0) | |||
| # Zero-initialize the last BN in each residual branch, | |||
| # so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||
| # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||
| if zero_init_residual: | |||
| for m in self.modules(): | |||
| if isinstance(m, Bottleneck): | |||
| nn.init.constant_(m.bn3.weight, 0) | |||
| elif isinstance(m, BasicBlock): | |||
| nn.init.constant_(m.bn2.weight, 0) | |||
| def _make_layer(self, | |||
| block, | |||
| planes, | |||
| blocks, | |||
| stride=1, | |||
| dilate=False, | |||
| drop_path_rate=0.0): | |||
| norm_layer = self._norm_layer | |||
| downsample = None | |||
| previous_dilation = self.dilation | |||
| if dilate: | |||
| self.dilation *= stride | |||
| stride = 1 | |||
| if stride != 1 or self.inplanes != planes * block.expansion: | |||
| downsample = nn.Sequential( | |||
| conv1x1(self.inplanes, planes * block.expansion, stride), | |||
| norm_layer(planes * block.expansion), | |||
| ) | |||
| layers = [] | |||
| layers.append( | |||
| block(self.inplanes, planes, stride, downsample, self.groups, | |||
| self.base_width, previous_dilation, norm_layer)) | |||
| self.inplanes = planes * block.expansion | |||
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)] | |||
| for i in range(1, blocks): | |||
| layers.append( | |||
| block( | |||
| self.inplanes, | |||
| planes, | |||
| groups=self.groups, | |||
| base_width=self.base_width, | |||
| dilation=self.dilation, | |||
| norm_layer=norm_layer, | |||
| drop_path_rate=dpr[i])) | |||
| return nn.Sequential(*layers) | |||
| def _forward_impl(self, x): | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = self.relu(x) | |||
| x = self.maxpool(x) | |||
| x = self.layer1(x) | |||
| x = self.layer2(x) | |||
| x = self.layer3(x) | |||
| return x | |||
| def forward(self, x): | |||
| return self._forward_impl(x) | |||
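A hedged instantiation sketch: unlike torchvision's ResNet, the module above stops after three bottleneck stages and has no average pool or classifier, so it returns a spatial feature map. The `layers=[3, 4, 6]` value below is only an illustration (it matches the first three stages of a ResNet-50); the real stage configuration comes from the OFA model config.

    import torch

    backbone = ResNet(layers=[3, 4, 6], drop_path_rate=0.0)
    feats = backbone(torch.randn(1, 3, 224, 224))
    # Stem (stride 2) + maxpool (stride 2) + two strided stages give 1/16 resolution;
    # layer3 uses planes=256 with expansion 4, hence 1024 channels.
    print(feats.shape)   # torch.Size([1, 1024, 14, 14])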
| @@ -0,0 +1,48 @@ | |||
| # Copyright 2022 OFA-Sys Team. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes for OFA.""" | |||
| from transformers.models.bart.tokenization_bart import BartTokenizer | |||
| from transformers.utils import logging | |||
| logger = logging.get_logger(__name__) | |||
| VOCAB_FILES_NAMES = {'vocab_file': 'vocab.json', 'merges_file': 'merges.txt'} | |||
| PRETRAINED_VOCAB_FILES_MAP = { | |||
| 'vocab_file': { | |||
| 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', | |||
| }, | |||
| 'merges_file': { | |||
| 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', | |||
| }, | |||
| } | |||
| PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { | |||
| 'ofa-base': 1024, | |||
| } | |||
| class OFATokenizer(BartTokenizer): | |||
| """ | |||
| Construct an OFA tokenizer. | |||
| [`~OFATokenizer`] is identical to [`BartTokenizer`] and runs end-to-end tokenization: punctuation splitting and | |||
| byte-level Byte-Pair Encoding. | |||
| Refer to superclass [`BartTokenizer`] for usage examples and documentation concerning parameters. | |||
| """ | |||
| vocab_files_names = VOCAB_FILES_NAMES | |||
| pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
| max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES | |||
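A hedged usage sketch; 'path/to/ofa-model' is a placeholder for a local directory containing vocab.json and merges.txt, the import path is inferred from the relative imports elsewhere in this change, and the extra <code_i>/<bin_i> tokens mirror what the captioning model adds further down.

    from modelscope.models.multi_modal.ofa import OFATokenizer

    tokenizer = OFATokenizer.from_pretrained('path/to/ofa-model')
    tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
    tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])

    ids = tokenizer([' what does the image describe?'],
                    return_tensors='pt')['input_ids']
    print(ids.shape)   # (1, seq_len); BOS/EOS are added by the underlying BART tokenizer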
| @@ -0,0 +1,59 @@ | |||
| # Copyright 2022 OFA-Sys Team. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes for OFA.""" | |||
| from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast | |||
| from transformers.utils import logging | |||
| from .tokenization_ofa import OFATokenizer | |||
| logger = logging.get_logger(__name__) | |||
| VOCAB_FILES_NAMES = { | |||
| 'vocab_file': 'vocab.json', | |||
| 'merges_file': 'merges.txt', | |||
| 'tokenizer_file': 'tokenizer.json' | |||
| } | |||
| PRETRAINED_VOCAB_FILES_MAP = { | |||
| 'vocab_file': { | |||
| 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', | |||
| }, | |||
| 'merges_file': { | |||
| 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', | |||
| }, | |||
| 'tokenizer_file': { | |||
| 'ofa-base': | |||
| 'https://huggingface.co/ofa-base/resolve/main/tokenizer.json', | |||
| }, | |||
| } | |||
| PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { | |||
| 'ofa-base': 1024, | |||
| } | |||
| class OFATokenizerFast(BartTokenizerFast): | |||
| r""" | |||
| Construct a "fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library). | |||
| [`~OFATokenizerFast`] is identical to [`BartTokenizerFast`] and runs end-to-end tokenization: punctuation splitting | |||
| and byte-level Byte-Pair Encoding. | |||
| Refer to superclass [`BartTokenizerFast`] for usage examples and documentation concerning parameters. | |||
| """ | |||
| vocab_files_names = VOCAB_FILES_NAMES | |||
| pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
| max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES | |||
| slow_tokenizer_class = OFATokenizer | |||
| @@ -0,0 +1,53 @@ | |||
| from typing import Any, Dict | |||
| import torch.cuda | |||
| from modelscope.metainfo import Models | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| from .ofa import OFAModel, OFATokenizer | |||
| from .ofa.generate import sequence_generator as sg | |||
| from .ofa.generate.utils import move_to_device | |||
| __all__ = ['OfaForImageCaptioning'] | |||
| @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) | |||
| class OfaForImageCaptioning(Model): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| super().__init__(model_dir=model_dir, *args, **kwargs) | |||
| model = OFAModel.from_pretrained(model_dir) | |||
| self.model = model.module if hasattr(model, 'module') else model | |||
| self.tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
| self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||
| self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) | |||
| self._device = torch.device('cuda') if torch.cuda.is_available() \ | |||
| else torch.device('cpu') | |||
| self.model.to(self._device) | |||
| # Initialize generator | |||
| sg_args = { | |||
| 'tokenizer': self.tokenizer, | |||
| 'beam_size': 5, | |||
| 'max_len_b': 16, | |||
| 'min_len': 1, | |||
| 'no_repeat_ngram_size': 3, | |||
| 'constraint_range': None | |||
| } | |||
| if 'beam_search' in kwargs: | |||
| sg_args.update(kwargs['beam_search']) | |||
| self.generator = sg.SequenceGenerator(**sg_args) | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| input = move_to_device(input, self._device) | |||
| gen_output = self.generator.generate([self.model], input) | |||
| gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] | |||
| result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) | |||
| return {'image_id': '42', OutputKeys.CAPTION: result[0]} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| # Generation and decoding already happen in forward(), so no extra postprocessing is needed. | |||
| return inputs | |||
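A hedged end-to-end sketch that combines this model with the OFA preprocessor further down in this diff. The top-level import paths and the preprocessor's constructor signature are assumptions; the registered pipeline (see the test at the end of this change) remains the supported entry point.

    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.outputs import OutputKeys
    # Assumed re-export locations for the classes introduced in this change.
    from modelscope.models.multi_modal import OfaForImageCaptioning
    from modelscope.preprocessors import OfaImageCaptionPreprocessor

    model_dir = snapshot_download('damo/ofa_image-caption_coco_large_en')
    preprocessor = OfaImageCaptionPreprocessor(model_dir)
    model = OfaForImageCaptioning(model_dir)

    sample = preprocessor('data/test/images/image_captioning.png')
    print(model.forward(sample)[OutputKeys.CAPTION])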
| @@ -44,7 +44,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/nlp_space_dialog-modeling'), | |||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | |||
| 'damo/nlp_space_dialog-state-tracking'), | |||
| Tasks.image_captioning: (Pipelines.image_caption, | |||
| Tasks.image_captioning: (Pipelines.image_captioning, | |||
| 'damo/ofa_image-caption_coco_large_en'), | |||
| Tasks.image_generation: | |||
| (Pipelines.person_image_cartoon, | |||
| @@ -11,7 +11,7 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_captioning, module_name=Pipelines.image_caption) | |||
| Tasks.image_captioning, module_name=Pipelines.image_captioning) | |||
| class ImageCaptionPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -2,13 +2,14 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields, ModelFile | |||
| from modelscope.models.multi_modal.ofa import OFATokenizer | |||
| from modelscope.utils.constant import Fields | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| @@ -31,84 +32,39 @@ class OfaImageCaptionPreprocessor(Preprocessor): | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||
| model_dir) | |||
| self.tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
| self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||
| self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) | |||
| if osp.exists(model_dir): | |||
| local_model_dir = model_dir | |||
| else: | |||
| local_model_dir = snapshot_download(model_dir) | |||
| local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| bpe_dir = local_model_dir | |||
| from fairseq import checkpoint_utils, tasks, utils | |||
| from ofa.tasks.mm_tasks import CaptionTask | |||
| tasks.register_task('caption', CaptionTask) | |||
| overrides = { | |||
| 'bpe_dir': bpe_dir, | |||
| 'eval_cider': False, | |||
| 'beam': 5, | |||
| 'max_len_b': 16, | |||
| 'no_repeat_ngram_size': 3, | |||
| 'seed': 7 | |||
| } | |||
| model, cfg, task = checkpoint_utils.load_model_ensemble_and_task( | |||
| utils.split_paths(local_model), arg_overrides=overrides) | |||
| del model | |||
| # Initialize transform | |||
| from torchvision import transforms | |||
| mean = [0.5, 0.5, 0.5] | |||
| std = [0.5, 0.5, 0.5] | |||
| patch_image_size = 480 | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| lambda image: image.convert('RGB'), | |||
| transforms.Resize( | |||
| (cfg.task.patch_image_size, cfg.task.patch_image_size), | |||
| interpolation=Image.BICUBIC), | |||
| transforms.Resize((patch_image_size, patch_image_size), | |||
| interpolation=Image.BICUBIC), | |||
| transforms.ToTensor(), | |||
| transforms.Normalize(mean=mean, std=std), | |||
| ]) | |||
| self.task = task | |||
| self.bos_item = torch.LongTensor([task.src_dict.bos()]) | |||
| self.eos_item = torch.LongTensor([task.src_dict.eos()]) | |||
| self.pad_idx = task.src_dict.pad() | |||
| @type_assert(object, (str, tuple, Image.Image)) | |||
| def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: | |||
| def encode_text(text, length=None, append_bos=False, append_eos=False): | |||
| s = self.task.tgt_dict.encode_line( | |||
| line=self.task.bpe.encode(text), | |||
| add_if_not_exist=False, | |||
| append_eos=False).long() | |||
| if length is not None: | |||
| s = s[:length] | |||
| if append_bos: | |||
| s = torch.cat([self.bos_item, s]) | |||
| if append_eos: | |||
| s = torch.cat([s, self.eos_item]) | |||
| return s | |||
| if isinstance(data, Image.Image): | |||
| patch_image = self.patch_resize_transform(data).unsqueeze(0) | |||
| else: | |||
| patch_image = self.patch_resize_transform( | |||
| load_image(data)).unsqueeze(0) | |||
| patch_mask = torch.tensor([True]) | |||
| text = 'what does the image describe?' | |||
| src_text = encode_text( | |||
| text, append_bos=True, append_eos=True).unsqueeze(0) | |||
| src_length = torch.LongTensor( | |||
| [s.ne(self.pad_idx).long().sum() for s in src_text]) | |||
| sample = { | |||
| 'id': np.array(['42']), | |||
| 'net_input': { | |||
| 'src_tokens': src_text, | |||
| 'src_lengths': src_length, | |||
| 'patch_images': patch_image, | |||
| 'patch_masks': patch_mask, | |||
| } | |||
| text = ' what does the image describe?' | |||
| inputs = self.tokenizer([text], max_length=1024, | |||
| return_tensors='pt')['input_ids'] | |||
| sample = dict() | |||
| sample['net_input'] = { | |||
| 'input_ids': inputs, | |||
| 'patch_images': patch_image, | |||
| 'patch_masks': torch.tensor([True]) | |||
| } | |||
| return sample | |||
| @@ -14,7 +14,7 @@ class ImageCaptionTest(unittest.TestCase): | |||
| def test_run(self): | |||
| img_captioning = pipeline( | |||
| Tasks.image_captioning, | |||
| model='damo/ofa_image-caption_coco_large_en') | |||
| model='damo/ofa_image-caption_coco_distilled_en') | |||
| result = img_captioning('data/test/images/image_captioning.png') | |||
| print(result[OutputKeys.CAPTION]) | |||