
[pipelines] add wenetruntime

pengzhendong, 3 years ago
commit 2e30caf1e6 (master^2)
4 changed files with 135 additions and 0 deletions
  1. modelscope/metainfo.py (+2, -0)
  2. modelscope/models/audio/asr/wenet_automatic_speech_recognition.py (+45, -0)
  3. modelscope/pipelines/audio/asr_wenet_inference_pipeline.py (+87, -0)
  4. requirements/audio.txt (+1, -0)

modelscope/metainfo.py (+2, -0)

@@ -92,6 +92,7 @@ class Models(object):
     speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     generic_asr = 'generic-asr'
+    wenet_asr = 'wenet-asr'
 
     # multi-modal models
     ofa = 'ofa'
@@ -267,6 +268,7 @@ class Pipelines(object):
     speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     asr_inference = 'asr-inference'
+    asr_wenet_inference = 'asr-wenet-inference'
 
     # multi-modal tasks
     image_captioning = 'image-captioning'
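
These two string constants are the registry keys consumed by the register_module decorators in the new files below; a hosted model's configuration.json selects them by value. A minimal sketch of that relationship, assuming the standard ModelScope configuration layout:

from modelscope.metainfo import Models, Pipelines

# The keys added above. A model's configuration.json would reference
# them by value, e.g.
#   "model":    {"type": "wenet-asr"}
#   "pipeline": {"type": "asr-wenet-inference"}
assert Models.wenet_asr == 'wenet-asr'
assert Pipelines.asr_wenet_inference == 'asr-wenet-inference'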


modelscope/models/audio/asr/wenet_automatic_speech_recognition.py (+45, -0)

@@ -0,0 +1,45 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from typing import Any, Dict
+
+import wenetruntime as wenet
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+
+__all__ = ['WeNetAutomaticSpeechRecognition']
+
+
+@MODELS.register_module(
+    Tasks.auto_speech_recognition, module_name=Models.wenet_asr)
+class WeNetAutomaticSpeechRecognition(Model):
+
+    def __init__(self, model_dir: str, am_model_name: str,
+                 model_config: Dict[str, Any], *args, **kwargs):
+        """Initialize the model info.
+
+        Args:
+            model_dir (str): the model directory.
+            am_model_name (str): the AM model name from configuration.json.
+            model_config (Dict[str, Any]): the detailed model config from configuration.json.
+        """
+        super().__init__(model_dir, am_model_name, model_config, *args,
+                         **kwargs)
+        self.model_cfg = {
+            # the recognition model dir path
+            'model_dir': model_dir,
+            # the recognition model config dict
+            'model_config': model_config
+        }
+        self.decoder = None
+
+    def forward(self) -> Dict[str, Any]:
+        """Preload the wenetruntime decoder and return the model info.
+        """
+        model_dir = self.model_cfg['model_dir']
+        self.decoder = wenet.Decoder(model_dir, lang='chs')
+
+        return self.model_cfg
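
For reference, the `wenet.Decoder` call above can also be driven standalone, following the same call pattern this commit uses; a minimal sketch, where the model directory and wav file are hypothetical placeholders:

import wenetruntime as wenet

# Same pattern as forward() above: positional model dir plus lang='chs'.
decoder = wenet.Decoder('/path/to/model_dir', lang='chs')  # hypothetical path

# decode() takes raw audio bytes, mirroring the pipeline's forward()
# below, and returns the recognition result.
with open('test_16k.wav', 'rb') as f:  # hypothetical 16 kHz wav
    print(decoder.decode(f.read()))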

modelscope/pipelines/audio/asr_wenet_inference_pipeline.py (+87, -0)

@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, Union
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import WavToScp
+from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
+                                                load_bytes_from_url)
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['WeNetAutomaticSpeechRecognitionPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference)
+class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
+    """ASR inference pipeline based on wenetruntime.
+    """
+
+    def __init__(self,
+                 model: Union[Model, str] = None,
+                 preprocessor: WavToScp = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an ASR pipeline for prediction.
+        """
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.model_cfg = self.model.forward()
+        self.decoder = self.model.decoder
+
+    def __call__(self,
+                 audio_in: Union[str, bytes],
+                 audio_fs: int = None,
+                 recog_type: str = None,
+                 audio_format: str = None) -> Dict[str, Any]:
+        from easyasr.common import asr_utils
+
+        self.recog_type = recog_type
+        self.audio_format = audio_format
+        self.audio_fs = audio_fs
+
+        if isinstance(audio_in, str):
+            # load pcm data from url if audio_in is a url str
+            self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
+        elif isinstance(audio_in, bytes):
+            # load pcm data from wav data if audio_in is in wave format
+            self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
+        else:
+            self.audio_in, checking_audio_fs = audio_in, None
+
+        # set the sample rate of audio_in if checking_audio_fs is valid
+        if checking_audio_fs is not None:
+            self.audio_fs = checking_audio_fs
+
+        if recog_type is None or audio_format is None:
+            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
+                audio_in=self.audio_in,
+                recog_type=recog_type,
+                audio_format=audio_format)
+
+        if hasattr(asr_utils, 'sample_rate_checking'):
+            checking_audio_fs = asr_utils.sample_rate_checking(
+                self.audio_in, self.audio_format)
+            if checking_audio_fs is not None:
+                self.audio_fs = checking_audio_fs
+
+        self.model_cfg['audio'] = self.audio_in
+        self.model_cfg['audio_fs'] = self.audio_fs
+
+        output = self.forward(self.model_cfg)
+        rst = self.postprocess(output['asr_result'])
+        return rst
+
+    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Run decoding.
+        """
+        inputs['asr_result'] = self.decoder.decode(inputs['audio'])
+        return inputs
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Process the ASR results.
+        """
+        return inputs
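
End to end, the pipeline above would typically be driven through ModelScope's `pipeline()` factory; a hedged usage sketch, where the model ID and audio URL are hypothetical placeholders (any model whose configuration.json names the 'asr-wenet-inference' pipeline would resolve here; `recog_type` and `audio_format` are inferred by `asr_utils.type_checking` when omitted):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Hypothetical model ID registered with Models.wenet_asr.
asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/some-wenet-asr-zh-16k')  # hypothetical

# A URL input goes through load_bytes_from_url(); wav bytes would go
# through extract_pcm_from_wav() instead.
result = asr(audio_in='https://example.com/test_16k.wav')  # hypothetical URL
print(result)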

requirements/audio.txt (+1, -0)

@@ -25,3 +25,4 @@ torchaudio
 tqdm
 ttsfrd>=0.0.3
 unidecode
+wenetruntime
