| @@ -73,7 +73,7 @@ class Pipelines(object): | |||||
| asr_inference = 'asr-inference' | asr_inference = 'asr-inference' | ||||
| # multi-modal tasks | # multi-modal tasks | ||||
| image_caption = 'image-captioning' | |||||
| image_captioning = 'image-captioning' | |||||
| multi_modal_embedding = 'multi-modal-embedding' | multi_modal_embedding = 'multi-modal-embedding' | ||||
| visual_question_answering = 'visual-question-answering' | visual_question_answering = 'visual-question-answering' | ||||
| text_to_image_synthesis = 'text-to-image-synthesis' | text_to_image_synthesis = 'text-to-image-synthesis' | ||||
| @@ -1,5 +1,5 @@ | |||||
| from .clip.clip_model import CLIPForMultiModalEmbedding | from .clip.clip_model import CLIPForMultiModalEmbedding | ||||
| from .image_captioning_model import OfaForImageCaptioning | |||||
| from .imagen.imagen_model import ImagenForTextToImageSynthesis | from .imagen.imagen_model import ImagenForTextToImageSynthesis | ||||
| from .mplug_for_visual_question_answering import \ | from .mplug_for_visual_question_answering import \ | ||||
| MPlugForVisualQuestionAnswering | MPlugForVisualQuestionAnswering | ||||
| from .ofa_for_image_captioning_model import OfaForImageCaptioning | |||||
| @@ -0,0 +1,2 @@ | |||||
| from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel | |||||
| from .tokenization_ofa import OFATokenizer | |||||
| @@ -0,0 +1,194 @@ | |||||
| # Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ OFA model configuration""" | |||||
| import warnings | |||||
| from transformers import PretrainedConfig | |||||
| from transformers.utils import logging | |||||
| logger = logging.get_logger(__name__) | |||||
| OFA_PRETRAINED_CONFIG_ARCHIVE_MAP = { | |||||
| 'ofa-medium': 'https://huggingface.co/ofa-base/resolve/main/config.json', | |||||
| # OFA models are implemented to be compatible with both huggingface | |||||
| # and modelscope frameworks. For all OFA models available on huggingface, | |||||
| # please refer to https://huggingface.co/models?filter=ofa | |||||
| } | |||||
class OFAConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults yields the small architecture encoded in the signature below (`d_model=512`, 4 encoder layers and
    4 decoder layers, 8 attention heads each).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 59457):
            Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`~OFAModel`].
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        encoder_layers (`int`, *optional*, defaults to 4):
            Number of encoder layers. Also stored as `num_hidden_layers`.
        encoder_ffn_dim (`int`, *optional*, defaults to 2048):
            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
        encoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder. Also reachable as
            `num_attention_heads` through `attribute_map`.
        decoder_layers (`int`, *optional*, defaults to 4):
            Number of decoder layers.
        decoder_ffn_dim (`int`, *optional*, defaults to 2048):
            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
        decoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to [`PretrainedConfig`].
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        d_model (`int`, *optional*, defaults to 512):
            Dimension of the layers and the pooler layer. Also reachable as `hidden_size` through `attribute_map`.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for classifier.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            If `True`, the embedding scale factor will be sqrt(d_model).
        pad_token_id (`int`, *optional*, defaults to 1):
            Forwarded unchanged to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Forwarded unchanged to [`PretrainedConfig`].
        decoder_start_token_id (`int`, *optional*, defaults to 0):
            Forwarded unchanged to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Forwarded unchanged to [`PretrainedConfig`].
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            Forwarded unchanged to [`PretrainedConfig`].

    OFA-specific options. Each of the following is stored on the config under the same name exactly as given; their
    effect is defined by the modeling code, which is not part of this file:

        encoder_normalize_before (defaults to `True`), decoder_normalize_before (defaults to `True`),
        normformer (defaults to `True`), encoder_drop_path_rate (defaults to 0.0),
        decoder_drop_path_rate (defaults to 0.0), layernorm_embedding (defaults to `True`),
        patch_layernorm_embedding (defaults to `True`), resnet_type (defaults to `"resnet101"`),
        resnet_model_path (defaults to `None`), resnet_drop_path_rate (defaults to 0.0),
        token_bucket_size (defaults to 256), image_bucket_size (defaults to 42),
        add_type_embedding (defaults to `True`), share_decoder_input_output_embed (defaults to `True`),
        attn_scale_factor (defaults to 2.0), code_layernorm_embedding (defaults to `True`),
        code_image_size (defaults to 128), entangle_position_embedding (defaults to `False`).

    Any extra keyword arguments are passed through to [`PretrainedConfig`]. The legacy flag
    `force_bos_token_to_be_generated=True` is still honoured: when `forced_bos_token_id` is unset it is replaced by
    `bos_token_id` and a warning is emitted.
    """

    model_type = 'ofa'
    keys_to_ignore_at_inference = ['past_key_values']

    # Generic attribute names resolved to the OFA-specific ones.
    attribute_map = {
        'num_attention_heads': 'encoder_attention_heads',
        'hidden_size': 'd_model'
    }

    def __init__(self,
                 vocab_size=59457,
                 max_position_embeddings=1024,
                 encoder_layers=4,
                 encoder_ffn_dim=512 * 4,
                 encoder_attention_heads=8,
                 decoder_layers=4,
                 decoder_ffn_dim=512 * 4,
                 decoder_attention_heads=8,
                 encoder_layerdrop=0.0,
                 decoder_layerdrop=0.0,
                 use_cache=True,
                 is_encoder_decoder=True,
                 activation_function='gelu',
                 d_model=512,
                 dropout=0.1,
                 attention_dropout=0.0,
                 activation_dropout=0.0,
                 init_std=0.02,
                 classifier_dropout=0.0,
                 scale_embedding=False,
                 pad_token_id=1,
                 bos_token_id=0,
                 decoder_start_token_id=0,
                 eos_token_id=2,
                 forced_eos_token_id=2,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 normformer=True,
                 encoder_drop_path_rate=0.0,
                 decoder_drop_path_rate=0.0,
                 layernorm_embedding=True,
                 patch_layernorm_embedding=True,
                 resnet_type='resnet101',
                 resnet_model_path=None,
                 resnet_drop_path_rate=0.0,
                 token_bucket_size=256,
                 image_bucket_size=42,
                 add_type_embedding=True,
                 share_decoder_input_output_embed=True,
                 attn_scale_factor=2.,
                 code_layernorm_embedding=True,
                 code_image_size=128,
                 entangle_position_embedding=False,
                 **kwargs):
        # Core transformer hyper-parameters.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        # `num_hidden_layers` mirrors the encoder depth.
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        # OFA-specific options, stored verbatim for the modeling code.
        self.encoder_normalize_before = encoder_normalize_before
        self.decoder_normalize_before = decoder_normalize_before
        self.normformer = normformer
        self.encoder_drop_path_rate = encoder_drop_path_rate
        self.decoder_drop_path_rate = decoder_drop_path_rate
        self.layernorm_embedding = layernorm_embedding
        self.patch_layernorm_embedding = patch_layernorm_embedding
        self.resnet_type = resnet_type
        self.resnet_model_path = resnet_model_path
        self.resnet_drop_path_rate = resnet_drop_path_rate
        self.token_bucket_size = token_bucket_size
        self.image_bucket_size = image_bucket_size
        self.add_type_embedding = add_type_embedding
        self.share_decoder_input_output_embed = share_decoder_input_output_embed
        self.attn_scale_factor = attn_scale_factor
        self.code_layernorm_embedding = code_layernorm_embedding
        self.code_image_size = code_image_size
        self.entangle_position_embedding = entangle_position_embedding

        # Special-token ids and any unrecognised kwargs are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get(
                'force_bos_token_to_be_generated', False):
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. '
                'The config can simply be saved and uploaded again to be fixed.'
            )
| @@ -0,0 +1,51 @@ | |||||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||||
| # | |||||
| # This source code is licensed under the MIT license which can be found at | |||||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||||
| import uuid | |||||
| from typing import Dict, Optional | |||||
| from torch import Tensor | |||||
class FairseqIncrementalState(object):
    """Mixin that namespaces entries of a shared incremental-state dict per instance.

    Every instance receives a random UUID string on construction. Keys passed
    to the getter/setter are prefixed with that UUID, so several instances can
    store values in the same dictionary without overwriting each other.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative initialisation: pass all arguments on to the next class.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Generate a new UUID-based prefix for this instance's state keys."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` prefixed with this instance's id and a dot."""
        return f'{self._incremental_state_id}.{key}'

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Returns the value stored under this instance's namespaced ``key``, or
        ``None`` when ``incremental_state`` is ``None`` or holds no such entry.
        """
        namespaced_key = self._get_full_incremental_state_key(key)
        if incremental_state is not None and namespaced_key in incremental_state:
            return incremental_state[namespaced_key]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores ``value`` in place under this instance's namespaced ``key`` and
        returns the same dictionary; a ``None`` dictionary is returned as is.
        """
        if incremental_state is None:
            return incremental_state
        namespaced_key = self._get_full_incremental_state_key(key)
        incremental_state[namespaced_key] = value
        return incremental_state
def with_incremental_state(cls):
    """Class decorator that makes ``FairseqIncrementalState`` the first base of ``cls``.

    The remaining bases keep their relative order; an existing occurrence of
    ``FairseqIncrementalState`` is dropped so it is not listed twice. The class
    is modified in place and returned.
    """
    other_bases = [
        base for base in cls.__bases__ if base != FairseqIncrementalState
    ]
    cls.__bases__ = (FairseqIncrementalState, *other_bases)
    return cls
| @@ -0,0 +1,510 @@ | |||||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||||
| # | |||||
| # This source code is licensed under the MIT license which can be found at | |||||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||||
| import math | |||||
| from typing import Dict, Optional, Tuple | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from fairseq import utils | |||||
| from fairseq.incremental_decoding_utils import with_incremental_state | |||||
| from fairseq.modules.fairseq_dropout import FairseqDropout | |||||
| from fairseq.modules.quant_noise import quant_noise | |||||
| from torch import Tensor, nn | |||||
| from torch.nn import Parameter | |||||
@with_incremental_state
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Args:
        embed_dim: size of the query features and of the output.
        num_heads: number of attention heads; must divide ``embed_dim``.
        kdim: input size of the key projection (defaults to ``embed_dim``).
        vdim: input size of the value projection (defaults to ``embed_dim``).
        dropout: probability handed to the ``FairseqDropout`` module that is
            applied to the attention probabilities.
        bias: whether the q/k/v/out linear projections carry a bias.
        add_bias_kv: create learnable ``bias_k`` / ``bias_v`` parameters of
            shape ``(1, 1, embed_dim)`` that are concatenated to the projected
            keys and values in ``forward``.
        add_zero_attn: append one all-zero key/value position in ``forward``.
        self_attention: project q, k and v all from ``query``.
        encoder_decoder_attention: project k and v both from ``key``.
        q_noise: forwarded to ``quant_noise`` for every projection.
        qn_block_size: forwarded to ``quant_noise`` for every projection.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__)

        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), 'embed_dim must be divisible by num_heads'
        # Queries are multiplied by 1/sqrt(head_dim) before the dot product.
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            'Self-attention requires query, key and '
            'value to be of the same size')

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        """Switch ``forward`` to its ONNX-traceable code path."""
        self.onnx_trace = True

    def reset_parameters(self):
        """(Re-)initialise the projection weights, the output bias and bias_k/bias_v."""
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            incremental_state (dict, optional): cache of previously computed
                keys/values/padding mask; when given, the cached tensors are
                reused and the cache is updated in place.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            static_kv (bool, optional): with a populated cache, reuse the
                cached keys and values instead of projecting `key`/`value`
                again (only valid for encoder-decoder attention).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.

        Returns:
            ``(attn, attn_weights)``; ``attn_weights`` is ``None`` when
            *need_weights* is false. With *before_softmax* the pair is
            ``(attn_weights, v)`` instead.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == 'xla'

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f'query dim {embed_dim} != {self.embed_dim}'
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                # The previous form `assert src_len, bsz == value.shape[:2]`
                # used the comparison as the assertion *message*, so the
                # key/value shape agreement was never actually checked.
                assert (src_len, bsz) == tuple(value.shape[:2])

        if (not self.onnx_trace
                and not is_tpu  # don't use PyTorch version on TPUs
                and incremental_state is None and not static_kv
                # A workaround for quantization to work. Otherwise JIT compilation
                # treats bias in linear module as method.
                and not torch.jit.is_scripting()):
            assert key is not None and value is not None
            # NOTE(review): torch.cat below needs all three projection biases;
            # they are None when the module was built with bias=False.
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            # One extra key/value position is appended, so both masks grow by
            # one zero column to stay aligned.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(
                            key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # Fold the heads into the batch dimension: (bsz * num_heads, len, head_dim).
        q = (
            q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1))
        if k is not None:
            k = (
                k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1))
        if v is not None:
            v = (
                v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                _prev_key = saved_state['prev_key']
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if 'prev_value' in saved_state:
                _prev_value = saved_state['prev_value']
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if 'prev_key_padding_mask' in saved_state:
                prev_key_padding_mask = saved_state['prev_key_padding_mask']
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float('-inf'),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask, float('-inf'))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Merge the cached padding mask with the current one along dim 1.

        When only one of the two masks exists, the missing part is filled with
        zeros so the result has ``src_len`` columns. Returns ``None`` when both
        masks are ``None``.
        """
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(),
                 key_padding_mask.float()],
                dim=1)
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(),
                     filler.float()], dim=1)
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1)
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    # For encoder-decoder attention, stop at the first buffer
                    # whose leading size already equals new_order's length.
                    if self.encoder_decoder_attention and input_buffer_k.size(
                            0) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state,
                                                       input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str,
                                                         Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        """Return this module's ``'attn_state'`` entry, or a new empty dict."""
        result = self.get_incremental_state(incremental_state, 'attn_state')
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        """Store ``buffer`` as this module's ``'attn_state'`` entry."""
        return self.set_incremental_state(incremental_state, 'attn_state',
                                          buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int,
                          bsz: int):
        """Hook applied to the raw attention scores; returns them unchanged here."""
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Split fused ``in_proj_weight`` / ``in_proj_bias`` entries in place.

        Each fused tensor is cut into three equal slices along dim 0 and
        stored as ``q_proj``, ``k_proj`` and ``v_proj`` entries under the same
        prefix; the fused keys are removed from ``state_dict``.
        """
        prefix = name + '.' if name != '' else ''
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + 'in_proj_weight'):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim]
                items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2
                                                                       * dim]
                items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2
                                                                       * dim:]

                keys_to_remove.append(k)

                k_bias = prefix + 'in_proj_bias'
                if k_bias in state_dict.keys():
                    # The bias is sliced with the `dim` derived from the weight.
                    fused_bias = state_dict[k_bias]
                    items_to_add[prefix + 'q_proj.bias'] = fused_bias[:dim]
                    items_to_add[prefix + 'k_proj.bias'] = fused_bias[dim:2
                                                                      * dim]
                    items_to_add[prefix + 'v_proj.bias'] = fused_bias[2 * dim:]

                    keys_to_remove.append(prefix + 'in_proj_bias')

        # Mutate only after iteration so the dict does not change size mid-loop.
        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
| @@ -0,0 +1,155 @@ | |||||
| # Originally from Microsoft Corporation. | |||||
| # Licensed under the MIT License. | |||||
| """ Wrapper for ngram_repeat_block cuda extension """ | |||||
| import math | |||||
| import warnings | |||||
| from typing import Dict, List | |||||
| import torch | |||||
| from torch import nn | |||||
| try: | |||||
| from fairseq import ngram_repeat_block_cuda | |||||
| EXTENSION_BUILT = True | |||||
| except ImportError: | |||||
| EXTENSION_BUILT = False | |||||
def is_cuda_extension_usable() -> bool:
    """Check whether ngram_repeat_block_cuda is built properly.

    Returns:
        ``False`` when the extension import failed (``EXTENSION_BUILT`` is
        false) or CUDA is unavailable. Otherwise runs the extension once on a
        small CUDA input and returns ``True`` on success, or ``False`` (after
        emitting a warning) when that call raises ``RuntimeError``.
    """
    if not EXTENSION_BUILT or not torch.cuda.is_available():
        return False
    bsz = 2
    tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]],
                          dtype=torch.long,
                          device='cuda')
    lprobs = torch.rand((8, 12), device='cuda')
    try:
        outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3)
        outputs = outputs + 4  # This line breaks if the extension is built incorrectly.
        return True
    except RuntimeError:
        # Adjacent literals are concatenated verbatim, so the separating space
        # has to live inside the first literal; without it the message read
        # "...must be rebuilt.Run TORCH_CUDA_ARCH_LIST=...".
        warnings.warn(
            'NGramRepeatBlock extension must be rebuilt. '
            'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace'
        )
        return False
class NGramRepeatBlock(nn.Module):
    """Bans tokens that would repeat an already generated n-gram.

    Dispatches to the ``ngram_repeat_block_cuda`` extension when it is usable
    and requested, otherwise to a pure-Python implementation.
    """

    def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
        super().__init__()
        # The extension is only probed when the caller asked for it.
        if use_extension:
            self.use_extension = is_cuda_extension_usable()
        else:
            self.use_extension = False
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def reset_parameters(self):
        """No-op: this module sets no parameters of its own."""
        pass

    @torch.jit.unused
    def call_cuda_extension(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        """Forward the arguments to the CUDA extension (note: step precedes beam_size there)."""
        ngram_size = self.no_repeat_ngram_size
        return ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, step,
                                               beam_size, ngram_size)

    def forward(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        """
        Args:
            tokens(Tensor): Input tokens(Bsz*beam, seq_len)
            lprobs(Tensor): likelihood probability,
            Expected to be updated in place.(Bsz*beam, vocab_size)
            bsz(int): batch size
            beam_size(int): beam size
            step(int): current step

        Returns:
            The result of the CUDA extension when ``self.use_extension`` is
            true, otherwise *lprobs* after the Python fallback has run.
        """
        expected_rows = bsz * beam_size
        assert tokens.size(
            0) == expected_rows, f'expected {expected_rows} got {tokens.size(0)}'
        assert lprobs.size(
            0) == expected_rows, f'expected {expected_rows} got {lprobs.size(0)}'
        if not self.use_extension:
            return self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)
        return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step)

    def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int,
                         step: int):
        """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf"""
        num_hypos = bsz * beam_size
        ngram_size = self.no_repeat_ngram_size
        seen_ngrams: List[Dict[str, List[int]]] = [
            torch.jit.annotate(Dict[str, List[int]], {})
            for _ in range(num_hypos)
        ]
        host_tokens = tokens.cpu()
        for hypo in range(num_hypos):
            history: List[int] = host_tokens[hypo].tolist()
            # Row i is the history shifted by i; transposing yields the n-grams.
            shifted = [history[offset:] for offset in range(ngram_size)]
            for ngram in self.transpose_list(shifted):
                prefix_key = ','.join([str(tok) for tok in ngram[:-1]])
                followers = seen_ngrams[hypo].get(
                    prefix_key, torch.jit.annotate(List[int], []))
                seen_ngrams[hypo][prefix_key] = followers + [ngram[-1]]

        if step + 2 - ngram_size < 0:
            # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
            banned = [
                torch.jit.annotate(List[int], []) for _ in range(num_hypos)
            ]
        else:
            banned = [
                self.calculate_banned_tokens(tokens, step, seen_ngrams,
                                             ngram_size, hypo)
                for hypo in range(num_hypos)
            ]

        for hypo in range(num_hypos):
            banned_idx = torch.tensor(banned[hypo], dtype=torch.int64)
            lprobs[hypo][banned_idx] = torch.tensor(-math.inf).to(lprobs)
        return lprobs

    @staticmethod
    def calculate_banned_tokens(
        tokens,
        step: int,
        gen_ngrams: List[Dict[str, List[int]]],
        no_repeat_ngram_size: int,
        bbsz_idx: int,
    ):
        """Look up the tokens that followed the current (n-1)-token suffix before."""
        start = step + 2 - no_repeat_ngram_size
        suffix: List[int] = tokens[bbsz_idx, start:step + 1].tolist()
        # before decoding the next token, prevent decoding of ngrams that have already appeared
        lookup_key = ','.join([str(tok) for tok in suffix])
        return gen_ngrams[bbsz_idx].get(lookup_key,
                                        torch.jit.annotate(List[int], []))

    @staticmethod
    def transpose_list(l: List[List[int]]):  # noqa
        """Transpose a list of lists, truncating to the shortest row."""
        # GeneratorExp aren't supported in TS so ignoring the lint
        shortest = min([len(row) for row in l])  # noqa
        return [[row[col] for row in l] for col in range(shortest)]
| @@ -0,0 +1,848 @@ | |||||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||||
| # | |||||
| # This source code is licensed under the MIT license which can be found at | |||||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||||
| import math | |||||
| from typing import List, Optional | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch import Tensor | |||||
| from .token_generation_constraints import (ConstraintState, | |||||
| OrderedConstraintState, | |||||
| UnorderedConstraintState) | |||||
class Search(nn.Module):
    """Base class for search strategies; subclasses implement ``step``."""

    def __init__(self, tokenizer):
        super().__init__()
        self.pad = tokenizer.pad_token_id
        self.unk = tokenizer.unk_token_id
        self.eos = tokenizer.eos_token_id
        # Vocabulary size = number of distinct ids over base + added vocab.
        id_to_token = {
            idx: token
            for token, idx in tokenizer.get_vocab().items()
        }
        for token, idx in tokenizer.get_added_vocab().items():
            id_to_token[idx] = token
        self.vocab_size = len(id_to_token)
        self.src_lengths = torch.tensor(-1)
        self.supports_constraints = False
        self.stop_on_max_len = False

    def step(self,
             step,
             lprobs,
             scores,
             prev_output_tokens=None,
             original_batch_idxs=None):
        """Take a single search step.

        Args:
            step: the current search step, starting at 0
            lprobs: (bsz x input_beam_size x vocab_size)
                the model's log-probabilities over the vocabulary at the current step
            scores: (bsz x input_beam_size x step)
                the historical model scores of each hypothesis up to this point
            prev_output_tokens: (bsz x step)
                the previously generated output tokens
            original_batch_idxs: (bsz)
                the tensor with the batch indices, in the range [0, bsz)
                this is useful in case there has been applied a re-ordering
                and we need to know the original indices

        Return: A tuple of (scores, indices, beams) where:
            scores: (bsz x output_beam_size)
                the scores of the chosen elements; output_beam_size can be
                larger than input_beam_size, e.g., we may return
                2*input_beam_size to account for EOS
            indices: (bsz x output_beam_size)
                the indices of the chosen elements
            beams: (bsz x output_beam_size)
                the hypothesis ids of the chosen elements, in the range [0, input_beam_size)

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @torch.jit.export
    def set_src_lengths(self, src_lengths):
        """Remember *src_lengths* on the instance."""
        self.src_lengths = src_lengths

    @torch.jit.export
    def init_constraints(self, batch_constraints: Optional[Tensor],
                         beam_size: int):
        """Initialize constraint states for constrained decoding (if supported).

        The base implementation does nothing.

        Args:
            batch_constraints: (torch.Tensor, optional)
                the list of constraints, in packed form
            beam_size: (int)
                the beam size
        """
        pass

    def prune_sentences(self, batch_idxs: Tensor):
        """
        Removes constraint states for completed sentences (if supported).
        This is called from sequence_generator._generate() when sentences are
        deleted from the batch. The base implementation does nothing.

        Args:
            batch_idxs: Indices of *sentences* whose constraint state should be *kept*.
        """
        pass

    def update_constraints(self, active_hypos: Tensor):
        """
        Updates the constraint states by selecting the beam items that are retained.
        This is called at each time step of sequence_generator._generate() when
        the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size.
        The base implementation does nothing.

        Args:
            active_hypos: (batch size, beam size)
              list of integers denoting, for each sentence, which beam candidate items
              should be kept.
        """
        pass
class BeamSearch(Search):
    """Plain beam search: keep the best-scoring continuations over all beams."""

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)
        self.constraint_states = None

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores: Optional[Tensor],
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """Select up to ``2 * beam_size`` candidates per sentence.

        Args:
            step: current decoding step, starting at 0.
            lprobs: tensor unpacked as (bsz, beam_size, vocab_size).
            scores: per-step scores; column ``step - 1`` is added to
                *lprobs* when ``step > 0`` (must not be None then).
            prev_output_tokens: unused here.
            original_batch_idxs: unused here.

        Returns:
            ``(scores_buf, indices_buf, beams_buf)``: the top-k values, the
            token index within the vocabulary and the beam each came from.
        """
        bsz, beam_size, vocab_size = lprobs.size()

        if step != 0:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
        else:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        flat_lprobs = lprobs.view(bsz, -1)
        # Take the best 2 x beam_size predictions. We'll choose the first
        # beam_size of these which don't predict eos to continue with.
        num_candidates = min(
            beam_size * 2,
            flat_lprobs.size(1) - 1,  # -1 so we never select pad
        )
        scores_buf, flat_indices = torch.topk(flat_lprobs, k=num_candidates)

        # Project back into relative indices and beams
        beams_buf = flat_indices // vocab_size
        indices_buf = flat_indices.fmod(vocab_size)
        return scores_buf, indices_buf, beams_buf
class PrefixConstrainedBeamSearch(Search):
    """Beam search restricted to tokens allowed by ``prefix_allowed_tokens_fn``."""

    def __init__(self, tgt_dict, prefix_allowed_tokens_fn):
        super().__init__(tgt_dict)
        self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self.stop_on_max_len = True

    @torch.jit.export
    def apply_mask(self, x, prev_output_tokens, original_batch_idxs):
        """Build an additive mask shaped like *x*: 0 where allowed, -inf elsewhere.

        ``prefix_allowed_tokens_fn(batch_id, prefix)`` is called once per row
        of *prev_output_tokens*; its result indexes the last dim of the mask.
        """
        beams_per_sentence = x.shape[0] // original_batch_idxs.shape[0]
        # Repeat each batch id once per beam so it lines up with the rows of x.
        expanded_ids = original_batch_idxs.unsqueeze(-1).repeat(
            (1, beams_per_sentence))
        original_batch_idxs = expanded_ids.flatten().tolist()

        mask = torch.full_like(x, -math.inf)
        for row, (prefix, batch_id) in enumerate(
                zip(prev_output_tokens, original_batch_idxs)):
            allowed = self.prefix_allowed_tokens_fn(batch_id, prefix)
            mask[row, :, allowed] = 0
        return mask

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs: Tensor,
        scores: Tensor,
        prev_output_tokens: Tensor,
        original_batch_idxs: Tensor,
    ):
        """Mask *lprobs* in place, then pick up to ``beam_size`` candidates.

        Returns:
            ``(scores_buf, indices_buf, beams_buf)`` from a top-k over the
            flattened (beam, vocab) axis of each sentence.
        """
        bsz, beam_size, vocab_size = lprobs.size()

        allowed_mask = self.apply_mask(
            lprobs.view(bsz * beam_size, 1, vocab_size),
            prev_output_tokens,
            original_batch_idxs,
        )
        # In-place add: the caller's tensor is modified.
        lprobs += allowed_mask.view(bsz, beam_size, vocab_size)

        if step != 0:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
        else:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        flat_lprobs = lprobs.view(bsz, -1)
        # Take the best beam_size predictions. We'll choose the first
        # beam_size of these which don't predict eos to continue with.
        num_candidates = min(
            beam_size,
            flat_lprobs.size(1) - 1,  # -1 so we never select pad
        )
        scores_buf, flat_indices = torch.topk(flat_lprobs, k=num_candidates)

        beams_buf = flat_indices // vocab_size
        indices_buf = flat_indices.fmod(vocab_size)
        return scores_buf, indices_buf, beams_buf
class LexicallyConstrainedBeamSearch(Search):
    """Implements lexically constrained beam search as described in

        Fast Lexically Constrained Decoding with Dynamic Beam
        Allocation for Neural Machine Translation.  Post & Vilar,
        NAACL 2018.  https://www.aclweb.org/anthology/N18-1119/

    and

        Improved Lexically Constrained Decoding for Translation and
        Monolingual Rewriting. Hu et al, NAACL
        2019. https://www.aclweb.org/anthology/N19-1090/

    This is accomplished by maintaining, for each beam hypothesis, a
    ConstraintState object (see token_generation_constraints) that tracks
    which constraints have been generated and using this information to
    shape the beam for each input sentence.
    """

    def __init__(self, tokenizer, representation):
        """
        Args:
            tokenizer: passed to ``Search.__init__``; must also provide
                ``get_vocab()`` and ``get_added_vocab()``.
            representation: ``'ordered'`` or ``'unordered'``; selects the
                constraint-state class used by ``init_constraints``.
        """
        super().__init__(tokenizer)
        self.representation = representation
        # Count distinct token ids over the base and added vocabularies.
        tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()}
        added = {
            value: key
            for key, value in tokenizer.get_added_vocab().items()
        }
        tgt_dict.update(added)
        self.vocab_size = len(tgt_dict)
        # Set on every call to step(); used there and in step_sentence().
        self.num_cands = 0
        self.supports_constraints = True
        # NOTE(review): self.constraint_states is only assigned in
        # init_constraints(); step() reads it, so init_constraints()
        # presumably has to be called first - confirm against the caller.

    @torch.jit.export
    def init_constraints(self, batch_constraints: Optional[Tensor],
                         beam_size: int):
        """Create the per-sentence, per-beam constraint states.

        Builds one state per entry of *batch_constraints* and stores a list
        of ``beam_size`` references to it in ``self.constraint_states``.
        """
        self.constraint_states = []
        for constraint_tensor in batch_constraints:
            if self.representation == 'ordered':
                constraint_state = OrderedConstraintState.create(
                    constraint_tensor)
            elif self.representation == 'unordered':
                constraint_state = UnorderedConstraintState.create(
                    constraint_tensor)
            # NOTE(review): any other representation value leaves
            # constraint_state unassigned (or stale from a previous iteration).

            # Every beam entry starts out referencing the same state object.
            self.constraint_states.append(
                [constraint_state for i in range(beam_size)])

    @torch.jit.export
    def prune_sentences(self, batch_idxs: Tensor):
        """Keep only the constraint states of the sentences in *batch_idxs*."""
        self.constraint_states = [
            self.constraint_states[i] for i in batch_idxs.tolist()
        ]

    @torch.jit.export
    def update_constraints(self, active_hypos: Tensor):
        """Reorder each sentence's states to follow the retained hypotheses.

        Row ``sentid`` of *active_hypos* lists the candidate positions to
        keep for that sentence. Does nothing when there are no states.
        """
        if self.constraint_states:
            batch_size = active_hypos.size(0)
            for sentid in range(batch_size):
                self.constraint_states[sentid] = [
                    self.constraint_states[sentid][i]
                    for i in active_hypos[sentid]
                ]

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs: Tensor,
        scores: Optional[Tensor],
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """
        A constrained step builds a large candidates list from the following:
        - the top 2 * {beam_size} items over the whole beam
        - for each item in the beam
          - the top {each_k} (default 1)
          - all next constraints
        We then compute the constrained state of each beam item, and assign
        stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so
        on. We then sort by (stripe, score), and truncate the list at
        2 * beam size.

        Args:
            step: the decoder step
            lprobs: (batch size, beam size, target vocab)
                the target-vocab distributions for each item in the beam.
            scores: historical scores; column ``step - 1`` is added to
                *lprobs* when ``step > 0``.
            prev_output_tokens: not used by this method.
            original_batch_idxs: not used by this method.
        Return: A tuple of (scores, indices, beams) where:
            scores: (batch, output beam size)
                the scores of the chosen elements
            indices: (batch, output beam size)
                the target vocab indices of the chosen elements
            beams: (batch, output beam size)
                the 0-indexed hypothesis ids of the chosen elements
        The new constraint states are not returned; they are written back
        to ``self.constraint_states``.
        """
        each_k = 1
        device = lprobs.device

        batch_size, beam_size, vocab_size = lprobs.size()

        self.num_cands = min(
            # Just take the k-best. We'll get another k from the 1-best from each
            # row, plus more from the constraints
            beam_size * 2,
            lprobs.view(batch_size, -1).size(1)
            - 1,  # -1 so we never select pad
        )

        # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items
        constraint_states = self.constraint_states
        if constraint_states and step > 0:
            not_finished_indices = []
            for sentno, sent_constraints in enumerate(constraint_states):
                for beamno, state in enumerate(sent_constraints):
                    # Row of this hypothesis in the flattened (batch*beam) view.
                    index = sentno * beam_size + beamno
                    if not state.finished:
                        not_finished_indices.append(index)
            not_finished_indices = torch.tensor(not_finished_indices)
            if not_finished_indices.numel() > 0:
                # Written through a view, so the caller's lprobs is modified.
                lprobs.view(batch_size * beam_size, -1)[not_finished_indices,
                                                        self.eos] = -math.inf

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam entry for each batch item
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            assert scores is not None
            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)

        top_prediction = torch.topk(
            lprobs.view(batch_size, -1),
            self.num_cands,
        )
        scores_buf, indices_buf = top_prediction
        # Project back into relative indices and beams
        beams_buf = indices_buf // vocab_size
        indices_buf = indices_buf.fmod(vocab_size)

        # Short circuit if there are no constraints in this batch
        if not constraint_states:
            return scores_buf, indices_buf, beams_buf

        # STEP 1: get top-1 from each hypothesis across all sentences in the batch
        if step > 0:
            top_scores, top_indices = torch.topk(
                lprobs.view(batch_size * beam_size, -1),
                k=each_k,
                dim=1,
            )
            top_scores = top_scores.view(batch_size, -1)
            top_indices = top_indices.view(batch_size, -1)
            scores_buf = torch.cat((scores_buf, top_scores), dim=1)
            indices_buf = torch.cat((indices_buf, top_indices), dim=1)
            # The per-row top-1 of beam b belongs to beam b.
            new_beams = torch.arange(
                0, beam_size, device=device).repeat(batch_size, 1)
            beams_buf = torch.cat((beams_buf, new_beams), dim=1)

        # Now, process sentences in the batch one by one.
        new_scores_buf = torch.zeros((batch_size, 2 * beam_size),
                                     device=device)
        new_indices_buf = torch.zeros((batch_size, 2 * beam_size),
                                      device=device).long()
        new_beams_buf = torch.zeros((batch_size, 2 * beam_size),
                                    device=device).long()
        for sentno, states in enumerate(constraint_states):
            scores, indices, beams, new_states = self.step_sentence(
                step,
                sentno,
                lprobs[sentno],
                constraint_states[sentno],
                beams_buf[sentno].clone(),
                indices_buf[sentno].clone(),
                scores_buf[sentno].clone(),
            )
            # NOTE(review): these row assignments presumably rely on
            # step_sentence returning exactly 2 * beam_size candidates - confirm.
            new_scores_buf[sentno] = scores
            new_indices_buf[sentno] = indices
            new_beams_buf[sentno] = beams
            self.constraint_states[sentno] = new_states

        return new_scores_buf, new_indices_buf, new_beams_buf

    @torch.jit.export
    def step_sentence(
        self,
        step: int,
        sentno: int,
        lprobs: Tensor,
        constraint_states: List[List[ConstraintState]],
        beams_buf: Tensor,
        indices_buf: Tensor,
        scores_buf: Tensor,
    ):
        """Does per-sentence processing. Adds all constraints for each
        hypothesis to the list of candidates; then removes duplicates,
        sorts, and dynamically stripes across the banks. All tensor inputs
        are collapsed to those pertaining to a single input sentence.

        Returns:
            ``(scores_buf, indices_buf, beams_buf, constraint_states)``: the
            three tensors truncated to ``self.num_cands`` entries, plus the
            reordered list of advanced states (the list is not truncated).
        """
        device = lprobs.device

        # STEP 2: Add all constraints for each beam item
        for beamno, state in enumerate(constraint_states):
            next_tokens = torch.tensor(
                list(state.next_tokens()), device=device).long()
            if next_tokens.numel() != 0:
                indices_buf = torch.cat((indices_buf, next_tokens))
                next_beams = (
                    torch.tensor(beamno, device=device).repeat(
                        next_tokens.size(0)).long())
                beams_buf = torch.cat((beams_buf, next_beams))
                next_values = lprobs[beamno].take(next_tokens.view(-1))
                scores_buf = torch.cat((scores_buf, next_values))

            # At the 0th time step, there is just one beam item
            if step == 0:
                break

        # STEP 3: Compute the "bank" for each candidate. This is the
        # number of constraints it's generated. We need this so that
        # we can do round-robin allocation of the beam across these
        # banks. If C is the number of constraints, we select the best
        # item in bank C, then the best in bank C-1, etc, followed by
        # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so
        # on, until the maximum beam size. We accomplish this by
        # creating a sort key and striping across the banks.

        # Compute the new states for all candidates
        cands_size = indices_buf.size(0)
        constraint_states = [
            constraint_states[beams_buf[i]].advance(indices_buf[i])
            for i in range(cands_size)
        ]

        banks = torch.tensor([state.bank for state in constraint_states],
                             device=device)

        # STEP 4: Sort
        # NOTE(review): `state` here is the variable left over from the STEP 2
        # loop (the comprehensions above have their own scope), i.e. the last
        # beam state visited there.
        num_constraint_tokens = len(state.tokens)

        # Sort by keys (bank, score) (i.e., sort banks together, and scores
        # within banks). AFAIK pytorch doesn't support either stable sort or
        # multi-key sorting, so we have to hack this.
        MAX_SCORE = -100
        sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf
        sort_values, sort_indices = sort_key.sort(dim=0, descending=True)
        scores_buf = scores_buf[sort_indices]
        indices_buf = indices_buf[sort_indices]
        beams_buf = beams_buf[sort_indices]
        banks = banks[sort_indices]

        # Sort the constraints to follow suit
        constraint_states = [constraint_states[i] for i in sort_indices]

        # STEP 5: Remove duplicates. The topk calls (overall and
        # per-row) plus the per-row generation of constraints will
        # produce duplicates. Here we remove them.

        def roll(t):
            """Rolls a 1d tensor right by 1 (the last element wraps to the front).

            [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3]
            """
            return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0)

        # We map candidates (beam, token_id) to a single dimension.
        # This is then shifted by 1. We can then easily identify
        # duplicates and create a mask that identifies unique
        # extensions.
        uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf
        uniques_mask = roll(uniques_mask) != uniques_mask

        # Use the mask to pare down the data structures
        scores_buf = torch.masked_select(scores_buf, uniques_mask)
        indices_buf = torch.masked_select(indices_buf, uniques_mask)
        beams_buf = torch.masked_select(beams_buf, uniques_mask)
        banks = torch.masked_select(banks, uniques_mask)
        # Drop the matching entries from the Python list of states.
        # NOTE(review): element 0 is always kept here, whereas masked_select
        # above follows uniques_mask[0] - confirm the two cannot diverge.
        i = 1
        for mask in uniques_mask[1:]:
            if not mask:
                constraint_states.pop(i)
            i += mask

        # STEP 6: Assign IDs round-robin across banks, sort, and
        # truncate. Now that the candidates are sorted by (bank,
        # score) and uniqed, we dynamically allocate the {beam_size}
        # beam by striping across the candidates. These stripes will
        # be used as sort keys to do round-robin selection. This is
        # accomplished in a single pass with offsets. Sorting by
        # highest-banks (furthest-along hypotheses) first ensures
        # progress through the constraints.
        #
        # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0
        # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1
        # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7
        #            = 0 5 10 1 6 11 13 2 7 12 3 8
        #
        # Sorting by this then gives the following banks:
        #
        #             3 2 1 0 3 2 1 0 3 2 1 2
        #
        # We'll take the top {beam_size} of these.
        stripe_offsets = [
            offset * (len(banks) + 1) for offset in range(len(banks) + 1)
        ]
        stripes = torch.zeros_like(banks)
        # Starts at -1 so the first candidate (same bank as cur_bank) gets 0.
        cur_bank_count = -1
        cur_bank = banks[0]
        for i, bank in enumerate(banks):
            if bank != cur_bank:
                cur_bank_count = 0
                cur_bank = bank
            else:
                cur_bank_count += 1
            stripes[i] = num_constraint_tokens - bank + stripe_offsets[
                cur_bank_count]

        # STEP 7: Sort by the stripes values
        sort_values, sort_indices = stripes.sort(dim=0)
        scores_buf = scores_buf[sort_indices]
        indices_buf = indices_buf[sort_indices]
        beams_buf = beams_buf[sort_indices]
        constraint_states = [constraint_states[i] for i in sort_indices]

        # STEP 8: Truncate to the candidates size!
        scores_buf = scores_buf[:self.num_cands]
        indices_buf = indices_buf[:self.num_cands]
        beams_buf = beams_buf[:self.num_cands]

        return scores_buf, indices_buf, beams_buf, constraint_states
class LengthConstrainedBeamSearch(Search):
    """Beam search that forbids or forces EOS based on source-length bounds.

    The bounds are ``min_len_a * src_lengths + min_len_b`` and
    ``max_len_a * src_lengths + max_len_b``.
    """

    def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
        super().__init__(tgt_dict)
        self.min_len_a = min_len_a
        self.min_len_b = min_len_b
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        # The actual candidate selection is delegated to a plain beam search.
        self.beam = BeamSearch(tgt_dict)
        self.needs_src_lengths = True

    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """Adjust the EOS column of *lprobs* in place, then run ``self.beam.step``.

        Sentences with ``step < min_lens`` get EOS set to -inf; sentences
        with ``step >= max_lens`` get EOS set to 0.
        """
        min_lens = self.min_len_a * self.src_lengths + self.min_len_b
        max_lens = self.max_len_a * self.src_lengths + self.max_len_b
        too_short = step < min_lens
        too_long = step >= max_lens
        lprobs[too_short, :, self.eos] = -math.inf
        lprobs[too_long, :, self.eos] = 0
        return self.beam.step(step, lprobs, scores)
class DiverseBeamSearch(Search):
    """Diverse Beam Search.

    See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
    Models" for details.

    We only implement the Hamming Diversity penalty here, which performed best
    in the original paper.
    """

    def __init__(self, tgt_dict, num_groups, diversity_strength):
        super().__init__(tgt_dict)
        self.num_groups = num_groups
        # Stored negated so it can be used directly as an additive weight.
        self.diversity_strength = -diversity_strength
        self.beam = BeamSearch(tgt_dict)

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """Run one beam-search step per group, penalising tokens earlier groups chose.

        Raises:
            ValueError: if the beam size is not divisible by ``num_groups``.

        Returns:
            ``(scores_buf, indices_buf, beams_buf)`` with the per-group
            results interleaved along the last dimension.
        """
        bsz, beam_size, vocab_size = lprobs.size()
        num_groups = self.num_groups
        if beam_size % num_groups != 0:
            raise ValueError(
                'DiverseBeamSearch requires --beam to be divisible by the number of groups'
            )

        # initialize diversity penalty
        diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs)

        group_scores, group_indices, group_beams = [], [], []
        for g in range(num_groups):
            # Group g owns every num_groups-th beam starting at g.
            lprobs_g = lprobs[:, g::num_groups, :]
            if step > 0:
                scores_g = scores[:, g::num_groups, :]
            else:
                scores_g = None

            if g == 0:
                lprobs_g = lprobs_g.contiguous()
            else:
                # apply diversity penalty
                lprobs_g = torch.add(
                    lprobs_g,
                    other=diversity_buf.unsqueeze(1),
                    alpha=self.diversity_strength,
                )

            scores_buf, indices_buf, beams_buf = self.beam.step(
                step, lprobs_g, scores_g)
            # Map group-local beam ids back to ids in the full beam.
            beams_buf.mul_(num_groups).add_(g)

            group_scores.append(scores_buf.clone())
            group_indices.append(indices_buf.clone())
            group_beams.append(beams_buf.clone())

            # update diversity penalty
            increments = torch.ones(indices_buf.size()).to(diversity_buf)
            diversity_buf.scatter_add_(1, indices_buf, increments)

        # interleave results from different groups
        scores_buf = torch.stack(group_scores, dim=2).view(bsz, -1)
        indices_buf = torch.stack(group_indices, dim=2).view(bsz, -1)
        beams_buf = torch.stack(group_beams, dim=2).view(bsz, -1)
        return scores_buf, indices_buf, beams_buf
class Sampling(Search):
    """Search strategy that draws the next token at random.

    Nucleus (top-P) sampling is used when ``sampling_topp > 0``; otherwise
    top-K sampling is used when ``sampling_topk > 0``; otherwise tokens are
    drawn from the whole distribution.
    """
    sampling_topk: int
    sampling_topp: float

    def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0):
        super().__init__(tgt_dict)
        self.sampling_topk = sampling_topk
        self.sampling_topp = sampling_topp

    def _sample_topp(self, lprobs):
        """Keep the smallest set of tokens whose cumulative probability exceeds p.

        See `"The Curious Case of Neural Text Degeneration"
        (Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.

        Args:
            lprobs: (bsz x input_beam_size x vocab_size)
                log-probabilities over the vocabulary at the current step.
                Exponentiated in place, so the caller's tensor holds
                probabilities afterwards.

        Return: A tuple of (kept_probs, kept_indices) where:
            kept_probs: (bsz x input_beam_size x ?)
                probabilities of the candidates to sample from, sorted in
                descending order; entries outside the top-P set are zeroed.
                The width of the last dimension is determined by top-P.
            kept_indices: (bsz x input_beam_size x ?)
                vocabulary indices of those candidates.
        """
        probs = lprobs.exp_()

        # descending sort along the vocabulary dimension
        sorted_probs, sorted_indices = probs.sort(descending=True)

        # positions whose running mass is still strictly below p
        below_p = sorted_probs.cumsum(dim=2).lt(self.sampling_topp)

        # 'lt' stops one short: the count of True entries is the index of the
        # first position that pushes the mass past p, so include it as well.
        cutoff = below_p.cumsum(dim=2)[:, :, -1:]
        cutoff.clamp_(0, below_p.size()[2] - 1)
        keep = below_p.scatter_(2, cutoff, 1)

        # drop the columns that no row needs
        width = cutoff.max() + 1
        keep = keep[:, :, :width]
        kept_probs = sorted_probs[:, :, :width]
        kept_indices = sorted_indices[:, :, :width]

        # zero everything outside the top-P set so it can never be drawn
        kept_probs = kept_probs.masked_fill_(~keep, 0)
        return kept_probs, kept_indices

    @torch.jit.export
    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """Sample one token per beam.

        Returns ``(scores_buf, indices_buf, beams_buf)``, each of shape
        (bsz x beam_size): cumulative log-scores, sampled vocabulary indices
        and the beam each sample extends.
        """
        bsz, beam_size, vocab_size = lprobs.size()
        first_step = step == 0
        use_topp = self.sampling_topp > 0
        use_topk = self.sampling_topk > 0

        if first_step:
            # every hypothesis is identical at the first step, so only the
            # first beam of each sentence is used
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        if use_topp:
            # restrict to the nucleus
            probs, top_indices = self._sample_topp(lprobs)
        elif use_topk:
            # restrict to the k most likely tokens
            lprobs, top_indices = lprobs.topk(self.sampling_topk)
            probs = lprobs.exp_()
        else:
            probs = lprobs.exp_()
            # placeholder so both branches define the same names (type check)
            top_indices = torch.empty(0).to(probs)

        if first_step:
            # draw beam_size samples from the single distribution per sentence
            indices_buf = torch.multinomial(
                probs.view(bsz, -1),
                beam_size,
                replacement=True,
            ).view(bsz, beam_size)
            # broadcast the single row to every beam for the gather below
            probs = probs.expand(bsz, beam_size, -1)
        else:
            # draw one sample from each beam's own distribution
            indices_buf = torch.multinomial(
                probs.view(bsz * beam_size, -1),
                1,
                replacement=True,
            ).view(bsz, beam_size)

        # log-probability of each drawn sample
        scores_buf = torch.gather(
            probs, dim=2, index=indices_buf.unsqueeze(-1))
        scores_buf = scores_buf.log_().view(bsz, -1)

        if use_topk or use_topp:
            # samples index into the truncated candidate list; map them back
            # to vocabulary ids
            indices_buf = torch.gather(
                top_indices.expand(bsz, beam_size, -1),
                dim=2,
                index=indices_buf.unsqueeze(-1),
            ).squeeze(2)

        if first_step:
            beams_buf = indices_buf.new_zeros(bsz, beam_size)
        else:
            beams_buf = torch.arange(0,
                                     beam_size).to(indices_buf).repeat(bsz, 1)
            # turn per-step scores into cumulative scores
            scores_buf.add_(
                torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf))

        return scores_buf, indices_buf, beams_buf
class DiverseSiblingsSearch(Search):
    """
    Beam search with diverse siblings.

    See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details.
    https://arxiv.org/abs/1611.08562

    1/ Calculate hypotheses for each beam
    2/ Intra-sibling ordering
    3/ Rewrite scores
    4/ Choose top K hypotheses

    if diversity_rate == 0 is equivalent to BeamSearch
    """

    def __init__(self, tgt_dict, diversity_rate):
        super().__init__(tgt_dict)
        self.diversity_rate = diversity_rate
        self.beam = BeamSearch(tgt_dict)

    def step(
        self,
        step: int,
        lprobs,
        scores,
        prev_output_tokens: Optional[Tensor] = None,
        original_batch_idxs: Optional[Tensor] = None,
    ):
        """Run one search step and return ``(scores, token indices, beam indices)``.

        The first step is delegated to plain beam search. Later steps add the
        previous cumulative scores to ``lprobs`` in place, rank each beam's
        continuations, penalise the r-th ranked sibling by
        ``r * diversity_rate`` and keep the overall top ``k`` candidates.
        """
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            return self.beam.step(step, lprobs, scores)

        # Take the best 2 x beam_size predictions. We'll choose the first
        # beam_size of these which don't predict eos to continue with.
        # The "- 1" ensures pad is never selected.
        k = min(beam_size * 2, lprobs.view(bsz, -1).size(1) - 1)

        # penalty grows linearly with a candidate's rank among its siblings
        sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate

        lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        per_beam_scores: List[Tensor] = []
        per_beam_tokens: List[Tensor] = []
        for beam in range(beam_size):
            # 1/ top-k continuations of this beam (already sorted => 2/)
            top_scores, top_tokens = torch.topk(
                lprobs[:, beam, :].view(bsz, -1), k)
            top_tokens.fmod_(vocab_size)
            # 3/ rewrite scores with the sibling-rank penalty
            top_scores.sub_(sibling_score)
            per_beam_scores.append(top_scores)
            per_beam_tokens.append(top_tokens)

        # 4/ choose the top k hypotheses across all beams
        indices = torch.stack(per_beam_tokens, dim=1).view(bsz, -1)
        final_scores, final_indices = torch.topk(
            torch.stack(per_beam_scores, dim=1).view(bsz, -1),
            k,
        )

        # each beam contributed k consecutive candidates
        final_beams = final_indices // k

        # translate candidate positions back into vocabulary ids
        final_indices = torch.gather(indices, 1, final_indices)

        return final_scores, final_indices, final_beams
| @@ -0,0 +1,996 @@ | |||||
| # Copyright 2022 The OFA-Sys Team. | |||||
| # All rights reserved. | |||||
| # This source code is licensed under the Apache 2.0 license | |||||
| # You may obtain a copy of the License at | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| import math | |||||
| import sys | |||||
| from typing import Dict, List, Optional, Tuple | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch import Tensor | |||||
| from ..generate import search | |||||
| from .ngram_repeat_block import NGramRepeatBlock | |||||
def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    r"""
    Broadcast a `[bsz, src_len]` mask to `[bsz, 1, tgt_len, src_len]` and make it additive.

    Positions where `mask` is non-zero become the most negative finite value
    of `dtype`; all other positions are 0. `tgt_len` defaults to `src_len`.
    """
    bsz, src_len = mask.size()
    if tgt_len is None:
        tgt_len = src_len

    broadcast = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len)
    additive = broadcast.to(dtype)
    lowest = torch.finfo(dtype).min
    return additive.masked_fill(additive.bool(), lowest)
| class SequenceGenerator(nn.Module): | |||||
| def __init__(self, | |||||
| tokenizer, | |||||
| beam_size=1, | |||||
| max_len_a=0, | |||||
| max_len_b=200, | |||||
| max_len=0, | |||||
| min_len=1, | |||||
| normalize_scores=True, | |||||
| len_penalty=1.0, | |||||
| unk_penalty=0.0, | |||||
| temperature=1.0, | |||||
| match_source_len=False, | |||||
| no_repeat_ngram_size=0, | |||||
| search_strategy=None, | |||||
| eos=None, | |||||
| symbols_to_strip_from_output=None, | |||||
| lm_model=None, | |||||
| lm_weight=1.0, | |||||
| constraint_trie=None, | |||||
| constraint_range=None, | |||||
| gen_code=False, | |||||
| gen_box=False, | |||||
| ignore_eos=False, | |||||
| zero_shot=False): | |||||
| """Generates translations of a given source sentence. | |||||
| Args: | |||||
| models (List[~fairseq.models.FairseqModel]): ensemble of models, | |||||
| currently support fairseq.models.TransformerModel for scripting | |||||
| beam_size (int, optional): beam width (default: 1) | |||||
| max_len_a/b (int, optional): generate sequences of maximum length | |||||
| ax + b, where x is the source length | |||||
| max_len (int, optional): the maximum length of the generated output | |||||
| (not including end-of-sentence) | |||||
| min_len (int, optional): the minimum length of the generated output | |||||
| (not including end-of-sentence) | |||||
| normalize_scores (bool, optional): normalize scores by the length | |||||
| of the output (default: True) | |||||
| len_penalty (float, optional): length penalty, where <1.0 favors | |||||
| shorter, >1.0 favors longer sentences (default: 1.0) | |||||
| unk_penalty (float, optional): unknown word penalty, where <0 | |||||
| produces more unks, >0 produces fewer (default: 0.0) | |||||
| temperature (float, optional): temperature, where values | |||||
| >1.0 produce more uniform samples and values <1.0 produce | |||||
| sharper samples (default: 1.0) | |||||
| match_source_len (bool, optional): outputs should match the source | |||||
| length (default: False) | |||||
| """ | |||||
| super().__init__() | |||||
| self.gen_code = gen_code | |||||
| self.gen_box = gen_box | |||||
| self.ignore_eos = ignore_eos | |||||
| self.tokenizer = tokenizer | |||||
| self.tgt_dict = { | |||||
| value: key | |||||
| for key, value in tokenizer.get_vocab().items() | |||||
| } | |||||
| added = { | |||||
| value: key | |||||
| for key, value in tokenizer.get_added_vocab().items() | |||||
| } | |||||
| self.tgt_dict.update(added) | |||||
| self.pad = tokenizer.pad_token_id | |||||
| self.unk = tokenizer.unk_token_id | |||||
| self.bos = tokenizer.bos_token_id | |||||
| self.eos = tokenizer.eos_token_id | |||||
| self.symbols_to_strip_from_output = ( | |||||
| symbols_to_strip_from_output.union({self.eos}) if | |||||
| symbols_to_strip_from_output is not None else {self.bos, self.eos}) | |||||
| self.vocab_size = len(self.tgt_dict) | |||||
| self.beam_size = beam_size | |||||
| # the max beam size is the dictionary size - 1, since we never select pad | |||||
| self.beam_size = min(beam_size, self.vocab_size - 1) | |||||
| self.max_len_a = max_len_a | |||||
| self.max_len_b = max_len_b | |||||
| self.min_len = min_len | |||||
| self.max_len = max_len | |||||
| self.normalize_scores = normalize_scores | |||||
| self.len_penalty = len_penalty | |||||
| self.unk_penalty = unk_penalty | |||||
| self.temperature = temperature | |||||
| self.match_source_len = match_source_len | |||||
| self.zero_shot = zero_shot | |||||
| if no_repeat_ngram_size > 0: | |||||
| self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) | |||||
| else: | |||||
| self.repeat_ngram_blocker = None | |||||
| assert temperature > 0, '--temperature must be greater than 0' | |||||
| self.search = ( | |||||
| search.BeamSearch(self.tokenizer) | |||||
| if search_strategy is None else search_strategy) | |||||
| # We only need to set src_lengths in LengthConstrainedBeamSearch. | |||||
| # As a module attribute, setting it would break in multithread | |||||
| # settings when the model is shared. | |||||
| self.should_set_src_lengths = ( | |||||
| hasattr(self.search, 'needs_src_lengths') | |||||
| and self.search.needs_src_lengths) | |||||
| self.lm_model = lm_model | |||||
| self.lm_weight = lm_weight | |||||
| if self.lm_model is not None: | |||||
| self.lm_model.eval() | |||||
| self.constraint_trie = constraint_trie | |||||
| self.constraint_start = None | |||||
| self.constraint_end = None | |||||
| if constraint_range is not None: | |||||
| constraint_start, constraint_end = constraint_range.split(',') | |||||
| self.constraint_start = int(constraint_start) | |||||
| self.constraint_end = int(constraint_end) | |||||
| @torch.no_grad() | |||||
| def forward( | |||||
| self, | |||||
| sample: Dict[str, Dict[str, Tensor]], | |||||
| prefix_tokens: Optional[Tensor] = None, | |||||
| bos_token: Optional[int] = None, | |||||
| ): | |||||
| """Generate a batch of translations. | |||||
| Args: | |||||
| sample (dict): batch | |||||
| prefix_tokens (torch.LongTensor, optional): force decoder to begin | |||||
| with these tokens | |||||
| bos_token (int, optional): beginning of sentence token | |||||
| (default: self.eos) | |||||
| """ | |||||
# NOTE(review): _generate's first two positional parameters are
# (models, sample), so this call binds `sample` to `models` and
# `prefix_tokens` to `sample`; no models are passed at all. Presumably a
# leftover from an API where the models lived on the instance - confirm
# before relying on forward(); generate() passes the models explicitly.
| return self._generate(sample, prefix_tokens, bos_token=bos_token) | |||||
| @torch.no_grad() | |||||
| def generate(self, models, sample: Dict[str, Dict[str, Tensor]], | |||||
| **kwargs) -> List[List[Dict[str, Tensor]]]: | |||||
| """Generate translations. Match the api of other fairseq generators. | |||||
| Args: | |||||
| models (List[~fairseq.models.FairseqModel]): ensemble of models | |||||
| sample (dict): batch | |||||
| prefix_tokens (torch.LongTensor, optional): force decoder to begin | |||||
| with these tokens | |||||
| constraints (torch.LongTensor, optional): force decoder to include | |||||
| the list of constraints | |||||
| bos_token (int, optional): beginning of sentence token | |||||
| (default: self.eos) | |||||
| """ | |||||
# Thin wrapper (gradients disabled by the decorator): `models`, `sample`
# and any keyword arguments (prefix_tokens, constraints, bos_token) are
# handed unchanged to _generate, whose result is returned as-is.
| return self._generate(models, sample, **kwargs) | |||||
| def _generate( | |||||
| self, | |||||
| models, | |||||
| sample: Dict[str, Dict[str, Tensor]], | |||||
| prefix_tokens: Optional[Tensor] = None, | |||||
| constraints: Optional[Tensor] = None, | |||||
| bos_token: Optional[int] = None, | |||||
| ): | |||||
# Core decoding routine: wraps `models` in an EnsembleModel, encodes the
# batch once, then runs up to max_len + 1 search steps. Hypotheses that
# emit EOS are handed to finalize_hypos and collected in `finalized`, which
# is returned with each sentence's hypotheses sorted by descending score.
# NOTE(review): `bos_token` is accepted but never read below; tokens[:, 0]
# is always set to self.bos.
| model = EnsembleModel(models) | |||||
| # incremental_states = torch.jit.annotate( | |||||
| # List[Dict[str, Dict[str, Optional[Tensor]]]], | |||||
| # [ | |||||
| # torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) | |||||
| # for i in range(model.models_size) | |||||
| # ], | |||||
| # ) | |||||
# one (initially empty) incremental-state slot per ensemble member
| incremental_states = torch.jit.annotate( | |||||
| List[Tuple[Tuple[torch.Tensor]]], | |||||
| [ | |||||
| torch.jit.annotate(Tuple[Tuple[torch.Tensor]], {}) | |||||
| for i in range(model.models_size) | |||||
| ], | |||||
| ) | |||||
| # print("incremental_states",incremental_states) | |||||
| # print("incremental_states[0]",incremental_states[0]) | |||||
| net_input = sample['net_input'] | |||||
# locate the source tensor under whichever key this batch uses
# ('src_tokens', 'input_ids', 'source' or 'features') and derive the
# per-sentence source lengths from it
| if 'src_tokens' in net_input: | |||||
| src_tokens = net_input['src_tokens'] | |||||
| # length of the source text being the character length except EndOfSentence and pad | |||||
| src_lengths = ((src_tokens.ne(self.eos) | |||||
| & src_tokens.ne(self.pad)).long().sum(dim=1)) | |||||
| elif 'input_ids' in net_input: | |||||
| src_tokens = net_input['input_ids'] | |||||
| # length of the source text being the character length except EndOfSentence and pad | |||||
| src_lengths = ((src_tokens.ne(self.eos) | |||||
| & src_tokens.ne(self.pad)).long().sum(dim=1)) | |||||
| elif 'source' in net_input: | |||||
| src_tokens = net_input['source'] | |||||
| src_lengths = ( | |||||
| net_input['padding_mask'].size(-1) | |||||
| - net_input['padding_mask'].sum(-1) | |||||
| if net_input['padding_mask'] is not None else torch.tensor( | |||||
| src_tokens.size(-1)).to(src_tokens)) | |||||
| elif 'features' in net_input: | |||||
| src_tokens = net_input['features'] | |||||
| src_lengths = ( | |||||
| net_input['padding_mask'].size(-1) | |||||
| - net_input['padding_mask'].sum(-1) | |||||
| if net_input['padding_mask'] is not None else torch.tensor( | |||||
| src_tokens.size(-1)).to(src_tokens)) | |||||
| else: | |||||
| raise Exception( | |||||
| 'expected src_tokens or source in net input. input keys: ' | |||||
| + str(net_input.keys())) | |||||
| # bsz: total number of sentences in beam | |||||
| # Note that src_tokens may have more than 2 dimensions (i.e. audio features) | |||||
| bsz, src_len = src_tokens.size()[:2] | |||||
| beam_size = self.beam_size | |||||
| if constraints is not None and not self.search.supports_constraints: | |||||
| raise NotImplementedError( | |||||
| "Target-side constraints were provided, but search method doesn't support them" | |||||
| ) | |||||
| # Initialize constraints, when active | |||||
| self.search.init_constraints(constraints, beam_size) | |||||
# decoding budget: longest source length when match_source_len is set,
# otherwise max_len_a * src_len + max_len_b
| max_len: int = -1 | |||||
| if self.match_source_len: | |||||
| max_len = src_lengths.max().item() | |||||
| else: | |||||
| max_len = int(self.max_len_a * src_len + self.max_len_b) | |||||
| assert ( | |||||
| self.min_len <= max_len | |||||
| ), 'min_len cannot be larger than max_len, please adjust these!' | |||||
| # compute the encoder output for each beam | |||||
| with torch.autograd.profiler.record_function( | |||||
| 'EnsembleModel: forward_encoder'): | |||||
| encoder_outs = model.forward_encoder(net_input) | |||||
| # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores | |||||
| new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) | |||||
| new_order = new_order.to(src_tokens.device).long() | |||||
| encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) | |||||
| # ensure encoder_outs is a List. | |||||
| assert encoder_outs is not None | |||||
| # initialize buffers | |||||
| scores = (torch.zeros(bsz * beam_size, | |||||
| max_len + 1).to(src_tokens).float() | |||||
| ) # +1 for eos; pad is never chosen for scoring | |||||
| tokens = (torch.zeros(bsz * beam_size, | |||||
| max_len + 2).to(src_tokens).long().fill_( | |||||
| self.pad)) # +2 for eos and pad | |||||
| # tokens[:, 0] = self.eos if bos_token is None else bos_token | |||||
| tokens[:, 0] = self.bos | |||||
| attn: Optional[Tensor] = None | |||||
| # A list that indicates candidates that should be ignored. | |||||
| # For example, suppose we're sampling and have already finalized 2/5 | |||||
| # samples. Then cands_to_ignore would mark 2 positions as being ignored, | |||||
| # so that we only finalize the remaining 3 samples. | |||||
| cands_to_ignore = (torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) | |||||
| ) # forward and backward-compatible False mask | |||||
| # list of completed sentences | |||||
| finalized = torch.jit.annotate( | |||||
| List[List[Dict[str, Tensor]]], | |||||
| [ | |||||
| torch.jit.annotate(List[Dict[str, Tensor]], []) | |||||
| for i in range(bsz) | |||||
| ], | |||||
| ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step | |||||
| # a boolean array indicating if the sentence at the index is finished or not | |||||
| finished = [False for i in range(bsz)] | |||||
| num_remaining_sent = bsz # number of sentences remaining | |||||
| # number of candidate hypos per step | |||||
| cand_size = 2 * beam_size # 2 x beam size in case half are EOS | |||||
| # offset arrays for converting between different indexing schemes | |||||
| bbsz_offsets = ((torch.arange(0, bsz) | |||||
| * beam_size).unsqueeze(1).type_as(tokens).to( | |||||
| src_tokens.device)) | |||||
| cand_offsets = torch.arange(0, cand_size).type_as(tokens).to( | |||||
| src_tokens.device) | |||||
| reorder_state: Optional[Tensor] = None | |||||
| batch_idxs: Optional[Tensor] = None | |||||
| original_batch_idxs: Optional[Tensor] = None | |||||
| if 'id' in sample and isinstance(sample['id'], Tensor): | |||||
| original_batch_idxs = sample['id'] | |||||
| else: | |||||
| original_batch_idxs = torch.arange(0, bsz).type_as(tokens) | |||||
| for step in range(max_len + 1): # one extra step for EOS marker | |||||
| # reorder decoder internal states based on the prev choice of beams | |||||
| if reorder_state is not None: | |||||
| if batch_idxs is not None: | |||||
| # update beam indices to take into account removed sentences | |||||
| corr = batch_idxs - torch.arange( | |||||
| batch_idxs.numel()).type_as(batch_idxs) | |||||
| reorder_state.view(-1, beam_size).add_( | |||||
| corr.unsqueeze(-1) * beam_size) | |||||
| original_batch_idxs = original_batch_idxs[batch_idxs] | |||||
| model.reorder_incremental_state(incremental_states, | |||||
| reorder_state) # todo | |||||
| encoder_outs = model.reorder_encoder_out( | |||||
| encoder_outs, reorder_state) | |||||
| with torch.autograd.profiler.record_function( | |||||
| 'EnsembleModel: forward_decoder'): | |||||
| lprobs, avg_attn_scores = model.forward_decoder( | |||||
| tokens[:, :step + 1], | |||||
| encoder_outs, | |||||
| incremental_states, | |||||
| self.temperature, | |||||
| constraint_trie=self.constraint_trie, | |||||
| constraint_start=self.constraint_start, | |||||
| constraint_end=self.constraint_end, | |||||
| gen_code=self.gen_code, | |||||
| zero_shot=self.zero_shot, | |||||
| prefix_tokens=prefix_tokens) | |||||
# optional language-model fusion: add the weighted LM log-probabilities of
# the last position to the decoder's log-probabilities
| if self.lm_model is not None: | |||||
| lm_out = self.lm_model(tokens[:, :step + 1]) | |||||
| probs = self.lm_model.get_normalized_probs( | |||||
| lm_out, log_probs=True, sample=None) | |||||
| probs = probs[:, -1, :] * self.lm_weight | |||||
| lprobs += probs | |||||
| # handle prefix tokens (possibly with different lengths) | |||||
| if (prefix_tokens is not None and step < prefix_tokens.size(1) | |||||
| and step < max_len): | |||||
| lprobs, tokens, scores = self._prefix_tokens( | |||||
| step, lprobs, scores, tokens, prefix_tokens, beam_size) | |||||
| elif step < self.min_len: | |||||
| # minimum length constraint (does not apply if using prefix_tokens) | |||||
| lprobs[:, self.eos] = -math.inf | |||||
# `x != x` is true only for NaN, so this replaces NaN entries with -inf
| lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) | |||||
| lprobs[:, self.pad] = -math.inf # never select pad | |||||
| lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty | |||||
| if (self.gen_code or self.gen_box) and step < max_len: | |||||
| lprobs[:, :4] = -math.inf | |||||
# NOTE(review): 59457 is a hard-coded vocabulary offset - presumably the
# boundary between two token ranges of the OFA vocabulary; confirm against
# the tokenizer before changing the vocabulary.
| if self.gen_box: | |||||
| lprobs[:, -1] = -math.inf | |||||
| if (step + 1) % 5 == 0: | |||||
| lprobs[:, self.constraint_start:59457] = -math.inf | |||||
| else: | |||||
| lprobs[:, 59457:] = -math.inf | |||||
| # handle max length constraint | |||||
| if step >= max_len: | |||||
| lprobs[:, :self.eos] = -math.inf | |||||
| lprobs[:, self.eos + 1:] = -math.inf | |||||
# NOTE(review): writes a positive value into a log-probability column;
# looks deliberate (every other column is -inf at this point) - confirm.
| if self.ignore_eos: | |||||
| lprobs[:, self.eos] = 1 | |||||
| # Record attention scores, only support avg_attn_scores is a Tensor | |||||
| if avg_attn_scores is not None: | |||||
| if attn is None: | |||||
| attn = torch.empty(bsz * beam_size, | |||||
| avg_attn_scores.size(1), | |||||
| max_len + 2).to(scores) | |||||
| # print("+++++++ debug attention shape +++++++") | |||||
| # print("attn", attn.shape) | |||||
| # print("avg_attn_scores", avg_attn_scores.shape) | |||||
| attn[:, :, step + 1].copy_(avg_attn_scores) | |||||
| # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape) | |||||
| # print("attn", attn.shape) | |||||
| scores = scores.type_as(lprobs) | |||||
| eos_bbsz_idx = torch.empty(0).to( | |||||
| tokens | |||||
| ) # indices of hypothesis ending with eos (finished sentences) | |||||
| eos_scores = torch.empty(0).to( | |||||
| scores | |||||
| ) # scores of hypothesis ending with eos (finished sentences) | |||||
| if self.should_set_src_lengths: | |||||
| self.search.set_src_lengths(src_lengths) | |||||
| if self.repeat_ngram_blocker is not None: | |||||
| lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, | |||||
| beam_size, step) | |||||
| # Shape: (batch, cand_size) | |||||
| cand_scores, cand_indices, cand_beams = self.search.step( | |||||
| step, | |||||
| lprobs.view(bsz, -1, self.vocab_size), | |||||
| scores.view(bsz, beam_size, -1)[:, :, :step], | |||||
| tokens[:, :step + 1], | |||||
| original_batch_idxs, | |||||
| ) | |||||
| # cand_bbsz_idx contains beam indices for the top candidate | |||||
| # hypotheses, with a range of values: [0, bsz*beam_size), | |||||
| # and dimensions: [bsz, cand_size] | |||||
| cand_bbsz_idx = cand_beams.add(bbsz_offsets) | |||||
| # finalize hypotheses that end in eos | |||||
| # Shape of eos_mask: (batch size, beam size) | |||||
| eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) | |||||
| eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to( | |||||
| eos_mask) | |||||
| # only consider eos when it's among the top beam_size indices | |||||
| # Now we know what beam item(s) to finish | |||||
| # Shape: 1d list of absolute-numbered | |||||
| eos_bbsz_idx = torch.masked_select( | |||||
| cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]) | |||||
| finalized_sents: List[int] = [] | |||||
| if eos_bbsz_idx.numel() > 0: | |||||
| eos_scores = torch.masked_select( | |||||
| cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]) | |||||
| finalized_sents = self.finalize_hypos( | |||||
| step, | |||||
| eos_bbsz_idx, | |||||
| eos_scores, | |||||
| tokens, | |||||
| scores, | |||||
| finalized, | |||||
| finished, | |||||
| beam_size, | |||||
| attn, | |||||
| src_lengths, | |||||
| max_len, | |||||
| ) | |||||
| num_remaining_sent -= len(finalized_sents) | |||||
| assert num_remaining_sent >= 0 | |||||
| if num_remaining_sent == 0: | |||||
| break | |||||
| if self.search.stop_on_max_len and step >= max_len: | |||||
| break | |||||
| assert step < max_len, f'{step} < {max_len}' | |||||
| # Remove finalized sentences (ones for which {beam_size} | |||||
| # finished hypotheses have been generated) from the batch. | |||||
| if len(finalized_sents) > 0: | |||||
| new_bsz = bsz - len(finalized_sents) | |||||
| # construct batch_idxs which holds indices of batches to keep for the next pass | |||||
| batch_mask = torch.ones( | |||||
| bsz, dtype=torch.bool, device=cand_indices.device) | |||||
| batch_mask[finalized_sents] = False | |||||
| # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it | |||||
| batch_idxs = torch.arange( | |||||
| bsz, device=cand_indices.device).masked_select(batch_mask) | |||||
| # Choose the subset of the hypothesized constraints that will continue | |||||
| self.search.prune_sentences(batch_idxs) | |||||
| eos_mask = eos_mask[batch_idxs] | |||||
| cand_beams = cand_beams[batch_idxs] | |||||
| bbsz_offsets.resize_(new_bsz, 1) | |||||
| cand_bbsz_idx = cand_beams.add(bbsz_offsets) | |||||
| cand_scores = cand_scores[batch_idxs] | |||||
| cand_indices = cand_indices[batch_idxs] | |||||
| if prefix_tokens is not None: | |||||
| prefix_tokens = prefix_tokens[batch_idxs] | |||||
| src_lengths = src_lengths[batch_idxs] | |||||
| cands_to_ignore = cands_to_ignore[batch_idxs] | |||||
| scores = scores.view(bsz, -1)[batch_idxs].view( | |||||
| new_bsz * beam_size, -1) | |||||
| tokens = tokens.view(bsz, -1)[batch_idxs].view( | |||||
| new_bsz * beam_size, -1) | |||||
| if attn is not None: | |||||
| attn = attn.view(bsz, -1)[batch_idxs].view( | |||||
| new_bsz * beam_size, attn.size(1), -1) | |||||
| bsz = new_bsz | |||||
| else: | |||||
| batch_idxs = None | |||||
| # Set active_mask so that values > cand_size indicate eos hypos | |||||
| # and values < cand_size indicate candidate active hypos. | |||||
| # After, the min values per row are the top candidate active hypos | |||||
| # Rewrite the operator since the element wise or is not supported in torchscript. | |||||
| eos_mask[:, :beam_size] = ~( # noqa | |||||
| (~cands_to_ignore) & (~eos_mask[:, :beam_size])) # noqa | |||||
| active_mask = torch.add( | |||||
| eos_mask.type_as(cand_offsets) * cand_size, | |||||
| cand_offsets[:eos_mask.size(1)], | |||||
| ) | |||||
| # get the top beam_size active hypotheses, which are just | |||||
| # the hypos with the smallest values in active_mask. | |||||
| # {active_hypos} indicates which {beam_size} hypotheses | |||||
| # from the list of {2 * beam_size} candidates were | |||||
| # selected. Shapes: (batch size, beam size) | |||||
| new_cands_to_ignore, active_hypos = torch.topk( | |||||
| active_mask, k=beam_size, dim=1, largest=False) | |||||
| # update cands_to_ignore to ignore any finalized hypos. | |||||
| cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] | |||||
| # Make sure there is at least one active item for each sentence in the batch. | |||||
| assert (~cands_to_ignore).any(dim=1).all() | |||||
| # update cands_to_ignore to ignore any finalized hypos | |||||
| # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam | |||||
| # can be selected more than once). | |||||
| active_bbsz_idx = torch.gather( | |||||
| cand_bbsz_idx, dim=1, index=active_hypos) | |||||
| active_scores = torch.gather( | |||||
| cand_scores, dim=1, index=active_hypos) | |||||
| active_bbsz_idx = active_bbsz_idx.view(-1) | |||||
| active_scores = active_scores.view(-1) | |||||
| # copy tokens and scores for active hypotheses | |||||
| # Set the tokens for each beam (can select the same row more than once) | |||||
| tokens[:, :step + 1] = torch.index_select( | |||||
| tokens[:, :step + 1], dim=0, index=active_bbsz_idx) | |||||
| # Select the next token for each of them | |||||
| tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( | |||||
| cand_indices, dim=1, index=active_hypos) | |||||
| if step > 0: | |||||
| scores[:, :step] = torch.index_select( | |||||
| scores[:, :step], dim=0, index=active_bbsz_idx) | |||||
| scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( | |||||
| cand_scores, dim=1, index=active_hypos) | |||||
| # Update constraints based on which candidates were selected for the next beam | |||||
| self.search.update_constraints(active_hypos) | |||||
| # copy attention for active hypotheses | |||||
| if attn is not None: | |||||
| attn[:, :, :step + 2] = torch.index_select( | |||||
| attn[:, :, :step + 2], dim=0, index=active_bbsz_idx) | |||||
| # reorder incremental state in decoder | |||||
| reorder_state = active_bbsz_idx | |||||
| # sort by score descending | |||||
# NOTE: `scores` is rebound here to each sentence's finalized hypothesis
# scores; the decoding score buffer is not used past this point.
| for sent in range(len(finalized)): | |||||
| scores = torch.tensor( | |||||
| [float(elem['score'].item()) for elem in finalized[sent]]) | |||||
| _, sorted_scores_indices = torch.sort(scores, descending=True) | |||||
| finalized[sent] = [ | |||||
| finalized[sent][ssi] for ssi in sorted_scores_indices | |||||
| ] | |||||
| finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], | |||||
| finalized[sent]) | |||||
| return finalized | |||||
def _prefix_tokens(self, step: int, lprobs, scores, tokens, prefix_tokens,
                   beam_size: int):
    """Force the decoder to follow ``prefix_tokens`` at position ``step``.

    For every row whose prefix token at this step is not pad, all
    log-probabilities are pushed down (to ``min(prefix_lprobs) - 1``, or to
    ``-inf`` when a constraint trie is set) and the forced token's original
    log-probability is restored. If a prefix emits EOS at this step, the first
    beam's tokens, scores and lprobs are copied to the other beams of that
    sentence.

    Returns the (possibly updated) ``lprobs, tokens, scores``.
    """
    # one forced token per (sentence, beam) row
    forced = prefix_tokens[:, step].unsqueeze(-1).repeat(1,
                                                         beam_size).view(-1)
    forced_lprobs = lprobs.gather(-1, forced.unsqueeze(-1))
    has_prefix = forced.ne(self.pad)

    floor = (
        torch.min(forced_lprobs) - 1
        if self.constraint_trie is None else -math.inf)
    lprobs[has_prefix] = floor

    # put the forced token's own log-probability back
    lprobs[has_prefix] = lprobs[has_prefix].scatter(
        -1, forced[has_prefix].unsqueeze(-1), forced_lprobs[has_prefix])

    # if prefix includes eos, then we should make sure tokens and
    # scores are the same across all beams
    ends_here = forced.eq(self.eos)
    if ends_here.any():
        # validate that the first beam matches the prefix
        width = tokens.size(-1)
        first_beam = tokens[ends_here].view(-1, beam_size,
                                            width)[:, 0, 1:step + 1]
        ends_per_sentence = ends_here.view(-1, beam_size)[:, 0]
        target_prefix = prefix_tokens[ends_per_sentence][:, :step]
        assert (first_beam == target_prefix).all()

        # copy tokens, scores and lprobs from the first beam to all beams
        tokens, scores, lprobs = (
            self.replicate_first_beam(buf, ends_per_sentence, beam_size)
            for buf in (tokens, scores, lprobs))

    return lprobs, tokens, scores
def replicate_first_beam(self, tensor, mask, beam_size: int):
    """Copy the first beam's row over every beam of the sentences selected by ``mask``.

    ``tensor`` is viewed as (sentences, beam_size, last_dim); the result is
    returned flattened back to (sentences * beam_size, last_dim).
    """
    last_dim = tensor.size(-1)
    grouped = tensor.view(-1, beam_size, last_dim)
    # the (n, 1, last_dim) slice broadcasts across the beam dimension
    grouped[mask] = grouped[mask][:, :1, :]
    return grouped.view(-1, last_dim)
def finalize_hypos(
    self,
    step: int,
    bbsz_idx,
    eos_scores,
    tokens,
    scores,
    finalized: List[List[Dict[str, Tensor]]],
    finished: List[bool],
    beam_size: int,
    attn: Optional[Tensor],
    src_lengths,
    max_len: int,
):
    """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.

    A sentence is finalized when {beam_size} finished items have been collected for it.
    Returns number of sentences (not beam items) being finalized.
    These will be removed from the batch and not processed further.

    Args:
        step: current decoding step; hypotheses are cut after column
            ``step + 1`` of ``tokens``.
        bbsz_idx (Tensor): row indices into ``tokens`` / ``scores`` / ``attn``
            of the hypotheses that just ended; ``bbsz_idx // beam_size`` is
            used as the sentence index in the current batch.
        eos_scores (Tensor): one score per entry of ``bbsz_idx``; divided in
            place by ``(step + 1) ** self.len_penalty`` when
            ``self.normalize_scores`` is set.
        tokens: generated tokens; column 0 is skipped when copying.
        scores: cumulative per-position scores.
        finalized: one list per original sentence; result dicts are appended.
        finished: one flag per original sentence; set to True in place.
        beam_size: maximum number of hypotheses kept per sentence.
        attn: optional attention tensor, sliced along its last dimension.
        src_lengths: only read when ``self.match_source_len`` is set.
        max_len: forwarded to ``is_finished``.

    Returns:
        List of sentence indices (relative to the current, possibly reduced
        batch) that became finished during this call.
    """
    assert bbsz_idx.numel() == eos_scores.numel()
    # clone relevant token and attention tensors.
    # tokens is (batch * beam, max_len). So the index_select
    # gets the newly EOS rows, then selects cols 1..{step + 2}
    tokens_clone = tokens.index_select(
        0, bbsz_idx)[:, 1:step + 2]  # skip the first index, which is EOS
    tokens_clone[:, step] = self.eos
    attn_clone = (
        attn.index_select(0, bbsz_idx)[:, :, 1:step
                                       + 2] if attn is not None else None)
    # compute scores per token position
    pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
    pos_scores[:, step] = eos_scores
    # convert from cumulative to per-position scores
    pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
    # normalize sentence-level scores
    # (in-place division: the caller's eos_scores tensor is modified too)
    if self.normalize_scores:
        eos_scores /= (step + 1)**self.len_penalty
    # cum_unfin records which sentences in the batch are finished.
    # It helps match indexing between (a) the original sentences
    # in the batch and (b) the current, possibly-reduced set of
    # sentences.  Entry k holds how many finished sentences precede the
    # k-th still-unfinished sentence.
    cum_unfin: List[int] = []
    prev = 0
    for f in finished:
        if f:
            prev += 1
        else:
            cum_unfin.append(prev)
    cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
    # unfin_idx: sentence index in the current (possibly reduced) batch
    unfin_idx = bbsz_idx // beam_size
    # sent: sentence index in the original, unreduced batch
    sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
    # Pack every (sent, unfin_idx) pair into a single integer, with sent
    # shifted left by 32 bits, so that torch.unique can deduplicate the
    # pairs; they are unpacked again in the termination loop below.
    seen = (sent << 32) + unfin_idx
    unique_seen: List[int] = torch.unique(seen).tolist()
    if self.match_source_len:
        # hypotheses longer than their source get a score of -inf
        condition = step > torch.index_select(src_lengths, 0, unfin_idx)
        eos_scores = torch.where(condition, torch.tensor(-math.inf),
                                 eos_scores)
    sent_list: List[int] = sent.tolist()
    for i in range(bbsz_idx.size()[0]):
        # An input sentence (among those in a batch) is finished when
        # beam_size hypotheses have been collected for it
        if len(finalized[sent_list[i]]) < beam_size:
            if attn_clone is not None:
                # remove padding tokens from attn scores
                hypo_attn = attn_clone[i]
            else:
                hypo_attn = torch.empty(0)
            finalized[sent_list[i]].append({
                'tokens':
                tokens_clone[i],
                'score':
                eos_scores[i],
                'attention':
                hypo_attn,  # src_len x tgt_len
                'alignment':
                torch.empty(0),
                'positional_scores':
                pos_scores[i],
            })
    newly_finished: List[int] = []
    for unique_s in unique_seen:
        # check termination conditions for this sentence
        # (undo the packing: high bits -> sent, low 32 bits -> unfin_idx)
        unique_sent: int = unique_s >> 32
        unique_unfin_idx: int = unique_s - (unique_sent << 32)
        if not finished[unique_sent] and self.is_finished(
                step, unique_unfin_idx, max_len, len(
                    finalized[unique_sent]), beam_size):
            finished[unique_sent] = True
            newly_finished.append(unique_unfin_idx)
    return newly_finished
def is_finished(
    self,
    step: int,
    unfin_idx: int,
    max_len: int,
    finalized_sent_len: int,
    beam_size: int,
):
    """Tell whether decoding of one sentence is complete.

    That is the case once ``beam_size`` hypotheses have been finalized for
    the sentence, or once ``step`` has reached ``max_len``.  ``unfin_idx`` is
    accepted for interface compatibility and not consulted.
    """
    assert finalized_sent_len <= beam_size
    beam_is_full = finalized_sent_len == beam_size
    hit_length_limit = step == max_len
    return beam_is_full or hit_length_limit
class EnsembleModel(nn.Module):
    """A wrapper around an ensemble of models.

    ``forward_encoder`` runs the encoder of every member,
    ``forward_decoder`` performs one decoding step per member and combines
    the resulting log-probabilities, and the ``reorder_*`` methods delegate
    to each member model.
    """

    def __init__(self, models):
        super().__init__()
        self.models_size = len(models)
        # method '__len__' is not supported in ModuleList for torch script
        self.single_model = models[0]
        self.models = nn.ModuleList(models)

        # NOTE: a check that every member has an incremental decoder used to
        # live here and is disabled; incremental decoding is always assumed.
        self.has_incremental = True

    def forward(self):
        # Intentionally empty; the work happens in forward_encoder /
        # forward_decoder.
        pass

    def has_encoder(self):
        """Return True if the first model exposes an ``encoder`` attribute."""
        return hasattr(self.single_model, 'encoder')

    def has_incremental_states(self):
        """Return the ``has_incremental`` flag set in ``__init__``."""
        return self.has_incremental

    def max_decoder_positions(self):
        """Return the smallest ``max_decoder_positions()`` over the members.

        Members without that method are ignored; ``sys.maxsize`` is the
        result when no member provides it.
        """
        return min([
            m.max_decoder_positions()
            for m in self.models if hasattr(m, 'max_decoder_positions')
        ] + [sys.maxsize])

    @torch.jit.export
    def forward_encoder(self, net_input: Dict[str, Tensor]):
        """Run every member's encoder on ``net_input``.

        The ``decoder_input_ids`` entry is dropped and
        ``output_hidden_states=True`` is added before the call.  Returns a
        list with one encoder output per model, or None when the first model
        has no encoder.
        """
        if not self.has_encoder():
            return None
        encoder_input = {
            k: v
            for k, v in net_input.items() if k != 'decoder_input_ids'
        }
        encoder_input['output_hidden_states'] = True
        return [
            model.encoder.forward(**encoder_input) for model in self.models
        ]

    @torch.jit.export
    def forward_decoder(self,
                        tokens,
                        encoder_outs: List[Dict[str, List[Tensor]]],
                        incremental_states: List[Optional[torch.Tensor]],
                        temperature: float = 1.0,
                        constraint_trie=None,
                        constraint_start=None,
                        constraint_end=None,
                        gen_code=False,
                        zero_shot=False,
                        prefix_tokens=None):
        """Run one decoding step and return log-probs for the last position.

        Args:
            tokens: tokens generated so far, one row per hypothesis.
            encoder_outs: per-model encoder outputs from ``forward_encoder``.
            incremental_states: per-model cache, passed to the decoder as
                ``past_key_values``.
            temperature: the last-step decoder output is divided by this
                value in place.
            constraint_trie: optional trie; only the tokens it returns from
                ``get_next_layer`` stay allowed.  Mutually exclusive with
                ``constraint_start`` / ``constraint_end`` (asserted).
            constraint_start, constraint_end: optional vocabulary range;
                entries in ``[4, constraint_start)`` and from
                ``constraint_end`` onwards are set to ``-inf``.
            gen_code: multiplied into the ``code_masks`` tensor handed to the
                decoder.
            zero_shot: when true the constraints are applied after
                ``get_normalized_probs`` instead of before it.
            prefix_tokens: optional forced prefix, used to work out how many
                leading tokens to drop before querying the trie.

        Returns:
            ``(log_probs, attn)``.  With a single model these are that
            model's values; otherwise the log of the mean probability over
            the models and the mean attention (or None).
        """
        log_probs = []
        avg_attn: Optional[Tensor] = None
        encoder_out: Optional[Dict[str, List[Tensor]]] = None
        code_mask = (tokens.new_ones(tokens.size(0)) * gen_code).bool()

        for i, model in enumerate(self.models):
            # NOTE(review): the encoder-derived inputs below are only bound
            # when has_encoder() is true; a model without an encoder would
            # reach the decoder call with them undefined - confirm.
            if self.has_encoder():
                encoder_out = encoder_outs[i]
                encoder_hidden_states = encoder_out.last_hidden_state
                encoder_attention_mask = _expand_mask(
                    encoder_out.padding_mask, encoder_hidden_states.dtype,
                    tokens.shape[-1])
                src_pos_embed = encoder_out.position_embedding

                # mark the positions of `tokens` equal to the padding index
                attention_mask = tokens.eq(self.single_model.padding_idx)

            # decode each model
            if self.has_incremental_states():
                decoder_out = model.decoder.forward(  # TODO: model inputs differ
                    input_ids=tokens,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    code_masks=code_mask,
                    src_pos_embed=src_pos_embed,
                    past_key_values=incremental_states[i],
                    use_cache=True,
                    output_attentions=True)
            else:
                if hasattr(model, 'decoder'):
                    decoder_out = model.decoder.forward(  # TODO: model inputs differ
                        input_ids=tokens,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        code_masks=code_mask,
                        src_pos_embed=src_pos_embed)
                else:
                    decoder_out = model.forward(tokens)

            attn: Optional[Tensor] = None
            decoder_len = len(decoder_out)

            # Attention of the newest target position, taken from the last
            # entry of `cross_attentions` and averaged over dim 0 after the
            # transpose.
            if 'cross_attentions' in decoder_out:
                attn = decoder_out['cross_attentions'][-1].transpose(1, 0)
                attn = attn.mean(dim=0)  # (B, tgt_len, src_len)
                if attn is not None:
                    attn = attn[:, -1, :]

            # Keep only the last time step and apply the temperature in place.
            decoder_out_tuple = (
                decoder_out[0][:, -1:, :].div_(temperature),
                None if decoder_len <= 1 else attn,
            )

            # Rows of the decoder output per prefix row.
            beam_size = decoder_out_tuple[0].size(0) // prefix_tokens.size(
                0) if prefix_tokens is not None else 0
            if constraint_trie is not None and not zero_shot:
                assert constraint_start is None and constraint_end is None
                constraint_masks = decoder_out_tuple[0].new_zeros(
                    decoder_out_tuple[0].size()).bool()
                constraint_prefix_tokens = tokens.tolist()
                for token_index, constraint_prefix_token in enumerate(
                        constraint_prefix_tokens):
                    # count of prefix entries different from 1
                    # (presumably the pad index - confirm)
                    prefix_len = prefix_tokens[token_index // beam_size].ne(
                        1).sum().item() if prefix_tokens is not None else 0
                    if len(constraint_prefix_token) > prefix_len:
                        # Past the prefix: query the trie with the tokens
                        # generated after it, re-rooted at token 0.
                        constraint_prefix_token = [
                            0
                        ] + constraint_prefix_token[prefix_len + 1:]
                        constraint_nodes = constraint_trie.get_next_layer(
                            constraint_prefix_token)
                        constraint_masks[token_index][:,
                                                      constraint_nodes] = True
                    else:
                        # Still inside the prefix: allow everything.
                        constraint_masks[token_index] = True
                decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf)
            if constraint_start is not None and constraint_end is not None and not zero_shot:
                assert constraint_trie is None
                # entries 0..3 are left untouched
                # (presumably special tokens - confirm)
                decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf
                decoder_out_tuple[0][:, :, constraint_end:] = -math.inf

            probs = model.get_normalized_probs(
                decoder_out_tuple, log_probs=True, sample=None)
            # In zero-shot mode the same constraints are applied to the
            # normalized log-probabilities instead.
            if constraint_trie is not None and zero_shot:
                assert constraint_start is None and constraint_end is None
                constraint_masks = decoder_out_tuple[0].new_zeros(
                    decoder_out_tuple[0].size()).bool()
                constraint_prefix_tokens = tokens.tolist()
                for token_index, constraint_prefix_token in enumerate(
                        constraint_prefix_tokens):
                    constraint_nodes = constraint_trie.get_next_layer(
                        constraint_prefix_token)
                    constraint_masks[token_index][:, constraint_nodes] = True
                probs.masked_fill_(~constraint_masks, -math.inf)
            if constraint_start is not None and constraint_end is not None and zero_shot:
                assert constraint_trie is None
                probs[:, :, 4:constraint_start] = -math.inf
                probs[:, :, constraint_end:] = -math.inf
            probs = probs[:, -1, :]
            # Single model: nothing to average, return right away.
            if self.models_size == 1:
                return probs, attn

            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)

        # logsumexp(log p_i) - log(N) == log(mean_i p_i)
        avg_probs = torch.logsumexp(
            torch.stack(log_probs, dim=0), dim=0) - math.log(self.models_size)

        if avg_attn is not None:
            avg_attn.div_(self.models_size)
        return avg_probs, avg_attn

    @torch.jit.export
    def reorder_encoder_out(self,
                            encoder_outs: Optional[List[Dict[str,
                                                             List[Tensor]]]],
                            new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        new_outs: List[Dict[str, List[Tensor]]] = []
        if not self.has_encoder():
            return new_outs
        for i, model in enumerate(self.models):
            assert encoder_outs is not None
            new_outs.append(
                model.encoder.reorder_encoder_out(encoder_outs[i], new_order))
        return new_outs

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_states: List[Optional[torch.Tensor]],
        new_order,
    ):
        """Reorder each model's incremental state according to *new_order*.

        Delegates to ``decoder.reorder_incremental_state_scripting`` of every
        member and returns None; does nothing when incremental states are
        disabled.
        """
        if not self.has_incremental_states():
            return
        for i, model in enumerate(self.models):
            model.decoder.reorder_incremental_state_scripting(  # TODO
                incremental_states[i], new_order)
| @@ -0,0 +1,512 @@ | |||||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||||
| # | |||||
| # This source code is licensed under the MIT license which can be found at | |||||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||||
| """Implements tracking of constraints for a beam item. | |||||
| A list of constraints is given as a list of one or more token | |||||
| sequences, each of length at least one token. For example, for an input sentence | |||||
| > Die maschinelle Übersetzung ist schwer zu kontrollieren. | |||||
| We could have the constraints: | |||||
| * to influence | |||||
| * hard | |||||
| There are two implementations: | |||||
| * OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. | |||||
| * UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. | |||||
| The difference is that in the first, the constraints are assumed to be | |||||
| in order; the algorithm will permit zero or more tokens between them. | |||||
| In the second, the constraints are not ordered, so many orderings will | |||||
| be explored. | |||||
| The same sequence can be present any number of times, and will appear | |||||
| that many times in the output. | |||||
| """ | |||||
| from collections import Counter | |||||
| from typing import List, Set | |||||
| import torch | |||||
class ConstraintState:
    """Common base type of the constraint-tracking states; holds no data."""

    def __init__(self):
        return None
def pack_constraints(
        batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor:
    """Pack per-sentence constraint tensors into one padded LongTensor.

    ``batch_constraints`` holds, for every sentence, a list of 1-D
    constraint tensors.  For example, a batch of size 3 with 3, 0, and 1
    constraints:

        [ [ [3 1 2], [3], [4 5 6 7], ]
          [],
          [ [1 8 9 10 1 4 11 12], ]
        ]

    is packed as:

        [ [ 3  3  1  2  0  3  0  4  5  6  7  0],
          [ 0  0  0  0  0  0  0  0  0  0  0  0],
          [ 1  1  8  9 10  1  4 11 12  0  0  0] ]

    Each row starts with the number of constraints of that sentence, followed
    by the constraint tokens with a 0 after every constraint.  The row width
    is the maximum over the batch of

        (number of constraints) + (sum length of constraints) + 1,

    and never less than 1.
    """
    # Width needed by each sentence that has at least one constraint:
    # the count cell, all tokens, and one terminating zero per constraint.
    widths = [
        1 + sum(c.size(0) for c in sentence) + len(sentence)
        for sentence in batch_constraints if len(sentence)
    ]
    row_width = max([1] + widths)

    packed = torch.zeros((len(batch_constraints), row_width)).long()
    for row, sentence in enumerate(batch_constraints):
        packed[row, 0] = len(sentence)
        cursor = 1
        for constraint in sentence:
            stop = cursor + constraint.size(0)
            packed[row, cursor:stop] = constraint
            # leave one zero cell as the separator
            cursor = stop + 1
    return packed.long()
def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Split *one row* of a packed constraint tensor (i.e. the entry of a
    single sentence) back into its list of constraint tensors.

    The first cell gives the number of constraints; every constraint is the
    slice running up to the next 0 in the row.
    """
    cells = constraint_tensor.tolist()
    pieces = []
    start = 1
    for _ in range(constraint_tensor[0]):
        # the next zero terminates the current constraint
        stop = cells.index(0, start)
        pieces.append(constraint_tensor[start:stop])
        start = stop + 1
    return pieces
class ConstraintNode:
    """
    A node of the trie that stores a collection of unordered constraints.
    """

    def __init__(self, token: int = None, parent=None):
        # Token attached to this node; the root has none
        self.token = None if token is None else int(token)
        # Parent node; None for the root
        self.parent = parent
        # How many constraints end exactly at this node
        self.terminal = 0
        # Child nodes keyed by their token
        self.children = {}
        # How many constraints end at this node or anywhere below it
        self.num_constraints = 0

    @property
    def id(self):
        return self.token

    def __str__(self):
        term = self.terminal != 0
        return f'[{self.token}].{term}#{self.num_constraints}'

    def __getitem__(self, key: int):
        return self.children.get(key)

    def next_tokens(self) -> Set[int]:
        """Tokens labelling the children of this node."""
        return set(self.children)

    @staticmethod
    def create(constraints: List[List[int]]):
        """Build a trie holding every sequence of ``constraints``."""
        root = ConstraintNode()
        for sequence in constraints:
            root.add_sequence(sequence)
        return root

    @staticmethod
    def print_graph(node: 'ConstraintNode'):
        """Render the subtree below ``node`` as a parenthesised string."""
        if not node.children:
            return str(node)
        rendered = [
            ' ' + ConstraintNode.print_graph(child)
            for child in node.children.values()
        ]
        return f'({node}' + ''.join(rendered) + ')'

    def token_counts(self) -> Counter:
        """Count, for every token below this node, the number of
        constraints that use it.
        """
        counts = Counter()
        pending = list(self.children.values())
        while pending:
            current = pending.pop()
            counts[current.id] += current.num_constraints
            pending.extend(current.children.values())
        return counts

    def tokens(self) -> Set[int]:
        """The set of tokens appearing in any constraint."""
        return set(self.token_counts())

    def add_sequence(self, sequence: List[int]):
        """Insert one constraint, given as a sequence of integers, into
        the trie."""
        assert len(sequence) > 0

        # Walk down, creating missing nodes along the way.
        node = self
        for raw_token in sequence:
            token = int(raw_token)
            if token not in node.children:
                node.children[token] = ConstraintNode(token, parent=node)
            node = node.children[token]

        # The last node closes the constraint; it and all of its ancestors
        # gain one constraint.
        node.terminal += 1
        while node is not None:
            node.num_constraints += 1
            node = node.parent
class UnorderedConstraintState(ConstraintState):
    """
    Records progress through the set of constraints for each item in the beam
    using a trie.

    ``node`` is the current position in the trie, ``root`` the trie root,
    ``generated`` counts visits per trie node and ``completed`` counts the
    constraints already satisfied per terminal node.
    """

    def __init__(self,
                 node: ConstraintNode,
                 copy_from: 'ConstraintState' = None):
        self.node = node

        if copy_from is None:
            # The root node
            self.root = node
            # The set of states in the graph that have been completed
            self.completed = Counter()
            # How many times each trie node has been generated (visited)
            self.generated = Counter()
            # The list of tokens we need to generate
            # (only set on states created without copy_from)
            self.needed_tokens = self.root.tokens()
        else:
            # Copy the counters so the new state can diverge from its source
            self.completed = Counter(copy_from.completed)
            self.generated = Counter(copy_from.generated)
            self.root = copy_from.root

        # Mark the node as generated
        if self.node != self.root:
            self.generated[node] += 1

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Build a root state from one row of a packed constraint tensor."""
        constraint_list = unpack_constraints(constraint_tensor)
        constraint_trie_root = ConstraintNode.create(constraint_list)
        return UnorderedConstraintState(constraint_trie_root)

    def __str__(self):
        gen_str = ','.join([str(node) for node in self.generated])
        return f'{self.name}/{self.bank}({gen_str})x{self.num_completed}'

    def __copy__(self):
        copied_state = UnorderedConstraintState(self.node, copy_from=self)
        return copied_state

    def copy(self):
        return self.__copy__()

    @property
    def name(self):
        """'ROOT' at the trie root, otherwise the current node's token."""
        if self.node.id is None:
            return 'ROOT'
        else:
            return str(self.node.id)

    @property
    def is_root(self):
        return self.node == self.root

    @property
    def bank(self):
        """Sum of all per-node generation counts."""
        return sum(self.generated.values())

    @property
    def num_completed(self):
        """The number of constraints (not constraint tokens) that are completed.
        In addition to the already-completed states, we need to account for the
        current state, which might get marked as completed when another token
        is generated.
        """
        in_final = self.node.terminal and self.completed[
            self.node] < self.node.terminal
        return sum(self.completed.values()) + in_final

    @property
    def finished(self):
        """True when every constraint in the trie has been completed."""
        return self.root.num_constraints - self.num_completed == 0

    @property
    def token_counts(self):
        return self.root.token_counts()

    @property
    def tokens(self):
        return self.root.tokens()

    @property
    def num_constraint_tokens(self):
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the list of tokens that could come next.
        These are (a) all tokens extending the root state and, for
        non-root states, additionally all tokens extending the current
        state."""

        if self.node != self.root:
            return self.root.next_tokens().union(self.node.next_tokens())
        else:
            return self.root.next_tokens()

    def advance(self, token: int):
        """Reads in a token and advances the state. Here's how it works.

        We can advance to the next state if:
        - there is a matching child
        - its path isn't blocked

        A path is blocked when all constraints that are descendants of
        that node have already been generated, in the current state.

        If we are not able to advance from the current state, we "fall
        off the graph" and return to the root state. There, we again
        try to advance, checking the same criteria.

        In any case, when falling off the graph, we need to do some
        bookkeeping. We:
        - check whether any constraints were met (all prefixes of
          current state)
        - if one is found, mark it as completed
        - adjust visited nodes accordingly

        Returns a new state; this state is left unchanged.
        """
        token = int(token)

        next_state = None
        child = self.node[token]
        # Follow the child edge only while it still has unused constraints.
        if child is not None and self.generated[child] < child.num_constraints:
            next_state = UnorderedConstraintState(child, copy_from=self)

        def rewind():
            """If we're mid-trie and an "illegal" token is chosen next, we need
            to reset our state to the root state. However, along the way, we need
            to check whether a prefix of the current trie state represents a state
            we could mark as completed.
            """
            # Operates on `next_state` from the enclosing scope, walking up
            # from the node this state was on.
            node = self.node
            while node != self.root:
                if node.terminal and self.completed[node] < node.terminal:
                    next_state.completed[node] += 1
                    return

                next_state.generated[node] -= 1
                node = node.parent

        # Fall off the graph, check the root
        if next_state is None and token in self.root.next_tokens():
            child = self.root[token]
            # We can only traverse this edge if it's not saturated
            if self.generated[child] < child.num_constraints:
                next_state = UnorderedConstraintState(child, copy_from=self)
            else:
                next_state = UnorderedConstraintState(
                    self.root, copy_from=self)

            # Rewind
            rewind()

        elif next_state is None:
            next_state = UnorderedConstraintState(self.root, copy_from=self)
            # Rewind
            rewind()

        return next_state
class ConstraintSequence:
    """A set of possibly multi-token constraints stored as one flat list.

    Attributes:
        sequences: all constraint tokens concatenated, as Python ints.
        endpoints: parallel list of booleans; True marks the last token of a
            constraint.
        num_tokens: total number of constraint tokens.
        tokens: set of distinct constraint tokens.
    """

    def __init__(self, sequences: List[List[int]]):
        """Represents a set of possibly multitoken constraints by
        concatenating them and internally recording the end points.

        Args:
            sequences: one token sequence per constraint; lists of ints and
                1-D integer tensors are both accepted.
        """
        self.sequences = []
        self.endpoints = []
        self.num_tokens = 0
        self.tokens = set()
        for sequence in sequences:
            # Normalise to plain ints (as ConstraintNode does) so that tensor
            # elements deduplicate in `tokens` and compare cleanly.
            ids = [int(token) for token in sequence]
            if not ids:
                # An empty constraint has no end point; recording one would
                # leave `endpoints` out of step with `sequences`.
                continue
            self.tokens.update(ids)
            self.num_tokens += len(ids)
            self.endpoints += [False] * (len(ids) - 1) + [True]
            self.sequences += ids

    def __getitem__(self, key: int):
        return self.sequences[key]

    def __len__(self):
        return len(self.sequences)

    def __str__(self):
        return str(self.sequences)

    def token_counts(self) -> Counter:
        """Return how many times each token occurs across all constraints.

        ``OrderedConstraintState.token_counts`` calls this method.
        """
        return Counter(self.sequences)
class OrderedConstraintState(ConstraintState):
    """
    Records progress through the set of linear nonbranching constraints with gaps.

    ``state`` is the index (into the concatenated constraint sequence) of the
    last constraint token matched so far, or -1 when nothing has been matched.
    """

    def __init__(self, sequence: ConstraintSequence, state: int = -1):
        self.sequence = sequence
        self.state = state

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Build a root state from one row of a packed constraint tensor."""
        constraint_list = unpack_constraints(constraint_tensor)
        return OrderedConstraintState(ConstraintSequence(constraint_list), -1)

    def __str__(self):
        return f'{self.state}/{self.bank}x{self.num_completed}'

    def __copy__(self):
        return OrderedConstraintState(self.sequence, self.state)

    def copy(self):
        return self.__copy__()

    @property
    def num_completed(self):
        """Number of constraints whose last token has been matched."""
        if self.state == -1:
            return 0
        # endpoints holds booleans, so summing counts the True entries
        return sum(self.sequence.endpoints[0:self.state + 1])

    @property
    def is_root(self):
        return self.state == -1

    @property
    def name(self):
        """'ROOT' before any match, otherwise the last matched token."""
        if self.state == -1:
            return 'ROOT'
        else:
            return str(self.sequence[self.state])

    @property
    def bank(self) -> int:
        """Number of constraint tokens matched so far."""
        return self.state + 1

    @property
    def finished(self):
        """True once the last token of the last constraint is matched."""
        return self.state + 1 == len(self.sequence)

    @property
    def token_counts(self):
        """Counter of how often each token occurs across the constraints.

        Computed from the stored token list; the sequence object offers no
        ``token_counts()`` of its own to delegate to.  Tokens are converted
        to int so that tensor elements with equal value are counted together.
        """
        return Counter(int(token) for token in self.sequence.sequences)

    @property
    def tokens(self):
        return self.sequence.tokens

    @property
    def num_constraint_tokens(self):
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the list of tokens that could come next.
        These are (a) all tokens extending the root state and, for
        non-root states, additionally all tokens extending the current
        state."""

        tokens = set()
        if self.state > 0:
            tokens.add(self.sequence[0])
        if not self.finished:
            tokens.add(self.sequence[self.state + 1])
        return tokens

    def advance(self, token: int):
        """Reads in a token and advances the state. Here's how it works.

        We can advance to the next state if:
        - there is a matching child
        - its path isn't blocked

        A path is blocked when all constraints that are descendants of
        that node have already been generated, in the current state.

        If we are not able to advance from the current state, we "fall
        off the graph" and return to the root state. There, we again
        try to advance, checking the same criteria.

        In any case, when falling off the graph, we need to do some
        bookkeeping. We:
        - check whether any constraints were met (all prefixes of
          current state)
        - if one is found, mark it as completed
        - adjust visited nodes accordingly

        Returns a new state; this state is left unchanged.
        """
        token = int(token)

        if self.finished:
            # Accept anything
            next_state = self.copy()

        elif self.sequence[self.state + 1] == token:
            # Advance to the next token
            next_state = OrderedConstraintState(self.sequence, self.state + 1)

        elif self.sequence.endpoints[self.state]:
            # Accept anything between constraints (*)
            # At the root (state == -1) this reads the last end point; the
            # copy then simply stays at the root.
            next_state = self.copy()

        elif token == self.sequence[0]:
            # Start over having generated the first token
            next_state = OrderedConstraintState(self.sequence, 0)
        else:
            # Start over from the root
            next_state = OrderedConstraintState(self.sequence, -1)

        return next_state
| @@ -0,0 +1,124 @@ | |||||
| # Copyright (c) Facebook, Inc. and its affiliates. | |||||
| # | |||||
| # This source code is licensed under the MIT license which can be found at | |||||
| # https://github.com/facebookresearch/fairseq/blob/main/LICENSE | |||||
| import collections | |||||
| from collections import abc | |||||
| from itertools import accumulate | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| try: | |||||
| from amp_C import multi_tensor_l2norm | |||||
| multi_tensor_l2norm_available = True | |||||
| except ImportError: | |||||
| multi_tensor_l2norm_available = False | |||||
| try: | |||||
| import torch_xla.core.xla_model as xm | |||||
| except ImportError: | |||||
| xm = None | |||||
| MANIFOLD_PATH_SEP = '|' | |||||
def apply_to_sample(f, sample):
    """Apply ``f`` to every tensor found inside a nested ``sample``.

    Ordered dicts, dicts, lists, tuples and sets are traversed recursively
    and rebuilt with the same container kind; any other non-tensor value is
    returned untouched.

    Args:
        f: callable invoked on each tensor.
        sample: arbitrarily nested structure to transform.

    Returns:
        The transformed structure, or an empty dict when ``sample`` has a
        length of zero.
    """
    if hasattr(sample, '__len__') and len(sample) == 0:
        return {}

    def _recurse(node):
        if torch.is_tensor(node):
            return f(node)
        if isinstance(node, collections.OrderedDict):
            # Carry over instance attributes stored on the OrderedDict.
            mapped = collections.OrderedDict()
            for name, child in node.items():
                mapped[name] = _recurse(child)
            mapped.__dict__ = node.__dict__
            return mapped
        if isinstance(node, dict):
            return {name: _recurse(child) for name, child in node.items()}
        if isinstance(node, list):
            return [_recurse(child) for child in node]
        if isinstance(node, tuple):
            return tuple(_recurse(child) for child in node)
        if isinstance(node, set):
            return {_recurse(child) for child in node}
        return node

    return _recurse(sample)
def move_to_device(batch, device):
    r"""Recursively move every tensor in ``batch`` to ``device``.

    Lists and tuples are both returned as tuples, mappings are returned as
    plain dicts, and any other non-tensor value is returned unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return tuple(move_to_device(entry, device) for entry in batch)
    if isinstance(batch, abc.Mapping):
        moved = {}
        for name, entry in batch.items():
            moved[name] = move_to_device(entry, device)
        return moved
    return batch
def strip_pad(tensor, pad):
    """Return the elements of ``tensor`` that are not equal to ``pad``."""
    keep_mask = tensor.ne(pad)
    return tensor[keep_mask]
def get_token_to_word_mapping(tokens, exclude_list):
    """Map each token position to a running word counter.

    The counter increases by one at every token that is not in
    ``exclude_list``; excluded tokens keep the value reached so far.

    Returns:
        dict mapping position ``i`` to the counter value at that position.
    """
    count = len(tokens)
    starts = [int(token not in exclude_list) for token in tokens]
    running = list(accumulate(starts))
    return dict(zip(range(count), running))
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Pick, for each real target token, the source word with most attention.

    Target positions equal to ``pad`` or ``eos`` are skipped, and source
    positions equal to ``pad`` or ``eos`` are masked out before the argmax.

    Returns:
        A list of ``(source_word, target_word)`` pairs. Each value is the
        word counter from ``get_token_to_word_mapping`` minus one. The list
        is empty when no target token is valid or every source token is
        masked.
    """
    tgt_mask = (tgt_sent != pad) & (tgt_sent != eos)
    tgt_valid = tgt_mask.nonzero(as_tuple=False).squeeze(dim=-1)
    src_mask = (src_sent == pad) | (src_sent == eos)
    src_invalid = src_mask.nonzero(as_tuple=False).squeeze(dim=-1)

    src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
    tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])

    if len(tgt_valid) == 0 or len(src_invalid) >= len(src_sent):
        return []

    # Indexing with a tensor of positions yields a new tensor, so the
    # masking below is applied to that selection.
    attn_valid = attn[tgt_valid]
    attn_valid[:, src_invalid] = float('-inf')
    _, best_src = attn_valid.max(dim=1)

    pairs = []
    for tgt_pos, src_pos in zip(tgt_valid, best_src):
        src_word = src_token_to_word[src_pos.item()] - 1
        tgt_word = tgt_token_to_word[tgt_pos.item()] - 1
        pairs.append((src_word, tgt_word))
    return pairs
def softmax(x, dim: int, onnx_trace: bool = False):
    """Compute a float32 softmax of ``x`` along ``dim``.

    With ``onnx_trace`` the input is cast via ``x.float()`` first; otherwise
    the ``dtype`` argument of ``F.softmax`` is used.
    """
    if not onnx_trace:
        return F.softmax(x, dim=dim, dtype=torch.float32)
    return F.softmax(x.float(), dim=dim)
def log_softmax(x, dim: int, onnx_trace: bool = False):
    """Compute a float32 log-softmax of ``x`` along ``dim``.

    With ``onnx_trace`` the input is cast via ``x.float()`` first; otherwise
    the ``dtype`` argument of ``F.log_softmax`` is used.
    """
    if not onnx_trace:
        return F.log_softmax(x, dim=dim, dtype=torch.float32)
    return F.log_softmax(x.float(), dim=dim)
def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Format attention over non-pad positions as strings.

    Rows follow the non-pad target positions and columns the non-pad source
    positions; each value is rendered with six decimals. ``eos`` is accepted
    but not used.

    Returns:
        A list of rows of formatted strings, or an empty list when either
        sentence has no non-pad position.
    """
    tgt_valid = (tgt_sent != pad).nonzero(as_tuple=False)
    src_valid = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1)
    if len(tgt_valid) == 0 or len(src_valid) == 0:
        return []

    # tgt_valid keeps its trailing dimension so the two index tensors
    # broadcast to a (targets, sources) grid.
    attn_valid = attn[tgt_valid, src_valid]
    rows = []
    for src_probs in attn_valid:
        rows.append(['{:.6f}'.format(p) for p in src_probs.tolist()])
    return rows
| @@ -0,0 +1,283 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample.

    Intended for the main path of residual blocks. One random keep/drop
    decision is drawn per entry along the first dimension of ``x``; kept
    entries are scaled by ``1 / (1 - drop_prob)``.

    This is the operation sometimes called "DropConnect" in EfficientNet
    style code; the name "drop path" is used here instead. See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: input tensor.
        drop_prob: probability of dropping an entry.
        training: the input is returned as-is unless this is true.

    Returns:
        ``x`` itself when ``drop_prob`` is 0 or ``training`` is false,
        otherwise the masked and rescaled tensor.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One value per sample, broadcast over the remaining dimensions
    # (works with tensors of any rank, not just 2D ConvNets).
    mask_shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.add_(keep_prob).floor_()  # binarize
    return x.div(keep_prob) * mask
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (Stochastic Depth per sample).

    Intended for the main path of residual blocks. The module's own
    ``training`` flag is forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Drop probability handed to ``drop_path`` on every forward call.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` with the stored probability."""
        return drop_path(x, self.drop_prob, self.training)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **options)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 convolution."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block whose forward pass is disabled.

    The layers are still constructed, but calling ``forward`` always raises
    ``AssertionError``.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution.
        downsample: optional module stored as ``self.downsample``.
        groups: must be 1.
        base_width: must be 64.
        dilation: must not exceed 1.
        norm_layer: normalization layer factory; defaults to
            ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Always raise ``AssertionError``; this block cannot be run.

        The original guard was a bare ``assert False``, which is removed
        when Python runs with ``-O``. Raising explicitly keeps the block
        disabled in every interpreter mode.
        """
        raise AssertionError('BasicBlock.forward is disabled')
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with optional drop path.

    Bottleneck in torchvision places the stride for downsampling at the 3x3
    convolution (``self.conv2``), while the original implementation places
    the stride at the first 1x1 convolution (``self.conv1``) according to
    "Deep residual learning for image recognition"
    https://arxiv.org/abs/1512.03385. This variant is also known as
    ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    The residual branch goes through ``DropPath`` when ``drop_path_rate`` is
    positive and through ``nn.Identity`` otherwise.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None,
                 drop_path_rate=0.0):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        mid_channels = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion

        # Both self.conv2 and self.downsample layers downsample the input
        # when stride != 1.
        self.conv1 = conv1x1(inplanes, mid_channels)
        self.bn1 = norm_layer(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, groups,
                             dilation)
        self.bn2 = norm_layer(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        if drop_path_rate > 0.0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Run the three conv stages and add the (optionally downsampled) input."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(shortcut + self.drop_path(residual))
class ResNet(nn.Module):
    """ResNet trunk made of a stem and three ``Bottleneck`` stages.

    The forward pass runs conv1 -> bn1 -> relu -> maxpool -> layer1 ->
    layer2 -> layer3 and returns the resulting feature map.

    Args:
        layers: number of blocks in each of the three stages.
        zero_init_residual: zero the weight of the last norm layer in every
            residual block after the regular initialization.
        groups: ``groups`` value forwarded to every block.
        width_per_group: ``base_width`` value forwarded to every block.
        replace_stride_with_dilation: ``None`` or three flags; the first two
            select dilation instead of stride for layer2 and layer3.
        norm_layer: normalization layer factory; defaults to
            ``nn.BatchNorm2d``.
        drop_path_rate: upper end of the per-block drop path schedule used
            inside each stage.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have three
            elements.
    """

    def __init__(self,
                 layers,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None,
                 drop_path_rate=0.0):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(
            Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate)
        self.layer2 = self._make_layer(
            Bottleneck,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            drop_path_rate=drop_path_rate)
        self.layer3 = self._make_layer(
            Bottleneck,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            drop_path_rate=drop_path_rate)

        self._initialize_weights(zero_init_residual)

    def _initialize_weights(self, zero_init_residual):
        """Initialize conv and norm parameters, optionally zeroing residual norms."""
        norm_types = (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if not zero_init_residual:
            return
        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        # This must stay a separate pass: it has to run after the norm
        # weights above were set to 1.
        for module in self.modules():
            if isinstance(module, Bottleneck):
                nn.init.constant_(module.bn3.weight, 0)
            elif isinstance(module, BasicBlock):
                nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    stride=1,
                    dilate=False,
                    drop_path_rate=0.0):
        """Build one stage of ``blocks`` consecutive ``block`` modules.

        The first block receives the stride and the optional downsample
        projection; the remaining blocks receive drop path rates taken from
        a linear schedule between 0 and ``drop_path_rate``.
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        stage = [
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer)
        ]
        self.inplanes = out_planes

        # The first entry of the schedule belongs to the block created
        # above, which is built without a drop path rate.
        schedule = torch.linspace(0, drop_path_rate, blocks).tolist()
        for rate in schedule[1:]:
            stage.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    drop_path_rate=rate))
        return nn.Sequential(*stage)

    def _forward_impl(self, x):
        """Run the stem followed by the three residual stages."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return x

    def forward(self, x):
        """Return the feature map produced by ``_forward_impl``."""
        return self._forward_impl(x)
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2022 OFA-Sys Team. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Tokenization classes for OFA.""" | |||||
| from transformers.models.bart.tokenization_bart import BartTokenizer | |||||
| from transformers.utils import logging | |||||
| logger = logging.get_logger(__name__) | |||||
# File names of the vocabulary resources, keyed by resource kind.
VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

# Download location of each resource, keyed by resource kind and then by
# checkpoint name.
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file': {
        'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json',
    },
    'merges_file': {
        'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt',
    },
}

# Size value per checkpoint name; used below as ``max_model_input_sizes``.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'ofa-base': 1024,
}
class OFATokenizer(BartTokenizer):
    """
    OFA tokenizer.

    [`~OFATokenizer`] adds no behavior of its own to [`BartTokenizer`]; it
    only overrides the class-level tables that name the vocabulary files,
    their download locations and the maximum model input sizes.

    Refer to superclass [`BartTokenizer`] for usage examples and
    documentation concerning parameters.
    """

    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    vocab_files_names = VOCAB_FILES_NAMES
| @@ -0,0 +1,59 @@ | |||||
| # Copyright 2022 OFA-Sys Team. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Tokenization classes for OFA.""" | |||||
| from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast | |||||
| from transformers.utils import logging | |||||
| from .tokenization_ofa import OFATokenizer | |||||
| logger = logging.get_logger(__name__) | |||||
# File names of the tokenizer resources, keyed by resource kind.
VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
    'tokenizer_file': 'tokenizer.json',
}

# Download location of each resource, keyed by resource kind and then by
# checkpoint name.
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file': {
        'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json',
    },
    'merges_file': {
        'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt',
    },
    'tokenizer_file': {
        'ofa-base':
        'https://huggingface.co/ofa-base/resolve/main/tokenizer.json',
    },
}

# Size value per checkpoint name; used below as ``max_model_input_sizes``.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'ofa-base': 1024,
}
class OFATokenizerFast(BartTokenizerFast):
    r"""
    "Fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library).

    [`~OFATokenizerFast`] adds no behavior of its own to
    [`BartTokenizerFast`]; it only overrides the class-level tables that name
    the tokenizer files, their download locations and the maximum model
    input sizes, and declares [`OFATokenizer`] as its slow counterpart.

    Refer to superclass [`BartTokenizerFast`] for usage examples and
    documentation concerning parameters.
    """

    slow_tokenizer_class = OFATokenizer
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    vocab_files_names = VOCAB_FILES_NAMES
| @@ -0,0 +1,53 @@ | |||||
| from typing import Any, Dict | |||||
| import torch.cuda | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | |||||
| from ..base import Model | |||||
| from ..builder import MODELS | |||||
| from .ofa import OFAModel, OFATokenizer | |||||
| from .ofa.generate import sequence_generator as sg | |||||
| from .ofa.generate.utils import move_to_device | |||||
| __all__ = ['OfaForImageCaptioning'] | |||||
@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
class OfaForImageCaptioning(Model):
    """OFA model wrapper that generates a caption for a preprocessed image.

    Loads the OFA model and tokenizer from ``model_dir``, moves the model to
    CUDA when available (CPU otherwise) and builds a beam-search sequence
    generator.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """Load model, tokenizer and generator.

        Args:
            model_dir: directory passed to ``from_pretrained`` for both the
                model and the tokenizer.
            **kwargs: forwarded to the base class. An optional
                ``beam_search`` dict overrides the default generator
                arguments.
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        model = OFAModel.from_pretrained(model_dir)
        # Unwrap the model if it exposes a ``module`` attribute.
        self.model = model.module if hasattr(model, 'module') else model
        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
        self._device = torch.device('cuda') if torch.cuda.is_available() \
            else torch.device('cpu')
        self.model.to(self._device)
        # Initialize generator
        sg_args = {
            'tokenizer': self.tokenizer,
            'beam_size': 5,
            'max_len_b': 16,
            'min_len': 1,
            'no_repeat_ngram_size': 3,
            'constraint_range': None
        }
        # kwargs is a dict, so the override has to be looked up by key;
        # hasattr() on it never finds 'beam_search' and the caller's
        # settings were silently dropped.
        beam_search = kwargs.get('beam_search')
        if beam_search:
            sg_args.update(beam_search)
        self.generator = sg.SequenceGenerator(**sg_args)

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a caption for ``input``.

        Args:
            input: sample handed to the sequence generator after being moved
                to the model's device.

        Returns:
            Dict with the decoded caption of the first sample under
            ``OutputKeys.CAPTION`` and a fixed ``'image_id'`` of ``'42'``.
        """
        input = move_to_device(input, self._device)
        gen_output = self.generator.generate([self.model], input)
        # Keep the tokens of the first hypothesis of every sample.
        gen = [hypotheses[0]['tokens'] for hypotheses in gen_output]
        result = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
        # NOTE(review): 'image_id' is a hard-coded placeholder — confirm
        # whether any consumer relies on it.
        return {'image_id': '42', OutputKeys.CAPTION: result[0]}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``inputs`` unchanged; no post-processing is applied."""
        return inputs
| @@ -44,7 +44,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/nlp_space_dialog-modeling'), | 'damo/nlp_space_dialog-modeling'), | ||||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | ||||
| 'damo/nlp_space_dialog-state-tracking'), | 'damo/nlp_space_dialog-state-tracking'), | ||||
| Tasks.image_captioning: (Pipelines.image_caption, | |||||
| Tasks.image_captioning: (Pipelines.image_captioning, | |||||
| 'damo/ofa_image-caption_coco_large_en'), | 'damo/ofa_image-caption_coco_large_en'), | ||||
| Tasks.image_generation: | Tasks.image_generation: | ||||
| (Pipelines.person_image_cartoon, | (Pipelines.person_image_cartoon, | ||||
| @@ -11,7 +11,7 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_captioning, module_name=Pipelines.image_caption) | |||||
| Tasks.image_captioning, module_name=Pipelines.image_captioning) | |||||
| class ImageCaptionPipeline(Pipeline): | class ImageCaptionPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -2,13 +2,14 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| import numpy as np | |||||
| import torch | import torch | ||||
| from PIL import Image | from PIL import Image | ||||
| from torchvision import transforms | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.metainfo import Preprocessors | from modelscope.metainfo import Preprocessors | ||||
| from modelscope.utils.constant import Fields, ModelFile | |||||
| from modelscope.models.multi_modal.ofa import OFATokenizer | |||||
| from modelscope.utils.constant import Fields | |||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| from .base import Preprocessor | from .base import Preprocessor | ||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| @@ -31,84 +32,39 @@ class OfaImageCaptionPreprocessor(Preprocessor): | |||||
| model_dir (str): model path | model_dir (str): model path | ||||
| """ | """ | ||||
| super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
| model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||||
| model_dir) | |||||
| self.tokenizer = OFATokenizer.from_pretrained(model_dir) | |||||
| self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||||
| self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) | |||||
| if osp.exists(model_dir): | |||||
| local_model_dir = model_dir | |||||
| else: | |||||
| local_model_dir = snapshot_download(model_dir) | |||||
| local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE) | |||||
| bpe_dir = local_model_dir | |||||
| from fairseq import checkpoint_utils, tasks, utils | |||||
| from ofa.tasks.mm_tasks import CaptionTask | |||||
| tasks.register_task('caption', CaptionTask) | |||||
| overrides = { | |||||
| 'bpe_dir': bpe_dir, | |||||
| 'eval_cider': False, | |||||
| 'beam': 5, | |||||
| 'max_len_b': 16, | |||||
| 'no_repeat_ngram_size': 3, | |||||
| 'seed': 7 | |||||
| } | |||||
| model, cfg, task = checkpoint_utils.load_model_ensemble_and_task( | |||||
| utils.split_paths(local_model), arg_overrides=overrides) | |||||
| del model | |||||
| # Initialize transform | # Initialize transform | ||||
| from torchvision import transforms | |||||
| mean = [0.5, 0.5, 0.5] | mean = [0.5, 0.5, 0.5] | ||||
| std = [0.5, 0.5, 0.5] | std = [0.5, 0.5, 0.5] | ||||
| patch_image_size = 480 | |||||
| self.patch_resize_transform = transforms.Compose([ | self.patch_resize_transform = transforms.Compose([ | ||||
| lambda image: image.convert('RGB'), | lambda image: image.convert('RGB'), | ||||
| transforms.Resize( | |||||
| (cfg.task.patch_image_size, cfg.task.patch_image_size), | |||||
| interpolation=Image.BICUBIC), | |||||
| transforms.Resize((patch_image_size, patch_image_size), | |||||
| interpolation=Image.BICUBIC), | |||||
| transforms.ToTensor(), | transforms.ToTensor(), | ||||
| transforms.Normalize(mean=mean, std=std), | transforms.Normalize(mean=mean, std=std), | ||||
| ]) | ]) | ||||
| self.task = task | |||||
| self.bos_item = torch.LongTensor([task.src_dict.bos()]) | |||||
| self.eos_item = torch.LongTensor([task.src_dict.eos()]) | |||||
| self.pad_idx = task.src_dict.pad() | |||||
| @type_assert(object, (str, tuple, Image.Image)) | @type_assert(object, (str, tuple, Image.Image)) | ||||
| def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: | def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: | ||||
| def encode_text(text, length=None, append_bos=False, append_eos=False): | |||||
| s = self.task.tgt_dict.encode_line( | |||||
| line=self.task.bpe.encode(text), | |||||
| add_if_not_exist=False, | |||||
| append_eos=False).long() | |||||
| if length is not None: | |||||
| s = s[:length] | |||||
| if append_bos: | |||||
| s = torch.cat([self.bos_item, s]) | |||||
| if append_eos: | |||||
| s = torch.cat([s, self.eos_item]) | |||||
| return s | |||||
| if isinstance(data, Image.Image): | if isinstance(data, Image.Image): | ||||
| patch_image = self.patch_resize_transform(data).unsqueeze(0) | patch_image = self.patch_resize_transform(data).unsqueeze(0) | ||||
| else: | else: | ||||
| patch_image = self.patch_resize_transform( | patch_image = self.patch_resize_transform( | ||||
| load_image(data)).unsqueeze(0) | load_image(data)).unsqueeze(0) | ||||
| patch_mask = torch.tensor([True]) | |||||
| text = 'what does the image describe?' | |||||
| src_text = encode_text( | |||||
| text, append_bos=True, append_eos=True).unsqueeze(0) | |||||
| src_length = torch.LongTensor( | |||||
| [s.ne(self.pad_idx).long().sum() for s in src_text]) | |||||
| sample = { | |||||
| 'id': np.array(['42']), | |||||
| 'net_input': { | |||||
| 'src_tokens': src_text, | |||||
| 'src_lengths': src_length, | |||||
| 'patch_images': patch_image, | |||||
| 'patch_masks': patch_mask, | |||||
| } | |||||
| text = ' what does the image describe?' | |||||
| inputs = self.tokenizer([text], max_length=1024, | |||||
| return_tensors='pt')['input_ids'] | |||||
| sample = dict() | |||||
| sample['net_input'] = { | |||||
| 'input_ids': inputs, | |||||
| 'patch_images': patch_image, | |||||
| 'patch_masks': torch.tensor([True]) | |||||
| } | } | ||||
| return sample | return sample | ||||
| @@ -14,7 +14,7 @@ class ImageCaptionTest(unittest.TestCase): | |||||
| def test_run(self): | def test_run(self): | ||||
| img_captioning = pipeline( | img_captioning = pipeline( | ||||
| Tasks.image_captioning, | Tasks.image_captioning, | ||||
| model='damo/ofa_image-caption_coco_large_en') | |||||
| model='damo/ofa_image-caption_coco_distilled_en') | |||||
| result = img_captioning('data/test/images/image_captioning.png') | result = img_captioning('data/test/images/image_captioning.png') | ||||
| print(result[OutputKeys.CAPTION]) | print(result[OutputKeys.CAPTION]) | ||||