Browse Source

[to #42322933] fix: make TF warmup effective, and preload the model for warmup when constructing the pipeline

Fix the issue where warmup was ineffective during TF model inference;
additionally, call the preload interface when constructing the pipeline, so the model is loaded and TF is warmed up ahead of time.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9719779
master
shichen.fsc yingda.chen 3 years ago
parent
commit
c1dea3adf1
2 changed files with 36 additions and 8 deletions
  1. +25
    -2
      modelscope/models/audio/asr/generic_automatic_speech_recognition.py
  2. +11
    -6
      modelscope/pipelines/audio/asr_inference_pipeline.py

+ 25
- 2
modelscope/models/audio/asr/generic_automatic_speech_recognition.py View File

@@ -4,7 +4,7 @@ from typing import Any, Dict
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.utils.constant import Frameworks, Tasks

__all__ = ['GenericAutomaticSpeechRecognition']

@@ -36,6 +36,29 @@ class GenericAutomaticSpeechRecognition(Model):
}

def forward(self) -> Dict[str, Any]:
    """Preload the model (TF warmup) and return the model configuration.

    For TensorFlow-backed models, this calls easyasr's ``preload`` (when
    the installed easyasr version provides it) so the graph is loaded and
    warmed up before the first real inference request.

    Returns:
        Dict[str, Any]: the model configuration dict (``self.model_cfg``).
    """
    # Local import: no module-level `import os` is visible in this file's
    # import block, so bring it into scope here to keep the method
    # self-contained.
    import os

    if self.model_cfg['model_config']['type'] == Frameworks.tf:
        # easyasr is an optional runtime dependency; import lazily so
        # non-TF models never require it.
        from easyasr import asr_inference_paraformer_tf

        # Older easyasr versions have no preload(); warmup is then a no-op.
        if hasattr(asr_inference_paraformer_tf, 'preload'):
            model_workspace = self.model_cfg['model_workspace']
            model_path = os.path.join(model_workspace,
                                      self.model_cfg['am_model'])
            vocab_path = os.path.join(
                model_workspace,
                self.model_cfg['model_config']['vocab_file'])

            # Default TF graph output tensor names; the model config may
            # override either of them.
            sampled_ids = 'seq2seq/sampled_ids'
            sampled_lengths = 'seq2seq/sampled_lengths'
            if 'sampled_ids' in self.model_cfg['model_config']:
                sampled_ids = self.model_cfg['model_config']['sampled_ids']
            if 'sampled_lengths' in self.model_cfg['model_config']:
                sampled_lengths = self.model_cfg['model_config'][
                    'sampled_lengths']

            # NOTE(review): ngpu=1 is hard-coded here — presumably warmup
            # always targets a single GPU; confirm against easyasr's API.
            asr_inference_paraformer_tf.preload(
                ngpu=1,
                asr_model_file=model_path,
                vocab_file=vocab_path,
                sampled_ids=sampled_ids,
                sampled_lengths=sampled_lengths)

    return self.model_cfg

+ 11
- 6
modelscope/pipelines/audio/asr_inference_pipeline.py View File

@@ -30,6 +30,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
"""use `model` and `preprocessor` to create an asr pipeline for prediction
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.model_cfg = self.model.forward()

def __call__(self,
audio_in: Union[str, bytes],
@@ -49,16 +50,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
recog_type=recog_type,
audio_format=audio_format)

if hasattr(asr_utils, 'sample_rate_checking'):
if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None:
self.audio_fs = asr_utils.sample_rate_checking(
self.audio_in, self.audio_format)

if self.preprocessor is None:
self.preprocessor = WavToScp()

output = self.preprocessor.forward(self.model.forward(),
self.recog_type, self.audio_format,
self.audio_in, self.audio_fs)
output = self.preprocessor.forward(self.model_cfg, self.recog_type,
self.audio_format, self.audio_in,
self.audio_fs)
output = self.forward(output)
rst = self.postprocess(output)
return rst
@@ -198,8 +199,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):

for line in lines:
line_item = line.split(None, 1)
item = {'key': line_item[0], 'value': line_item[1].strip('\n')}
ref_list.append(item)
if len(line_item) > 1:
item = {
'key': line_item[0],
'value': line_item[1].strip('\n')
}
ref_list.append(item)

return ref_list



Loading…
Cancel
Save