From 2e30caf1e6dfb6a37e39599449583326aef889ae Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Wed, 23 Nov 2022 17:29:06 +0800 Subject: [PATCH 1/6] [pipelines] add wenetruntime --- modelscope/metainfo.py | 2 + .../asr/wenet_automatic_speech_recognition.py | 45 ++++++++++ .../audio/asr_wenet_inference_pipeline.py | 87 +++++++++++++++++++ requirements/audio.txt | 1 + 4 files changed, 135 insertions(+) create mode 100644 modelscope/models/audio/asr/wenet_automatic_speech_recognition.py create mode 100644 modelscope/pipelines/audio/asr_wenet_inference_pipeline.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index ccd36349..b13e7aec 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -92,6 +92,7 @@ class Models(object): speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' kws_kwsbp = 'kws-kwsbp' generic_asr = 'generic-asr' + wenet_asr = 'wenet-asr' # multi-modal models ofa = 'ofa' @@ -267,6 +268,7 @@ class Pipelines(object): speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' kws_kwsbp = 'kws-kwsbp' asr_inference = 'asr-inference' + asr_wenet_inference = 'asr-wenet-inference' # multi-modal tasks image_captioning = 'image-captioning' diff --git a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py new file mode 100644 index 00000000..7db11190 --- /dev/null +++ b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + +import wenetruntime as wenet + +__all__ = ['WeNetAutomaticSpeechRecognition'] + + +@MODELS.register_module( + Tasks.auto_speech_recognition, module_name=Models.wenet_asr) +class WeNetAutomaticSpeechRecognition(Model): + + def __init__(self, model_dir: str, am_model_name: str, + model_config: Dict[str, Any], *args, **kwargs): + """initialize the info of model. + + Args: + model_dir (str): the model path. + am_model_name (str): the am model name from configuration.json + model_config (Dict[str, Any]): the detail config about model from configuration.json + """ + super().__init__(model_dir, am_model_name, model_config, *args, + **kwargs) + self.model_cfg = { + # the recognition model dir path + 'model_dir': model_dir, + # the recognition model config dict + 'model_config': model_config + } + self.decoder = None + + def forward(self) -> Dict[str, Any]: + """preload model and return the info of the model + """ + model_dir = self.model_cfg['model_dir'] + self.decoder = wenet.Decoder(model_dir, lang='chs') + + return self.model_cfg diff --git a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py new file mode 100644 index 00000000..33e8c617 --- /dev/null +++ b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import WavToScp +from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, + load_bytes_from_url) +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['WeNetAutomaticSpeechRecognitionPipeline'] + + +@PIPELINES.register_module( + Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference) +class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): + """ASR Inference Pipeline + """ + + def __init__(self, + model: Union[Model, str] = None, + preprocessor: WavToScp = None, + **kwargs): + """use `model` and `preprocessor` to create an asr pipeline for prediction + """ + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.model_cfg = self.model.forward() + self.decoder = self.model.decoder + + def __call__(self, + audio_in: Union[str, bytes], + audio_fs: int = None, + recog_type: str = None, + audio_format: str = None) -> Dict[str, Any]: + from easyasr.common import asr_utils + + self.recog_type = recog_type + self.audio_format = audio_format + self.audio_fs = audio_fs + + if isinstance(audio_in, str): + # load pcm data from url if audio_in is url str + self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) + elif isinstance(audio_in, bytes): + # load pcm data from wav data if audio_in is wave format + self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) + else: + self.audio_in = audio_in + + # set the sample_rate of audio_in if checking_audio_fs is valid + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs + + if recog_type is None or audio_format is None: + self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( + audio_in=self.audio_in, + recog_type=recog_type, + audio_format=audio_format) + + if hasattr(asr_utils, 'sample_rate_checking'): + checking_audio_fs = asr_utils.sample_rate_checking( + self.audio_in, self.audio_format) + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs + + self.model_cfg['audio'] = self.audio_in + self.model_cfg['audio_fs'] = self.audio_fs + + output = self.forward(self.model_cfg) + rst = self.postprocess(output['asr_result']) + return rst + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Decoding + """ + inputs['asr_result'] = self.decoder.decode(inputs['audio']) + return inputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the asr results + """ + return inputs diff --git a/requirements/audio.txt b/requirements/audio.txt index bef32121..86c78d3c 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -25,3 +25,4 @@ torchaudio tqdm ttsfrd>=0.0.3 unidecode +wenetruntime From 2605824dea612f2780ccbabb9ba7cf53bc89bfb8 Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Wed, 23 Nov 2022 21:58:03 +0800 Subject: [PATCH 2/6] [tests] add unittest --- .../asr/wenet_automatic_speech_recognition.py | 23 ++- .../audio/asr_wenet_inference_pipeline.py | 14 +- ...test_wenet_automatic_speech_recognition.py | 131 ++++++++++++++++++ 3 files changed, 146 insertions(+), 22 deletions(-) create mode 100644 tests/pipelines/test_wenet_automatic_speech_recognition.py diff --git a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py index 7db11190..1947629f 100644 --- a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py +++ b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py @@ -8,6 +8,7 @@ from modelscope.models.base import Model from modelscope.models.builder import MODELS from modelscope.utils.constant import Tasks +import json import wenetruntime as wenet __all__ = ['WeNetAutomaticSpeechRecognition'] @@ -23,23 +24,15 @@ class WeNetAutomaticSpeechRecognition(Model): Args: model_dir (str): the model path. - am_model_name (str): the am model name from configuration.json - model_config (Dict[str, Any]): the detail config about model from configuration.json """ super().__init__(model_dir, am_model_name, model_config, *args, **kwargs) - self.model_cfg = { - # the recognition model dir path - 'model_dir': model_dir, - # the recognition model config dict - 'model_config': model_config - } - self.decoder = None - - def forward(self) -> Dict[str, Any]: - """preload model and return the info of the model - """ - model_dir = self.model_cfg['model_dir'] self.decoder = wenet.Decoder(model_dir, lang='chs') - return self.model_cfg + def forward(self, inputs: Dict[str, Any]) -> str: + if inputs['audio_format'] == 'wav': + rst = self.decoder.decode_wav(inputs['audio']) + else: + rst = self.decoder.decode(inputs['audio']) + text = json.loads(rst)['nbest'][0]['sentence'] + return {'text': text} diff --git a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py index 33e8c617..6df47bcb 100644 --- a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py @@ -29,8 +29,6 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): """use `model` and `preprocessor` to create an asr pipeline for prediction """ super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.model_cfg = self.model.forward() - self.decoder = self.model.decoder def __call__(self, audio_in: Union[str, bytes], @@ -68,17 +66,19 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): if checking_audio_fs is not None: self.audio_fs = checking_audio_fs - self.model_cfg['audio'] = self.audio_in - self.model_cfg['audio_fs'] = self.audio_fs - - output = self.forward(self.model_cfg) + inputs = { + 'audio': self.audio_in, + 'audio_format': self.audio_format, + 'audio_fs': self.audio_fs + } + output = self.forward(inputs) rst = self.postprocess(output['asr_result']) return rst def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Decoding """ - inputs['asr_result'] = self.decoder.decode(inputs['audio']) + inputs['asr_result'] = self.model(inputs) return inputs def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/tests/pipelines/test_wenet_automatic_speech_recognition.py b/tests/pipelines/test_wenet_automatic_speech_recognition.py new file mode 100644 index 00000000..4adf8119 --- /dev/null +++ b/tests/pipelines/test_wenet_automatic_speech_recognition.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import unittest +from typing import Any, Dict, Union + +import numpy as np +import soundfile + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import ColorCodes, Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import download_and_untar, test_level + +logger = get_logger() + +WAV_FILE = 'data/test/audios/asr_example.wav' +URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' + + +class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase, + DemoCompatibilityCheck): + action_info = { + 'test_run_with_pcm': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_url': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_wav': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'wav_example': { + 'text': '每一天都要快乐喔' + } + } + + def setUp(self) -> None: + self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online' + # this temporary workspace dir will store waveform files + self.workspace = os.path.join(os.getcwd(), '.tmp') + self.task = Tasks.auto_speech_recognition + if not os.path.exists(self.workspace): + os.mkdir(self.workspace) + + def tearDown(self) -> None: + # remove workspace dir (.tmp) + shutil.rmtree(self.workspace, ignore_errors=True) + + def run_pipeline(self, + model_id: str, + audio_in: Union[str, bytes], + sr: int = None) -> Dict[str, Any]: + inference_16k_pipline = pipeline( + task=Tasks.auto_speech_recognition, model=model_id) + rec_result = inference_16k_pipline(audio_in, audio_fs=sr) + return rec_result + + def log_error(self, functions: str, result: Dict[str, Any]) -> None: + logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' + + ColorCodes.END) + logger.error( + ColorCodes.MAGENTA + functions + ' correct result example:' + + ColorCodes.YELLOW + + str(self.action_info[self.action_info[functions]['example']]) + + ColorCodes.END) + raise ValueError('asr result is mismatched') + + def check_result(self, functions: str, result: Dict[str, Any]) -> None: + if result.__contains__(self.action_info[functions]['checking_item']): + logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' + + ColorCodes.END) + logger.info( + ColorCodes.YELLOW + + str(result[self.action_info[functions]['checking_item']]) + + ColorCodes.END) + else: + self.log_error(functions, result) + + def wav2bytes(self, wav_file): + audio, fs = soundfile.read(wav_file) + + # float32 -> int16 + audio = np.asarray(audio) + dtype = np.dtype('int16') + i = np.iinfo(dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) + + # int16(PCM_16) -> byte + audio = audio.tobytes() + return audio, fs + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_pcm(self): + """run with wav data + """ + logger.info('Run ASR test with wav data (wenet)...') + audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) + rec_result = self.run_pipeline( + model_id=self.am_model_id, audio_in=audio, sr=sr) + self.check_result('test_run_with_pcm', rec_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav(self): + """run with single waveform file + """ + logger.info('Run ASR test with waveform file (wenet)...') + wav_file_path = os.path.join(os.getcwd(), WAV_FILE) + rec_result = self.run_pipeline( + model_id=self.am_model_id, audio_in=wav_file_path) + self.check_result('test_run_with_wav', rec_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_url(self): + """run with single url file + """ + logger.info('Run ASR test with url file (wenet)...') + rec_result = self.run_pipeline( + model_id=self.am_model_id, audio_in=URL_FILE) + self.check_result('test_run_with_url', rec_result) + + +if __name__ == '__main__': + unittest.main() From eb2ef3a1cfc7ec511e73cc37d7d66a544dc59dfb Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Thu, 24 Nov 2022 19:48:48 +0800 Subject: [PATCH 3/6] [lint] fix lint --- .../models/audio/asr/wenet_automatic_speech_recognition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py index 1947629f..feb822d4 100644 --- a/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py +++ b/modelscope/models/audio/asr/wenet_automatic_speech_recognition.py @@ -3,14 +3,14 @@ import os from typing import Any, Dict +import json +import wenetruntime as wenet + from modelscope.metainfo import Models from modelscope.models.base import Model from modelscope.models.builder import MODELS from modelscope.utils.constant import Tasks -import json -import wenetruntime as wenet - __all__ = ['WeNetAutomaticSpeechRecognition'] From b0cf09d7b0bf25e110f6fb52aa77161f6cd1deea Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Thu, 24 Nov 2022 22:12:58 +0800 Subject: [PATCH 4/6] [ci] chang pypi url to tsinghua --- .dev_scripts/ci_container_test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.dev_scripts/ci_container_test.sh b/.dev_scripts/ci_container_test.sh index a3f13137..35b43535 100644 --- a/.dev_scripts/ci_container_test.sh +++ b/.dev_scripts/ci_container_test.sh @@ -1,4 +1,5 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then + pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple pip install -r requirements/tests.txt git config --global --add safe.directory /Maas-lib git config --global user.email tmp From a2532210af2712aa87ff0a72065ed84e567779f8 Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Fri, 25 Nov 2022 11:47:25 +0800 Subject: [PATCH 5/6] fix wenetruntime version --- requirements/audio.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/audio.txt b/requirements/audio.txt index 86c78d3c..037bb839 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -25,4 +25,5 @@ torchaudio tqdm ttsfrd>=0.0.3 unidecode -wenetruntime +# wenetruntime version should be the same as torch +wenetruntime==1.11 From 02d2469e55347c95349820caf660f2df1128fb58 Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Fri, 25 Nov 2022 15:37:45 +0800 Subject: [PATCH 6/6] check wenetruntime --- modelscope/utils/error.py | 5 +++++ modelscope/utils/import_utils.py | 7 +++++++ requirements/audio.txt | 2 -- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/modelscope/utils/error.py b/modelscope/utils/error.py index a894063c..8128f7b0 100644 --- a/modelscope/utils/error.py +++ b/modelscope/utils/error.py @@ -70,6 +70,11 @@ PYTORCH_IMPORT_ERROR = """ installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. """ +WENETRUNTIME_IMPORT_ERROR = """ +{0} requires the wenetruntime library but it was not found in your environment. You can install it with pip: +`pip install wenetruntime==TORCH_VER` +""" + # docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 5db5ea98..64072eee 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -245,6 +245,10 @@ def is_torch_cuda_available(): return False +def is_wenetruntime_available(): + return importlib.util.find_spec('wenetruntime') is not None + + def is_tf_available(): return _tf_available @@ -280,6 +284,9 @@ REQUIREMENTS_MAAPING = OrderedDict([ ('timm', (is_timm_available, TIMM_IMPORT_ERROR)), ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), + ('wenetruntime', + (is_wenetruntime_available, + WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))), ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)), ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)), ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)), diff --git a/requirements/audio.txt b/requirements/audio.txt index 037bb839..bef32121 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -25,5 +25,3 @@ torchaudio tqdm ttsfrd>=0.0.3 unidecode -# wenetruntime version should be the same as torch -wenetruntime==1.11