[pipelines] support wenet note: ut failed is due to a run.py enveironment setup issue that is being fixed. nothing to do with the change.master^2
| @@ -1,4 +1,5 @@ | |||||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | ||||
| pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple | |||||
| pip install -r requirements/tests.txt | pip install -r requirements/tests.txt | ||||
| git config --global --add safe.directory /Maas-lib | git config --global --add safe.directory /Maas-lib | ||||
| git config --global user.email tmp | git config --global user.email tmp | ||||
| @@ -92,6 +92,7 @@ class Models(object): | |||||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | ||||
| kws_kwsbp = 'kws-kwsbp' | kws_kwsbp = 'kws-kwsbp' | ||||
| generic_asr = 'generic-asr' | generic_asr = 'generic-asr' | ||||
| wenet_asr = 'wenet-asr' | |||||
| # multi-modal models | # multi-modal models | ||||
| ofa = 'ofa' | ofa = 'ofa' | ||||
| @@ -267,6 +268,7 @@ class Pipelines(object): | |||||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | ||||
| kws_kwsbp = 'kws-kwsbp' | kws_kwsbp = 'kws-kwsbp' | ||||
| asr_inference = 'asr-inference' | asr_inference = 'asr-inference' | ||||
| asr_wenet_inference = 'asr-wenet-inference' | |||||
| # multi-modal tasks | # multi-modal tasks | ||||
| image_captioning = 'image-captioning' | image_captioning = 'image-captioning' | ||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| import wenetruntime as wenet | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import Tasks | |||||
| __all__ = ['WeNetAutomaticSpeechRecognition'] | |||||
| @MODELS.register_module( | |||||
| Tasks.auto_speech_recognition, module_name=Models.wenet_asr) | |||||
| class WeNetAutomaticSpeechRecognition(Model): | |||||
| def __init__(self, model_dir: str, am_model_name: str, | |||||
| model_config: Dict[str, Any], *args, **kwargs): | |||||
| """initialize the info of model. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, am_model_name, model_config, *args, | |||||
| **kwargs) | |||||
| self.decoder = wenet.Decoder(model_dir, lang='chs') | |||||
| def forward(self, inputs: Dict[str, Any]) -> str: | |||||
| if inputs['audio_format'] == 'wav': | |||||
| rst = self.decoder.decode_wav(inputs['audio']) | |||||
| else: | |||||
| rst = self.decoder.decode(inputs['audio']) | |||||
| text = json.loads(rst)['nbest'][0]['sentence'] | |||||
| return {'text': text} | |||||
| @@ -0,0 +1,87 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict, Union | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import WavToScp | |||||
| from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, | |||||
| load_bytes_from_url) | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| __all__ = ['WeNetAutomaticSpeechRecognitionPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference) | |||||
| class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| """ASR Inference Pipeline | |||||
| """ | |||||
| def __init__(self, | |||||
| model: Union[Model, str] = None, | |||||
| preprocessor: WavToScp = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create an asr pipeline for prediction | |||||
| """ | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| def __call__(self, | |||||
| audio_in: Union[str, bytes], | |||||
| audio_fs: int = None, | |||||
| recog_type: str = None, | |||||
| audio_format: str = None) -> Dict[str, Any]: | |||||
| from easyasr.common import asr_utils | |||||
| self.recog_type = recog_type | |||||
| self.audio_format = audio_format | |||||
| self.audio_fs = audio_fs | |||||
| if isinstance(audio_in, str): | |||||
| # load pcm data from url if audio_in is url str | |||||
| self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) | |||||
| elif isinstance(audio_in, bytes): | |||||
| # load pcm data from wav data if audio_in is wave format | |||||
| self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) | |||||
| else: | |||||
| self.audio_in = audio_in | |||||
| # set the sample_rate of audio_in if checking_audio_fs is valid | |||||
| if checking_audio_fs is not None: | |||||
| self.audio_fs = checking_audio_fs | |||||
| if recog_type is None or audio_format is None: | |||||
| self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | |||||
| audio_in=self.audio_in, | |||||
| recog_type=recog_type, | |||||
| audio_format=audio_format) | |||||
| if hasattr(asr_utils, 'sample_rate_checking'): | |||||
| checking_audio_fs = asr_utils.sample_rate_checking( | |||||
| self.audio_in, self.audio_format) | |||||
| if checking_audio_fs is not None: | |||||
| self.audio_fs = checking_audio_fs | |||||
| inputs = { | |||||
| 'audio': self.audio_in, | |||||
| 'audio_format': self.audio_format, | |||||
| 'audio_fs': self.audio_fs | |||||
| } | |||||
| output = self.forward(inputs) | |||||
| rst = self.postprocess(output['asr_result']) | |||||
| return rst | |||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """Decoding | |||||
| """ | |||||
| inputs['asr_result'] = self.model(inputs) | |||||
| return inputs | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the asr results | |||||
| """ | |||||
| return inputs | |||||
| @@ -70,6 +70,11 @@ PYTORCH_IMPORT_ERROR = """ | |||||
| installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. | installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. | ||||
| """ | """ | ||||
| WENETRUNTIME_IMPORT_ERROR = """ | |||||
| {0} requires the wenetruntime library but it was not found in your environment. You can install it with pip: | |||||
| `pip install wenetruntime==TORCH_VER` | |||||
| """ | |||||
| # docstyle-ignore | # docstyle-ignore | ||||
| SCIPY_IMPORT_ERROR = """ | SCIPY_IMPORT_ERROR = """ | ||||
| {0} requires the scipy library but it was not found in your environment. You can install it with pip: | {0} requires the scipy library but it was not found in your environment. You can install it with pip: | ||||
| @@ -245,6 +245,10 @@ def is_torch_cuda_available(): | |||||
| return False | return False | ||||
| def is_wenetruntime_available(): | |||||
| return importlib.util.find_spec('wenetruntime') is not None | |||||
| def is_tf_available(): | def is_tf_available(): | ||||
| return _tf_available | return _tf_available | ||||
| @@ -280,6 +284,9 @@ REQUIREMENTS_MAAPING = OrderedDict([ | |||||
| ('timm', (is_timm_available, TIMM_IMPORT_ERROR)), | ('timm', (is_timm_available, TIMM_IMPORT_ERROR)), | ||||
| ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), | ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), | ||||
| ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), | ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), | ||||
| ('wenetruntime', | |||||
| (is_wenetruntime_available, | |||||
| WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))), | |||||
| ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)), | ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)), | ||||
| ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)), | ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)), | ||||
| ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)), | ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)), | ||||
| @@ -0,0 +1,131 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import unittest | |||||
| from typing import Any, Dict, Union | |||||
| import numpy as np | |||||
| import soundfile | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import ColorCodes, Tasks | |||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import download_and_untar, test_level | |||||
| logger = get_logger() | |||||
| WAV_FILE = 'data/test/audios/asr_example.wav' | |||||
| URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' | |||||
| class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| DemoCompatibilityCheck): | |||||
| action_info = { | |||||
| 'test_run_with_pcm': { | |||||
| 'checking_item': OutputKeys.TEXT, | |||||
| 'example': 'wav_example' | |||||
| }, | |||||
| 'test_run_with_url': { | |||||
| 'checking_item': OutputKeys.TEXT, | |||||
| 'example': 'wav_example' | |||||
| }, | |||||
| 'test_run_with_wav': { | |||||
| 'checking_item': OutputKeys.TEXT, | |||||
| 'example': 'wav_example' | |||||
| }, | |||||
| 'wav_example': { | |||||
| 'text': '每一天都要快乐喔' | |||||
| } | |||||
| } | |||||
| def setUp(self) -> None: | |||||
| self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online' | |||||
| # this temporary workspace dir will store waveform files | |||||
| self.workspace = os.path.join(os.getcwd(), '.tmp') | |||||
| self.task = Tasks.auto_speech_recognition | |||||
| if not os.path.exists(self.workspace): | |||||
| os.mkdir(self.workspace) | |||||
| def tearDown(self) -> None: | |||||
| # remove workspace dir (.tmp) | |||||
| shutil.rmtree(self.workspace, ignore_errors=True) | |||||
| def run_pipeline(self, | |||||
| model_id: str, | |||||
| audio_in: Union[str, bytes], | |||||
| sr: int = None) -> Dict[str, Any]: | |||||
| inference_16k_pipline = pipeline( | |||||
| task=Tasks.auto_speech_recognition, model=model_id) | |||||
| rec_result = inference_16k_pipline(audio_in, audio_fs=sr) | |||||
| return rec_result | |||||
| def log_error(self, functions: str, result: Dict[str, Any]) -> None: | |||||
| logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' | |||||
| + ColorCodes.END) | |||||
| logger.error( | |||||
| ColorCodes.MAGENTA + functions + ' correct result example:' | |||||
| + ColorCodes.YELLOW | |||||
| + str(self.action_info[self.action_info[functions]['example']]) | |||||
| + ColorCodes.END) | |||||
| raise ValueError('asr result is mismatched') | |||||
| def check_result(self, functions: str, result: Dict[str, Any]) -> None: | |||||
| if result.__contains__(self.action_info[functions]['checking_item']): | |||||
| logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' | |||||
| + ColorCodes.END) | |||||
| logger.info( | |||||
| ColorCodes.YELLOW | |||||
| + str(result[self.action_info[functions]['checking_item']]) | |||||
| + ColorCodes.END) | |||||
| else: | |||||
| self.log_error(functions, result) | |||||
| def wav2bytes(self, wav_file): | |||||
| audio, fs = soundfile.read(wav_file) | |||||
| # float32 -> int16 | |||||
| audio = np.asarray(audio) | |||||
| dtype = np.dtype('int16') | |||||
| i = np.iinfo(dtype) | |||||
| abs_max = 2**(i.bits - 1) | |||||
| offset = i.min + abs_max | |||||
| audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) | |||||
| # int16(PCM_16) -> byte | |||||
| audio = audio.tobytes() | |||||
| return audio, fs | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_pcm(self): | |||||
| """run with wav data | |||||
| """ | |||||
| logger.info('Run ASR test with wav data (wenet)...') | |||||
| audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||||
| rec_result = self.run_pipeline( | |||||
| model_id=self.am_model_id, audio_in=audio, sr=sr) | |||||
| self.check_result('test_run_with_pcm', rec_result) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_wav(self): | |||||
| """run with single waveform file | |||||
| """ | |||||
| logger.info('Run ASR test with waveform file (wenet)...') | |||||
| wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||||
| rec_result = self.run_pipeline( | |||||
| model_id=self.am_model_id, audio_in=wav_file_path) | |||||
| self.check_result('test_run_with_wav', rec_result) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_url(self): | |||||
| """run with single url file | |||||
| """ | |||||
| logger.info('Run ASR test with url file (wenet)...') | |||||
| rec_result = self.run_pipeline( | |||||
| model_id=self.am_model_id, audio_in=URL_FILE) | |||||
| self.check_result('test_run_with_url', rec_result) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||