From 2e30caf1e6dfb6a37e39599449583326aef889ae Mon Sep 17 00:00:00 2001
From: pengzhendong <275331498@qq.com>
Date: Wed, 23 Nov 2022 17:29:06 +0800
Subject: [PATCH] [pipelines] add wenetruntime

---
 modelscope/metainfo.py                            |  2 +
 .../asr/wenet_automatic_speech_recognition.py     | 45 ++++++++++
 .../audio/asr_wenet_inference_pipeline.py         | 87 +++++++++++++++++++
 requirements/audio.txt                            |  1 +
 4 files changed, 135 insertions(+)
 create mode 100644 modelscope/models/audio/asr/wenet_automatic_speech_recognition.py
 create mode 100644 modelscope/pipelines/audio/asr_wenet_inference_pipeline.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index ccd36349..b13e7aec 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -92,6 +92,7 @@ class Models(object):
     speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     generic_asr = 'generic-asr'
+    wenet_asr = 'wenet-asr'
 
     # multi-modal models
     ofa = 'ofa'
@@ -267,6 +268,7 @@ class Pipelines(object):
     speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     asr_inference = 'asr-inference'
+    asr_wenet_inference = 'asr-wenet-inference'
 
     # multi-modal tasks
     image_captioning = 'image-captioning'
diff --git a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py
new file mode 100644
index 00000000..7db11190
--- /dev/null
+++ b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py
@@ -0,0 +1,45 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from typing import Any, Dict
+
+import wenetruntime as wenet
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+
+__all__ = ['WeNetAutomaticSpeechRecognition']
+
+
+@MODELS.register_module(
+    Tasks.auto_speech_recognition, module_name=Models.wenet_asr)
+class WeNetAutomaticSpeechRecognition(Model):
+
+    def __init__(self, model_dir: str, am_model_name: str,
+                 model_config: Dict[str, Any], *args, **kwargs):
+        """Initialize the model.
+
+        Args:
+            model_dir (str): the model path.
+            am_model_name (str): the acoustic model name from configuration.json
+            model_config (Dict[str, Any]): the detailed model config from configuration.json
+        """
+        super().__init__(model_dir, am_model_name, model_config, *args,
+                         **kwargs)
+        self.model_cfg = {
+            # the recognition model dir path
+            'model_dir': model_dir,
+            # the recognition model config dict
+            'model_config': model_config
+        }
+        self.decoder = None
+
+    def forward(self) -> Dict[str, Any]:
+        """Preload the wenetruntime decoder and return the model config.
+        """
+        model_dir = self.model_cfg['model_dir']
+        self.decoder = wenet.Decoder(model_dir, lang='chs')
+
+        return self.model_cfg
diff --git a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
new file mode 100644
index 00000000..33e8c617
--- /dev/null
+++ b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, Union
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import WavToScp
+from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
+                                                load_bytes_from_url)
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['WeNetAutomaticSpeechRecognitionPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference)
+class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
+    """WeNet ASR inference pipeline.
+    """
+
+    def __init__(self,
+                 model: Union[Model, str] = None,
+                 preprocessor: WavToScp = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an ASR pipeline for prediction.
+        """
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.model_cfg = self.model.forward()
+        self.decoder = self.model.decoder
+
+    def __call__(self,
+                 audio_in: Union[str, bytes],
+                 audio_fs: int = None,
+                 recog_type: str = None,
+                 audio_format: str = None) -> Dict[str, Any]:
+        from easyasr.common import asr_utils
+
+        self.recog_type = recog_type
+        self.audio_format = audio_format
+        self.audio_fs = audio_fs
+
+        if isinstance(audio_in, str):
+            # load pcm data from url if audio_in is url str
+            self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
+        elif isinstance(audio_in, bytes):
+            # load pcm data from wav data if audio_in is wave format
+            self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
+        else:
+            self.audio_in, checking_audio_fs = audio_in, None
+
+        # set the sample_rate of audio_in if checking_audio_fs is valid
+        if checking_audio_fs is not None:
+            self.audio_fs = checking_audio_fs
+
+        if recog_type is None or audio_format is None:
+            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
+                audio_in=self.audio_in,
+                recog_type=recog_type,
+                audio_format=audio_format)
+
+        if hasattr(asr_utils, 'sample_rate_checking'):
+            checking_audio_fs = asr_utils.sample_rate_checking(
+                self.audio_in, self.audio_format)
+            if checking_audio_fs is not None:
+                self.audio_fs = checking_audio_fs
+
+        self.model_cfg['audio'] = self.audio_in
+        self.model_cfg['audio_fs'] = self.audio_fs
+
+        output = self.forward(self.model_cfg)
+        rst = self.postprocess(output['asr_result'])
+        return rst
+
+    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Decode the input audio with the wenetruntime decoder.
+        """
+        inputs['asr_result'] = self.decoder.decode(inputs['audio'])
+        return inputs
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Postprocess the ASR results.
+        """
+        return inputs
diff --git a/requirements/audio.txt b/requirements/audio.txt
index bef32121..86c78d3c 100644
--- a/requirements/audio.txt
+++ b/requirements/audio.txt
@@ -25,3 +25,4 @@ torchaudio
 tqdm
 ttsfrd>=0.0.3
 unidecode
+wenetruntime
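
Usage sketch (not part of the patch; the model id and audio URL below are
placeholders): once a ModelScope model's configuration.json selects the
'wenet-asr' model type and the 'asr-wenet-inference' pipeline registered
above, the new pipeline could be driven through the standard pipeline()
factory roughly like this:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # pipeline() resolves the task/model pair to
    # WeNetAutomaticSpeechRecognitionPipeline via the registry entries added
    # in metainfo.py; the model id is a placeholder, not a real asset.
    asr = pipeline(
        task=Tasks.auto_speech_recognition,
        model='<namespace>/<wenet-runtime-model-id>')

    # A wav URL/path or raw wav bytes is accepted (the URL is a placeholder);
    # the call returns the wenetruntime result produced by Decoder.decode().
    print(asr('https://example.com/test.wav'))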