|
|
@@ -33,6 +33,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
 
     def __call__(self,
                  audio_in: Union[str, bytes],
+                 audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None) -> Dict[str, Any]:
         from easyasr.common import asr_utils
@@ -40,17 +41,24 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         self.recog_type = recog_type
         self.audio_format = audio_format
         self.audio_in = audio_in
+        self.audio_fs = audio_fs
 
         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
-                audio_in, recog_type, audio_format)
+                audio_in=audio_in,
+                recog_type=recog_type,
+                audio_format=audio_format)
+
+        if hasattr(asr_utils, 'sample_rate_checking'):
+            self.audio_fs = asr_utils.sample_rate_checking(
+                self.audio_in, self.audio_format)
 
         if self.preprocessor is None:
             self.preprocessor = WavToScp()
 
         output = self.preprocessor.forward(self.model.forward(),
                                            self.recog_type, self.audio_format,
-                                           self.audio_in)
+                                           self.audio_in, self.audio_fs)
         output = self.forward(output)
         rst = self.postprocess(output)
         return rst
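The `hasattr` guard above keeps the pipeline working against easyasr releases that predate `sample_rate_checking`. As a rough, hypothetical sketch of what such a probe might do for WAV input (the function below is an assumption for illustration, not easyasr code):

```python
# Hypothetical stand-in for easyasr's sample_rate_checking; shown only to
# illustrate probing the input's sample rate before inference.
import wave

def probe_wav_sample_rate(path: str) -> int:
    """Read the sample rate declared in a WAV file's header."""
    with wave.open(path, 'rb') as wav_file:
        return wav_file.getframerate()

print(probe_wav_sample_rate('example.wav'))  # e.g. 16000
```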
|
|
@@ -77,7 +85,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             'audio_in': inputs['audio_lists'],
             'name_and_type': data_cmd,
             'asr_model_file': inputs['am_model_path'],
-            'idx_text': ''
+            'idx_text': '',
+            'sampled_ids': 'seq2seq/sampled_ids',
+            'sampled_lengths': 'seq2seq/sampled_lengths',
+            'lang': 'zh-cn',
+            'fs': {
+                'audio_fs': inputs['audio_fs'],
+                'model_fs': 16000
+            }
         }
 
         if self.framework == Frameworks.torch:
@@ -97,16 +112,24 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['asr_train_config'] = inputs['am_model_config']
             cmd['batch_size'] = inputs['model_config']['batch_size']
             cmd['frontend_conf'] = frontend_conf
+            if frontend_conf is not None and 'fs' in frontend_conf:
+                cmd['fs']['model_fs'] = frontend_conf['fs']
 
         elif self.framework == Frameworks.tf:
-            cmd['fs'] = inputs['model_config']['fs']
+            cmd['fs']['model_fs'] = inputs['model_config']['fs']
             cmd['hop_length'] = inputs['model_config']['hop_length']
             cmd['feature_dims'] = inputs['model_config']['feature_dims']
             cmd['predictions_file'] = 'text'
             cmd['mvn_file'] = inputs['am_mvn_file']
             cmd['vocab_file'] = inputs['vocab_file']
+            cmd['lang'] = inputs['model_lang']
             if 'idx_text' in inputs:
                 cmd['idx_text'] = inputs['idx_text']
+            if 'sampled_ids' in inputs['model_config']:
+                cmd['sampled_ids'] = inputs['model_config']['sampled_ids']
+            if 'sampled_lengths' in inputs['model_config']:
+                cmd['sampled_lengths'] = inputs['model_config'][
+                    'sampled_lengths']
 
         else:
             raise ValueError('model type is mismatching')
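With this hunk, `cmd['fs']` carries both the incoming rate (`audio_fs`) and the rate the model expects (`model_fs`, defaulting to 16000 and overridable through `frontend_conf['fs']`). A minimal sketch of the resampling decision this bookkeeping enables, assuming a torchaudio-based frontend (illustrative only, not code from this patch):

```python
import torch
import torchaudio

def maybe_resample(waveform: torch.Tensor, audio_fs: int,
                   model_fs: int = 16000) -> torch.Tensor:
    """Resample only when the input rate differs from the model's rate."""
    if audio_fs == model_fs:
        return waveform  # already at the rate the model was trained on
    resampler = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                               new_freq=model_fs)
    return resampler(waveform)
```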
|
|
@@ -134,8 +157,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         # run with datasets, and audio format is waveform or kaldi_ark or tfrecord
         elif inputs['recog_type'] != 'wav':
             inputs['reference_list'] = self.ref_list_tidy(inputs)
+
+            if hasattr(asr_utils, 'set_parameters'):
+                asr_utils.set_parameters(language=inputs['model_lang'])
             inputs['datasets_result'] = asr_utils.compute_wer(
-                inputs['asr_result'], inputs['reference_list'])
+                hyp_list=inputs['asr_result'],
+                ref_list=inputs['reference_list'])
 
         else:
             raise ValueError('recog_type and audio_format are mismatching')
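The `compute_wer` call switches to keyword arguments (`hyp_list`, `ref_list`), making the hypothesis/reference order explicit at the call site. For readers unfamiliar with the metric, a minimal single-pair word-error-rate sketch (illustrative; easyasr's helper additionally aligns hypotheses to references by utterance key):

```python
def word_error_rate(hyp: str, ref: str) -> float:
    """WER = word-level edit distance divided by reference length."""
    h, r = hyp.split(), ref.split()
    # dp[i][j] = edit distance between r[:i] and h[:j]
    dp = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        dp[i][0] = i
    for j in range(len(h) + 1):
        dp[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,          # deletion
                           dp[i][j - 1] + 1,          # insertion
                           dp[i - 1][j - 1] + cost)   # substitution
    return dp[len(r)][len(h)] / max(len(r), 1)

print(word_error_rate('hello there world', 'hello world'))  # 0.5
```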
|
|
@@ -170,8 +197,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             lines = f.readlines()
 
         for line in lines:
-            line_item = line.split()
-            item = {'key': line_item[0], 'value': line_item[1]}
+            line_item = line.split(None, 1)
+            item = {'key': line_item[0], 'value': line_item[1].strip('\n')}
             ref_list.append(item)
 
         return ref_list
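The `split(None, 1)` change fixes a real bug: reference transcripts contain spaces, so splitting on every whitespace run kept only the first word of each transcript. `maxsplit=1` preserves the utterance key plus the full text:

```python
line = 'utt_001 hello world this is a test\n'

# old behaviour: only the first word of the transcript survives
print(line.split()[1])                 # 'hello'

# new behaviour: key plus the complete transcript
key, value = line.split(None, 1)
print(value.strip('\n'))               # 'hello world this is a test'
```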
|
|
@@ -180,6 +207,13 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         asr_result = []
         if self.framework == Frameworks.torch:
             from easyasr import asr_inference_paraformer_espnet
+
+            if hasattr(asr_inference_paraformer_espnet, 'set_parameters'):
+                asr_inference_paraformer_espnet.set_parameters(
+                    sample_rate=cmd['fs'])
+                asr_inference_paraformer_espnet.set_parameters(
+                    language=cmd['lang'])
+
             asr_result = asr_inference_paraformer_espnet.asr_inference(
                 batch_size=cmd['batch_size'],
                 maxlenratio=cmd['maxlenratio'],
@@ -195,8 +229,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 asr_train_config=cmd['asr_train_config'],
                 asr_model_file=cmd['asr_model_file'],
                 frontend_conf=cmd['frontend_conf'])
         elif self.framework == Frameworks.tf:
             from easyasr import asr_inference_paraformer_tf
+
+            if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
+                asr_inference_paraformer_tf.set_parameters(
+                    language=cmd['lang'])
+            else:
+                # in order to support easyasr-0.0.2
+                cmd['fs'] = cmd['fs']['model_fs']
+
             asr_result = asr_inference_paraformer_tf.asr_inference(
                 ngpu=cmd['ngpu'],
                 name_and_type=cmd['name_and_type'],
@@ -208,6 +250,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 predictions_file=cmd['predictions_file'],
                 fs=cmd['fs'],
                 hop_length=cmd['hop_length'],
-                feature_dims=cmd['feature_dims'])
+                feature_dims=cmd['feature_dims'],
+                sampled_ids=cmd['sampled_ids'],
+                sampled_lengths=cmd['sampled_lengths'])
 
         return asr_result
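Taken together, these changes let callers pass the input sample rate explicitly. A minimal usage sketch, assuming a ModelScope-style `pipeline` factory; the model id below is a placeholder, not verified against a specific release:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# model id is a placeholder; substitute a real ASR model from the hub
asr = pipeline(task=Tasks.auto_speech_recognition,
               model='damo/speech_paraformer_asr-example')

# audio_fs declares the waveform's sample rate so the pipeline can map it
# to the model's rate (model_fs, typically 16 kHz) when they differ
result = asr(audio_in='example.wav', audio_fs=8000)
print(result)
```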