Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10909489master^2
| @@ -90,6 +90,7 @@ class Models(object): | |||||
| mglm = 'mglm' | mglm = 'mglm' | ||||
| codegeex = 'codegeex' | codegeex = 'codegeex' | ||||
| bloom = 'bloom' | bloom = 'bloom' | ||||
| unite = 'unite' | |||||
| # audio models | # audio models | ||||
| sambert_hifigan = 'sambert-hifigan' | sambert_hifigan = 'sambert-hifigan' | ||||
| @@ -275,6 +276,7 @@ class Pipelines(object): | |||||
| translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | ||||
| translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | ||||
| token_classification = 'token-classification' | token_classification = 'token-classification' | ||||
| translation_evaluation = 'translation-evaluation' | |||||
| # audio tasks | # audio tasks | ||||
| sambert_hifigan_tts = 'sambert-hifigan-tts' | sambert_hifigan_tts = 'sambert-hifigan-tts' | ||||
| @@ -404,6 +406,7 @@ class Preprocessors(object): | |||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| mglm_summarization = 'mglm-summarization' | mglm_summarization = 'mglm-summarization' | ||||
| sentence_piece = 'sentence-piece' | sentence_piece = 'sentence-piece' | ||||
| translation_evaluation = 'translation-evaluation-preprocessor' | |||||
| # audio preprocessor | # audio preprocessor | ||||
| linear_aec_fbank = 'linear-aec-fbank' | linear_aec_fbank = 'linear-aec-fbank' | ||||
| @@ -51,6 +51,7 @@ if TYPE_CHECKING: | |||||
| VecoForSequenceClassification, | VecoForSequenceClassification, | ||||
| VecoForTokenClassification, VecoModel) | VecoForTokenClassification, VecoModel) | ||||
| from .bloom import BloomModel | from .bloom import BloomModel | ||||
| from .unite import UniTEModel | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'backbones': ['SbertModel'], | 'backbones': ['SbertModel'], | ||||
| @@ -108,6 +109,7 @@ else: | |||||
| ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], | ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], | ||||
| 'gpt_neo': ['GPTNeoModel'], | 'gpt_neo': ['GPTNeoModel'], | ||||
| 'bloom': ['BloomModel'], | 'bloom': ['BloomModel'], | ||||
| 'unite': ['UniTEModel'] | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
# Static type checkers follow the real imports in the ``else`` branch; at
# runtime this module object is replaced by a ``LazyImportModule`` built from
# ``_import_structure``.
if not TYPE_CHECKING:
    import sys

    # Maps each submodule of this package to the public names it provides.
    _import_structure = {
        'configuration_unite': ['UniTEConfig'],
        'modeling_unite': ['UniTEForTranslationEvaluation'],
    }

    sys.modules[__name__] = LazyImportModule(
        __name__,
        __file__,
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
else:
    from .configuration_unite import UniTEConfig
    from .modeling_unite import UniTEForTranslationEvaluation
| @@ -0,0 +1,21 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| """UniTE model configuration""" | |||||
| from enum import Enum | |||||
| from modelscope.utils import logger as logging | |||||
| from modelscope.utils.config import Config | |||||
| logger = logging.get_logger(__name__) | |||||
class EvaluationMode(Enum):
    """Which sentences accompany the hypothesis during evaluation."""

    # Hypothesis is scored against the source sentence only.
    SRC = 'src'
    # Hypothesis is scored against the reference translation only.
    REF = 'ref'
    # Hypothesis is scored against both source and reference.
    SRC_REF = 'src-ref'
class UniTEConfig(Config):
    """Configuration object for UniTE; a thin subclass of ``Config``."""

    def __init__(self, **kwargs):
        # No UniTE-specific fields are defined here: every keyword argument
        # is handed to ``Config`` unchanged.
        super(UniTEConfig, self).__init__(**kwargs)
| @@ -0,0 +1,400 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| """PyTorch UniTE model.""" | |||||
| import math | |||||
| import warnings | |||||
| from dataclasses import dataclass | |||||
| from typing import Dict, List, Optional, Tuple, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.utils.checkpoint | |||||
| from packaging import version | |||||
| from torch.nn import (Dropout, Linear, Module, Parameter, ParameterList, | |||||
| Sequential) | |||||
| from torch.nn.functional import softmax | |||||
| from torch.nn.utils.rnn import pad_sequence | |||||
| from transformers import XLMRobertaConfig, XLMRobertaModel | |||||
| from transformers.activations import ACT2FN | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger(__name__) | |||||
| __all__ = ['UniTEForTranslationEvaluation'] | |||||
| def _layer_norm_all(tensor, mask_float): | |||||
| broadcast_mask = mask_float.unsqueeze(dim=-1) | |||||
| num_elements_not_masked = broadcast_mask.sum() * tensor.size(-1) | |||||
| tensor_masked = tensor * broadcast_mask | |||||
| mean = tensor_masked.sum([-1, -2, -3], | |||||
| keepdim=True) / num_elements_not_masked | |||||
| variance = (((tensor_masked - mean) * broadcast_mask)**2).sum( | |||||
| [-1, -2, -3], keepdim=True) / num_elements_not_masked | |||||
| return (tensor - mean) / torch.sqrt(variance + 1e-12) | |||||
class LayerwiseAttention(Module):
    """Mix per-layer hidden states with learned softmax weights.

    Each layer gets one learnable logit (``scalar_parameters``, initialised
    to zero) and the mixed result is scaled by a learnable ``gamma``
    (initialised to one).
    """

    def __init__(
        self,
        num_layers: int,
        model_dim: int,
        dropout: Optional[float] = None,
    ) -> None:
        """
        Args:
            num_layers: Number of layer outputs that will be mixed.
            model_dim: Hidden size of the layer outputs (stored only).
            dropout: Optional dropout ratio. When truthy, two buffers
                (``dropout_mask`` and ``dropout_fill``) are registered.
        """
        super().__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.dropout = dropout

        self.scalar_parameters = Parameter(
            torch.zeros((num_layers, ), requires_grad=True))
        self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=True)

        if self.dropout:
            # NOTE(review): these buffers are registered (and therefore part
            # of the state dict) but ``forward`` never reads them, so no
            # layer dropout is applied -- confirm whether that is intended.
            dropout_mask = torch.zeros(len(self.scalar_parameters))
            dropout_fill = torch.empty(len(
                self.scalar_parameters)).fill_(-1e20)
            self.register_buffer('dropout_mask', dropout_mask)
            self.register_buffer('dropout_fill', dropout_fill)

    def forward(
        self,
        tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the weighted layer mix taken at the first sequence position.

        Args:
            tensors: One tensor per layer, all of the same shape; assumed to
                be (batch, seq, dim) -- TODO confirm against the encoder.
            mask: Mask over (batch, seq) where non-zero marks real tokens.
                ``None`` means every position is a real token.

        Returns:
            ``gamma`` times the softmax-weighted sum of the normalised
            layers, indexed at position 0 of the sequence dimension.
        """
        stacked = torch.stack(list(tensors), dim=0)
        normed_weights = softmax(
            self.scalar_parameters, dim=0).view(-1, 1, 1, 1)

        if mask is None:
            # The signature advertises ``mask=None``; treat it as "nothing
            # is padding" instead of failing on ``None.float()``.
            mask_float = stacked.new_ones(
                stacked.shape[1:-1], dtype=torch.float)
        else:
            mask_float = mask.float()

        weighted_sum = (normed_weights
                        * _layer_norm_all(stacked, mask_float)).sum(dim=0)
        # Keep only the representation of the first token of each sequence.
        weighted_sum = weighted_sum[:, 0, :]
        return self.gamma * weighted_sum
class FeedForward(Module):
    """Stack of Linear -> activation -> Dropout layers with a linear head."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int = 1,
        hidden_sizes: Optional[List[int]] = None,
        activations: str = 'Sigmoid',
        final_activation: Optional[str] = None,
        dropout: float = 0.1,
    ) -> None:
        """
        Feed Forward Neural Network.

        Args:
            in_dim (:obj:`int`):
                Number of input features.
            out_dim (:obj:`int`, defaults to 1):
                Number of output features. Default is 1 -- a single scalar.
            hidden_sizes (:obj:`List[int]`, defaults to `[3072, 768]`):
                List with hidden layer sizes. ``None`` selects the default;
                an empty list yields a single linear layer.
            activations (:obj:`str`, defaults to `Sigmoid`):
                Name of the activation function to be used in the hidden
                layers (see :meth:`build_activation` for name resolution).
            final_activation (:obj:`str`, Optional, defaults to `None`):
                Name of the final activation function if any.
            dropout (:obj:`float`, defaults to 0.1):
                Dropout ratio to be used in the hidden layers.
        """
        super().__init__()
        # Copy so that neither a shared default nor the caller's list can be
        # aliased by this module.
        if hidden_sizes is None:
            hidden_sizes = [3072, 768]
        else:
            hidden_sizes = list(hidden_sizes)

        # The module order (Linear, activation, Dropout per hidden layer,
        # then the output Linear) fixes the ``ff.<index>`` state-dict keys,
        # so it must not change.
        layers = []
        prev_dim = in_dim
        for size in hidden_sizes:
            layers.append(Linear(prev_dim, size))
            layers.append(self.build_activation(activations))
            layers.append(Dropout(dropout))
            prev_dim = size

        layers.append(Linear(prev_dim, int(out_dim)))
        if final_activation is not None:
            layers.append(self.build_activation(final_activation))

        self.ff = Sequential(*layers)

    def build_activation(self, activation: str) -> Module:
        """Look up ``activation`` in ``ACT2FN``.

        The name is tried as given first and lower-cased second, so that
        capitalised names such as the default ``'Sigmoid'`` resolve.

        Raises:
            KeyError: If neither spelling is a key of ``ACT2FN``.
        """
        # NOTE(review): ``Sequential`` needs ``Module`` instances; this
        # assumes the installed transformers version stores modules (not
        # plain functions) in ACT2FN -- confirm the minimum version.
        key = activation if activation in ACT2FN else activation.lower()
        return ACT2FN[key]

    def forward(self, in_features: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``in_features``."""
        return self.ff(in_features)
@MODELS.register_module(Tasks.translation_evaluation, module_name=Models.unite)
class UniTEForTranslationEvaluation(TorchModel):
    """UniTE translation-quality estimator: XLM-R encoder plus MLP head."""

    def __init__(self,
                 attention_probs_dropout_prob: float = 0.1,
                 bos_token_id: int = 0,
                 eos_token_id: int = 2,
                 pad_token_id: int = 1,
                 hidden_act: str = 'gelu',
                 hidden_dropout_prob: float = 0.1,
                 hidden_size: int = 1024,
                 initializer_range: float = 0.02,
                 intermediate_size: int = 4096,
                 layer_norm_eps: float = 1e-05,
                 max_position_embeddings: int = 512,
                 num_attention_heads: int = 16,
                 num_hidden_layers: int = 24,
                 type_vocab_size: int = 1,
                 use_cache: bool = True,
                 vocab_size: int = 250002,
                 mlp_hidden_sizes: Optional[List[int]] = None,
                 mlp_act: str = 'tanh',
                 mlp_final_act: Optional[str] = None,
                 mlp_dropout: float = 0.1,
                 **kwargs):
        r"""The UniTE Model which outputs the scalar to describe the corresponding
            translation quality of hypothesis. The model architecture includes two
            modules: a pre-trained language model (PLM) to derive representations,
            and a multi-layer perceptron (MLP) to give predicted score.

        Args:
            attention_probs_dropout_prob (:obj:`float`, defaults to 0.1):
                The dropout ratio for attention weights inside PLM.
            bos_token_id (:obj:`int`, defaults to 0):
                The numeric id representing beginning-of-sentence symbol.
            eos_token_id (:obj:`int`, defaults to 2):
                The numeric id representing ending-of-sentence symbol.
            pad_token_id (:obj:`int`, defaults to 1):
                The numeric id representing padding symbol.
            hidden_act (:obj:`str`, defaults to :obj:`"gelu"`):
                Activation inside PLM.
            hidden_dropout_prob (:obj:`float`, defaults to 0.1):
                The dropout ratio for activation states inside PLM.
            hidden_size (:obj:`int`, defaults to 1024):
                The dimensionality of PLM.
            initializer_range (:obj:`float`, defaults to 0.02):
                The hyper-parameter for initializing PLM.
            intermediate_size (:obj:`int`, defaults to 4096):
                The dimensionality of PLM inside feed-forward block.
            layer_norm_eps (:obj:`float`, defaults to 1e-5):
                The value for setting epsilon to avoid zero-division inside
                    layer normalization.
            max_position_embeddings: (:obj:`int`, defaults to 512):
                The maximum value for identifying the length of input sequence.
            num_attention_heads (:obj:`int`, defaults to 16):
                The number of attention heads inside multi-head attention layer.
            num_hidden_layers (:obj:`int`, defaults to 24):
                The number of layers inside PLM.
            type_vocab_size (:obj:`int`, defaults to 1):
                The number of type embeddings.
            use_cache (:obj:`bool`, defaults to :obj:`True`):
                Forwarded unchanged to the encoder config as ``use_cache``.
            vocab_size (:obj:`int`, defaults to 250002):
                The size of vocabulary.
            mlp_hidden_sizes (:obj:`List[int]`, defaults to `[3072, 1024]`):
                The size of hidden states inside MLP. ``None`` selects the
                default.
            mlp_act (:obj:`str`, defaults to :obj:`"tanh"`):
                Activation inside MLP.
            mlp_final_act (:obj:`str`, `optional`, defaults to :obj:`None`):
                Activation at the end of MLP.
            mlp_dropout (:obj:`float`, defaults to 0.1):
                The dropout ratio for MLP.
            **kwargs: Passed on to ``TorchModel.__init__``.
        """
        super().__init__(**kwargs)

        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.type_vocab_size = type_vocab_size
        self.use_cache = use_cache
        self.vocab_size = vocab_size
        # Copy instead of aliasing a shared default or the caller's list.
        if mlp_hidden_sizes is None:
            self.mlp_hidden_sizes = [3072, 1024]
        else:
            self.mlp_hidden_sizes = list(mlp_hidden_sizes)
        self.mlp_act = mlp_act
        self.mlp_final_act = mlp_final_act
        self.mlp_dropout = mlp_dropout

        self.encoder_config = XLMRobertaConfig(
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            layer_norm_eps=self.layer_norm_eps,
            use_cache=self.use_cache)

        self.encoder = XLMRobertaModel(
            self.encoder_config, add_pooling_layer=False)

        # One more entry than there are encoder layers -- presumably because
        # the encoder's hidden_states also include the embedding output;
        # verify against the transformers version in use.
        self.layerwise_attention = LayerwiseAttention(
            num_layers=self.num_hidden_layers + 1,
            model_dim=self.hidden_size,
            dropout=self.mlp_dropout)

        self.estimator = FeedForward(
            in_dim=self.hidden_size,
            out_dim=1,
            hidden_sizes=self.mlp_hidden_sizes,
            activations=self.mlp_act,
            final_activation=self.mlp_final_act,
            dropout=self.mlp_dropout)

    def forward(self, input_sentences: List[torch.Tensor]):
        """Score a batch of segment groups.

        Args:
            input_sentences: Two or three token-id tensors (hypothesis
                first), each assumed to be (batch, seq) and right-padded
                with ``pad_token_id`` -- TODO confirm against the
                preprocessor.

        Returns:
            The estimator output with its last dimension squeezed, i.e. one
            score per batch row.
        """
        input_ids = self.combine_input_sentences(input_sentences)
        attention_mask = input_ids.ne(self.pad_token_id).long()
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True)
        mix_states = self.layerwise_attention(outputs['hidden_states'],
                                              attention_mask)
        pred = self.estimator(mix_states)
        return pred.squeeze(dim=-1)

    def load_checkpoint(self, path: str):
        """Load a state dict from ``path`` into this model.

        Args:
            path: File readable by ``torch.load`` that holds a state dict.
        """
        # Map to CPU first so checkpoints saved on a GPU also load on
        # CPU-only hosts; load_state_dict copies onto the parameters'
        # current device.
        # NOTE(review): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        state_dict = torch.load(path, map_location='cpu')
        self.load_state_dict(state_dict)
        logger.info('Loading checkpoint parameters from %s', path)

    def combine_input_sentences(self, input_sent_groups: List[torch.Tensor]):
        """Join the segment groups into one padded batch of token ids.

        The first token of every group after the first is replaced by
        ``eos_token_id`` so it acts as a separator. Rows are then
        concatenated, truncated if too long, and padded with
        ``pad_token_id``.

        Args:
            input_sent_groups: Token-id tensors, hypothesis first; assumed
                to be (batch, seq) each.

        Returns:
            A (batch, max_len) tensor of token ids.
        """
        # Clone every group that gets rewritten so the caller's tensors are
        # left untouched.
        groups = [
            group if idx == 0 else group.clone()
            for idx, group in enumerate(input_sent_groups)
        ]
        for group in groups[1:]:
            group[:, 0] = self.eos_token_id

        # Pad with the same id that ``forward`` uses to build the attention
        # mask.
        if len(groups) == 3:
            return self.cut_long_sequences3(
                groups, pad_idx=self.pad_token_id)
        return self.cut_long_sequences2(groups, pad_idx=self.pad_token_id)

    @staticmethod
    def _strip_padding(tensor_tuple, pad_idx: int):
        """Return copies of the 1-D segments with ``pad_idx`` removed."""
        return tuple(x.masked_select(x.ne(pad_idx)) for x in tensor_tuple)

    @staticmethod
    def _truncate_to_lengths(tensor_tuple, lengths: Dict[int, int]):
        """Cut segment ``k`` to ``lengths[k]``, keeping its final token."""
        new_lens = [lengths[k] for k in range(len(tensor_tuple))]
        truncated = tuple(x[:y] for x, y in zip(tensor_tuple, new_lens))
        for short, full in zip(truncated, tensor_tuple):
            # Re-attach the closing token that truncation would drop.
            short[-1] = full[-1]
        return truncated

    @staticmethod
    def _concat_and_pad(collected_tuples, pad_idx: int):
        """Concatenate each row's segments and pad rows to equal length."""
        concat_tensor = [torch.cat(x, dim=0) for x in collected_tuples]
        return pad_sequence(
            concat_tensor, batch_first=True, padding_value=pad_idx)

    @staticmethod
    def cut_long_sequences2(all_input_concat: List[List[torch.Tensor]],
                            maximum_length: int = 512,
                            pad_idx: int = 1):
        """Concatenate two segments per row, truncating over-long rows.

        Padding is removed from every segment before lengths are measured.
        If a row exceeds ``maximum_length``, both segments are trimmed by
        the same amount when both are longer than half the budget;
        otherwise only the longer segment is shortened.

        Args:
            all_input_concat: Per-group batches of token ids; row ``i`` of
                every group belongs to the same example.
            maximum_length: Maximum number of tokens per combined row.
            pad_idx: Padding id, both stripped on input and used to pad the
                output.

        Returns:
            A (batch, max_len) tensor of token ids.
        """
        cls = UniTEForTranslationEvaluation
        collected_tuples = []
        for tensor_tuple in zip(*all_input_concat):
            # Batch padding must not count towards the length budget or end
            # up between the segments.
            tensor_tuple = cls._strip_padding(tensor_tuple, pad_idx)
            all_lens = tuple(len(x) for x in tensor_tuple)

            if sum(all_lens) <= maximum_length:
                collected_tuples.append(tensor_tuple)
                continue

            lengths = dict(enumerate(all_lens))
            lengths_sorted_idxes = [
                idx for idx, _ in sorted(
                    lengths.items(), key=lambda d: d[1], reverse=True)
            ]

            offset = math.ceil((sum(all_lens) - maximum_length) / 2)

            if min(all_lens) > (maximum_length
                                // 2) and min(all_lens) > offset:
                lengths = {k: v - offset for k, v in lengths.items()}
            else:
                lengths[lengths_sorted_idxes[0]] = (
                    maximum_length - lengths[lengths_sorted_idxes[1]])

            collected_tuples.append(
                cls._truncate_to_lengths(tensor_tuple, lengths))

        return cls._concat_and_pad(collected_tuples, pad_idx)

    @staticmethod
    def cut_long_sequences3(all_input_concat: List[List[torch.Tensor]],
                            maximum_length: int = 512,
                            pad_idx: int = 1):
        """Concatenate three segments per row, truncating over-long rows.

        Padding is removed from every segment before lengths are measured.
        If a row exceeds ``maximum_length``, all segments are trimmed by
        the same amount when each is longer than a third of the budget.
        Otherwise the longest segments are shortened step by step until the
        row fits.

        Args:
            all_input_concat: Per-group batches of token ids; row ``i`` of
                every group belongs to the same example.
            maximum_length: Maximum number of tokens per combined row.
            pad_idx: Padding id, both stripped on input and used to pad the
                output.

        Returns:
            A (batch, max_len) tensor of token ids.
        """
        cls = UniTEForTranslationEvaluation
        collected_tuples = []
        for tensor_tuple in zip(*all_input_concat):
            # Batch padding must not count towards the length budget or end
            # up between the segments.
            tensor_tuple = cls._strip_padding(tensor_tuple, pad_idx)
            all_lens = tuple(len(x) for x in tensor_tuple)

            if sum(all_lens) <= maximum_length:
                collected_tuples.append(tensor_tuple)
                continue

            lengths = dict(enumerate(all_lens))
            lengths_sorted_idxes = [
                idx for idx, _ in sorted(
                    lengths.items(), key=lambda d: d[1], reverse=True)
            ]

            offset = math.ceil((sum(all_lens) - maximum_length) / 3)

            if min(all_lens) > (maximum_length
                                // 3) and min(all_lens) > offset:
                lengths = {k: v - offset for k, v in lengths.items()}
            else:
                # Segment indices ordered by their ORIGINAL lengths; the
                # order is deliberately not recomputed inside the loop.
                longest, middle, shortest = lengths_sorted_idxes[:3]
                while sum(lengths.values()) > maximum_length:
                    if lengths[longest] > lengths[middle]:
                        # Give the longest whatever the other two leave
                        # over, but never cut it below the middle one.
                        offset = (maximum_length - lengths[middle]
                                  - lengths[shortest])
                        if offset > lengths[middle]:
                            lengths[longest] = offset
                        else:
                            lengths[longest] = lengths[middle]
                    elif (lengths[longest] == lengths[middle]
                          > lengths[shortest]):
                        # The two longest are tied: split the remaining
                        # budget between them, but never go below the
                        # shortest.
                        offset = (maximum_length - lengths[shortest]) // 2
                        if offset > lengths[shortest]:
                            lengths[longest] = lengths[middle] = offset
                        else:
                            lengths[longest] = lengths[middle] = lengths[
                                shortest]
                    else:
                        # All three are tied: give each an equal share.
                        lengths[longest] = lengths[middle] = lengths[
                            shortest] = maximum_length // 3

            collected_tuples.append(
                cls._truncate_to_lengths(tensor_tuple, lengths))

        return cls._concat_and_pad(collected_tuples, pad_idx)
| @@ -801,6 +801,11 @@ TASK_OUTPUTS = { | |||||
| # ] | # ] | ||||
| # } | # } | ||||
| Tasks.product_segmentation: [OutputKeys.MASKS], | Tasks.product_segmentation: [OutputKeys.MASKS], | ||||
| # { | |||||
| # 'scores': [0.1, 0.2, 0.3, ...] | |||||
| # } | |||||
| Tasks.translation_evaluation: [OutputKeys.SCORES] | |||||
| } | } | ||||
| @@ -183,6 +183,11 @@ TASK_INPUTS = { | |||||
| 'query_set': InputType.LIST, | 'query_set': InputType.LIST, | ||||
| 'support_set': InputType.LIST, | 'support_set': InputType.LIST, | ||||
| }, | }, | ||||
| Tasks.translation_evaluation: { | |||||
| 'hyp': InputType.LIST, | |||||
| 'src': InputType.LIST, | |||||
| 'ref': InputType.LIST, | |||||
| }, | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| Tasks.auto_speech_recognition: | Tasks.auto_speech_recognition: | ||||
| @@ -217,6 +217,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_swin-t_referring_video-object-segmentation'), | 'damo/cv_swin-t_referring_video-object-segmentation'), | ||||
| Tasks.video_summarization: (Pipelines.video_summarization, | Tasks.video_summarization: (Pipelines.video_summarization, | ||||
| 'damo/cv_googlenet_pgl-video-summarization'), | 'damo/cv_googlenet_pgl-video-summarization'), | ||||
| Tasks.translation_evaluation: | |||||
| (Pipelines.translation_evaluation, | |||||
| 'damo/nlp_unite_mup_translation_evaluation_multilingual_large'), | |||||
| } | } | ||||
| @@ -32,6 +32,7 @@ if TYPE_CHECKING: | |||||
| from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | ||||
| from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline | from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline | ||||
| from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline | from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline | ||||
| from .translation_evaluation_pipeline import TranslationEvaluationPipeline | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -77,6 +78,7 @@ else: | |||||
| ['CodeGeeXCodeTranslationPipeline'], | ['CodeGeeXCodeTranslationPipeline'], | ||||
| 'codegeex_code_generation_pipeline': | 'codegeex_code_generation_pipeline': | ||||
| ['CodeGeeXCodeGenerationPipeline'], | ['CodeGeeXCodeGenerationPipeline'], | ||||
| 'translation_evaluation_pipeline': ['TranslationEvaluationPipeline'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,111 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| from enum import Enum | |||||
| from typing import Any, Dict, List, Optional, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.nlp.unite.configuration_unite import EvaluationMode | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import InputModel, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import (Preprocessor, | |||||
| TranslationEvaluationPreprocessor) | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger(__name__) | |||||
| __all__ = ['TranslationEvaluationPipeline'] | |||||
@PIPELINES.register_module(
    Tasks.translation_evaluation, module_name=Pipelines.translation_evaluation)
class TranslationEvaluationPipeline(Pipeline):
    """Pipeline that scores translation hypotheses with a UniTE model."""

    def __init__(self,
                 model: InputModel,
                 preprocessor: Optional[Preprocessor] = None,
                 eval_mode: EvaluationMode = EvaluationMode.SRC_REF,
                 **kwargs):
        r"""Build a translation evaluation pipeline with a model dir or a model id in the model hub.

        Args:
            model: The model input accepted by ``Pipeline`` (see
                ``InputModel``).
            preprocessor: Optional preprocessor. When ``None``, a
                ``TranslationEvaluationPreprocessor`` is created from the
                model directory and ``eval_mode``.
            eval_mode: Evaluation mode, choosing one from `"EvaluationMode.SRC_REF"`,
                `"EvaluationMode.SRC"`, `"EvaluationMode.REF"`. Aside from hypothesis, the
                source/reference/source+reference can be presented during evaluation.

        Raises:
            ValueError: If ``eval_mode`` is not an ``EvaluationMode`` member.
        """
        # NOTE(review): **kwargs is accepted but not forwarded to
        # ``Pipeline.__init__`` -- confirm that this is intended.
        super().__init__(model=model, preprocessor=preprocessor)

        self.eval_mode = eval_mode
        self.checking_eval_mode()

        if preprocessor is None:
            self.preprocessor = TranslationEvaluationPreprocessor(
                self.model.model_dir, self.eval_mode)
        else:
            self.preprocessor = preprocessor

        self.model.load_checkpoint(
            osp.join(self.model.model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
        self.model.eval()

    def checking_eval_mode(self):
        """Log the active evaluation mode and reject unknown values.

        Raises:
            ValueError: If ``self.eval_mode`` is not an ``EvaluationMode``
                member.
        """
        if self.eval_mode == EvaluationMode.SRC:
            logger.info('Evaluation mode: source-only')
        elif self.eval_mode == EvaluationMode.REF:
            logger.info('Evaluation mode: reference-only')
        elif self.eval_mode == EvaluationMode.SRC_REF:
            logger.info('Evaluation mode: source-reference-combined')
        else:
            # The adjacent literals are concatenated, so each one carries
            # its own separating space.
            raise ValueError(
                'Evaluation mode should be one choice among '
                '\'EvaluationMode.SRC\', \'EvaluationMode.REF\', and '
                '\'EvaluationMode.SRC_REF\'.')

    def change_eval_mode(self,
                         eval_mode: EvaluationMode = EvaluationMode.SRC_REF):
        """Switch the pipeline and its preprocessor to ``eval_mode``.

        Raises:
            ValueError: If ``eval_mode`` is not an ``EvaluationMode``
                member; the previously active mode is kept in that case.
        """
        logger.info('Changing the evaluation mode.')
        previous_mode = self.eval_mode
        self.eval_mode = eval_mode
        try:
            self.checking_eval_mode()
        except ValueError:
            # Do not leave the pipeline holding a rejected mode.
            self.eval_mode = previous_mode
            raise
        self.preprocessor.eval_mode = eval_mode

    def __call__(self, input_dict: Dict[str, Union[str, List[str]]], **kwargs):
        r"""Implementation of __call__ function.

        Args:
            input_dict: The formatted dict containing the inputted sentences.
            An example of the formatted dict:
                ```
                input_dict = {
                    'hyp': [
                        'This is a sentence.',
                        'This is another sentence.',
                    ],
                    'src': [
                        '这是个句子。',
                        '这是另一个句子。',
                    ],
                    'ref': [
                        'It is a sentence.',
                        'It is another sentence.',
                    ]
                }
                ```
        """
        return super().__call__(input=input_dict, **kwargs)

    def forward(self, input_ids: List[torch.Tensor]) -> torch.Tensor:
        """Run the model on the preprocessed token-id tensors."""
        return self.model(input_ids)

    def postprocess(self, output: torch.Tensor) -> Dict[str, Any]:
        """Convert the model output into ``{OutputKeys.SCORES: list}``."""
        return {OutputKeys.SCORES: output.cpu().tolist()}
| @@ -33,7 +33,8 @@ if TYPE_CHECKING: | |||||
| DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, | DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, | ||||
| DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor, | DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor, | ||||
| TableQuestionAnsweringPreprocessor, NERPreprocessorViet, | TableQuestionAnsweringPreprocessor, NERPreprocessorViet, | ||||
| NERPreprocessorThai, WordSegmentationPreprocessorThai) | |||||
| NERPreprocessorThai, WordSegmentationPreprocessorThai, | |||||
| TranslationEvaluationPreprocessor) | |||||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | ||||
| else: | else: | ||||
| @@ -72,7 +73,8 @@ else: | |||||
| 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', | 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', | ||||
| 'DialogStateTrackingPreprocessor', | 'DialogStateTrackingPreprocessor', | ||||
| 'ConversationalTextToSqlPreprocessor', | 'ConversationalTextToSqlPreprocessor', | ||||
| 'TableQuestionAnsweringPreprocessor' | |||||
| 'TableQuestionAnsweringPreprocessor', | |||||
| 'TranslationEvaluationPreprocessor' | |||||
| ], | ], | ||||
| } | } | ||||
| @@ -28,6 +28,7 @@ if TYPE_CHECKING: | |||||
| from .space_T_en import ConversationalTextToSqlPreprocessor | from .space_T_en import ConversationalTextToSqlPreprocessor | ||||
| from .space_T_cn import TableQuestionAnsweringPreprocessor | from .space_T_cn import TableQuestionAnsweringPreprocessor | ||||
| from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | ||||
| from .translation_evaluation_preprocessor import TranslationEvaluationPreprocessor | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], | 'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], | ||||
| @@ -76,6 +77,8 @@ else: | |||||
| ], | ], | ||||
| 'space_T_en': ['ConversationalTextToSqlPreprocessor'], | 'space_T_en': ['ConversationalTextToSqlPreprocessor'], | ||||
| 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], | 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], | ||||
| 'translation_evaluation_preprocessor': | |||||
| ['TranslationEvaluationPreprocessor'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,87 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict, List, Union | |||||
| from transformers import AutoTokenizer | |||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.models.nlp.unite.configuration_unite import EvaluationMode | |||||
| from modelscope.preprocessors import Preprocessor | |||||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||||
| from modelscope.utils.constant import Fields, ModeKeys | |||||
| from .transformers_tokenizer import NLPTokenizer | |||||
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.translation_evaluation)
class TranslationEvaluationPreprocessor(Preprocessor):
    r"""The tokenizer preprocessor used for translation evaluation.
    """

    def __init__(self,
                 model_dir: str,
                 eval_mode: EvaluationMode,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        r"""preprocess the data via the vocab file from the `model_dir` path

        Args:
            model_dir: Path handed to ``NLPTokenizer`` as ``model_dir``.
            eval_mode: Evaluation mode, choosing one from `"EvaluationMode.SRC_REF"`,
                `"EvaluationMode.SRC"`, `"EvaluationMode.REF"`. Aside from hypothesis, the
                source/reference/source+reference can be presented during evaluation.
            mode: Passed to ``Preprocessor.__init__``.
            **kwargs: Passed to ``NLPTokenizer`` as ``tokenize_kwargs``.
        """
        super().__init__(mode=mode)

        self.tokenizer = NLPTokenizer(
            model_dir=model_dir, use_fast=False, tokenize_kwargs=kwargs)
        self.eval_mode = eval_mode

    def __call__(self, input_dict: Dict[str, Any]) -> List[Any]:
        """Tokenize the hypothesis and the sentences the mode requires.

        Args:
            input_dict: Holds ``'hyp'`` plus ``'src'`` and/or ``'ref'``
                depending on ``eval_mode``; each value is a string or a
                list of strings. The dict itself is not modified.

        Returns:
            The ``input_ids`` produced by the tokenizer for the hypothesis,
            then the source (if used), then the reference (if used).

        Raises:
            ValueError: If a key required by ``eval_mode`` is missing, or
                if the sentence lists differ in length.
            KeyError: If ``'hyp'`` is missing.
        """
        needs_src = self.eval_mode in (EvaluationMode.SRC,
                                       EvaluationMode.SRC_REF)
        needs_ref = self.eval_mode in (EvaluationMode.REF,
                                       EvaluationMode.SRC_REF)

        if self.eval_mode == EvaluationMode.SRC and 'src' not in input_dict:
            raise ValueError(
                'Source sentences are required for source-only evaluation mode.'
            )
        if self.eval_mode == EvaluationMode.REF and 'ref' not in input_dict:
            raise ValueError(
                'Reference sentences are required for reference-only evaluation mode.'
            )
        if self.eval_mode == EvaluationMode.SRC_REF and (
                'src' not in input_dict or 'ref' not in input_dict):
            raise ValueError(
                'Source and reference sentences are both required for source-reference-combined evaluation mode.'
            )

        keys = ['hyp']
        if needs_src:
            keys.append('src')
        if needs_ref:
            keys.append('ref')

        # Wrap bare strings locally instead of writing back into the
        # caller's dict.
        sentence_groups = []
        for key in keys:
            sentences = input_dict[key]
            if isinstance(sentences, str):
                sentences = [sentences]
            sentence_groups.append(sentences)

        # The model pairs sentences by position; unequal lengths would
        # silently drop the unmatched tail.
        if len({len(group) for group in sentence_groups}) > 1:
            raise ValueError(
                'All inputs (%s) should contain the same number of sentences.'
                % ', '.join(keys))

        return [
            self.tokenizer(group, return_tensors='pt',
                           padding=True)['input_ids']
            for group in sentence_groups
        ]
| @@ -133,6 +133,7 @@ class NLPTasks(object): | |||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| extractive_summarization = 'extractive-summarization' | extractive_summarization = 'extractive-summarization' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| translation_evaluation = 'translation-evaluation' | |||||
| class AudioTasks(object): | class AudioTasks(object): | ||||
| @@ -0,0 +1,73 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.models.nlp.unite.configuration_unite import EvaluationMode | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | |||||
class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck):
    """Runs the translation-evaluation pipeline in all three modes."""

    def setUp(self) -> None:
        self.task = Tasks.translation_evaluation
        self.model_id_large = 'damo/nlp_unite_mup_translation_evaluation_multilingual_large'
        self.model_id_base = 'damo/nlp_unite_mup_translation_evaluation_multilingual_base'

    @staticmethod
    def _sample_input():
        """Build a fresh two-sentence example for every run."""
        return {
            'hyp': [
                'This is a sentence.',
                'This is another sentence.',
            ],
            'src': [
                '这是个句子。',
                '这是另一个句子。',
            ],
            'ref': [
                'It is a sentence.',
                'It is another sentence.',
            ]
        }

    def _run_all_eval_modes(self, model_id):
        """Print scores for the default mode, then for SRC and REF."""
        input_dict = self._sample_input()
        pipeline_ins = pipeline(self.task, model=model_id)
        print(pipeline_ins(input_dict))

        for eval_mode in (EvaluationMode.SRC, EvaluationMode.REF):
            pipeline_ins.change_eval_mode(eval_mode=eval_mode)
            print(pipeline_ins(input_dict))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_unite_large(self):
        self._run_all_eval_modes(self.model_id_large)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_unite_base(self):
        self._run_all_eval_modes(self.model_id_base)
| if __name__ == '__main__': | |||||
| unittest.main() | |||||