diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index e4d7e373..30860f29 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -428,6 +428,12 @@ TASK_OUTPUTS = {
     # {"text": "this is a text answser. "}
     Tasks.visual_question_answering: [OutputKeys.TEXT],

+    # auto_speech_recognition result for a single sample
+    # {
+    #    "text": "每天都要快乐喔"
+    # }
+    Tasks.auto_speech_recognition: [OutputKeys.TEXT],
+
     # {
     #     "scores": [0.9, 0.1, 0.1],
     #     "labels": ["entailment", "contradiction", "neutral"]
diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py
index ac53d12d..353a0d47 100644
--- a/modelscope/pipelines/audio/asr_inference_pipeline.py
+++ b/modelscope/pipelines/audio/asr_inference_pipeline.py
@@ -33,6 +33,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):

     def __call__(self,
                  audio_in: Union[str, bytes],
+                 audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None) -> Dict[str, Any]:
         from easyasr.common import asr_utils
@@ -40,17 +41,24 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         self.recog_type = recog_type
         self.audio_format = audio_format
         self.audio_in = audio_in
+        self.audio_fs = audio_fs

         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
-                audio_in, recog_type, audio_format)
+                audio_in=audio_in,
+                recog_type=recog_type,
+                audio_format=audio_format)
+
+        if hasattr(asr_utils, 'sample_rate_checking'):
+            self.audio_fs = asr_utils.sample_rate_checking(
+                self.audio_in, self.audio_format)

         if self.preprocessor is None:
             self.preprocessor = WavToScp()

         output = self.preprocessor.forward(self.model.forward(),
                                            self.recog_type, self.audio_format,
-                                           self.audio_in)
+                                           self.audio_in, self.audio_fs)
         output = self.forward(output)
         rst = self.postprocess(output)
         return rst
@@ -77,7 +85,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             'audio_in': inputs['audio_lists'],
             'name_and_type': data_cmd,
             'asr_model_file': inputs['am_model_path'],
-            'idx_text': ''
+            'idx_text': '',
+            'sampled_ids': 'seq2seq/sampled_ids',
+            'sampled_lengths': 'seq2seq/sampled_lengths',
+            'lang': 'zh-cn',
+            'fs': {
+                'audio_fs': inputs['audio_fs'],
+                'model_fs': 16000
+            }
         }

         if self.framework == Frameworks.torch:
@@ -97,16 +112,24 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['asr_train_config'] = inputs['am_model_config']
             cmd['batch_size'] = inputs['model_config']['batch_size']
             cmd['frontend_conf'] = frontend_conf
+            if frontend_conf is not None and 'fs' in frontend_conf:
+                cmd['fs']['model_fs'] = frontend_conf['fs']

         elif self.framework == Frameworks.tf:
-            cmd['fs'] = inputs['model_config']['fs']
+            cmd['fs']['model_fs'] = inputs['model_config']['fs']
             cmd['hop_length'] = inputs['model_config']['hop_length']
             cmd['feature_dims'] = inputs['model_config']['feature_dims']
             cmd['predictions_file'] = 'text'
             cmd['mvn_file'] = inputs['am_mvn_file']
             cmd['vocab_file'] = inputs['vocab_file']
+            cmd['lang'] = inputs['model_lang']
             if 'idx_text' in inputs:
                 cmd['idx_text'] = inputs['idx_text']
+            if 'sampled_ids' in inputs['model_config']:
+                cmd['sampled_ids'] = inputs['model_config']['sampled_ids']
+            if 'sampled_lengths' in inputs['model_config']:
+                cmd['sampled_lengths'] = inputs['model_config'][
+                    'sampled_lengths']

         else:
             raise ValueError('model type is mismatching')
@@ -134,8 +157,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         # run with datasets, and audio format is waveform or kaldi_ark or tfrecord
         elif inputs['recog_type'] != 'wav':
             inputs['reference_list'] = self.ref_list_tidy(inputs)
+
+            if hasattr(asr_utils, 'set_parameters'):
+                asr_utils.set_parameters(language=inputs['model_lang'])
             inputs['datasets_result'] = asr_utils.compute_wer(
-                inputs['asr_result'], inputs['reference_list'])
+                hyp_list=inputs['asr_result'],
+                ref_list=inputs['reference_list'])

         else:
             raise ValueError('recog_type and audio_format are mismatching')
@@ -170,8 +197,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             lines = f.readlines()

         for line in lines:
-            line_item = line.split()
-            item = {'key': line_item[0], 'value': line_item[1]}
+            line_item = line.split(None, 1)
+            item = {'key': line_item[0], 'value': line_item[1].strip('\n')}
             ref_list.append(item)

         return ref_list
@@ -180,6 +207,13 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         asr_result = []
         if self.framework == Frameworks.torch:
             from easyasr import asr_inference_paraformer_espnet
+
+            if hasattr(asr_inference_paraformer_espnet, 'set_parameters'):
+                asr_inference_paraformer_espnet.set_parameters(
+                    sample_rate=cmd['fs'])
+                asr_inference_paraformer_espnet.set_parameters(
+                    language=cmd['lang'])
+
             asr_result = asr_inference_paraformer_espnet.asr_inference(
                 batch_size=cmd['batch_size'],
                 maxlenratio=cmd['maxlenratio'],
@@ -195,8 +229,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 asr_train_config=cmd['asr_train_config'],
                 asr_model_file=cmd['asr_model_file'],
                 frontend_conf=cmd['frontend_conf'])
+
         elif self.framework == Frameworks.tf:
             from easyasr import asr_inference_paraformer_tf
+            if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
+                asr_inference_paraformer_tf.set_parameters(
+                    language=cmd['lang'])
+            else:
+                # in order to support easyasr-0.0.2
+                cmd['fs'] = cmd['fs']['model_fs']
+
             asr_result = asr_inference_paraformer_tf.asr_inference(
                 ngpu=cmd['ngpu'],
                 name_and_type=cmd['name_and_type'],
@@ -208,6 +250,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 predictions_file=cmd['predictions_file'],
                 fs=cmd['fs'],
                 hop_length=cmd['hop_length'],
-                feature_dims=cmd['feature_dims'])
+                feature_dims=cmd['feature_dims'],
+                sampled_ids=cmd['sampled_ids'],
+                sampled_lengths=cmd['sampled_lengths'])

         return asr_result
diff --git a/modelscope/preprocessors/asr.py b/modelscope/preprocessors/asr.py
index de0eb634..d58383d7 100644
--- a/modelscope/preprocessors/asr.py
+++ b/modelscope/preprocessors/asr.py
@@ -23,7 +23,8 @@ class WavToScp(Preprocessor):
                  model: Model = None,
                  recog_type: str = None,
                  audio_format: str = None,
-                 audio_in: Union[str, bytes] = None) -> Dict[str, Any]:
+                 audio_in: Union[str, bytes] = None,
+                 audio_fs: int = None) -> Dict[str, Any]:
         assert model is not None, 'preprocess model is empty'
         assert recog_type is not None and len(
             recog_type) > 0, 'preprocess recog_type is empty'
@@ -32,12 +33,12 @@ class WavToScp(Preprocessor):
         self.am_model = model

         out = self.forward(self.am_model.forward(), recog_type, audio_format,
-                           audio_in)
+                           audio_in, audio_fs)
         return out

-    def forward(self, model: Dict[str, Any], recog_type: str,
-                audio_format: str, audio_in: Union[str,
-                                                   bytes]) -> Dict[str, Any]:
+    def forward(self, model: Dict[str,
+                                  Any], recog_type: str, audio_format: str,
+                audio_in: Union[str, bytes], audio_fs: int) -> Dict[str, Any]:
         assert len(recog_type) > 0, 'preprocess recog_type is empty'
         assert len(audio_format) > 0, 'preprocess audio_format is empty'
         assert len(
@@ -65,7 +66,9 @@ class WavToScp(Preprocessor):
             # the asr audio format setting, eg: wav, pcm, kaldi_ark, tfrecord
             'audio_format': audio_format,
             # the recognition model config dict
-            'model_config': model['model_config']
+            'model_config': model['model_config'],
+            # the sample rate of audio_in
+            'audio_fs': audio_fs
         }

         if isinstance(audio_in, str):
@@ -186,6 +189,12 @@ class WavToScp(Preprocessor):
             assert os.path.exists(
                 inputs['idx_text']), 'idx text does not exist'

+        # set asr model language
+        if 'lang' in inputs['model_config']:
+            inputs['model_lang'] = inputs['model_config']['lang']
+        else:
+            inputs['model_lang'] = 'zh-cn'
+
         return inputs

     def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]:
diff --git a/tests/pipelines/test_automatic_speech_recognition.py b/tests/pipelines/test_automatic_speech_recognition.py
index 0659720a..9dad7573 100644
--- a/tests/pipelines/test_automatic_speech_recognition.py
+++ b/tests/pipelines/test_automatic_speech_recognition.py
@@ -98,12 +98,14 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
         # remove workspace dir (.tmp)
         shutil.rmtree(self.workspace, ignore_errors=True)

-    def run_pipeline(self, model_id: str,
-                     audio_in: Union[str, bytes]) -> Dict[str, Any]:
+    def run_pipeline(self,
+                     model_id: str,
+                     audio_in: Union[str, bytes],
+                     sr: int = 16000) -> Dict[str, Any]:
         inference_16k_pipline = pipeline(
             task=Tasks.auto_speech_recognition, model=model_id)

-        rec_result = inference_16k_pipline(audio_in)
+        rec_result = inference_16k_pipline(audio_in, audio_fs=sr)

         return rec_result

@@ -129,7 +131,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
         else:
             self.log_error(functions, result)

-    def wav2bytes(self, wav_file) -> bytes:
+    def wav2bytes(self, wav_file):
         audio, fs = soundfile.read(wav_file)

         # float32 -> int16
@@ -142,7 +144,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
         # int16(PCM_16) -> byte
         audio = audio.tobytes()

-        return audio
+        return audio, fs

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_pytorch(self):
@@ -164,10 +166,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):

         logger.info('Run ASR test with wav data (pytorch)...')

-        audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
+        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))

         rec_result = self.run_pipeline(
-            model_id=self.am_pytorch_model_id, audio_in=audio)
+            model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_pytorch', rec_result)

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -190,10 +192,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):

         logger.info('Run ASR test with wav data (tensorflow)...')

-        audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
+        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))

         rec_result = self.run_pipeline(
-            model_id=self.am_tf_model_id, audio_in=audio)
+            model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_tf', rec_result)

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
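
For reference, a minimal sketch of how a caller would pass the new audio_fs argument after this change, mirroring the updated test helper wav2bytes(). The wav path 'example.wav' and the model id placeholder '<asr-model-id>' are illustrative assumptions, not values taken from this patch; numpy, soundfile and modelscope are assumed to be installed.

    import numpy as np
    import soundfile

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # read a local wav file; 'example.wav' is a placeholder path
    audio, fs = soundfile.read('example.wav')
    # float32 -> int16 PCM bytes, as in the test helper wav2bytes()
    audio = np.array(audio * 32768, dtype=np.int16).tobytes()

    # '<asr-model-id>' is a placeholder; use a real ModelScope ASR model id
    asr = pipeline(task=Tasks.auto_speech_recognition, model='<asr-model-id>')
    # audio_fs tells the pipeline the sample rate of the raw PCM bytes
    rec_result = asr(audio, audio_fs=fs)
    print(rec_result[OutputKeys.TEXT])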