
[to #42322933] fix: make tf warmup effective, and preload the model for warmup when constructing the pipeline

Fix warmup having no effect during tf model inference; also call the preload interface when constructing the pipeline, so the model is loaded and tf is warmed up ahead of time.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9719779
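
In practice this means the warmup now happens at pipeline construction rather than on the first request. A hedged usage sketch follows; the model ID is illustrative only, not taken from this commit:

    # Constructing the pipeline now calls Model.forward(), which triggers
    # easyasr's preload for tf models before any audio is processed.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    asr = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/some-paraformer-tf-model')  # hypothetical model ID

    # the first call no longer pays the tf load/warmup cost
    result = asr(audio_in='test.wav')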
Branch: master · shichen.fsc, yingda.chen · 3 years ago · commit c1dea3adf1

2 changed files with 36 additions and 8 deletions:

  1. modelscope/models/audio/asr/generic_automatic_speech_recognition.py (+25, -2)
  2. modelscope/pipelines/audio/asr_inference_pipeline.py (+11, -6)

modelscope/models/audio/asr/generic_automatic_speech_recognition.py (+25, -2)

@@ -4,7 +4,7 @@ from typing import Any, Dict
 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import Frameworks, Tasks

 __all__ = ['GenericAutomaticSpeechRecognition']

@@ -36,6 +36,29 @@ class GenericAutomaticSpeechRecognition(Model):
         }

     def forward(self) -> Dict[str, Any]:
-        """return the info of the model
+        """preload model and return the info of the model
         """
+        if self.model_cfg['model_config']['type'] == Frameworks.tf:
+            from easyasr import asr_inference_paraformer_tf
+            if hasattr(asr_inference_paraformer_tf, 'preload'):
+                model_workspace = self.model_cfg['model_workspace']
+                model_path = os.path.join(model_workspace,
+                                          self.model_cfg['am_model'])
+                vocab_path = os.path.join(
+                    model_workspace,
+                    self.model_cfg['model_config']['vocab_file'])
+                sampled_ids = 'seq2seq/sampled_ids'
+                sampled_lengths = 'seq2seq/sampled_lengths'
+                if 'sampled_ids' in self.model_cfg['model_config']:
+                    sampled_ids = self.model_cfg['model_config']['sampled_ids']
+                if 'sampled_lengths' in self.model_cfg['model_config']:
+                    sampled_lengths = self.model_cfg['model_config'][
+                        'sampled_lengths']
+                asr_inference_paraformer_tf.preload(
+                    ngpu=1,
+                    asr_model_file=model_path,
+                    vocab_file=vocab_path,
+                    sampled_ids=sampled_ids,
+                    sampled_lengths=sampled_lengths)
+
         return self.model_cfg
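
For context, the fix works because preload builds the tf graph/session once and runs a throwaway inference, so later requests start fast. Below is a minimal sketch of that pattern, assuming a TF1-style graph/session API; it is NOT the easyasr implementation (easyasr's preload() takes the model/vocab arguments shown in the diff above):

    import numpy as np
    import tensorflow as tf

    _cache = {}

    def preload(key='asr'):
        """Build the graph and session once and run a throwaway inference,
        so the first real request pays no graph-construction cost."""
        if key in _cache:  # repeated calls are cheap no-ops
            return _cache[key]
        graph = tf.Graph()
        with graph.as_default():  # stand-in for loading the frozen am_model
            feats = tf.compat.v1.placeholder(tf.float32, [None, 80], 'feats')
            out = tf.reduce_sum(feats, axis=-1, name='out')
        sess = tf.compat.v1.Session(graph=graph)
        # warmup: push dummy features through the whole graph once
        sess.run(out, feed_dict={feats: np.zeros((1, 80), np.float32)})
        _cache[key] = (sess, feats, out)
        return _cache[key]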

modelscope/pipelines/audio/asr_inference_pipeline.py (+11, -6)

@@ -30,6 +30,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         """use `model` and `preprocessor` to create an asr pipeline for prediction
         """
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.model_cfg = self.model.forward()

     def __call__(self,
                  audio_in: Union[str, bytes],
@@ -49,16 +50,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             recog_type=recog_type,
             audio_format=audio_format)

-        if hasattr(asr_utils, 'sample_rate_checking'):
+        if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None:
             self.audio_fs = asr_utils.sample_rate_checking(
                 self.audio_in, self.audio_format)

         if self.preprocessor is None:
             self.preprocessor = WavToScp()

-        output = self.preprocessor.forward(self.model.forward(),
-                                           self.recog_type, self.audio_format,
-                                           self.audio_in, self.audio_fs)
+        output = self.preprocessor.forward(self.model_cfg, self.recog_type,
+                                           self.audio_format, self.audio_in,
+                                           self.audio_fs)
         output = self.forward(output)
         rst = self.postprocess(output)
         return rst
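
The first change in this hunk makes a caller-supplied audio_fs authoritative: sample_rate_checking only probes the input when no sampling rate was passed in. A standalone sketch of that precedence rule; resolve_audio_fs and probe are hypothetical names, not modelscope or easyasr API:

    def resolve_audio_fs(audio_fs, audio_in, audio_format, probe=None):
        if audio_fs is not None:  # explicit sampling rate from the caller wins
            return audio_fs
        if probe is not None:  # otherwise fall back to probing the input
            return probe(audio_in, audio_format)
        return None

    assert resolve_audio_fs(16000, 'a.wav', 'wav', lambda *_: 8000) == 16000
    assert resolve_audio_fs(None, 'a.wav', 'wav', lambda *_: 8000) == 8000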
@@ -198,8 +199,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):

         for line in lines:
             line_item = line.split(None, 1)
-            item = {'key': line_item[0], 'value': line_item[1].strip('\n')}
-            ref_list.append(item)
+            if len(line_item) > 1:
+                item = {
+                    'key': line_item[0],
+                    'value': line_item[1].strip('\n')
+                }
+                ref_list.append(item)

         return ref_list
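
The last hunk guards against reference lines that carry a key but no value. A quick demonstration of why split(None, 1) needs that length check:

    # str.split(None, 1) splits on the first whitespace run, at most once:
    print('utt1 hello world\n'.split(None, 1))  # ['utt1', 'hello world\n']
    print('utt1\n'.split(None, 1))  # ['utt1'] -> no [1] element; indexing
                                    # it would raise IndexError without the guard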



