
[to #42322933] Bug fix: fix ASR runtime error after python setup.py install, and add logger.info calls to report decoding progress
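
Note on the fix (inferred from the file list below): the eleven empty __init__.py files are what cure the runtime error. setuptools' find_packages() only treats directories containing an __init__.py as packages, so without them the vendored espnet subpackages were omitted from the installed distribution and their imports failed after python setup.py install. A minimal sketch of that mechanism, assuming modelscope's setup.py discovers packages via find_packages():

    # Sketch only; the distribution name is an assumed placeholder,
    # not taken from this diff.
    from setuptools import find_packages, setup

    setup(
        name='modelscope',          # assumed name
        packages=find_packages(),   # skips any directory lacking __init__.py
    )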

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9358893
Branch: master · shichen.fsc, 3 years ago · commit d5affa2e31
15 changed files with 139 additions and 7 deletions
  1. +2   -2   modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py
  2. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/__init__.py
  3. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/asr/__init__.py
  4. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/asr/decoder/__init__.py
  5. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/__init__.py
  6. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/__init__.py
  7. +113 -0   modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/wav_frontend.py
  8. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/asr/streaming_utilis/__init__.py
  9. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/nets/__init__.py
 10. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/__init__.py
 11. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/cif_utils/__init__.py
 12. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/__init__.py
 13. +0   -0   modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/__init__.py
 14. +7   -0   modelscope/pipelines/audio/asr/asr_inference_pipeline.py
 15. +17  -5   tests/pipelines/test_automatic_speech_recognition.py

modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py (+2, -2)

@@ -10,7 +10,6 @@ from typing import Any, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
-from espnet2.asr.frontend.default import DefaultFrontend
 from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer
 from espnet2.asr.transducer.beam_search_transducer import \
     ExtendedHypothesis as ExtTransHypothesis  # noqa: H301
@@ -35,6 +34,7 @@ from espnet.nets.scorers.length_bonus import LengthBonus
 from espnet.utils.cli_utils import get_commandline_args
 from typeguard import check_argument_types

+from .espnet.asr.frontend.wav_frontend import WavFrontend
 from .espnet.tasks.asr import ASRTaskNAR as ASRTask


@@ -70,7 +70,7 @@ class Speech2Text:
         asr_model, asr_train_args = ASRTask.build_model_from_file(
             asr_train_config, asr_model_file, device)
         if asr_model.frontend is None and frontend_conf is not None:
-            frontend = DefaultFrontend(**frontend_conf)
+            frontend = WavFrontend(**frontend_conf)
             asr_model.frontend = frontend
         asr_model.to(dtype=getattr(torch, dtype)).eval()



modelscope/pipelines/audio/asr/asr_engine/espnet/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/asr/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/asr/decoder/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/wav_frontend.py (+113, -0)

@@ -0,0 +1,113 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.

import copy
from typing import Optional, Tuple, Union

import humanfriendly
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.layers.log_mel import LogMel
from espnet2.layers.stft import Stft
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet.nets.pytorch_backend.frontends.frontend import Frontend
from typeguard import check_argument_types


class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = 400,
        hop_length: int = 160,
        window: Optional[str] = 'hamming',
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fs = fs

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = 'default'

    def output_size(self) -> int:
        return self.n_mels

    def forward(
            self, input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        sample_frequency = self.fs
        num_mel_bins = self.n_mels
        frame_length = self.win_length * 1000 / sample_frequency
        frame_shift = self.hop_length * 1000 / sample_frequency

        waveform = input * (1 << 15)

        mat = kaldi.fbank(
            waveform,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=1.0,
            energy_floor=0.0,
            window_type=self.window,
            sample_frequency=sample_frequency)

        input_feats = mat[None, :]
        feats_lens = torch.randn(1)
        feats_lens.fill_(input_feats.shape[1])

        return input_feats, feats_lens
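
For orientation, a minimal usage sketch of the new frontend (not part of the commit), assuming the defaults above. Shapes follow the forward() implementation; note that input_lengths is accepted but unused there, and the feature lengths are derived from the fbank frame count:

    # Minimal sketch: one second of fake 16 kHz mono audio through WavFrontend.
    import torch

    from modelscope.pipelines.audio.asr.asr_engine.espnet.asr.frontend.wav_frontend import \
        WavFrontend

    frontend = WavFrontend(fs=16000, n_mels=80)
    wav = torch.randn(1, 16000)                  # (channels, samples)
    wav_lens = torch.tensor([16000])             # accepted but unused by forward()
    feats, feats_lens = frontend(wav, wav_lens)
    # feats: (1, num_frames, 80) kaldi-style fbank features; feats_lens: (1,)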

modelscope/pipelines/audio/asr/asr_engine/espnet/asr/streaming_utilis/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/nets/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/cif_utils/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/__init__.py (+0, -0)


modelscope/pipelines/audio/asr/asr_inference_pipeline.py (+7, -0)

@@ -11,8 +11,11 @@ from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import WavToScp
 from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
 from .asr_engine.common import asr_utils

+logger = get_logger()
+
 __all__ = ['AutomaticSpeechRecognitionPipeline']


@@ -76,6 +79,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         """Decoding
         """

+        logger.info(f"Decoding with {inputs['audio_format']} files ...")
+
         j: int = 0
         process = []

@@ -134,6 +139,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         """process the asr results
         """

+        logger.info('Computing the result of ASR ...')
+
         rst = {'rec_result': 'None'}

         # single wav task
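
The two logger.info calls above surface pipeline progress at run time. A hypothetical invocation sketch follows; the task constant and model id are assumptions, not taken from this diff, and the wav path comes from the test file below:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    asr = pipeline(
        task=Tasks.auto_speech_recognition,      # assumed constant name
        model='<some-asr-model-id>')             # placeholder model id
    result = asr('data/test/audios/asr_example.wav')
    # expected log lines while running:
    #   Decoding with wav files ...
    #   Computing the result of ASR ...
    print(result['rec_result'])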


tests/pipelines/test_automatic_speech_recognition.py (+17, -5)

@@ -8,8 +8,11 @@ import requests

 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
 from modelscope.utils.test_utils import test_level

+logger = get_logger()
+
 WAV_FILE = 'data/test/audios/asr_example.wav'

 LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
@@ -38,6 +41,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
         '''run with single waveform file
         '''

+        logger.info('Run ASR test with waveform file ...')
+
         wav_file_path = os.path.join(os.getcwd(), WAV_FILE)

         inference_16k_pipline = pipeline(
@@ -57,7 +62,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             'rec_result': 'None'
         }
         '''
-        print('test_run_with_wav rec result: ' + rec_result['rec_result'])
+        logger.info('test_run_with_wav rec result: '
+                    + rec_result['rec_result'])

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset(self):
@@ -78,6 +84,9 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
            data.text  # hypothesis text
        '''

+        logger.info('Run ASR test with waveform dataset ...')
+        logger.info('Downloading waveform testsets file ...')
+
         # downloading pos_testsets file
         testsets_file_path = os.path.join(self._workspace,
                                           LITTLE_TESTSETS_FILE)
@@ -124,8 +133,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             }
         }
         '''
-        print('test_run_with_wav_dataset datasets result: ')
-        print(rec_result['datasets_result'])
+        logger.info('test_run_with_wav_dataset datasets result: ')
+        logger.info(rec_result['datasets_result'])

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_ark_dataset(self):
@@ -146,6 +155,9 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
            data.text
        '''

+        logger.info('Run ASR test with ark dataset ...')
+        logger.info('Downloading ark testsets file ...')
+
         # downloading pos_testsets file
         testsets_file_path = os.path.join(self._workspace,
                                           AISHELL1_TESTSETS_FILE)
@@ -191,8 +203,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             }
         }
         '''
-        print('test_run_with_ark_dataset datasets result: ')
-        print(rec_result['datasets_result'])
+        logger.info('test_run_with_ark_dataset datasets result: ')
+        logger.info(rec_result['datasets_result'])


 if __name__ == '__main__':

