[pipelines] support wenet note: ut failed is due to a run.py enveironment setup issue that is being fixed. nothing to do with the change.master^2
| @@ -1,4 +1,5 @@ | |||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
| pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple | |||
| pip install -r requirements/tests.txt | |||
| git config --global --add safe.directory /Maas-lib | |||
| git config --global user.email tmp | |||
| @@ -92,6 +92,7 @@ class Models(object): | |||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
| kws_kwsbp = 'kws-kwsbp' | |||
| generic_asr = 'generic-asr' | |||
| wenet_asr = 'wenet-asr' | |||
| # multi-modal models | |||
| ofa = 'ofa' | |||
| @@ -267,6 +268,7 @@ class Pipelines(object): | |||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
| kws_kwsbp = 'kws-kwsbp' | |||
| asr_inference = 'asr-inference' | |||
| asr_wenet_inference = 'asr-wenet-inference' | |||
| # multi-modal tasks | |||
| image_captioning = 'image-captioning' | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| import json | |||
| import wenetruntime as wenet | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['WeNetAutomaticSpeechRecognition'] | |||
| @MODELS.register_module( | |||
| Tasks.auto_speech_recognition, module_name=Models.wenet_asr) | |||
| class WeNetAutomaticSpeechRecognition(Model): | |||
| def __init__(self, model_dir: str, am_model_name: str, | |||
| model_config: Dict[str, Any], *args, **kwargs): | |||
| """initialize the info of model. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, am_model_name, model_config, *args, | |||
| **kwargs) | |||
| self.decoder = wenet.Decoder(model_dir, lang='chs') | |||
| def forward(self, inputs: Dict[str, Any]) -> str: | |||
| if inputs['audio_format'] == 'wav': | |||
| rst = self.decoder.decode_wav(inputs['audio']) | |||
| else: | |||
| rst = self.decoder.decode(inputs['audio']) | |||
| text = json.loads(rst)['nbest'][0]['sentence'] | |||
| return {'text': text} | |||
| @@ -0,0 +1,87 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import WavToScp | |||
| from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, | |||
| load_bytes_from_url) | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| __all__ = ['WeNetAutomaticSpeechRecognitionPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference) | |||
| class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): | |||
| """ASR Inference Pipeline | |||
| """ | |||
| def __init__(self, | |||
| model: Union[Model, str] = None, | |||
| preprocessor: WavToScp = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create an asr pipeline for prediction | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| def __call__(self, | |||
| audio_in: Union[str, bytes], | |||
| audio_fs: int = None, | |||
| recog_type: str = None, | |||
| audio_format: str = None) -> Dict[str, Any]: | |||
| from easyasr.common import asr_utils | |||
| self.recog_type = recog_type | |||
| self.audio_format = audio_format | |||
| self.audio_fs = audio_fs | |||
| if isinstance(audio_in, str): | |||
| # load pcm data from url if audio_in is url str | |||
| self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) | |||
| elif isinstance(audio_in, bytes): | |||
| # load pcm data from wav data if audio_in is wave format | |||
| self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) | |||
| else: | |||
| self.audio_in = audio_in | |||
| # set the sample_rate of audio_in if checking_audio_fs is valid | |||
| if checking_audio_fs is not None: | |||
| self.audio_fs = checking_audio_fs | |||
| if recog_type is None or audio_format is None: | |||
| self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | |||
| audio_in=self.audio_in, | |||
| recog_type=recog_type, | |||
| audio_format=audio_format) | |||
| if hasattr(asr_utils, 'sample_rate_checking'): | |||
| checking_audio_fs = asr_utils.sample_rate_checking( | |||
| self.audio_in, self.audio_format) | |||
| if checking_audio_fs is not None: | |||
| self.audio_fs = checking_audio_fs | |||
| inputs = { | |||
| 'audio': self.audio_in, | |||
| 'audio_format': self.audio_format, | |||
| 'audio_fs': self.audio_fs | |||
| } | |||
| output = self.forward(inputs) | |||
| rst = self.postprocess(output['asr_result']) | |||
| return rst | |||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """Decoding | |||
| """ | |||
| inputs['asr_result'] = self.model(inputs) | |||
| return inputs | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the asr results | |||
| """ | |||
| return inputs | |||
| @@ -70,6 +70,11 @@ PYTORCH_IMPORT_ERROR = """ | |||
| installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. | |||
| """ | |||
| WENETRUNTIME_IMPORT_ERROR = """ | |||
| {0} requires the wenetruntime library but it was not found in your environment. You can install it with pip: | |||
| `pip install wenetruntime==TORCH_VER` | |||
| """ | |||
| # docstyle-ignore | |||
| SCIPY_IMPORT_ERROR = """ | |||
| {0} requires the scipy library but it was not found in your environment. You can install it with pip: | |||
| @@ -245,6 +245,10 @@ def is_torch_cuda_available(): | |||
| return False | |||
| def is_wenetruntime_available(): | |||
| return importlib.util.find_spec('wenetruntime') is not None | |||
| def is_tf_available(): | |||
| return _tf_available | |||
| @@ -280,6 +284,9 @@ REQUIREMENTS_MAAPING = OrderedDict([ | |||
| ('timm', (is_timm_available, TIMM_IMPORT_ERROR)), | |||
| ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), | |||
| ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), | |||
| ('wenetruntime', | |||
| (is_wenetruntime_available, | |||
| WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))), | |||
| ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)), | |||
| ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)), | |||
| ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)), | |||
| @@ -0,0 +1,131 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import unittest | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import soundfile | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import ColorCodes, Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import download_and_untar, test_level | |||
| logger = get_logger() | |||
| WAV_FILE = 'data/test/audios/asr_example.wav' | |||
| URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' | |||
| class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase, | |||
| DemoCompatibilityCheck): | |||
| action_info = { | |||
| 'test_run_with_pcm': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'wav_example' | |||
| }, | |||
| 'test_run_with_url': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'wav_example' | |||
| }, | |||
| 'test_run_with_wav': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'wav_example' | |||
| }, | |||
| 'wav_example': { | |||
| 'text': '每一天都要快乐喔' | |||
| } | |||
| } | |||
| def setUp(self) -> None: | |||
| self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online' | |||
| # this temporary workspace dir will store waveform files | |||
| self.workspace = os.path.join(os.getcwd(), '.tmp') | |||
| self.task = Tasks.auto_speech_recognition | |||
| if not os.path.exists(self.workspace): | |||
| os.mkdir(self.workspace) | |||
| def tearDown(self) -> None: | |||
| # remove workspace dir (.tmp) | |||
| shutil.rmtree(self.workspace, ignore_errors=True) | |||
| def run_pipeline(self, | |||
| model_id: str, | |||
| audio_in: Union[str, bytes], | |||
| sr: int = None) -> Dict[str, Any]: | |||
| inference_16k_pipline = pipeline( | |||
| task=Tasks.auto_speech_recognition, model=model_id) | |||
| rec_result = inference_16k_pipline(audio_in, audio_fs=sr) | |||
| return rec_result | |||
| def log_error(self, functions: str, result: Dict[str, Any]) -> None: | |||
| logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' | |||
| + ColorCodes.END) | |||
| logger.error( | |||
| ColorCodes.MAGENTA + functions + ' correct result example:' | |||
| + ColorCodes.YELLOW | |||
| + str(self.action_info[self.action_info[functions]['example']]) | |||
| + ColorCodes.END) | |||
| raise ValueError('asr result is mismatched') | |||
| def check_result(self, functions: str, result: Dict[str, Any]) -> None: | |||
| if result.__contains__(self.action_info[functions]['checking_item']): | |||
| logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' | |||
| + ColorCodes.END) | |||
| logger.info( | |||
| ColorCodes.YELLOW | |||
| + str(result[self.action_info[functions]['checking_item']]) | |||
| + ColorCodes.END) | |||
| else: | |||
| self.log_error(functions, result) | |||
| def wav2bytes(self, wav_file): | |||
| audio, fs = soundfile.read(wav_file) | |||
| # float32 -> int16 | |||
| audio = np.asarray(audio) | |||
| dtype = np.dtype('int16') | |||
| i = np.iinfo(dtype) | |||
| abs_max = 2**(i.bits - 1) | |||
| offset = i.min + abs_max | |||
| audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) | |||
| # int16(PCM_16) -> byte | |||
| audio = audio.tobytes() | |||
| return audio, fs | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_pcm(self): | |||
| """run with wav data | |||
| """ | |||
| logger.info('Run ASR test with wav data (wenet)...') | |||
| audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_model_id, audio_in=audio, sr=sr) | |||
| self.check_result('test_run_with_pcm', rec_result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_wav(self): | |||
| """run with single waveform file | |||
| """ | |||
| logger.info('Run ASR test with waveform file (wenet)...') | |||
| wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_model_id, audio_in=wav_file_path) | |||
| self.check_result('test_run_with_wav', rec_result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_url(self): | |||
| """run with single url file | |||
| """ | |||
| logger.info('Run ASR test with url file (wenet)...') | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_model_id, audio_in=URL_FILE) | |||
| self.check_result('test_run_with_url', rec_result) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||