From 2b620841466baa1316d0850ae4cd150a3888c3b5 Mon Sep 17 00:00:00 2001 From: "jiangyu.xzy" Date: Fri, 25 Nov 2022 17:49:24 +0800 Subject: [PATCH] add funasr based asr inference Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10868583 --- .../pipelines/audio/asr_inference_pipeline.py | 36 ++++++++++++++++--- modelscope/preprocessors/asr.py | 27 ++++++++++++++ modelscope/utils/import_utils.py | 1 + requirements/audio.txt | 1 + 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py index da339083..c788e783 100644 --- a/modelscope/pipelines/audio/asr_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_inference_pipeline.py @@ -39,7 +39,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): audio_fs: int = None, recog_type: str = None, audio_format: str = None) -> Dict[str, Any]: - from easyasr.common import asr_utils + from funasr.utils import asr_utils self.recog_type = recog_type self.audio_format = audio_format @@ -109,6 +109,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): 'sampled_ids': 'seq2seq/sampled_ids', 'sampled_lengths': 'seq2seq/sampled_lengths', 'lang': 'zh-cn', + 'code_base': inputs['code_base'], 'fs': { 'audio_fs': inputs['audio_fs'], 'model_fs': 16000 @@ -130,6 +131,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): cmd['ctc_weight'] = root['ctc_weight'] cmd['lm_weight'] = root['lm_weight'] cmd['asr_train_config'] = inputs['am_model_config'] + cmd['lm_file'] = inputs['lm_model_path'] + cmd['lm_train_config'] = inputs['lm_model_config'] cmd['batch_size'] = inputs['model_config']['batch_size'] cmd['frontend_conf'] = frontend_conf if frontend_conf is not None and 'fs' in frontend_conf: @@ -161,7 +164,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """process the asr results """ - from easyasr.common import asr_utils + from funasr.utils import asr_utils logger.info('Computing the result of ASR ...') @@ -229,7 +232,33 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): def run_inference(self, cmd): asr_result = [] - if self.framework == Frameworks.torch: + if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr': + from funasr.bin import asr_inference_paraformer_modelscope + + if hasattr(asr_inference_paraformer_modelscope, 'set_parameters'): + asr_inference_paraformer_modelscope.set_parameters( + sample_rate=cmd['fs']) + asr_inference_paraformer_modelscope.set_parameters( + language=cmd['lang']) + + asr_result = asr_inference_paraformer_modelscope.asr_inference( + batch_size=cmd['batch_size'], + maxlenratio=cmd['maxlenratio'], + minlenratio=cmd['minlenratio'], + beam_size=cmd['beam_size'], + ngpu=cmd['ngpu'], + ctc_weight=cmd['ctc_weight'], + lm_weight=cmd['lm_weight'], + penalty=cmd['penalty'], + log_level=cmd['log_level'], + name_and_type=cmd['name_and_type'], + audio_lists=cmd['audio_in'], + asr_train_config=cmd['asr_train_config'], + asr_model_file=cmd['asr_model_file'], + lm_file=cmd['lm_file'], + lm_train_config=cmd['lm_train_config'], + frontend_conf=cmd['frontend_conf']) + elif self.framework == Frameworks.torch: from easyasr import asr_inference_paraformer_espnet if hasattr(asr_inference_paraformer_espnet, 'set_parameters'): @@ -253,7 +282,6 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): asr_train_config=cmd['asr_train_config'], asr_model_file=cmd['asr_model_file'], frontend_conf=cmd['frontend_conf']) - elif self.framework == Frameworks.tf: from easyasr import asr_inference_paraformer_tf if hasattr(asr_inference_paraformer_tf, 'set_parameters'): diff --git a/modelscope/preprocessors/asr.py b/modelscope/preprocessors/asr.py index 91bf5860..1537b137 100644 --- a/modelscope/preprocessors/asr.py +++ b/modelscope/preprocessors/asr.py @@ -97,6 +97,12 @@ class WavToScp(Preprocessor): assert inputs['model_config'].__contains__( 'type'), 'model type does not exist' inputs['model_type'] = inputs['model_config']['type'] + # code base + if 'code_base' in inputs['model_config']: + code_base = inputs['model_config']['code_base'] + else: + code_base = None + inputs['code_base'] = code_base if inputs['model_type'] == Frameworks.torch: assert inputs['model_config'].__contains__( @@ -127,6 +133,27 @@ class WavToScp(Preprocessor): assert os.path.exists( asr_model_wav_config), 'asr_model_wav_config does not exist' + # the lm model file path + if 'lm_model_name' in inputs['model_config']: + lm_model_path = os.path.join( + inputs['model_workspace'], + inputs['model_config']['lm_model_name']) + else: + lm_model_path = None + # the lm config file path + if 'lm_model_config' in inputs['model_config']: + lm_model_config = os.path.join( + inputs['model_workspace'], + inputs['model_config']['lm_model_config']) + else: + lm_model_config = None + if lm_model_path and lm_model_config and os.path.exists( + lm_model_path) and os.path.exists(lm_model_config): + inputs['lm_model_path'] = lm_model_path + inputs['lm_model_config'] = lm_model_config + else: + inputs['lm_model_path'] = None + inputs['lm_model_config'] = None if inputs['audio_format'] == 'wav' or inputs[ 'audio_format'] == 'pcm': inputs['asr_model_config'] = asr_model_wav_config diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 5db5ea98..f817b7a5 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -288,6 +288,7 @@ REQUIREMENTS_MAAPING = OrderedDict([ ('espnet', (is_espnet_available, GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))), ('easyasr', (is_package_available('easyasr'), AUDIO_IMPORT_ERROR)), + ('funasr', (is_package_available('funasr'), AUDIO_IMPORT_ERROR)), ('kwsbp', (is_package_available('kwsbp'), AUDIO_IMPORT_ERROR)), ('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)), ('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)), diff --git a/requirements/audio.txt b/requirements/audio.txt index bef32121..bef3764b 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,5 +1,6 @@ easyasr>=0.0.2 espnet==202204 +funasr>=0.1.0 h5py inflect keras