diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
index 5213fdd1..11accf0a 100644
--- a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
+++ b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
@@ -4,7 +4,7 @@ from typing import Any, Dict
 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import Frameworks, Tasks
 
 __all__ = ['GenericAutomaticSpeechRecognition']
 
@@ -36,6 +36,29 @@ class GenericAutomaticSpeechRecognition(Model):
         }
 
     def forward(self) -> Dict[str, Any]:
-        """return the info of the model
+        """preload model and return the info of the model
         """
+        if self.model_cfg['model_config']['type'] == Frameworks.tf:
+            from easyasr import asr_inference_paraformer_tf
+            if hasattr(asr_inference_paraformer_tf, 'preload'):
+                model_workspace = self.model_cfg['model_workspace']
+                model_path = os.path.join(model_workspace,
+                                          self.model_cfg['am_model'])
+                vocab_path = os.path.join(
+                    model_workspace,
+                    self.model_cfg['model_config']['vocab_file'])
+                sampled_ids = 'seq2seq/sampled_ids'
+                sampled_lengths = 'seq2seq/sampled_lengths'
+                if 'sampled_ids' in self.model_cfg['model_config']:
+                    sampled_ids = self.model_cfg['model_config']['sampled_ids']
+                if 'sampled_lengths' in self.model_cfg['model_config']:
+                    sampled_lengths = self.model_cfg['model_config'][
+                        'sampled_lengths']
+                asr_inference_paraformer_tf.preload(
+                    ngpu=1,
+                    asr_model_file=model_path,
+                    vocab_file=vocab_path,
+                    sampled_ids=sampled_ids,
+                    sampled_lengths=sampled_lengths)
+
         return self.model_cfg
diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py
index 353a0d47..b321b770 100644
--- a/modelscope/pipelines/audio/asr_inference_pipeline.py
+++ b/modelscope/pipelines/audio/asr_inference_pipeline.py
@@ -30,6 +30,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         """use `model` and `preprocessor` to create an asr pipeline for prediction
         """
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.model_cfg = self.model.forward()
 
     def __call__(self,
                  audio_in: Union[str, bytes],
@@ -49,16 +50,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             recog_type=recog_type,
             audio_format=audio_format)
 
-        if hasattr(asr_utils, 'sample_rate_checking'):
+        if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None:
             self.audio_fs = asr_utils.sample_rate_checking(
                 self.audio_in, self.audio_format)
 
         if self.preprocessor is None:
             self.preprocessor = WavToScp()
 
-        output = self.preprocessor.forward(self.model.forward(),
-                                           self.recog_type, self.audio_format,
-                                           self.audio_in, self.audio_fs)
+        output = self.preprocessor.forward(self.model_cfg, self.recog_type,
+                                           self.audio_format, self.audio_in,
+                                           self.audio_fs)
         output = self.forward(output)
         rst = self.postprocess(output)
         return rst
@@ -198,8 +199,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         ref_list = []
 
         for line in lines:
             line_item = line.split(None, 1)
-            item = {'key': line_item[0], 'value': line_item[1].strip('\n')}
-            ref_list.append(item)
+            if len(line_item) > 1:
+                item = {
+                    'key': line_item[0],
+                    'value': line_item[1].strip('\n')
+                }
+                ref_list.append(item)
 
         return ref_list
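
For reviewers, a rough sketch of how the new behavior surfaces at call time. The model id is a placeholder, not one tied to this PR; the pipeline() entry point and Tasks constant are the existing modelscope API.

    # Usage sketch -- the model id below is a placeholder.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Pipeline construction now calls model.forward() once, so a TensorFlow
    # paraformer graph is preloaded here instead of on the first __call__.
    asr = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/<some-paraformer-tf-model>')  # placeholder id

    # With audio_fs given explicitly, sample_rate_checking is skipped and the
    # caller-provided sampling rate is used as-is.
    rst = asr('example.wav', audio_fs=16000)
    print(rst)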
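
And a small repro of the parsing edge case the last hunk guards against, assuming a reference line that carries a key but an empty transcript:

    # Edge case behind the len(line_item) > 1 guard: a key-only line.
    line = 'utt_001'                  # reference entry with no transcript
    line_item = line.split(None, 1)   # -> ['utt_001']
    # Before this change, line_item[1] raised IndexError; the guarded code
    # now skips such lines instead of crashing the CER computation.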