
[to #42322933] add foreign language support and audio data resampling -- asr inference

ASR inference now supports additional foreign languages, including WER computation for them.
Add audio resampling: based on the sample rate of the incoming audio and the sample rate supported by the current model, resampling is performed inside easyasr. Note that when the input is PCM data, its sample rate must be passed in as well; otherwise 16 kHz (16000) is assumed.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9580609
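
A minimal usage sketch of the new audio_fs argument (the model id and file names below are placeholders; pipeline and Tasks are the existing ModelScope entry points):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    asr = pipeline(task=Tasks.auto_speech_recognition, model='<your-asr-model-id>')

    # wav input: the sample rate is read from the audio itself and, if it differs
    # from the rate the model expects, resampling is done inside easyasr
    print(asr('example.wav'))

    # pcm (headerless) input: pass the sample rate explicitly,
    # otherwise 16000 Hz is assumed
    with open('example.pcm', 'rb') as f:
        pcm_bytes = f.read()
    print(asr(pcm_bytes, audio_fs=8000))
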
master
shichen.fsc, yingda.chen · 3 years ago
commit e134254c6c
4 changed files with 84 additions and 23 deletions

  1. modelscope/outputs.py (+6, -0)
  2. modelscope/pipelines/audio/asr_inference_pipeline.py (+52, -8)
  3. modelscope/preprocessors/asr.py (+15, -6)
  4. tests/pipelines/test_automatic_speech_recognition.py (+11, -9)

modelscope/outputs.py (+6, -0)

@@ -428,6 +428,12 @@ TASK_OUTPUTS = {
# {"text": "this is a text answser. "}
Tasks.visual_question_answering: [OutputKeys.TEXT],

# auto_speech_recognition result for a single sample
# {
# "text": "每天都要快乐喔"
# }
Tasks.auto_speech_recognition: [OutputKeys.TEXT],

# {
# "scores": [0.9, 0.1, 0.1],
# "labels": ["entailment", "contradiction", "neutral"]


modelscope/pipelines/audio/asr_inference_pipeline.py (+52, -8)

@@ -33,6 +33,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):

     def __call__(self,
                  audio_in: Union[str, bytes],
+                 audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None) -> Dict[str, Any]:
         from easyasr.common import asr_utils
@@ -40,17 +41,24 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         self.recog_type = recog_type
         self.audio_format = audio_format
         self.audio_in = audio_in
+        self.audio_fs = audio_fs

         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
-                audio_in, recog_type, audio_format)
+                audio_in=audio_in,
+                recog_type=recog_type,
+                audio_format=audio_format)

+        if hasattr(asr_utils, 'sample_rate_checking'):
+            self.audio_fs = asr_utils.sample_rate_checking(
+                self.audio_in, self.audio_format)
+
         if self.preprocessor is None:
             self.preprocessor = WavToScp()

         output = self.preprocessor.forward(self.model.forward(),
                                            self.recog_type, self.audio_format,
-                                           self.audio_in)
+                                           self.audio_in, self.audio_fs)
         output = self.forward(output)
         rst = self.postprocess(output)
         return rst
@@ -77,7 +85,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             'audio_in': inputs['audio_lists'],
             'name_and_type': data_cmd,
             'asr_model_file': inputs['am_model_path'],
-            'idx_text': ''
+            'idx_text': '',
+            'sampled_ids': 'seq2seq/sampled_ids',
+            'sampled_lengths': 'seq2seq/sampled_lengths',
+            'lang': 'zh-cn',
+            'fs': {
+                'audio_fs': inputs['audio_fs'],
+                'model_fs': 16000
+            }
         }

         if self.framework == Frameworks.torch:
@@ -97,16 +112,24 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['asr_train_config'] = inputs['am_model_config']
             cmd['batch_size'] = inputs['model_config']['batch_size']
             cmd['frontend_conf'] = frontend_conf
+            if frontend_conf is not None and 'fs' in frontend_conf:
+                cmd['fs']['model_fs'] = frontend_conf['fs']

         elif self.framework == Frameworks.tf:
-            cmd['fs'] = inputs['model_config']['fs']
+            cmd['fs']['model_fs'] = inputs['model_config']['fs']
             cmd['hop_length'] = inputs['model_config']['hop_length']
             cmd['feature_dims'] = inputs['model_config']['feature_dims']
             cmd['predictions_file'] = 'text'
             cmd['mvn_file'] = inputs['am_mvn_file']
             cmd['vocab_file'] = inputs['vocab_file']
+            cmd['lang'] = inputs['model_lang']
             if 'idx_text' in inputs:
                 cmd['idx_text'] = inputs['idx_text']
+            if 'sampled_ids' in inputs['model_config']:
+                cmd['sampled_ids'] = inputs['model_config']['sampled_ids']
+            if 'sampled_lengths' in inputs['model_config']:
+                cmd['sampled_lengths'] = inputs['model_config'][
+                    'sampled_lengths']

         else:
             raise ValueError('model type is mismatching')
@@ -134,8 +157,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         # run with datasets, and audio format is waveform or kaldi_ark or tfrecord
         elif inputs['recog_type'] != 'wav':
             inputs['reference_list'] = self.ref_list_tidy(inputs)
+
+            if hasattr(asr_utils, 'set_parameters'):
+                asr_utils.set_parameters(language=inputs['model_lang'])
             inputs['datasets_result'] = asr_utils.compute_wer(
-                inputs['asr_result'], inputs['reference_list'])
+                hyp_list=inputs['asr_result'],
+                ref_list=inputs['reference_list'])

         else:
             raise ValueError('recog_type and audio_format are mismatching')
@@ -170,8 +197,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             lines = f.readlines()

             for line in lines:
-                line_item = line.split()
-                item = {'key': line_item[0], 'value': line_item[1]}
+                line_item = line.split(None, 1)
+                item = {'key': line_item[0], 'value': line_item[1].strip('\n')}
                 ref_list.append(item)

         return ref_list
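
The switch to line.split(None, 1) above matters for non-Chinese references, where the transcript itself contains spaces. A quick illustration with a made-up reference line:

    line = 'utt_001 hello world how are you\n'
    line_item = line.split(None, 1)
    item = {'key': line_item[0], 'value': line_item[1].strip('\n')}
    # item == {'key': 'utt_001', 'value': 'hello world how are you'}
    # the old line.split() + line_item[1] would have kept only 'hello'
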
@@ -180,6 +207,13 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         asr_result = []
         if self.framework == Frameworks.torch:
             from easyasr import asr_inference_paraformer_espnet
+
+            if hasattr(asr_inference_paraformer_espnet, 'set_parameters'):
+                asr_inference_paraformer_espnet.set_parameters(
+                    sample_rate=cmd['fs'])
+                asr_inference_paraformer_espnet.set_parameters(
+                    language=cmd['lang'])
+
             asr_result = asr_inference_paraformer_espnet.asr_inference(
                 batch_size=cmd['batch_size'],
                 maxlenratio=cmd['maxlenratio'],
@@ -195,8 +229,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 asr_train_config=cmd['asr_train_config'],
                 asr_model_file=cmd['asr_model_file'],
                 frontend_conf=cmd['frontend_conf'])
+
         elif self.framework == Frameworks.tf:
             from easyasr import asr_inference_paraformer_tf
+            if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
+                asr_inference_paraformer_tf.set_parameters(
+                    language=cmd['lang'])
+            else:
+                # in order to support easyasr-0.0.2
+                cmd['fs'] = cmd['fs']['model_fs']
+
             asr_result = asr_inference_paraformer_tf.asr_inference(
                 ngpu=cmd['ngpu'],
                 name_and_type=cmd['name_and_type'],
@@ -208,6 +250,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 predictions_file=cmd['predictions_file'],
                 fs=cmd['fs'],
                 hop_length=cmd['hop_length'],
-                feature_dims=cmd['feature_dims'])
+                feature_dims=cmd['feature_dims'],
+                sampled_ids=cmd['sampled_ids'],
+                sampled_lengths=cmd['sampled_lengths'])

         return asr_result
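
For context, cmd['fs'] carries both the input rate (audio_fs) and the model rate (model_fs); the resampling itself happens inside easyasr. A rough sketch of the idea only, not the easyasr implementation (assumes scipy is available):

    import math

    import numpy as np
    from scipy.signal import resample_poly

    def resample_to_model_fs(audio: np.ndarray, audio_fs: int,
                             model_fs: int = 16000) -> np.ndarray:
        # nothing to do when the input already matches the model's rate
        if audio_fs == model_fs:
            return audio
        # polyphase resampling by the rational factor model_fs / audio_fs
        g = math.gcd(model_fs, audio_fs)
        return resample_poly(audio, model_fs // g, audio_fs // g)
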

modelscope/preprocessors/asr.py (+15, -6)

@@ -23,7 +23,8 @@ class WavToScp(Preprocessor):
                  model: Model = None,
                  recog_type: str = None,
                  audio_format: str = None,
-                 audio_in: Union[str, bytes] = None) -> Dict[str, Any]:
+                 audio_in: Union[str, bytes] = None,
+                 audio_fs: int = None) -> Dict[str, Any]:
         assert model is not None, 'preprocess model is empty'
         assert recog_type is not None and len(
             recog_type) > 0, 'preprocess recog_type is empty'
@@ -32,12 +33,12 @@

         self.am_model = model
         out = self.forward(self.am_model.forward(), recog_type, audio_format,
-                           audio_in)
+                           audio_in, audio_fs)
         return out

-    def forward(self, model: Dict[str, Any], recog_type: str,
-                audio_format: str, audio_in: Union[str,
-                                                   bytes]) -> Dict[str, Any]:
+    def forward(self, model: Dict[str,
+                                  Any], recog_type: str, audio_format: str,
+                audio_in: Union[str, bytes], audio_fs: int) -> Dict[str, Any]:
         assert len(recog_type) > 0, 'preprocess recog_type is empty'
         assert len(audio_format) > 0, 'preprocess audio_format is empty'
         assert len(
@@ -65,7 +66,9 @@ class WavToScp(Preprocessor):
             # the asr audio format setting, eg: wav, pcm, kaldi_ark, tfrecord
             'audio_format': audio_format,
             # the recognition model config dict
-            'model_config': model['model_config']
+            'model_config': model['model_config'],
+            # the sample rate of audio_in
+            'audio_fs': audio_fs
         }

         if isinstance(audio_in, str):
@@ -186,6 +189,12 @@ class WavToScp(Preprocessor):
                 assert os.path.exists(
                     inputs['idx_text']), 'idx text does not exist'

+        # set asr model language
+        if 'lang' in inputs['model_config']:
+            inputs['model_lang'] = inputs['model_config']['lang']
+        else:
+            inputs['model_lang'] = 'zh-cn'
+
         return inputs

     def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]:


tests/pipelines/test_automatic_speech_recognition.py (+11, -9)

@@ -98,12 +98,14 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
         # remove workspace dir (.tmp)
         shutil.rmtree(self.workspace, ignore_errors=True)

-    def run_pipeline(self, model_id: str,
-                     audio_in: Union[str, bytes]) -> Dict[str, Any]:
+    def run_pipeline(self,
+                     model_id: str,
+                     audio_in: Union[str, bytes],
+                     sr: int = 16000) -> Dict[str, Any]:
         inference_16k_pipline = pipeline(
             task=Tasks.auto_speech_recognition, model=model_id)

-        rec_result = inference_16k_pipline(audio_in)
+        rec_result = inference_16k_pipline(audio_in, audio_fs=sr)

         return rec_result

@@ -129,7 +131,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
         else:
             self.log_error(functions, result)

-    def wav2bytes(self, wav_file) -> bytes:
+    def wav2bytes(self, wav_file):
         audio, fs = soundfile.read(wav_file)

         # float32 -> int16
@@ -142,7 +144,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):

         # int16(PCM_16) -> byte
         audio = audio.tobytes()
-        return audio
+        return audio, fs

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_pytorch(self):
@@ -164,10 +166,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):

         logger.info('Run ASR test with wav data (pytorch)...')

-        audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
+        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))

         rec_result = self.run_pipeline(
-            model_id=self.am_pytorch_model_id, audio_in=audio)
+            model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_pytorch', rec_result)

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -190,10 +192,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):

         logger.info('Run ASR test with wav data (tensorflow)...')

-        audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
+        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))

         rec_result = self.run_pipeline(
-            model_id=self.am_tf_model_id, audio_in=audio)
+            model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_tf', rec_result)

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')

