ofa增加asr任务infer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10761019
master^2
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:46dbc998c9d1d48111267c40741dd3200f2e5bcf4075f8c4c97f4451160dce50 | |||||
| size 134570 | |||||
| @@ -284,6 +284,7 @@ class Pipelines(object): | |||||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | video_multi_modal_embedding = 'video-multi-modal-embedding' | ||||
| image_text_retrieval = 'image-text-retrieval' | image_text_retrieval = 'image-text-retrieval' | ||||
| ofa_ocr_recognition = 'ofa-ocr-recognition' | ofa_ocr_recognition = 'ofa-ocr-recognition' | ||||
| ofa_asr = 'ofa-asr' | |||||
| # science tasks | # science tasks | ||||
| protein_structure = 'unifold-protein-structure' | protein_structure = 'unifold-protein-structure' | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from .modeling_mmspeech import MMSpeechModel | |||||
| from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel | from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel | ||||
| from .tokenization_ofa import OFATokenizer, OFATokenizerZH | from .tokenization_ofa import OFATokenizer, OFATokenizerZH | ||||
| from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast | from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast | ||||
| @@ -0,0 +1,260 @@ | |||||
| # Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ MMSpeech model configuration""" | |||||
| import warnings | |||||
| from transformers import PretrainedConfig | |||||
| from transformers.utils import logging | |||||
| logger = logging.get_logger(__name__) | |||||
| class MMSpeechConfig(PretrainedConfig): | |||||
| r""" | |||||
| This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA | |||||
| model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | |||||
| defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) | |||||
| architecture. | |||||
| Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | |||||
| documentation from [`PretrainedConfig`] for more information. | |||||
| Args: | |||||
| vocab_size (`int`, *optional*, defaults to 50265): | |||||
| Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the | |||||
| `inputs_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`]. | |||||
| d_model (`int`, *optional*, defaults to 1024): | |||||
| Dimension of the layers and the pooler layer. | |||||
| encoder_layers (`int`, *optional*, defaults to 12): | |||||
| Number of encoder layers. | |||||
| decoder_layers (`int`, *optional*, defaults to 12): | |||||
| Number of decoder layers. | |||||
| encoder_attention_heads (`int`, *optional*, defaults to 16): | |||||
| Number of attention heads for each attention layer in the Transformer encoder. | |||||
| decoder_attention_heads (`int`, *optional*, defaults to 16): | |||||
| Number of attention heads for each attention layer in the Transformer decoder. | |||||
| decoder_ffn_dim (`int`, *optional*, defaults to 4096): | |||||
| Dimension of the "intermediate" (often named feed-forward) layer in decoder. | |||||
| encoder_ffn_dim (`int`, *optional*, defaults to 4096): | |||||
| Dimension of the "intermediate" (often named feed-forward) layer in decoder. | |||||
| activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): | |||||
| The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, | |||||
| `"relu"`, `"silu"` and `"gelu_new"` are supported. | |||||
| dropout (`float`, *optional*, defaults to 0.1): | |||||
| The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||||
| attention_dropout (`float`, *optional*, defaults to 0.0): | |||||
| The dropout ratio for the attention probabilities. | |||||
| activation_dropout (`float`, *optional*, defaults to 0.0): | |||||
| The dropout ratio for activations inside the fully connected layer. | |||||
| classifier_dropout (`float`, *optional*, defaults to 0.0): | |||||
| The dropout ratio for classifier. | |||||
| max_position_embeddings (`int`, *optional*, defaults to 1024): | |||||
| The maximum sequence length that this model might ever be used with. Typically set this to something large | |||||
| just in case (e.g., 512 or 1024 or 2048). | |||||
| init_std (`float`, *optional*, defaults to 0.02): | |||||
| The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||||
| encoder_layerdrop: (`float`, *optional*, defaults to 0.0): | |||||
| The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) | |||||
| for more details. | |||||
| decoder_layerdrop: (`float`, *optional*, defaults to 0.0): | |||||
| The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) | |||||
| for more details. | |||||
| use_cache (`bool`, *optional*, defaults to `True`): | |||||
| Whether or not the model should return the last key/values attentions (not used by all models). | |||||
| """ | |||||
| model_type = 'ofa' | |||||
| keys_to_ignore_at_inference = ['past_key_values'] | |||||
| attribute_map = { | |||||
| 'num_attention_heads': 'encoder_attention_heads', | |||||
| 'hidden_size': 'd_model' | |||||
| } | |||||
| def __init__(self, | |||||
| vocab_size=59457, | |||||
| max_position_embeddings=1024, | |||||
| encoder_layers=4, | |||||
| encoder_ffn_dim=512 * 4, | |||||
| encoder_attention_heads=8, | |||||
| decoder_layers=4, | |||||
| decoder_ffn_dim=512 * 4, | |||||
| decoder_attention_heads=8, | |||||
| encoder_layerdrop=0.0, | |||||
| decoder_layerdrop=0.0, | |||||
| use_cache=True, | |||||
| is_encoder_decoder=True, | |||||
| activation_function='gelu', | |||||
| d_model=512, | |||||
| dropout=0.1, | |||||
| attention_dropout=0.0, | |||||
| activation_dropout=0.0, | |||||
| init_std=0.02, | |||||
| classifier_dropout=0.0, | |||||
| scale_embedding=False, | |||||
| pad_token_id=1, | |||||
| bos_token_id=0, | |||||
| decoder_start_token_id=0, | |||||
| eos_token_id=2, | |||||
| forced_eos_token_id=2, | |||||
| encoder_normalize_before=True, | |||||
| decoder_normalize_before=True, | |||||
| normformer=True, | |||||
| encoder_drop_path_rate=0.0, | |||||
| decoder_drop_path_rate=0.0, | |||||
| layernorm_embedding=True, | |||||
| patch_layernorm_embedding=True, | |||||
| resnet_type='resnet101', | |||||
| resnet_model_path=None, | |||||
| resnet_drop_path_rate=0.0, | |||||
| token_bucket_size=256, | |||||
| image_bucket_size=42, | |||||
| add_type_embedding=True, | |||||
| share_decoder_input_output_embed=True, | |||||
| attn_scale_factor=2., | |||||
| code_layernorm_embedding=False, | |||||
| code_image_size=128, | |||||
| entangle_position_embedding=False, | |||||
| interpolate_position=False, | |||||
| orig_patch_image_size=224, | |||||
| share_attn_bias=False, | |||||
| use_image_feature=True, | |||||
| disable_entangle=False, | |||||
| use_ofasys=False, | |||||
| vit_type='vit_base', | |||||
| vit_drop_path_rate=0.0, | |||||
| required_seq_len_multiple=2, | |||||
| encoder_pos_conv_depth=5, | |||||
| encoder_conv_pos=95, | |||||
| encoder_conv_pos_groups=16, | |||||
| encoder_max_positions=100000, | |||||
| phone_vocab_size=141, | |||||
| audio_mask_prob=0.65, | |||||
| audio_mask_selection='static', | |||||
| audio_mask_other=0, | |||||
| audio_mask_length=10, | |||||
| audio_no_mask_overlap=False, | |||||
| audio_mask_min_space=1, | |||||
| audio_mask_channel_prob=0.0, | |||||
| audio_mask_channel_before=False, | |||||
| audio_mask_channel_selection='static', | |||||
| audio_mask_channel_other=0, | |||||
| audio_mask_channel_length=10, | |||||
| audio_no_mask_channel_overlap=False, | |||||
| audio_mask_channel_min_space=1, | |||||
| encoder_dropout_input=0.0, | |||||
| encoder_dropout_features=0.0, | |||||
| phone_dict_size=124, | |||||
| **kwargs): | |||||
| self.vocab_size = vocab_size | |||||
| self.max_position_embeddings = max_position_embeddings | |||||
| self.d_model = d_model | |||||
| self.encoder_ffn_dim = encoder_ffn_dim | |||||
| self.encoder_layers = encoder_layers | |||||
| self.encoder_attention_heads = encoder_attention_heads | |||||
| self.decoder_ffn_dim = decoder_ffn_dim | |||||
| self.decoder_layers = decoder_layers | |||||
| self.decoder_attention_heads = decoder_attention_heads | |||||
| self.dropout = dropout | |||||
| self.attention_dropout = attention_dropout | |||||
| self.activation_dropout = activation_dropout | |||||
| self.activation_function = activation_function | |||||
| self.init_std = init_std | |||||
| self.encoder_layerdrop = encoder_layerdrop | |||||
| self.decoder_layerdrop = decoder_layerdrop | |||||
| self.classifier_dropout = classifier_dropout | |||||
| self.use_cache = use_cache | |||||
| self.num_hidden_layers = encoder_layers | |||||
| self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True | |||||
| self.encoder_normalize_before = encoder_normalize_before | |||||
| self.decoder_normalize_before = decoder_normalize_before | |||||
| self.normformer = normformer | |||||
| self.encoder_drop_path_rate = encoder_drop_path_rate | |||||
| self.decoder_drop_path_rate = decoder_drop_path_rate | |||||
| self.layernorm_embedding = layernorm_embedding | |||||
| self.patch_layernorm_embedding = patch_layernorm_embedding | |||||
| self.resnet_type = resnet_type | |||||
| self.resnet_model_path = resnet_model_path | |||||
| self.resnet_drop_path_rate = resnet_drop_path_rate | |||||
| self.token_bucket_size = token_bucket_size | |||||
| self.image_bucket_size = image_bucket_size | |||||
| self.add_type_embedding = add_type_embedding | |||||
| self.share_decoder_input_output_embed = share_decoder_input_output_embed | |||||
| self.attn_scale_factor = attn_scale_factor | |||||
| self.code_layernorm_embedding = code_layernorm_embedding | |||||
| self.code_image_size = code_image_size | |||||
| self.entangle_position_embedding = entangle_position_embedding | |||||
| self.interpolate_position = interpolate_position | |||||
| self.orig_patch_image_size = orig_patch_image_size | |||||
| self.share_attn_bias = share_attn_bias | |||||
| self.use_image_feature = use_image_feature | |||||
| self.disable_entangle = disable_entangle | |||||
| self.use_ofasys = use_ofasys | |||||
| self.vit_type = vit_type | |||||
| self.vit_drop_path_rate = vit_drop_path_rate | |||||
| # FP16 optimization | |||||
| self.required_seq_len_multiple = required_seq_len_multiple | |||||
| # encoder_pos_conv | |||||
| self.encoder_pos_conv_depth = encoder_pos_conv_depth | |||||
| self.encoder_conv_pos = encoder_conv_pos | |||||
| self.encoder_conv_pos_groups = encoder_conv_pos_groups | |||||
| self.encoder_max_positions = encoder_max_positions | |||||
| # phone | |||||
| self.phone_vocab_size = phone_vocab_size | |||||
| # audio_mask | |||||
| self.audio_mask_prob = audio_mask_prob | |||||
| self.audio_mask_selection = audio_mask_selection | |||||
| self.audio_mask_other = audio_mask_other | |||||
| self.audio_mask_length = audio_mask_length | |||||
| self.audio_no_mask_overlap = audio_no_mask_overlap | |||||
| self.audio_mask_min_space = audio_mask_min_space | |||||
| self.audio_mask_channel_prob = audio_mask_channel_prob | |||||
| self.audio_mask_channel_before = audio_mask_channel_before | |||||
| self.audio_mask_channel_selection = audio_mask_channel_selection | |||||
| self.audio_mask_channel_other = audio_mask_channel_other | |||||
| self.audio_mask_channel_length = audio_mask_channel_length | |||||
| self.audio_no_mask_channel_overlap = audio_no_mask_channel_overlap | |||||
| self.audio_mask_channel_min_space = audio_mask_channel_min_space | |||||
| # audio encoder | |||||
| self.encoder_dropout_input = encoder_dropout_input | |||||
| self.encoder_dropout_features = encoder_dropout_features | |||||
| self.phone_dict_size = phone_dict_size | |||||
| super().__init__( | |||||
| pad_token_id=pad_token_id, | |||||
| bos_token_id=bos_token_id, | |||||
| eos_token_id=eos_token_id, | |||||
| is_encoder_decoder=is_encoder_decoder, | |||||
| decoder_start_token_id=decoder_start_token_id, | |||||
| forced_eos_token_id=forced_eos_token_id, | |||||
| **kwargs, | |||||
| ) | |||||
| # ensure backward compatibility for BART CNN models | |||||
| if self.forced_bos_token_id is None and kwargs.get( | |||||
| 'force_bos_token_to_be_generated', False): | |||||
| self.forced_bos_token_id = self.bos_token_id | |||||
| warnings.warn( | |||||
| f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' | |||||
| 'The config can simply be saved and uploaded again to be fixed.' | |||||
| ) | |||||
| @@ -227,6 +227,9 @@ class SequenceGenerator(nn.Module): | |||||
| - net_input['padding_mask'].sum(-1) | - net_input['padding_mask'].sum(-1) | ||||
| if net_input['padding_mask'] is not None else torch.tensor( | if net_input['padding_mask'] is not None else torch.tensor( | ||||
| src_tokens.size(-1)).to(src_tokens)) | src_tokens.size(-1)).to(src_tokens)) | ||||
| elif 'fbank' in net_input: | |||||
| src_tokens = net_input['fbank'] | |||||
| src_lengths = net_input['fbank_length'] | |||||
| else: | else: | ||||
| raise Exception( | raise Exception( | ||||
| 'expected src_tokens or source in net input. input keys: ' | 'expected src_tokens or source in net input. input keys: ' | ||||
| @@ -11,4 +11,5 @@ OFA_TASK_KEY_MAPPING = { | |||||
| Tasks.text_classification: OutputKeys.LABELS, | Tasks.text_classification: OutputKeys.LABELS, | ||||
| Tasks.image_classification: OutputKeys.LABELS, | Tasks.image_classification: OutputKeys.LABELS, | ||||
| Tasks.visual_entailment: OutputKeys.LABELS, | Tasks.visual_entailment: OutputKeys.LABELS, | ||||
| Tasks.auto_speech_recognition: OutputKeys.TEXT | |||||
| } | } | ||||
| @@ -19,7 +19,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
| from modelscope.utils.trie import Trie | from modelscope.utils.trie import Trie | ||||
| from .ofa import OFAModel, OFATokenizer, OFATokenizerZH | |||||
| from .ofa import MMSpeechModel, OFAModel, OFATokenizer, OFATokenizerZH | |||||
| from .ofa.generate import sequence_generator as sg | from .ofa.generate import sequence_generator as sg | ||||
| from .ofa.generate.utils import move_to_device | from .ofa.generate.utils import move_to_device | ||||
| from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks | from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks | ||||
| @@ -37,13 +37,20 @@ __all__ = ['OfaForAllTasks'] | |||||
| @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) | @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) | ||||
| @MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) | @MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) | ||||
| @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) | @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) | ||||
| @MODELS.register_module(Tasks.auto_speech_recognition, module_name=Models.ofa) | |||||
| class OfaForAllTasks(TorchModel): | class OfaForAllTasks(TorchModel): | ||||
| def __init__(self, model_dir, *args, **kwargs): | def __init__(self, model_dir, *args, **kwargs): | ||||
| super().__init__(model_dir=model_dir, *args, **kwargs) | super().__init__(model_dir=model_dir, *args, **kwargs) | ||||
| model = OFAModel.from_pretrained(model_dir) | |||||
| self.cfg = Config.from_file( | self.cfg = Config.from_file( | ||||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | osp.join(model_dir, ModelFile.CONFIGURATION)) | ||||
| multimodal_type = self.cfg.model.get('multimodal_type', 'default') | |||||
| if multimodal_type == 'default': | |||||
| model = OFAModel.from_pretrained(model_dir) | |||||
| elif multimodal_type == 'mmspeech': | |||||
| model = MMSpeechModel.from_pretrained(model_dir) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| self.model = model.module if hasattr(model, 'module') else model | self.model = model.module if hasattr(model, 'module') else model | ||||
| self.language = self.cfg.model.get('language', 'en') | self.language = self.cfg.model.get('language', 'en') | ||||
| if self.language == 'en': | if self.language == 'en': | ||||
| @@ -54,12 +61,20 @@ class OfaForAllTasks(TorchModel): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| # there is some diff between here and our ofa code, | # there is some diff between here and our ofa code, | ||||
| # there will be no need to use param: use_bpe | # there will be no need to use param: use_bpe | ||||
| if not model.use_ofasys: | if not model.use_ofasys: | ||||
| self.tokenizer.add_tokens( | |||||
| ['<code_{}>'.format(i) for i in range(8192)]) | |||||
| self.tokenizer.add_tokens( | |||||
| ['<bin_{}>'.format(i) for i in range(1000)]) | |||||
| self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||||
| if multimodal_type == 'default': | |||||
| self.tokenizer.add_tokens( | |||||
| ['<code_{}>'.format(i) for i in range(8192)]) | |||||
| self.tokenizer.add_tokens( | |||||
| ['<bin_{}>'.format(i) for i in range(1000)]) | |||||
| self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||||
| elif multimodal_type == 'mmspeech': | |||||
| self.tokenizer.add_tokens('<blank>') | |||||
| self.tokenizer.add_tokens( | |||||
| ['<audio_{}>'.format(i) for i in range(30000)]) | |||||
| self.cfg.update({'num_bins': 0, 'num_codes': 30000}) | |||||
| self.batch_size = self.cfg.model.get('batch_size', 1) | self.batch_size = self.cfg.model.get('batch_size', 1) | ||||
| self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | ||||
| self.max_image_size = self.cfg.model.get('max_image_size', 512) | self.max_image_size = self.cfg.model.get('max_image_size', 512) | ||||
| @@ -110,6 +125,7 @@ class OfaForAllTasks(TorchModel): | |||||
| Tasks.visual_question_answering: inference_d[self.gen_type], | Tasks.visual_question_answering: inference_d[self.gen_type], | ||||
| Tasks.text_classification: inference_d[self.gen_type], | Tasks.text_classification: inference_d[self.gen_type], | ||||
| Tasks.image_classification: inference_d[self.gen_type], | Tasks.image_classification: inference_d[self.gen_type], | ||||
| Tasks.auto_speech_recognition: self._text_gen_inference, | |||||
| } | } | ||||
| pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' | pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' | ||||
| self.pattern = re.compile(pattern_str) | self.pattern = re.compile(pattern_str) | ||||
| @@ -186,7 +186,10 @@ TASK_INPUTS = { | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| Tasks.auto_speech_recognition: | Tasks.auto_speech_recognition: | ||||
| InputType.AUDIO, | |||||
| [InputType.AUDIO, { | |||||
| 'wav': InputType.AUDIO, | |||||
| 'text': InputType.TEXT | |||||
| }], | |||||
| Tasks.speech_signal_process: | Tasks.speech_signal_process: | ||||
| InputType.AUDIO, | InputType.AUDIO, | ||||
| Tasks.acoustic_echo_cancellation: { | Tasks.acoustic_echo_cancellation: { | ||||
| @@ -13,6 +13,7 @@ if TYPE_CHECKING: | |||||
| from .video_multi_modal_embedding_pipeline import \ | from .video_multi_modal_embedding_pipeline import \ | ||||
| VideoMultiModalEmbeddingPipeline | VideoMultiModalEmbeddingPipeline | ||||
| from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline | from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline | ||||
| from .asr_pipeline import AutomaticSpeechRecognitionPipeline | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -26,7 +27,8 @@ else: | |||||
| 'video_multi_modal_embedding_pipeline': | 'video_multi_modal_embedding_pipeline': | ||||
| ['VideoMultiModalEmbeddingPipeline'], | ['VideoMultiModalEmbeddingPipeline'], | ||||
| 'generative_multi_modal_embedding_pipeline': | 'generative_multi_modal_embedding_pipeline': | ||||
| ['GEMMMultiModalEmbeddingPipeline'] | |||||
| ['GEMMMultiModalEmbeddingPipeline'], | |||||
| 'asr_pipeline': ['AutomaticSpeechRecognitionPipeline'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,54 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Model, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, | |||||
| Preprocessor) | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.auto_speech_recognition, module_name=Pipelines.ofa_asr) | |||||
| class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[Model, str], | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | |||||
| """ | |||||
| use `model` and `preprocessor` to create an automatic speech recognition pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| 'model must be a single str or OfaForAllTasks' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| if preprocessor is None: | |||||
| if isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||||
| elif isinstance(pipe_model, MPlugForAllTasks): | |||||
| preprocessor = MPlugPreprocessor(pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -53,7 +53,8 @@ class OfaPreprocessor(Preprocessor): | |||||
| Tasks.image_classification: OfaImageClassificationPreprocessor, | Tasks.image_classification: OfaImageClassificationPreprocessor, | ||||
| Tasks.text_classification: OfaTextClassificationPreprocessor, | Tasks.text_classification: OfaTextClassificationPreprocessor, | ||||
| Tasks.text_summarization: OfaSummarizationPreprocessor, | Tasks.text_summarization: OfaSummarizationPreprocessor, | ||||
| Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor | |||||
| Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor, | |||||
| Tasks.auto_speech_recognition: OfaASRPreprocessor | |||||
| } | } | ||||
| model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | ||||
| model_dir) | model_dir) | ||||
| @@ -1,4 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from .asr import OfaASRPreprocessor | |||||
| from .image_captioning import OfaImageCaptioningPreprocessor | from .image_captioning import OfaImageCaptioningPreprocessor | ||||
| from .image_classification import OfaImageClassificationPreprocessor | from .image_classification import OfaImageClassificationPreprocessor | ||||
| from .ocr_recognition import OfaOcrRecognitionPreprocessor | from .ocr_recognition import OfaOcrRecognitionPreprocessor | ||||
| @@ -0,0 +1,121 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import random | |||||
| from pathlib import Path | |||||
| from typing import Any, Dict | |||||
| import soundfile as sf | |||||
| import torch | |||||
| from fairseq.data.audio.feature_transforms import \ | |||||
| CompositeAudioFeatureTransform | |||||
| from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig | |||||
| from modelscope.utils.chinese_utils import pre_chinese | |||||
| from modelscope.utils.constant import ModeKeys | |||||
| from .base import OfaBasePreprocessor | |||||
| from .utils.text2phone import Text2Phone | |||||
| class OfaASRPreprocessor(OfaBasePreprocessor): | |||||
| def __init__(self, | |||||
| cfg, | |||||
| model_dir, | |||||
| mode=ModeKeys.INFERENCE, | |||||
| *args, | |||||
| **kwargs): | |||||
| """preprocess the data | |||||
| Args: | |||||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||||
| model_dir (str): model path, | |||||
| mode: preprocessor mode (model mode) | |||||
| """ | |||||
| super(OfaASRPreprocessor, self).__init__(cfg, model_dir, mode, *args, | |||||
| **kwargs) | |||||
| # Initialize transform | |||||
| self.data_cfg = S2TDataConfig( | |||||
| Path(os.path.join(model_dir, 'fbank_config.yaml'))) | |||||
| self.train_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( | |||||
| self.data_cfg.get_feature_transforms('train', True)) | |||||
| self.test_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( | |||||
| self.data_cfg.get_feature_transforms('test', False)) | |||||
| self.text2phone_tokenizer = Text2Phone( | |||||
| os.path.join(model_dir, 'text2phone_dict.txt')) | |||||
| self.phone_to_id, self.id_to_phone = self.build_phone_dict( | |||||
| os.path.join(model_dir, 'phone_dict.txt')) | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| if self.mode == ModeKeys.TRAIN: | |||||
| return self._build_train_sample(data) | |||||
| else: | |||||
| return self._build_infer_sample(data) | |||||
| def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| speed = random.choice([0.9, 1.0, 1.1]) | |||||
| wav, sr = sf.read(self.column_map['wav']) | |||||
| fbank = self.prepare_fbank( | |||||
| torch.tensor([wav], dtype=torch.float32), sr, speed, is_train=True) | |||||
| fbank_mask = torch.tensor([True]) | |||||
| sample = { | |||||
| 'fbank': fbank, | |||||
| 'fbank_mask': fbank_mask, | |||||
| 'label': data[self.column_map['text']] | |||||
| } | |||||
| target = sample['label'] | |||||
| if self.language == 'zh': | |||||
| target = pre_chinese(target, self.max_tgt_length) | |||||
| sample['target'] = self.tokenize_text(target, add_bos=False) | |||||
| else: | |||||
| target = target.translate(self.transtab).strip() | |||||
| target_token_list = target.strip().split() | |||||
| target = ' '.join(target_token_list[:self.max_tgt_length]) | |||||
| sample['target'] = self.tokenize_text(target, add_bos=False) | |||||
| phone_item = self.to_phone(target) - 3 | |||||
| phone_mask = torch.tensor([False]) | |||||
| sample['phone_item'] = phone_item | |||||
| sample['phone_mask'] = phone_mask | |||||
| sample['prev_output_tokens'] = torch.cat( | |||||
| [self.bos_item, sample['target'][:-1]]) | |||||
| return sample | |||||
| def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| speed = 1.0 | |||||
| wav, sr = sf.read(data[self.column_map['wav']]) | |||||
| fbank = self.prepare_fbank( | |||||
| torch.tensor([wav], dtype=torch.float32), | |||||
| sr, | |||||
| speed, | |||||
| is_train=False) | |||||
| fbank_mask = torch.tensor([True]) | |||||
| sample = {'fbank': fbank, 'fbank_mask': fbank_mask} | |||||
| if 'text' in self.column_map and self.column_map['text'] in data: | |||||
| sample['label'] = data[self.column_map['text']] | |||||
| # mock | |||||
| sample['phone_item'] = torch.tensor([6, 6, 6]) | |||||
| sample['phone_mask'] = torch.tensor([False]) | |||||
| return sample | |||||
| def to_phone(self, text): | |||||
| phones = self.text2phone_tokenizer.trans(text) | |||||
| ids = torch.tensor([self.phone_to_id[x] for x in phones.split(' ')]) | |||||
| return ids | |||||
| def build_phone_dict(self, phone_dict_path): | |||||
| phone_to_id = dict() | |||||
| id_to_phone = dict() | |||||
| with open(phone_dict_path, 'r') as phone_dict_file: | |||||
| for i, line in enumerate(phone_dict_file): | |||||
| phone = line.strip().split(' ')[0] | |||||
| phone_to_id[phone] = i | |||||
| id_to_phone[i] = phone_to_id | |||||
| return phone_to_id, id_to_phone | |||||
| @@ -6,11 +6,14 @@ from os import path as osp | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| import torchaudio | |||||
| from PIL import Image | from PIL import Image | ||||
| from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH | from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH | ||||
| from modelscope.preprocessors.image import load_image | from modelscope.preprocessors.image import load_image | ||||
| from modelscope.utils.trie import Trie | from modelscope.utils.trie import Trie | ||||
| from .utils.audio_helper import (_get_kaldi_fbank, _get_torchaudio_fbank, | |||||
| convert_waveform) | |||||
| from .utils.constant import OFA_TASK_KEY_MAPPING | from .utils.constant import OFA_TASK_KEY_MAPPING | ||||
| from .utils.random_help import set_torch_seed | from .utils.random_help import set_torch_seed | ||||
| @@ -88,6 +91,9 @@ class OfaBasePreprocessor: | |||||
| + answer_item.tolist() | + answer_item.tolist() | ||||
| + [tokenizer.eos_token_id]) | + [tokenizer.eos_token_id]) | ||||
| self.train_audio_feature_transforms = None | |||||
| self.test_audio_feature_transforms = None | |||||
| def tokenize_text(self, text, add_bos=True, add_eos=True): | def tokenize_text(self, text, add_bos=True, add_eos=True): | ||||
| if text is None: | if text is None: | ||||
| return None | return None | ||||
| @@ -163,3 +169,36 @@ class OfaBasePreprocessor: | |||||
| image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ | image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ | ||||
| else load_image(path_or_url_or_pil) | else load_image(path_or_url_or_pil) | ||||
| return image | return image | ||||
| def prepare_fbank(self, waveform, sample_rate, speed, is_train): | |||||
| waveform, _ = torchaudio.sox_effects.apply_effects_tensor( | |||||
| waveform, sample_rate, | |||||
| [['speed', str(speed)], ['rate', str(sample_rate)]]) | |||||
| _waveform, _ = convert_waveform( | |||||
| waveform, sample_rate, to_mono=True, normalize_volume=True) | |||||
| # Kaldi compliance: 16-bit signed integers | |||||
| _waveform = _waveform * (2**15) | |||||
| _waveform = _waveform.numpy() | |||||
| fbank = _get_kaldi_fbank(_waveform, sample_rate, 80) | |||||
| if fbank is None: | |||||
| fbank = _get_torchaudio_fbank(_waveform, sample_rate, 80) | |||||
| if fbank is None: | |||||
| raise ImportError( | |||||
| 'Please install pyKaldi or torchaudio to enable fbank feature extraction' | |||||
| ) | |||||
| if is_train and self.train_audio_feature_transforms is not None: | |||||
| fbank = self.train_audio_feature_transforms(fbank) | |||||
| elif ~is_train and self.test_audio_feature_transforms( | |||||
| fbank) is not None: | |||||
| fbank = self.test_audio_feature_transforms(fbank) | |||||
| fbank = torch.from_numpy(fbank).float() | |||||
| fbank = self.pack_frames(fbank) | |||||
| return fbank | |||||
| def pack_frames(self, feature: torch.Tensor): | |||||
| if self.cfg.n_frames_per_step == 1: | |||||
| return feature | |||||
| n_packed_frames = feature.shape[0] // self.cfg.n_frames_per_step | |||||
| feature = feature[:self.cfg.n_frames_per_step * n_packed_frames] | |||||
| return feature.reshape(n_packed_frames, -1) | |||||
| @@ -0,0 +1,91 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Optional, Tuple, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| def convert_waveform( | |||||
| waveform: Union[np.ndarray, torch.Tensor], | |||||
| sample_rate: int, | |||||
| normalize_volume: bool = False, | |||||
| to_mono: bool = False, | |||||
| to_sample_rate: Optional[int] = None, | |||||
| ) -> Tuple[Union[np.ndarray, torch.Tensor], int]: | |||||
| """convert a waveform: | |||||
| - to a target sample rate | |||||
| - from multi-channel to mono channel | |||||
| - volume normalization | |||||
| Args: | |||||
| waveform (numpy.ndarray or torch.Tensor): 2D original waveform | |||||
| (channels x length) | |||||
| sample_rate (int): original sample rate | |||||
| normalize_volume (bool): perform volume normalization | |||||
| to_mono (bool): convert to mono channel if having multiple channels | |||||
| to_sample_rate (Optional[int]): target sample rate | |||||
| Returns: | |||||
| waveform (numpy.ndarray): converted 2D waveform (channels x length) | |||||
| sample_rate (float): target sample rate | |||||
| """ | |||||
| try: | |||||
| import torchaudio.sox_effects as ta_sox | |||||
| except ImportError: | |||||
| raise ImportError('Please install torchaudio: pip install torchaudio') | |||||
| effects = [] | |||||
| if normalize_volume: | |||||
| effects.append(['gain', '-n']) | |||||
| if to_sample_rate is not None and to_sample_rate != sample_rate: | |||||
| effects.append(['rate', f'{to_sample_rate}']) | |||||
| if to_mono and waveform.shape[0] > 1: | |||||
| effects.append(['channels', '1']) | |||||
| if len(effects) > 0: | |||||
| is_np_input = isinstance(waveform, np.ndarray) | |||||
| _waveform = torch.from_numpy(waveform) if is_np_input else waveform | |||||
| converted, converted_sample_rate = ta_sox.apply_effects_tensor( | |||||
| _waveform, sample_rate, effects) | |||||
| if is_np_input: | |||||
| converted = converted.numpy() | |||||
| return converted, converted_sample_rate | |||||
| return waveform, sample_rate | |||||
| def _get_kaldi_fbank(waveform: np.ndarray, | |||||
| sample_rate: int, | |||||
| n_bins=80) -> Optional[np.ndarray]: | |||||
| """Get mel-filter bank features via PyKaldi.""" | |||||
| try: | |||||
| from kaldi.feat.fbank import Fbank, FbankOptions | |||||
| from kaldi.feat.mel import MelBanksOptions | |||||
| from kaldi.feat.window import FrameExtractionOptions | |||||
| from kaldi.matrix import Vector | |||||
| mel_opts = MelBanksOptions() | |||||
| mel_opts.num_bins = n_bins | |||||
| frame_opts = FrameExtractionOptions() | |||||
| frame_opts.samp_freq = sample_rate | |||||
| opts = FbankOptions() | |||||
| opts.mel_opts = mel_opts | |||||
| opts.frame_opts = frame_opts | |||||
| fbank = Fbank(opts=opts) | |||||
| features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() | |||||
| return features | |||||
| except ImportError: | |||||
| return None | |||||
| def _get_torchaudio_fbank(waveform: np.ndarray, | |||||
| sample_rate, | |||||
| n_bins=80) -> Optional[np.ndarray]: | |||||
| """Get mel-filter bank features via TorchAudio.""" | |||||
| try: | |||||
| import torchaudio.compliance.kaldi as ta_kaldi | |||||
| waveform = torch.from_numpy(waveform) | |||||
| features = ta_kaldi.fbank( | |||||
| waveform, num_mel_bins=n_bins, sample_frequency=sample_rate) | |||||
| return features.numpy() | |||||
| except ImportError: | |||||
| return None | |||||
| @@ -1,5 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import List | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| @@ -13,14 +15,12 @@ def collate_fn(samples, pad_idx, eos_idx): | |||||
| pad_idx, | pad_idx, | ||||
| eos_idx=eos_idx) | eos_idx=eos_idx) | ||||
| src_tokens = merge('source') | |||||
| batch = { | batch = { | ||||
| 'nsentences': len(samples), | 'nsentences': len(samples), | ||||
| 'net_input': { | |||||
| 'input_ids': src_tokens, | |||||
| }, | |||||
| 'net_input': {}, | |||||
| } | } | ||||
| if samples[0].get('source', None) is not None: | |||||
| batch['net_input']['input_ids'] = merge('source') | |||||
| if samples[0].get('id', None) is not None: | if samples[0].get('id', None) is not None: | ||||
| batch['id'] = np.array([s.get['id'] for s in samples]) | batch['id'] = np.array([s.get['id'] for s in samples]) | ||||
| if samples[0].get('target', None) is not None: | if samples[0].get('target', None) is not None: | ||||
| @@ -70,6 +70,20 @@ def collate_fn(samples, pad_idx, eos_idx): | |||||
| [s['region_coord'] for s in samples], dim=0) | [s['region_coord'] for s in samples], dim=0) | ||||
| if samples[0].get('sample', None) is not None: | if samples[0].get('sample', None) is not None: | ||||
| batch['samples'] = [s['sample'] for s in samples] | batch['samples'] = [s['sample'] for s in samples] | ||||
| # For asr | |||||
| if samples[0].get('fbank', None) is not None: | |||||
| batch['net_input']['fbank'] = _collate_frames( | |||||
| [s['fbank'] for s in samples]) | |||||
| batch['net_input']['fbank_length'] = torch.tensor( | |||||
| [s['fbank'].size(0) for s in samples], dtype=torch.long) | |||||
| if samples[0].get('fbank_mask', None) is not None: | |||||
| batch['net_input']['fbank_masks'] = torch.cat( | |||||
| [s['fbank_mask'] for s in samples]) | |||||
| if samples[0].get('phone_item', None) is not None: | |||||
| batch['net_input']['phone_items'] = merge('phone_item') | |||||
| batch['net_input']['phone_masks'] = torch.cat( | |||||
| [s['phone_mask'] for s in samples]) | |||||
| return batch | return batch | ||||
| @@ -113,3 +127,19 @@ def collate_tokens( | |||||
| for i, v in enumerate(values): | for i, v in enumerate(values): | ||||
| copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) | ||||
| return res | return res | ||||
| def _collate_frames(frames: List[torch.Tensor]): | |||||
| """ | |||||
| Convert a list of 2D frames into a padded 3D tensor | |||||
| Args: | |||||
| frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is | |||||
| length of i-th frame and f_dim is static dimension of features | |||||
| Returns: | |||||
| 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] | |||||
| """ | |||||
| max_len = max(frame.size(0) for frame in frames) | |||||
| out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) | |||||
| for i, v in enumerate(frames): | |||||
| out[i, :v.size(0)] = v | |||||
| return out | |||||
| @@ -9,5 +9,6 @@ OFA_TASK_KEY_MAPPING = { | |||||
| Tasks.visual_grounding: ['image', 'text'], | Tasks.visual_grounding: ['image', 'text'], | ||||
| Tasks.visual_question_answering: ['image', 'text'], | Tasks.visual_question_answering: ['image', 'text'], | ||||
| Tasks.visual_entailment: ['image', 'text', 'text2'], | Tasks.visual_entailment: ['image', 'text', 'text2'], | ||||
| Tasks.text_to_image_synthesis: ['text'] | |||||
| Tasks.text_to_image_synthesis: ['text'], | |||||
| Tasks.auto_speech_recognition: ['wav', 'text'], | |||||
| } | } | ||||
| @@ -0,0 +1,192 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from modelscope.utils.chinese_utils import normalize_chinese_number | |||||
| class TrieNode(object): | |||||
| def __init__(self): | |||||
| """ | |||||
| Initialize your data structure here. | |||||
| """ | |||||
| self.data = {} | |||||
| self.is_word = False | |||||
| class Trie(object): | |||||
| """ | |||||
| trie-tree | |||||
| """ | |||||
| def __init__(self): | |||||
| """ | |||||
| Initialize your data structure here. | |||||
| """ | |||||
| self.root = TrieNode() | |||||
| def insert(self, word): | |||||
| """ | |||||
| Inserts a word into the trie. | |||||
| :type word: str | |||||
| :rtype: void | |||||
| """ | |||||
| node = self.root | |||||
| for chars in word: | |||||
| child = node.data.get(chars) | |||||
| if not child: | |||||
| node.data[chars] = TrieNode() | |||||
| node = node.data[chars] | |||||
| node.is_word = True | |||||
| def search(self, word): | |||||
| """ | |||||
| Returns if the word is in the trie. | |||||
| :type word: str | |||||
| :rtype: bool | |||||
| """ | |||||
| node = self.root | |||||
| for chars in word: | |||||
| node = node.data.get(chars) | |||||
| if not node: | |||||
| return False | |||||
| return node.is_word | |||||
| def startsWith(self, prefix): | |||||
| """ | |||||
| Returns if there is any word in the trie that starts with the given prefix. | |||||
| :type prefix: str | |||||
| :rtype: bool | |||||
| """ | |||||
| node = self.root | |||||
| for chars in prefix: | |||||
| node = node.data.get(chars) | |||||
| if not node: | |||||
| return False | |||||
| return True | |||||
| def get_start(self, prefix): | |||||
| """ | |||||
| Returns words started with prefix | |||||
| :param prefix: | |||||
| :return: words (list) | |||||
| """ | |||||
| def get_key(pre, pre_node): | |||||
| word_list = [] | |||||
| if pre_node.is_word: | |||||
| word_list.append(pre) | |||||
| for x in pre_node.data.keys(): | |||||
| word_list.extend(get_key(pre + str(x), pre_node.data.get(x))) | |||||
| return word_list | |||||
| words = [] | |||||
| if not self.startsWith(prefix): | |||||
| return words | |||||
| if self.search(prefix): | |||||
| words.append(prefix) | |||||
| return words | |||||
| node = self.root | |||||
| for chars in prefix: | |||||
| node = node.data.get(chars) | |||||
| return get_key(prefix, node) | |||||
| class TrieTokenizer(Trie): | |||||
| """ | |||||
| word_split based on trie-tree | |||||
| """ | |||||
| def __init__(self, dict_path): | |||||
| super(TrieTokenizer, self).__init__() | |||||
| self.dict_path = dict_path | |||||
| self.create_trie_tree() | |||||
| def load_dict(self): | |||||
| words = [] | |||||
| with open(self.dict_path, mode='r', encoding='utf-8') as file: | |||||
| for line in file: | |||||
| words.append(line.strip().split('\t')[0].encode( | |||||
| 'utf-8').decode('utf-8-sig')) | |||||
| return words | |||||
| def create_trie_tree(self): | |||||
| words = self.load_dict() | |||||
| for word in words: | |||||
| self.insert(word) | |||||
| def mine_tree(self, tree, sentence, trace_index): | |||||
| if trace_index <= (len(sentence) - 1): | |||||
| if sentence[trace_index] in tree.data: | |||||
| trace_index = trace_index + 1 | |||||
| trace_index = self.mine_tree( | |||||
| tree.data[sentence[trace_index - 1]], sentence, | |||||
| trace_index) | |||||
| return trace_index | |||||
| def tokenize(self, sentence): | |||||
| tokens = [] | |||||
| sentence_len = len(sentence) | |||||
| while sentence_len != 0: | |||||
| trace_index = 0 | |||||
| trace_index = self.mine_tree(self.root, sentence, trace_index) | |||||
| if trace_index == 0: | |||||
| tokens.append(sentence[0:1]) | |||||
| sentence = sentence[1:len(sentence)] | |||||
| sentence_len = len(sentence) | |||||
| else: | |||||
| tokens.append(sentence[0:trace_index]) | |||||
| sentence = sentence[trace_index:len(sentence)] | |||||
| sentence_len = len(sentence) | |||||
| return tokens | |||||
| def combine(self, token_list): | |||||
| flag = 0 | |||||
| output = [] | |||||
| temp = [] | |||||
| for i in token_list: | |||||
| if len(i) != 1: | |||||
| if flag == 0: | |||||
| output.append(i[::]) | |||||
| else: | |||||
| output.append(''.join(temp)) | |||||
| output.append(i[::]) | |||||
| temp = [] | |||||
| flag = 0 | |||||
| else: | |||||
| if flag == 0: | |||||
| temp.append(i) | |||||
| flag = 1 | |||||
| else: | |||||
| temp.append(i) | |||||
| return output | |||||
| class Text2Phone: | |||||
| def __init__(self, phone_dict_path): | |||||
| self.trie_cws = TrieTokenizer(phone_dict_path) | |||||
| self.phone_map = self.get_phone_map(phone_dict_path) | |||||
| def get_phone_map(self, phone_dict_path): | |||||
| phone_map = dict() | |||||
| with open(phone_dict_path, 'r') as phone_map_file_reader: | |||||
| for line in phone_map_file_reader: | |||||
| key, phone_series = line.strip().split('\t') | |||||
| if key not in phone_map: | |||||
| phone_map[key] = phone_series | |||||
| return phone_map | |||||
| def trans(self, text): | |||||
| text = normalize_chinese_number(text) | |||||
| tokens = self.trie_cws.tokenize(text) | |||||
| phones = [] | |||||
| for word in tokens: | |||||
| if word in self.phone_map: | |||||
| phones.append(self.phone_map[word]) | |||||
| elif len(word) > 1: | |||||
| for char in word: | |||||
| if char in self.phone_map: | |||||
| phones.append(self.phone_map[char]) | |||||
| return ' '.join(phones) | |||||
| @@ -113,6 +113,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||||
| self.use_rdrop = args.get('use_rdrop', False) | self.use_rdrop = args.get('use_rdrop', False) | ||||
| self.reg_alpha = args.get('reg_alpha', 1.0) | self.reg_alpha = args.get('reg_alpha', 1.0) | ||||
| self.sample_patch_num = args.get('sample_patch_num', 196) | self.sample_patch_num = args.get('sample_patch_num', 196) | ||||
| self.ctc_weight = args.get('ctc_weight', 0.0) | |||||
| self.constraint_start = None | self.constraint_start = None | ||||
| self.constraint_end = None | self.constraint_end = None | ||||
| @@ -141,6 +142,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||||
| output = model.model(**sample['net_input']) | output = model.model(**sample['net_input']) | ||||
| loss, nll_loss, ntokens = self.compute_loss( | loss, nll_loss, ntokens = self.compute_loss( | ||||
| output.logits, sample, update_num, reduce=reduce) | output.logits, sample, update_num, reduce=reduce) | ||||
| if self.ctc_weight > 0: | |||||
| ctc_loss = self.compute_ctc_loss(model, output, sample) | |||||
| loss = nll_loss + ctc_loss | |||||
| sample_size = ( | sample_size = ( | ||||
| sample['target'].size(0) if self.sentence_avg else ntokens) | sample['target'].size(0) if self.sentence_avg else ntokens) | ||||
| logging_output = { | logging_output = { | ||||
| @@ -206,6 +210,32 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||||
| constraint_end=self.constraint_end) | constraint_end=self.constraint_end) | ||||
| return loss, nll_loss, ntokens | return loss, nll_loss, ntokens | ||||
| def compute_ctc_loss(self, model, output, sample): | |||||
| lprobs = model.get_encoder_normalized_probs( | |||||
| output, log_probs=True).contiguous() # (T, B, C) from the encoder | |||||
| non_padding_mask = ~output.encoder_padding_mask | |||||
| input_lengths = non_padding_mask.long().sum(-1) | |||||
| target_lengths = sample['ctc_output_lengths'] | |||||
| pad_mask = torch.arange(target_lengths.max()).expand([ | |||||
| target_lengths.shape[0], -1 | |||||
| ]).to(target_lengths) < target_lengths.unsqueeze(1) | |||||
| targets_flat = sample['ctc_outputs'].masked_select(pad_mask) | |||||
| with torch.backends.cudnn.flags(enabled=False): | |||||
| loss = F.ctc_loss( | |||||
| lprobs, | |||||
| targets_flat, | |||||
| input_lengths, | |||||
| target_lengths, | |||||
| blank=self.blank_idx, | |||||
| reduction='sum', | |||||
| zero_infinity=True, | |||||
| ) | |||||
| return loss | |||||
| def get_schedule(scheduler): | def get_schedule(scheduler): | ||||
| @@ -1,5 +1,13 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import re | |||||
| import string | |||||
| from zhconv import convert | |||||
| CHINESE_PUNCTUATION = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。' | |||||
| ENGLISH_PUNCTUATION = string.punctuation | |||||
| def is_chinese_char(word: str): | def is_chinese_char(word: str): | ||||
| chinese_punctuations = { | chinese_punctuations = { | ||||
| @@ -33,3 +41,28 @@ def rebuild_chinese_str(string: str): | |||||
| return ' '.join(''.join([ | return ' '.join(''.join([ | ||||
| f' {char} ' if is_chinese_char(char) else char for char in string | f' {char} ' if is_chinese_char(char) else char for char in string | ||||
| ]).split()) | ]).split()) | ||||
| def normalize_chinese_number(text): | |||||
| chinese_number = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九'] | |||||
| new_text = '' | |||||
| for x in text: | |||||
| if x in '0123456789': | |||||
| x = chinese_number[0] | |||||
| new_text += x | |||||
| new_text = convert(new_text, 'zh-hans') | |||||
| return new_text | |||||
| def pre_chinese(text, max_words): | |||||
| text = text.lower().replace(CHINESE_PUNCTUATION, | |||||
| ' ').replace(ENGLISH_PUNCTUATION, ' ') | |||||
| text = re.sub( | |||||
| r'\s{2,}', | |||||
| ' ', | |||||
| text, | |||||
| ) | |||||
| text = text.rstrip('\n') | |||||
| text = text.strip(' ')[:max_words] | |||||
| return text | |||||
| @@ -8,6 +8,7 @@ pytorch_lightning<=1.7.7 | |||||
| # which introduced compatability issues that are being investigated | # which introduced compatability issues that are being investigated | ||||
| rouge_score<=0.0.4 | rouge_score<=0.0.4 | ||||
| sacrebleu | sacrebleu | ||||
| soundfile | |||||
| taming-transformers-rom1504 | taming-transformers-rom1504 | ||||
| timm | timm | ||||
| tokenizers | tokenizers | ||||
| @@ -273,6 +273,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| result[OutputKeys.OUTPUT_IMG].save('result.png') | result[OutputKeys.OUTPUT_IMG].save('result.png') | ||||
| print(f'Output written to {osp.abspath("result.png")}') | print(f'Output written to {osp.abspath("result.png")}') | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_asr_with_name(self): | |||||
| model = 'damo/ofa_asr_pretrain_base_zh' | |||||
| ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model) | |||||
| example = {'wav': 'data/test/audios/asr_example_ofa.wav'} | |||||
| result = ofa_pipe(example) | |||||
| print(result[OutputKeys.TEXT]) | |||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | @unittest.skip('demo compatibility test is only enabled on a needed-basis') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| self.compatibility_check() | self.compatibility_check() | ||||