OFA: add ASR task inference
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10761019
master^2
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:46dbc998c9d1d48111267c40741dd3200f2e5bcf4075f8c4c97f4451160dce50 | |||
| size 134570 | |||
| @@ -284,6 +284,7 @@ class Pipelines(object): | |||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
| image_text_retrieval = 'image-text-retrieval' | |||
| ofa_ocr_recognition = 'ofa-ocr-recognition' | |||
| ofa_asr = 'ofa-asr' | |||
| # science tasks | |||
| protein_structure = 'unifold-protein-structure' | |||
| @@ -1,5 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .modeling_mmspeech import MMSpeechModel | |||
| from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel | |||
| from .tokenization_ofa import OFATokenizer, OFATokenizerZH | |||
| from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast | |||
| @@ -0,0 +1,260 @@ | |||
| # Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ MMSpeech model configuration""" | |||
| import warnings | |||
| from transformers import PretrainedConfig | |||
| from transformers.utils import logging | |||
| logger = logging.get_logger(__name__) | |||
| class MMSpeechConfig(PretrainedConfig): | |||
| r""" | |||
| This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA | |||
| model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | |||
| defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) | |||
| architecture. | |||
| Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | |||
| documentation from [`PretrainedConfig`] for more information. | |||
| Args: | |||
| vocab_size (`int`, *optional*, defaults to 59457): | |||
| Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the | |||
| `inputs_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`]. | |||
| d_model (`int`, *optional*, defaults to 512): | |||
| Dimension of the layers and the pooler layer. | |||
| encoder_layers (`int`, *optional*, defaults to 4): | |||
| Number of encoder layers. | |||
| decoder_layers (`int`, *optional*, defaults to 4): | |||
| Number of decoder layers. | |||
| encoder_attention_heads (`int`, *optional*, defaults to 8): | |||
| Number of attention heads for each attention layer in the Transformer encoder. | |||
| decoder_attention_heads (`int`, *optional*, defaults to 8): | |||
| Number of attention heads for each attention layer in the Transformer decoder. | |||
| decoder_ffn_dim (`int`, *optional*, defaults to 2048): | |||
| Dimension of the "intermediate" (often named feed-forward) layer in decoder. | |||
| encoder_ffn_dim (`int`, *optional*, defaults to 2048): | |||
| Dimension of the "intermediate" (often named feed-forward) layer in encoder. | |||
| activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): | |||
| The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, | |||
| `"relu"`, `"silu"` and `"gelu_new"` are supported. | |||
| dropout (`float`, *optional*, defaults to 0.1): | |||
| The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||
| attention_dropout (`float`, *optional*, defaults to 0.0): | |||
| The dropout ratio for the attention probabilities. | |||
| activation_dropout (`float`, *optional*, defaults to 0.0): | |||
| The dropout ratio for activations inside the fully connected layer. | |||
| classifier_dropout (`float`, *optional*, defaults to 0.0): | |||
| The dropout ratio for classifier. | |||
| max_position_embeddings (`int`, *optional*, defaults to 1024): | |||
| The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
| just in case (e.g., 512 or 1024 or 2048). | |||
| init_std (`float`, *optional*, defaults to 0.02): | |||
| The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
| encoder_layerdrop (`float`, *optional*, defaults to 0.0): | |||
| The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) | |||
| for more details. | |||
| decoder_layerdrop (`float`, *optional*, defaults to 0.0): | |||
| The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) | |||
| for more details. | |||
| use_cache (`bool`, *optional*, defaults to `True`): | |||
| Whether or not the model should return the last key/values attentions (not used by all models). | |||
| """ | |||
| model_type = 'ofa' | |||
| keys_to_ignore_at_inference = ['past_key_values'] | |||
| attribute_map = { | |||
| 'num_attention_heads': 'encoder_attention_heads', | |||
| 'hidden_size': 'd_model' | |||
| } | |||
| def __init__(self, | |||
| vocab_size=59457, | |||
| max_position_embeddings=1024, | |||
| encoder_layers=4, | |||
| encoder_ffn_dim=512 * 4, | |||
| encoder_attention_heads=8, | |||
| decoder_layers=4, | |||
| decoder_ffn_dim=512 * 4, | |||
| decoder_attention_heads=8, | |||
| encoder_layerdrop=0.0, | |||
| decoder_layerdrop=0.0, | |||
| use_cache=True, | |||
| is_encoder_decoder=True, | |||
| activation_function='gelu', | |||
| d_model=512, | |||
| dropout=0.1, | |||
| attention_dropout=0.0, | |||
| activation_dropout=0.0, | |||
| init_std=0.02, | |||
| classifier_dropout=0.0, | |||
| scale_embedding=False, | |||
| pad_token_id=1, | |||
| bos_token_id=0, | |||
| decoder_start_token_id=0, | |||
| eos_token_id=2, | |||
| forced_eos_token_id=2, | |||
| encoder_normalize_before=True, | |||
| decoder_normalize_before=True, | |||
| normformer=True, | |||
| encoder_drop_path_rate=0.0, | |||
| decoder_drop_path_rate=0.0, | |||
| layernorm_embedding=True, | |||
| patch_layernorm_embedding=True, | |||
| resnet_type='resnet101', | |||
| resnet_model_path=None, | |||
| resnet_drop_path_rate=0.0, | |||
| token_bucket_size=256, | |||
| image_bucket_size=42, | |||
| add_type_embedding=True, | |||
| share_decoder_input_output_embed=True, | |||
| attn_scale_factor=2., | |||
| code_layernorm_embedding=False, | |||
| code_image_size=128, | |||
| entangle_position_embedding=False, | |||
| interpolate_position=False, | |||
| orig_patch_image_size=224, | |||
| share_attn_bias=False, | |||
| use_image_feature=True, | |||
| disable_entangle=False, | |||
| use_ofasys=False, | |||
| vit_type='vit_base', | |||
| vit_drop_path_rate=0.0, | |||
| required_seq_len_multiple=2, | |||
| encoder_pos_conv_depth=5, | |||
| encoder_conv_pos=95, | |||
| encoder_conv_pos_groups=16, | |||
| encoder_max_positions=100000, | |||
| phone_vocab_size=141, | |||
| audio_mask_prob=0.65, | |||
| audio_mask_selection='static', | |||
| audio_mask_other=0, | |||
| audio_mask_length=10, | |||
| audio_no_mask_overlap=False, | |||
| audio_mask_min_space=1, | |||
| audio_mask_channel_prob=0.0, | |||
| audio_mask_channel_before=False, | |||
| audio_mask_channel_selection='static', | |||
| audio_mask_channel_other=0, | |||
| audio_mask_channel_length=10, | |||
| audio_no_mask_channel_overlap=False, | |||
| audio_mask_channel_min_space=1, | |||
| encoder_dropout_input=0.0, | |||
| encoder_dropout_features=0.0, | |||
| phone_dict_size=124, | |||
| **kwargs): | |||
| self.vocab_size = vocab_size | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.d_model = d_model | |||
| self.encoder_ffn_dim = encoder_ffn_dim | |||
| self.encoder_layers = encoder_layers | |||
| self.encoder_attention_heads = encoder_attention_heads | |||
| self.decoder_ffn_dim = decoder_ffn_dim | |||
| self.decoder_layers = decoder_layers | |||
| self.decoder_attention_heads = decoder_attention_heads | |||
| self.dropout = dropout | |||
| self.attention_dropout = attention_dropout | |||
| self.activation_dropout = activation_dropout | |||
| self.activation_function = activation_function | |||
| self.init_std = init_std | |||
| self.encoder_layerdrop = encoder_layerdrop | |||
| self.decoder_layerdrop = decoder_layerdrop | |||
| self.classifier_dropout = classifier_dropout | |||
| self.use_cache = use_cache | |||
| self.num_hidden_layers = encoder_layers | |||
| self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True | |||
| self.encoder_normalize_before = encoder_normalize_before | |||
| self.decoder_normalize_before = decoder_normalize_before | |||
| self.normformer = normformer | |||
| self.encoder_drop_path_rate = encoder_drop_path_rate | |||
| self.decoder_drop_path_rate = decoder_drop_path_rate | |||
| self.layernorm_embedding = layernorm_embedding | |||
| self.patch_layernorm_embedding = patch_layernorm_embedding | |||
| self.resnet_type = resnet_type | |||
| self.resnet_model_path = resnet_model_path | |||
| self.resnet_drop_path_rate = resnet_drop_path_rate | |||
| self.token_bucket_size = token_bucket_size | |||
| self.image_bucket_size = image_bucket_size | |||
| self.add_type_embedding = add_type_embedding | |||
| self.share_decoder_input_output_embed = share_decoder_input_output_embed | |||
| self.attn_scale_factor = attn_scale_factor | |||
| self.code_layernorm_embedding = code_layernorm_embedding | |||
| self.code_image_size = code_image_size | |||
| self.entangle_position_embedding = entangle_position_embedding | |||
| self.interpolate_position = interpolate_position | |||
| self.orig_patch_image_size = orig_patch_image_size | |||
| self.share_attn_bias = share_attn_bias | |||
| self.use_image_feature = use_image_feature | |||
| self.disable_entangle = disable_entangle | |||
| self.use_ofasys = use_ofasys | |||
| self.vit_type = vit_type | |||
| self.vit_drop_path_rate = vit_drop_path_rate | |||
| # FP16 optimization | |||
| self.required_seq_len_multiple = required_seq_len_multiple | |||
| # encoder_pos_conv | |||
| self.encoder_pos_conv_depth = encoder_pos_conv_depth | |||
| self.encoder_conv_pos = encoder_conv_pos | |||
| self.encoder_conv_pos_groups = encoder_conv_pos_groups | |||
| self.encoder_max_positions = encoder_max_positions | |||
| # phone | |||
| self.phone_vocab_size = phone_vocab_size | |||
| # audio_mask | |||
| self.audio_mask_prob = audio_mask_prob | |||
| self.audio_mask_selection = audio_mask_selection | |||
| self.audio_mask_other = audio_mask_other | |||
| self.audio_mask_length = audio_mask_length | |||
| self.audio_no_mask_overlap = audio_no_mask_overlap | |||
| self.audio_mask_min_space = audio_mask_min_space | |||
| self.audio_mask_channel_prob = audio_mask_channel_prob | |||
| self.audio_mask_channel_before = audio_mask_channel_before | |||
| self.audio_mask_channel_selection = audio_mask_channel_selection | |||
| self.audio_mask_channel_other = audio_mask_channel_other | |||
| self.audio_mask_channel_length = audio_mask_channel_length | |||
| self.audio_no_mask_channel_overlap = audio_no_mask_channel_overlap | |||
| self.audio_mask_channel_min_space = audio_mask_channel_min_space | |||
| # audio encoder | |||
| self.encoder_dropout_input = encoder_dropout_input | |||
| self.encoder_dropout_features = encoder_dropout_features | |||
| self.phone_dict_size = phone_dict_size | |||
| super().__init__( | |||
| pad_token_id=pad_token_id, | |||
| bos_token_id=bos_token_id, | |||
| eos_token_id=eos_token_id, | |||
| is_encoder_decoder=is_encoder_decoder, | |||
| decoder_start_token_id=decoder_start_token_id, | |||
| forced_eos_token_id=forced_eos_token_id, | |||
| **kwargs, | |||
| ) | |||
| # ensure backward compatibility for BART CNN models | |||
| if self.forced_bos_token_id is None and kwargs.get( | |||
| 'force_bos_token_to_be_generated', False): | |||
| self.forced_bos_token_id = self.bos_token_id | |||
| warnings.warn( | |||
| f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' | |||
| 'The config can simply be saved and uploaded again to be fixed.' | |||
| ) | |||
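For orientation, a minimal sketch of instantiating the new config. The import path is an assumption (mirroring `modeling_mmspeech` above); the new file's location is not shown in the hunk header.

```python
# Sketch only: the module path for MMSpeechConfig is assumed.
from modelscope.models.multi_modal.ofa.configuration_mmspeech import MMSpeechConfig

config = MMSpeechConfig()        # defaults above: d_model=512, 4 encoder / 4 decoder layers
print(config.hidden_size)        # 512, via the attribute_map alias for d_model
config = MMSpeechConfig(audio_mask_prob=0.5, phone_vocab_size=141)  # override audio-masking options
```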
| @@ -227,6 +227,9 @@ class SequenceGenerator(nn.Module): | |||
| - net_input['padding_mask'].sum(-1) | |||
| if net_input['padding_mask'] is not None else torch.tensor( | |||
| src_tokens.size(-1)).to(src_tokens)) | |||
| elif 'fbank' in net_input: | |||
| src_tokens = net_input['fbank'] | |||
| src_lengths = net_input['fbank_length'] | |||
| else: | |||
| raise Exception( | |||
| 'expected src_tokens or source in net input. input keys: ' | |||
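For context, a hedged sketch of the `net_input` an ASR batch now carries into the generator. The key names come from this diff and the collate changes below; the shapes and values are illustrative only.

```python
import torch

# Hypothetical batch of two utterances with 80-dim fbank features, padded to 312 frames.
net_input = {
    'fbank': torch.zeros(2, 312, 80),                     # (batch, frames, mel bins)
    'fbank_length': torch.tensor([312, 208]),             # true frame count per utterance
    'fbank_masks': torch.tensor([True, True]),
    'phone_items': torch.tensor([[6, 6, 6], [6, 6, 6]]),  # mocked phone ids at inference time
    'phone_masks': torch.tensor([False, False]),
}
# The generator reads src_tokens from 'fbank' and src_lengths from 'fbank_length'.
```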
| @@ -11,4 +11,5 @@ OFA_TASK_KEY_MAPPING = { | |||
| Tasks.text_classification: OutputKeys.LABELS, | |||
| Tasks.image_classification: OutputKeys.LABELS, | |||
| Tasks.visual_entailment: OutputKeys.LABELS, | |||
| Tasks.auto_speech_recognition: OutputKeys.TEXT | |||
| } | |||
| @@ -19,7 +19,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.trie import Trie | |||
| from .ofa import OFAModel, OFATokenizer, OFATokenizerZH | |||
| from .ofa import MMSpeechModel, OFAModel, OFATokenizer, OFATokenizerZH | |||
| from .ofa.generate import sequence_generator as sg | |||
| from .ofa.generate.utils import move_to_device | |||
| from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks | |||
| @@ -37,13 +37,20 @@ __all__ = ['OfaForAllTasks'] | |||
| @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.auto_speech_recognition, module_name=Models.ofa) | |||
| class OfaForAllTasks(TorchModel): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| super().__init__(model_dir=model_dir, *args, **kwargs) | |||
| model = OFAModel.from_pretrained(model_dir) | |||
| self.cfg = Config.from_file( | |||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
| multimodal_type = self.cfg.model.get('multimodal_type', 'default') | |||
| if multimodal_type == 'default': | |||
| model = OFAModel.from_pretrained(model_dir) | |||
| elif multimodal_type == 'mmspeech': | |||
| model = MMSpeechModel.from_pretrained(model_dir) | |||
| else: | |||
| raise NotImplementedError | |||
| self.model = model.module if hasattr(model, 'module') else model | |||
| self.language = self.cfg.model.get('language', 'en') | |||
| if self.language == 'en': | |||
| @@ -54,12 +61,20 @@ class OfaForAllTasks(TorchModel): | |||
| raise NotImplementedError | |||
| # Note: this differs slightly from our OFA code; | |||
| # the use_bpe param is no longer needed here. | |||
| if not model.use_ofasys: | |||
| self.tokenizer.add_tokens( | |||
| ['<code_{}>'.format(i) for i in range(8192)]) | |||
| self.tokenizer.add_tokens( | |||
| ['<bin_{}>'.format(i) for i in range(1000)]) | |||
| self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||
| if multimodal_type == 'default': | |||
| self.tokenizer.add_tokens( | |||
| ['<code_{}>'.format(i) for i in range(8192)]) | |||
| self.tokenizer.add_tokens( | |||
| ['<bin_{}>'.format(i) for i in range(1000)]) | |||
| self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) | |||
| elif multimodal_type == 'mmspeech': | |||
| self.tokenizer.add_tokens('<blank>') | |||
| self.tokenizer.add_tokens( | |||
| ['<audio_{}>'.format(i) for i in range(30000)]) | |||
| self.cfg.update({'num_bins': 0, 'num_codes': 30000}) | |||
| self.batch_size = self.cfg.model.get('batch_size', 1) | |||
| self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | |||
| self.max_image_size = self.cfg.model.get('max_image_size', 512) | |||
| @@ -110,6 +125,7 @@ class OfaForAllTasks(TorchModel): | |||
| Tasks.visual_question_answering: inference_d[self.gen_type], | |||
| Tasks.text_classification: inference_d[self.gen_type], | |||
| Tasks.image_classification: inference_d[self.gen_type], | |||
| Tasks.auto_speech_recognition: self._text_gen_inference, | |||
| } | |||
| pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' | |||
| self.pattern = re.compile(pattern_str) | |||
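To make the `multimodal_type` branching concrete, a hypothetical excerpt of a model's configuration, shown as a Python dict. Only the keys actually read in this diff are grounded; the rest of the real configuration file is omitted.

```python
# Hypothetical configuration excerpt; keys read above: multimodal_type, language, batch_size.
cfg_model = {
    'model': {
        'type': 'ofa',
        'multimodal_type': 'mmspeech',   # 'default' -> OFAModel, 'mmspeech' -> MMSpeechModel
        'language': 'zh',
        'batch_size': 1,
    }
}
```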
| @@ -186,7 +186,10 @@ TASK_INPUTS = { | |||
| # ============ audio tasks =================== | |||
| Tasks.auto_speech_recognition: | |||
| InputType.AUDIO, | |||
| [InputType.AUDIO, { | |||
| 'wav': InputType.AUDIO, | |||
| 'text': InputType.TEXT | |||
| }], | |||
| Tasks.speech_signal_process: | |||
| InputType.AUDIO, | |||
| Tasks.acoustic_echo_cancellation: { | |||
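With the TASK_INPUTS entry above, the ASR pipeline accepts either a bare audio input or a dict with an optional reference transcript. A sketch of both forms; the wav path is the test asset used at the bottom of this review, the transcript string is a placeholder.

```python
single_input = 'data/test/audios/asr_example_ofa.wav'   # InputType.AUDIO
dict_input = {
    'wav': 'data/test/audios/asr_example_ofa.wav',      # InputType.AUDIO
    'text': 'optional reference transcript',            # InputType.TEXT, kept as the sample label
}
```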
| @@ -13,6 +13,7 @@ if TYPE_CHECKING: | |||
| from .video_multi_modal_embedding_pipeline import \ | |||
| VideoMultiModalEmbeddingPipeline | |||
| from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline | |||
| from .asr_pipeline import AutomaticSpeechRecognitionPipeline | |||
| else: | |||
| _import_structure = { | |||
| @@ -26,7 +27,8 @@ else: | |||
| 'video_multi_modal_embedding_pipeline': | |||
| ['VideoMultiModalEmbeddingPipeline'], | |||
| 'generative_multi_modal_embedding_pipeline': | |||
| ['GEMMMultiModalEmbeddingPipeline'] | |||
| ['GEMMMultiModalEmbeddingPipeline'], | |||
| 'asr_pipeline': ['AutomaticSpeechRecognitionPipeline'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Model, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, | |||
| Preprocessor) | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.auto_speech_recognition, module_name=Pipelines.ofa_asr) | |||
| class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create an automatic speech recognition pipeline for prediction | |||
| Args: | |||
| model: model id on the ModelScope hub, or a Model instance. | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||
| 'model must be a single str or Model' | |||
| if isinstance(model, str): | |||
| pipe_model = Model.from_pretrained(model) | |||
| elif isinstance(model, Model): | |||
| pipe_model = model | |||
| else: | |||
| raise NotImplementedError | |||
| pipe_model.model.eval() | |||
| if preprocessor is None: | |||
| if isinstance(pipe_model, OfaForAllTasks): | |||
| preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||
| elif isinstance(pipe_model, MPlugForAllTasks): | |||
| preprocessor = MPlugPreprocessor(pipe_model.model_dir) | |||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
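End-to-end usage mirroring the new unit test at the bottom of this review:

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ofa_pipe = pipeline(Tasks.auto_speech_recognition,
                    model='damo/ofa_asr_pretrain_base_zh')
result = ofa_pipe({'wav': 'data/test/audios/asr_example_ofa.wav'})
print(result[OutputKeys.TEXT])
```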
| @@ -53,7 +53,8 @@ class OfaPreprocessor(Preprocessor): | |||
| Tasks.image_classification: OfaImageClassificationPreprocessor, | |||
| Tasks.text_classification: OfaTextClassificationPreprocessor, | |||
| Tasks.text_summarization: OfaSummarizationPreprocessor, | |||
| Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor | |||
| Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor, | |||
| Tasks.auto_speech_recognition: OfaASRPreprocessor | |||
| } | |||
| model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||
| model_dir) | |||
| @@ -1,4 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .asr import OfaASRPreprocessor | |||
| from .image_captioning import OfaImageCaptioningPreprocessor | |||
| from .image_classification import OfaImageClassificationPreprocessor | |||
| from .ocr_recognition import OfaOcrRecognitionPreprocessor | |||
| @@ -0,0 +1,121 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import random | |||
| from pathlib import Path | |||
| from typing import Any, Dict | |||
| import soundfile as sf | |||
| import torch | |||
| from fairseq.data.audio.feature_transforms import \ | |||
| CompositeAudioFeatureTransform | |||
| from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig | |||
| from modelscope.utils.chinese_utils import pre_chinese | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| from .utils.text2phone import Text2Phone | |||
| class OfaASRPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaASRPreprocessor, self).__init__(cfg, model_dir, mode, *args, | |||
| **kwargs) | |||
| # Initialize transform | |||
| self.data_cfg = S2TDataConfig( | |||
| Path(os.path.join(model_dir, 'fbank_config.yaml'))) | |||
| self.train_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( | |||
| self.data_cfg.get_feature_transforms('train', True)) | |||
| self.test_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict( | |||
| self.data_cfg.get_feature_transforms('test', False)) | |||
| self.text2phone_tokenizer = Text2Phone( | |||
| os.path.join(model_dir, 'text2phone_dict.txt')) | |||
| self.phone_to_id, self.id_to_phone = self.build_phone_dict( | |||
| os.path.join(model_dir, 'phone_dict.txt')) | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| if self.mode == ModeKeys.TRAIN: | |||
| return self._build_train_sample(data) | |||
| else: | |||
| return self._build_infer_sample(data) | |||
| def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| speed = random.choice([0.9, 1.0, 1.1]) | |||
| wav, sr = sf.read(data[self.column_map['wav']]) | |||
| fbank = self.prepare_fbank( | |||
| torch.tensor([wav], dtype=torch.float32), sr, speed, is_train=True) | |||
| fbank_mask = torch.tensor([True]) | |||
| sample = { | |||
| 'fbank': fbank, | |||
| 'fbank_mask': fbank_mask, | |||
| 'label': data[self.column_map['text']] | |||
| } | |||
| target = sample['label'] | |||
| if self.language == 'zh': | |||
| target = pre_chinese(target, self.max_tgt_length) | |||
| sample['target'] = self.tokenize_text(target, add_bos=False) | |||
| else: | |||
| target = target.translate(self.transtab).strip() | |||
| target_token_list = target.strip().split() | |||
| target = ' '.join(target_token_list[:self.max_tgt_length]) | |||
| sample['target'] = self.tokenize_text(target, add_bos=False) | |||
| phone_item = self.to_phone(target) - 3 | |||
| phone_mask = torch.tensor([False]) | |||
| sample['phone_item'] = phone_item | |||
| sample['phone_mask'] = phone_mask | |||
| sample['prev_output_tokens'] = torch.cat( | |||
| [self.bos_item, sample['target'][:-1]]) | |||
| return sample | |||
| def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| speed = 1.0 | |||
| wav, sr = sf.read(data[self.column_map['wav']]) | |||
| fbank = self.prepare_fbank( | |||
| torch.tensor([wav], dtype=torch.float32), | |||
| sr, | |||
| speed, | |||
| is_train=False) | |||
| fbank_mask = torch.tensor([True]) | |||
| sample = {'fbank': fbank, 'fbank_mask': fbank_mask} | |||
| if 'text' in self.column_map and self.column_map['text'] in data: | |||
| sample['label'] = data[self.column_map['text']] | |||
| # mock | |||
| sample['phone_item'] = torch.tensor([6, 6, 6]) | |||
| sample['phone_mask'] = torch.tensor([False]) | |||
| return sample | |||
| def to_phone(self, text): | |||
| phones = self.text2phone_tokenizer.trans(text) | |||
| ids = torch.tensor([self.phone_to_id[x] for x in phones.split(' ')]) | |||
| return ids | |||
| def build_phone_dict(self, phone_dict_path): | |||
| phone_to_id = dict() | |||
| id_to_phone = dict() | |||
| with open(phone_dict_path, 'r') as phone_dict_file: | |||
| for i, line in enumerate(phone_dict_file): | |||
| phone = line.strip().split(' ')[0] | |||
| phone_to_id[phone] = i | |||
| id_to_phone[i] = phone | |||
| return phone_to_id, id_to_phone | |||
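A small sketch of how `build_phone_dict` interprets phone_dict.txt: one phone per line, the first whitespace-separated field is the symbol, and the line index becomes its id. The entries below are invented; the real file ships with the model and is not part of this diff.

```python
# Hypothetical dictionary lines standing in for phone_dict.txt.
phone_dict_lines = ['<blank> 0\n', 'a1 123\n', 'ai4 45\n']

phone_to_id, id_to_phone = {}, {}
for i, line in enumerate(phone_dict_lines):   # mirrors build_phone_dict above
    phone = line.strip().split(' ')[0]
    phone_to_id[phone] = i
    id_to_phone[i] = phone
print(phone_to_id)   # {'<blank>': 0, 'a1': 1, 'ai4': 2}
```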
| @@ -6,11 +6,14 @@ from os import path as osp | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| import torchaudio | |||
| from PIL import Image | |||
| from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH | |||
| from modelscope.preprocessors.image import load_image | |||
| from modelscope.utils.trie import Trie | |||
| from .utils.audio_helper import (_get_kaldi_fbank, _get_torchaudio_fbank, | |||
| convert_waveform) | |||
| from .utils.constant import OFA_TASK_KEY_MAPPING | |||
| from .utils.random_help import set_torch_seed | |||
| @@ -88,6 +91,9 @@ class OfaBasePreprocessor: | |||
| + answer_item.tolist() | |||
| + [tokenizer.eos_token_id]) | |||
| self.train_audio_feature_transforms = None | |||
| self.test_audio_feature_transforms = None | |||
| def tokenize_text(self, text, add_bos=True, add_eos=True): | |||
| if text is None: | |||
| return None | |||
| @@ -163,3 +169,36 @@ class OfaBasePreprocessor: | |||
| image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ | |||
| else load_image(path_or_url_or_pil) | |||
| return image | |||
| def prepare_fbank(self, waveform, sample_rate, speed, is_train): | |||
| waveform, _ = torchaudio.sox_effects.apply_effects_tensor( | |||
| waveform, sample_rate, | |||
| [['speed', str(speed)], ['rate', str(sample_rate)]]) | |||
| _waveform, _ = convert_waveform( | |||
| waveform, sample_rate, to_mono=True, normalize_volume=True) | |||
| # Kaldi compliance: 16-bit signed integers | |||
| _waveform = _waveform * (2**15) | |||
| _waveform = _waveform.numpy() | |||
| fbank = _get_kaldi_fbank(_waveform, sample_rate, 80) | |||
| if fbank is None: | |||
| fbank = _get_torchaudio_fbank(_waveform, sample_rate, 80) | |||
| if fbank is None: | |||
| raise ImportError( | |||
| 'Please install pyKaldi or torchaudio to enable fbank feature extraction' | |||
| ) | |||
| if is_train and self.train_audio_feature_transforms is not None: | |||
| fbank = self.train_audio_feature_transforms(fbank) | |||
| elif not is_train and self.test_audio_feature_transforms is not None: | |||
| fbank = self.test_audio_feature_transforms(fbank) | |||
| fbank = torch.from_numpy(fbank).float() | |||
| fbank = self.pack_frames(fbank) | |||
| return fbank | |||
| def pack_frames(self, feature: torch.Tensor): | |||
| if self.cfg.n_frames_per_step == 1: | |||
| return feature | |||
| n_packed_frames = feature.shape[0] // self.cfg.n_frames_per_step | |||
| feature = feature[:self.cfg.n_frames_per_step * n_packed_frames] | |||
| return feature.reshape(n_packed_frames, -1) | |||
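A standalone sketch of what `pack_frames` does when the config sets `n_frames_per_step` greater than 1 (with a value of 1 it returns the features unchanged). The value 3 below is hypothetical.

```python
import torch

feature = torch.arange(7 * 80, dtype=torch.float32).reshape(7, 80)   # 7 frames x 80 mel bins
n_frames_per_step = 3                                                 # hypothetical config value
n_packed_frames = feature.shape[0] // n_frames_per_step               # 2
packed = feature[:n_frames_per_step * n_packed_frames].reshape(n_packed_frames, -1)
print(packed.shape)   # torch.Size([2, 240]): three consecutive frames stacked per row
```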
| @@ -0,0 +1,91 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Optional, Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| def convert_waveform( | |||
| waveform: Union[np.ndarray, torch.Tensor], | |||
| sample_rate: int, | |||
| normalize_volume: bool = False, | |||
| to_mono: bool = False, | |||
| to_sample_rate: Optional[int] = None, | |||
| ) -> Tuple[Union[np.ndarray, torch.Tensor], int]: | |||
| """convert a waveform: | |||
| - to a target sample rate | |||
| - from multi-channel to mono channel | |||
| - volume normalization | |||
| Args: | |||
| waveform (numpy.ndarray or torch.Tensor): 2D original waveform | |||
| (channels x length) | |||
| sample_rate (int): original sample rate | |||
| normalize_volume (bool): perform volume normalization | |||
| to_mono (bool): convert to mono channel if having multiple channels | |||
| to_sample_rate (Optional[int]): target sample rate | |||
| Returns: | |||
| waveform (numpy.ndarray): converted 2D waveform (channels x length) | |||
| sample_rate (int): converted sample rate | |||
| """ | |||
| try: | |||
| import torchaudio.sox_effects as ta_sox | |||
| except ImportError: | |||
| raise ImportError('Please install torchaudio: pip install torchaudio') | |||
| effects = [] | |||
| if normalize_volume: | |||
| effects.append(['gain', '-n']) | |||
| if to_sample_rate is not None and to_sample_rate != sample_rate: | |||
| effects.append(['rate', f'{to_sample_rate}']) | |||
| if to_mono and waveform.shape[0] > 1: | |||
| effects.append(['channels', '1']) | |||
| if len(effects) > 0: | |||
| is_np_input = isinstance(waveform, np.ndarray) | |||
| _waveform = torch.from_numpy(waveform) if is_np_input else waveform | |||
| converted, converted_sample_rate = ta_sox.apply_effects_tensor( | |||
| _waveform, sample_rate, effects) | |||
| if is_np_input: | |||
| converted = converted.numpy() | |||
| return converted, converted_sample_rate | |||
| return waveform, sample_rate | |||
| def _get_kaldi_fbank(waveform: np.ndarray, | |||
| sample_rate: int, | |||
| n_bins=80) -> Optional[np.ndarray]: | |||
| """Get mel-filter bank features via PyKaldi.""" | |||
| try: | |||
| from kaldi.feat.fbank import Fbank, FbankOptions | |||
| from kaldi.feat.mel import MelBanksOptions | |||
| from kaldi.feat.window import FrameExtractionOptions | |||
| from kaldi.matrix import Vector | |||
| mel_opts = MelBanksOptions() | |||
| mel_opts.num_bins = n_bins | |||
| frame_opts = FrameExtractionOptions() | |||
| frame_opts.samp_freq = sample_rate | |||
| opts = FbankOptions() | |||
| opts.mel_opts = mel_opts | |||
| opts.frame_opts = frame_opts | |||
| fbank = Fbank(opts=opts) | |||
| features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() | |||
| return features | |||
| except ImportError: | |||
| return None | |||
| def _get_torchaudio_fbank(waveform: np.ndarray, | |||
| sample_rate, | |||
| n_bins=80) -> Optional[np.ndarray]: | |||
| """Get mel-filter bank features via TorchAudio.""" | |||
| try: | |||
| import torchaudio.compliance.kaldi as ta_kaldi | |||
| waveform = torch.from_numpy(waveform) | |||
| features = ta_kaldi.fbank( | |||
| waveform, num_mel_bins=n_bins, sample_frequency=sample_rate) | |||
| return features.numpy() | |||
| except ImportError: | |||
| return None | |||
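A hedged usage sketch for the helpers above. The import path is inferred from the relative import added in base.py; the waveform is random data standing in for a real recording.

```python
import numpy as np

from modelscope.preprocessors.ofa.utils.audio_helper import (
    _get_kaldi_fbank, _get_torchaudio_fbank, convert_waveform)

waveform = np.random.randn(2, 16000).astype(np.float32)   # 2 channels, 1 s at 16 kHz
mono, sr = convert_waveform(waveform, 16000, to_mono=True, normalize_volume=True)
mono = mono * (2 ** 15)                                    # scale to 16-bit range, as in prepare_fbank
fbank = _get_kaldi_fbank(mono, sr, 80)                     # None if pyKaldi is not installed
if fbank is None:
    fbank = _get_torchaudio_fbank(mono, sr, 80)            # torchaudio fallback
print(fbank.shape)                                         # roughly (num_frames, 80)
```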
| @@ -1,5 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import List | |||
| import numpy as np | |||
| import torch | |||
| @@ -13,14 +15,12 @@ def collate_fn(samples, pad_idx, eos_idx): | |||
| pad_idx, | |||
| eos_idx=eos_idx) | |||
| src_tokens = merge('source') | |||
| batch = { | |||
| 'nsentences': len(samples), | |||
| 'net_input': { | |||
| 'input_ids': src_tokens, | |||
| }, | |||
| 'net_input': {}, | |||
| } | |||
| if samples[0].get('source', None) is not None: | |||
| batch['net_input']['input_ids'] = merge('source') | |||
| if samples[0].get('id', None) is not None: | |||
| batch['id'] = np.array([s['id'] for s in samples]) | |||
| if samples[0].get('target', None) is not None: | |||
| @@ -70,6 +70,20 @@ def collate_fn(samples, pad_idx, eos_idx): | |||
| [s['region_coord'] for s in samples], dim=0) | |||
| if samples[0].get('sample', None) is not None: | |||
| batch['samples'] = [s['sample'] for s in samples] | |||
| # For asr | |||
| if samples[0].get('fbank', None) is not None: | |||
| batch['net_input']['fbank'] = _collate_frames( | |||
| [s['fbank'] for s in samples]) | |||
| batch['net_input']['fbank_length'] = torch.tensor( | |||
| [s['fbank'].size(0) for s in samples], dtype=torch.long) | |||
| if samples[0].get('fbank_mask', None) is not None: | |||
| batch['net_input']['fbank_masks'] = torch.cat( | |||
| [s['fbank_mask'] for s in samples]) | |||
| if samples[0].get('phone_item', None) is not None: | |||
| batch['net_input']['phone_items'] = merge('phone_item') | |||
| batch['net_input']['phone_masks'] = torch.cat( | |||
| [s['phone_mask'] for s in samples]) | |||
| return batch | |||
| @@ -113,3 +127,19 @@ def collate_tokens( | |||
| for i, v in enumerate(values): | |||
| copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) | |||
| return res | |||
| def _collate_frames(frames: List[torch.Tensor]): | |||
| """ | |||
| Convert a list of 2D frames into a padded 3D tensor | |||
| Args: | |||
| frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is | |||
| length of i-th frame and f_dim is static dimension of features | |||
| Returns: | |||
| 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] | |||
| """ | |||
| max_len = max(frame.size(0) for frame in frames) | |||
| out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) | |||
| for i, v in enumerate(frames): | |||
| out[i, :v.size(0)] = v | |||
| return out | |||
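A quick check of the padding behaviour of `_collate_frames` (sizes are arbitrary; the import path is taken from the existing import of `collate_tokens` earlier in this review):

```python
import torch

from modelscope.preprocessors.ofa.utils.collate import _collate_frames

frames = [torch.ones(3, 80), torch.ones(5, 80)]   # two utterances: 3 and 5 frames of 80-dim fbank
batch = _collate_frames(frames)
print(batch.shape)          # torch.Size([2, 5, 80])
print(batch[0, 3:].sum())   # tensor(0.) -- the shorter utterance is zero-padded at the end
```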
| @@ -9,5 +9,6 @@ OFA_TASK_KEY_MAPPING = { | |||
| Tasks.visual_grounding: ['image', 'text'], | |||
| Tasks.visual_question_answering: ['image', 'text'], | |||
| Tasks.visual_entailment: ['image', 'text', 'text2'], | |||
| Tasks.text_to_image_synthesis: ['text'] | |||
| Tasks.text_to_image_synthesis: ['text'], | |||
| Tasks.auto_speech_recognition: ['wav', 'text'], | |||
| } | |||
| @@ -0,0 +1,192 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from modelscope.utils.chinese_utils import normalize_chinese_number | |||
| class TrieNode(object): | |||
| def __init__(self): | |||
| """ | |||
| Initialize your data structure here. | |||
| """ | |||
| self.data = {} | |||
| self.is_word = False | |||
| class Trie(object): | |||
| """ | |||
| trie-tree | |||
| """ | |||
| def __init__(self): | |||
| """ | |||
| Initialize your data structure here. | |||
| """ | |||
| self.root = TrieNode() | |||
| def insert(self, word): | |||
| """ | |||
| Inserts a word into the trie. | |||
| :type word: str | |||
| :rtype: void | |||
| """ | |||
| node = self.root | |||
| for chars in word: | |||
| child = node.data.get(chars) | |||
| if not child: | |||
| node.data[chars] = TrieNode() | |||
| node = node.data[chars] | |||
| node.is_word = True | |||
| def search(self, word): | |||
| """ | |||
| Returns if the word is in the trie. | |||
| :type word: str | |||
| :rtype: bool | |||
| """ | |||
| node = self.root | |||
| for chars in word: | |||
| node = node.data.get(chars) | |||
| if not node: | |||
| return False | |||
| return node.is_word | |||
| def startsWith(self, prefix): | |||
| """ | |||
| Returns if there is any word in the trie that starts with the given prefix. | |||
| :type prefix: str | |||
| :rtype: bool | |||
| """ | |||
| node = self.root | |||
| for chars in prefix: | |||
| node = node.data.get(chars) | |||
| if not node: | |||
| return False | |||
| return True | |||
| def get_start(self, prefix): | |||
| """ | |||
| Returns words started with prefix | |||
| :param prefix: | |||
| :return: words (list) | |||
| """ | |||
| def get_key(pre, pre_node): | |||
| word_list = [] | |||
| if pre_node.is_word: | |||
| word_list.append(pre) | |||
| for x in pre_node.data.keys(): | |||
| word_list.extend(get_key(pre + str(x), pre_node.data.get(x))) | |||
| return word_list | |||
| words = [] | |||
| if not self.startsWith(prefix): | |||
| return words | |||
| if self.search(prefix): | |||
| words.append(prefix) | |||
| return words | |||
| node = self.root | |||
| for chars in prefix: | |||
| node = node.data.get(chars) | |||
| return get_key(prefix, node) | |||
| class TrieTokenizer(Trie): | |||
| """ | |||
| word_split based on trie-tree | |||
| """ | |||
| def __init__(self, dict_path): | |||
| super(TrieTokenizer, self).__init__() | |||
| self.dict_path = dict_path | |||
| self.create_trie_tree() | |||
| def load_dict(self): | |||
| words = [] | |||
| with open(self.dict_path, mode='r', encoding='utf-8') as file: | |||
| for line in file: | |||
| words.append(line.strip().split('\t')[0].encode( | |||
| 'utf-8').decode('utf-8-sig')) | |||
| return words | |||
| def create_trie_tree(self): | |||
| words = self.load_dict() | |||
| for word in words: | |||
| self.insert(word) | |||
| def mine_tree(self, tree, sentence, trace_index): | |||
| if trace_index <= (len(sentence) - 1): | |||
| if sentence[trace_index] in tree.data: | |||
| trace_index = trace_index + 1 | |||
| trace_index = self.mine_tree( | |||
| tree.data[sentence[trace_index - 1]], sentence, | |||
| trace_index) | |||
| return trace_index | |||
| def tokenize(self, sentence): | |||
| tokens = [] | |||
| sentence_len = len(sentence) | |||
| while sentence_len != 0: | |||
| trace_index = 0 | |||
| trace_index = self.mine_tree(self.root, sentence, trace_index) | |||
| if trace_index == 0: | |||
| tokens.append(sentence[0:1]) | |||
| sentence = sentence[1:len(sentence)] | |||
| sentence_len = len(sentence) | |||
| else: | |||
| tokens.append(sentence[0:trace_index]) | |||
| sentence = sentence[trace_index:len(sentence)] | |||
| sentence_len = len(sentence) | |||
| return tokens | |||
| def combine(self, token_list): | |||
| flag = 0 | |||
| output = [] | |||
| temp = [] | |||
| for i in token_list: | |||
| if len(i) != 1: | |||
| if flag == 0: | |||
| output.append(i[::]) | |||
| else: | |||
| output.append(''.join(temp)) | |||
| output.append(i[::]) | |||
| temp = [] | |||
| flag = 0 | |||
| else: | |||
| if flag == 0: | |||
| temp.append(i) | |||
| flag = 1 | |||
| else: | |||
| temp.append(i) | |||
| return output | |||
| class Text2Phone: | |||
| def __init__(self, phone_dict_path): | |||
| self.trie_cws = TrieTokenizer(phone_dict_path) | |||
| self.phone_map = self.get_phone_map(phone_dict_path) | |||
| def get_phone_map(self, phone_dict_path): | |||
| phone_map = dict() | |||
| with open(phone_dict_path, 'r') as phone_map_file_reader: | |||
| for line in phone_map_file_reader: | |||
| key, phone_series = line.strip().split('\t') | |||
| if key not in phone_map: | |||
| phone_map[key] = phone_series | |||
| return phone_map | |||
| def trans(self, text): | |||
| text = normalize_chinese_number(text) | |||
| tokens = self.trie_cws.tokenize(text) | |||
| phones = [] | |||
| for word in tokens: | |||
| if word in self.phone_map: | |||
| phones.append(self.phone_map[word]) | |||
| elif len(word) > 1: | |||
| for char in word: | |||
| if char in self.phone_map: | |||
| phones.append(self.phone_map[char]) | |||
| return ' '.join(phones) | |||
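A hedged sketch of `Text2Phone` in use. The dictionary format (word, a tab, then a space-separated phone sequence) is inferred from `get_phone_map`; the entries, phones, and import path are assumptions for illustration.

```python
import tempfile

from modelscope.preprocessors.ofa.utils.text2phone import Text2Phone  # path inferred from asr.py

# Hypothetical text2phone_dict.txt entries; the real file ships with the model.
entries = '你好\tn i3 h ao3\n世界\tsh i4 j ie4\n'
with tempfile.NamedTemporaryFile('w', suffix='.txt', encoding='utf-8', delete=False) as f:
    f.write(entries)
    dict_path = f.name

t2p = Text2Phone(dict_path)
print(t2p.trans('你好，世界'))   # -> 'n i3 h ao3 sh i4 j ie4'; tokens without a phone entry are dropped
```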
| @@ -113,6 +113,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
| self.use_rdrop = args.get('use_rdrop', False) | |||
| self.reg_alpha = args.get('reg_alpha', 1.0) | |||
| self.sample_patch_num = args.get('sample_patch_num', 196) | |||
| self.ctc_weight = args.get('ctc_weight', 0.0) | |||
| self.constraint_start = None | |||
| self.constraint_end = None | |||
| @@ -141,6 +142,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
| output = model.model(**sample['net_input']) | |||
| loss, nll_loss, ntokens = self.compute_loss( | |||
| output.logits, sample, update_num, reduce=reduce) | |||
| if self.ctc_weight > 0: | |||
| ctc_loss = self.compute_ctc_loss(model, output, sample) | |||
| loss = nll_loss + ctc_loss | |||
| sample_size = ( | |||
| sample['target'].size(0) if self.sentence_avg else ntokens) | |||
| logging_output = { | |||
| @@ -206,6 +210,32 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
| constraint_end=self.constraint_end) | |||
| return loss, nll_loss, ntokens | |||
| def compute_ctc_loss(self, model, output, sample): | |||
| lprobs = model.get_encoder_normalized_probs( | |||
| output, log_probs=True).contiguous() # (T, B, C) from the encoder | |||
| non_padding_mask = ~output.encoder_padding_mask | |||
| input_lengths = non_padding_mask.long().sum(-1) | |||
| target_lengths = sample['ctc_output_lengths'] | |||
| pad_mask = torch.arange(target_lengths.max()).expand([ | |||
| target_lengths.shape[0], -1 | |||
| ]).to(target_lengths) < target_lengths.unsqueeze(1) | |||
| targets_flat = sample['ctc_outputs'].masked_select(pad_mask) | |||
| with torch.backends.cudnn.flags(enabled=False): | |||
| loss = F.ctc_loss( | |||
| lprobs, | |||
| targets_flat, | |||
| input_lengths, | |||
| target_lengths, | |||
| blank=self.blank_idx, | |||
| reduction='sum', | |||
| zero_infinity=True, | |||
| ) | |||
| return loss | |||
| def get_schedule(scheduler): | |||
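As a standalone shape reminder for the CTC term above (illustrative sizes only; `blank=0` stands in for `self.blank_idx`): `lprobs` must be (T, B, C) log-probabilities from the encoder head, targets are flattened across the batch, and both length tensors are per-sample.

```python
import torch
import torch.nn.functional as F

T, B, C = 50, 2, 64                                # frames, batch, vocab incl. blank (illustrative)
lprobs = torch.randn(T, B, C).log_softmax(-1)      # (T, B, C) log-probabilities
targets = torch.randint(1, C, (B, 12))             # padded label ids, blank (0) excluded
target_lengths = torch.tensor([12, 7])
input_lengths = torch.tensor([50, 44])

pad_mask = torch.arange(target_lengths.max()).expand(B, -1) < target_lengths.unsqueeze(1)
targets_flat = targets.masked_select(pad_mask)     # 1-D, length = target_lengths.sum()
loss = F.ctc_loss(lprobs, targets_flat, input_lengths, target_lengths,
                  blank=0, reduction='sum', zero_infinity=True)
```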
| @@ -1,5 +1,13 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import re | |||
| import string | |||
| from zhconv import convert | |||
| CHINESE_PUNCTUATION = '＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·！？｡。' | |||
| ENGLISH_PUNCTUATION = string.punctuation | |||
| def is_chinese_char(word: str): | |||
| chinese_punctuations = { | |||
| @@ -33,3 +41,28 @@ def rebuild_chinese_str(string: str): | |||
| return ' '.join(''.join([ | |||
| f' {char} ' if is_chinese_char(char) else char for char in string | |||
| ]).split()) | |||
| def normalize_chinese_number(text): | |||
| chinese_number = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九'] | |||
| new_text = '' | |||
| for x in text: | |||
| if x in '0123456789': | |||
| x = chinese_number[int(x)] | |||
| new_text += x | |||
| new_text = convert(new_text, 'zh-hans') | |||
| return new_text | |||
| def pre_chinese(text, max_words): | |||
| # strip punctuation char-by-char (a plain str.replace of the whole string would not match) | |||
| text = re.sub('[{}]'.format(re.escape( | |||
| CHINESE_PUNCTUATION + ENGLISH_PUNCTUATION)), ' ', text.lower()) | |||
| text = re.sub( | |||
| r'\s{2,}', | |||
| ' ', | |||
| text, | |||
| ) | |||
| text = text.rstrip('\n') | |||
| text = text.strip(' ')[:max_words] | |||
| return text | |||
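Quick usage of the two helpers added here (import path taken from the existing `modelscope.utils.chinese_utils` imports in this review):

```python
from modelscope.utils.chinese_utils import normalize_chinese_number, pre_chinese

print(normalize_chinese_number('2022年'))       # -> '二零二二年': each ASCII digit mapped to its Chinese numeral
print(pre_chinese('你好，世界！Hello!!', 10))     # punctuation replaced by spaces, lowercased, capped at 10 chars
```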
| @@ -8,6 +8,7 @@ pytorch_lightning<=1.7.7 | |||
| # which introduced compatibility issues that are being investigated | |||
| rouge_score<=0.0.4 | |||
| sacrebleu | |||
| soundfile | |||
| taming-transformers-rom1504 | |||
| timm | |||
| tokenizers | |||
| @@ -273,6 +273,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| result[OutputKeys.OUTPUT_IMG].save('result.png') | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_asr_with_name(self): | |||
| model = 'damo/ofa_asr_pretrain_base_zh' | |||
| ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model) | |||
| example = {'wav': 'data/test/audios/asr_example_ofa.wav'} | |||
| result = ofa_pipe(example) | |||
| print(result[OutputKeys.TEXT]) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||