diff --git a/modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py b/modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py
index 8290578a..befb7a01 100755
--- a/modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py
+++ b/modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py
@@ -10,7 +10,6 @@ from typing import Any, Optional, Sequence, Tuple, Union
 import numpy as np
 import torch
-from espnet2.asr.frontend.default import DefaultFrontend
 from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer
 from espnet2.asr.transducer.beam_search_transducer import \
     ExtendedHypothesis as ExtTransHypothesis  # noqa: H301
@@ -35,6 +34,7 @@ from espnet.nets.scorers.length_bonus import LengthBonus
 from espnet.utils.cli_utils import get_commandline_args
 from typeguard import check_argument_types
 
+from .espnet.asr.frontend.wav_frontend import WavFrontend
 from .espnet.tasks.asr import ASRTaskNAR as ASRTask
@@ -70,7 +70,7 @@ class Speech2Text:
         asr_model, asr_train_args = ASRTask.build_model_from_file(
             asr_train_config, asr_model_file, device)
         if asr_model.frontend is None and frontend_conf is not None:
-            frontend = DefaultFrontend(**frontend_conf)
+            frontend = WavFrontend(**frontend_conf)
             asr_model.frontend = frontend
         asr_model.to(dtype=getattr(torch, dtype)).eval()
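The hunk above swaps ESPnet's spectrogram-based `DefaultFrontend` for the repo-local `WavFrontend` whenever a loaded model has no frontend of its own. A minimal sketch of that branch, assuming a hypothetical `attach_frontend` helper and illustrative `frontend_conf` values (the keys themselves are real `WavFrontend.__init__` parameters from the new file below):

```python
from modelscope.pipelines.audio.asr.asr_engine.espnet.asr.frontend.wav_frontend import \
    WavFrontend


def attach_frontend(asr_model, frontend_conf):
    # Mirrors the changed branch in Speech2Text.__init__: only models
    # shipped without a frontend get the Kaldi-fbank WavFrontend.
    if asr_model.frontend is None and frontend_conf is not None:
        asr_model.frontend = WavFrontend(**frontend_conf)


# Illustrative configuration; real values come from the model's
# training config passed to ASRTask.build_model_from_file.
frontend_conf = {
    'fs': 16000,
    'n_mels': 80,
    'win_length': 400,
    'hop_length': 160,
}
```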
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/decoder/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/decoder/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/wav_frontend.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/wav_frontend.py
new file mode 100644
index 00000000..1adc24f1
--- /dev/null
+++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/frontend/wav_frontend.py
@@ -0,0 +1,113 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+
+import copy
+from typing import Optional, Tuple, Union
+
+import humanfriendly
+import numpy as np
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+from espnet2.asr.frontend.abs_frontend import AbsFrontend
+from espnet2.layers.log_mel import LogMel
+from espnet2.layers.stft import Stft
+from espnet2.utils.get_default_kwargs import get_default_kwargs
+from espnet.nets.pytorch_backend.frontends.frontend import Frontend
+from typeguard import check_argument_types
+
+
+class WavFrontend(AbsFrontend):
+    """Conventional frontend structure for ASR.
+
+    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
+
+    Note: ``forward`` bypasses the Stft/LogMel chain built below and
+    extracts Kaldi-compatible Fbank features from the raw waveform.
+    """
+
+    def __init__(
+        self,
+        fs: Union[int, str] = 16000,
+        n_fft: int = 512,
+        win_length: int = 400,
+        hop_length: int = 160,
+        window: Optional[str] = 'hamming',
+        center: bool = True,
+        normalized: bool = False,
+        onesided: bool = True,
+        n_mels: int = 80,
+        fmin: Optional[int] = None,
+        fmax: Optional[int] = None,
+        htk: bool = False,
+        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+        apply_stft: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
+
+        # Deepcopy (in general, a dict shouldn't be used as a default arg)
+        frontend_conf = copy.deepcopy(frontend_conf)
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.fs = fs
+
+        if apply_stft:
+            self.stft = Stft(
+                n_fft=n_fft,
+                win_length=win_length,
+                hop_length=hop_length,
+                center=center,
+                window=window,
+                normalized=normalized,
+                onesided=onesided,
+            )
+        else:
+            self.stft = None
+        self.apply_stft = apply_stft
+
+        if frontend_conf is not None:
+            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
+        else:
+            self.frontend = None
+
+        self.logmel = LogMel(
+            fs=fs,
+            n_fft=n_fft,
+            n_mels=n_mels,
+            fmin=fmin,
+            fmax=fmax,
+            htk=htk,
+        )
+        self.n_mels = n_mels
+        self.frontend_type = 'default'
+
+    def output_size(self) -> int:
+        return self.n_mels
+
+    def forward(
+            self, input: torch.Tensor,
+            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # input_lengths is unused: this frontend assumes a single
+        # utterance per batch.
+        sample_frequency = self.fs
+        num_mel_bins = self.n_mels
+        # Kaldi takes frame sizes in milliseconds, not samples.
+        frame_length = self.win_length * 1000 / sample_frequency
+        frame_shift = self.hop_length * 1000 / sample_frequency
+
+        # kaldi.fbank expects int16-range amplitudes, so scale the float
+        # waveform (assumed to lie in [-1, 1]) by 2**15.
+        waveform = input * (1 << 15)
+
+        mat = kaldi.fbank(
+            waveform,
+            num_mel_bins=num_mel_bins,
+            frame_length=frame_length,
+            frame_shift=frame_shift,
+            dither=1.0,
+            energy_floor=0.0,
+            window_type=self.window,
+            sample_frequency=sample_frequency)
+
+        # Add a batch dimension and report the frame count as the length.
+        input_feats = mat[None, :]
+        feats_lens = input_feats.new_full((1,), input_feats.shape[1])
+
+        return input_feats, feats_lens
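`WavFrontend.forward` builds its features with `torchaudio.compliance.kaldi.fbank` rather than the ESPnet `Stft`/`LogMel` members. A minimal standalone sketch of the same computation on dummy audio (the one-second input and the printed shape are illustrative):

```python
import torch
import torchaudio.compliance.kaldi as kaldi

# One second of dummy 16 kHz audio in [-1, 1], shaped (channels, samples).
waveform = torch.rand(1, 16000) * 2 - 1

# Scale to the int16 range Kaldi-style fbank expects, then extract
# 80-dim Fbank features with 25 ms windows and a 10 ms shift
# (400 and 160 samples at 16 kHz, matching WavFrontend's defaults).
mat = kaldi.fbank(
    waveform * (1 << 15),
    num_mel_bins=80,
    frame_length=25.0,
    frame_shift=10.0,
    dither=1.0,
    energy_floor=0.0,
    window_type='hamming',
    sample_frequency=16000)

feats = mat[None, :]  # add batch dimension -> (1, num_frames, 80)
print(feats.shape)    # approximately torch.Size([1, 98, 80])
```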
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/streaming_utilis/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/streaming_utilis/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/cif_utils/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/cif_utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr/asr_inference_pipeline.py
index 4c94c1d2..20e7b6bf 100644
--- a/modelscope/pipelines/audio/asr/asr_inference_pipeline.py
+++ b/modelscope/pipelines/audio/asr/asr_inference_pipeline.py
@@ -11,8 +11,11 @@ from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import WavToScp
 from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
 from .asr_engine.common import asr_utils
 
+logger = get_logger()
+
 __all__ = ['AutomaticSpeechRecognitionPipeline']
@@ -76,6 +79,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         """Decoding
         """
 
+        logger.info(f"Decoding with {inputs['audio_format']} files ...")
+
         j: int = 0
         process = []
@@ -134,6 +139,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         """process the asr results
         """
 
+        logger.info('Computing the result of ASR ...')
+
         rst = {'rec_result': 'None'}
 
         # single wav task
diff --git a/tests/pipelines/test_automatic_speech_recognition.py b/tests/pipelines/test_automatic_speech_recognition.py
index 14d33b8f..22d1d777 100644
--- a/tests/pipelines/test_automatic_speech_recognition.py
+++ b/tests/pipelines/test_automatic_speech_recognition.py
@@ -8,8 +8,11 @@ import requests
 
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
 from modelscope.utils.test_utils import test_level
 
+logger = get_logger()
+
 WAV_FILE = 'data/test/audios/asr_example.wav'
 LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
@@ -38,6 +41,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
         '''run with single waveform file
         '''
 
+        logger.info('Run ASR test with waveform file ...')
+
         wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
 
         inference_16k_pipline = pipeline(
@@ -57,7 +62,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             'rec_result': 'None'
         }
         '''
-        print('test_run_with_wav rec result: ' + rec_result['rec_result'])
+        logger.info('test_run_with_wav rec result: '
+                    + rec_result['rec_result'])
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset(self):
@@ -78,6 +84,9 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             data.text  # hypothesis text
         '''
 
+        logger.info('Run ASR test with waveform dataset ...')
+        logger.info('Downloading waveform testsets file ...')
+
         # downloading pos_testsets file
         testsets_file_path = os.path.join(self._workspace,
                                           LITTLE_TESTSETS_FILE)
@@ -124,8 +133,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             }
         }
         '''
-        print('test_run_with_wav_dataset datasets result: ')
-        print(rec_result['datasets_result'])
+        logger.info('test_run_with_wav_dataset datasets result: ')
+        logger.info(rec_result['datasets_result'])
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_ark_dataset(self):
@@ -146,6 +155,9 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             data.text
         '''
 
+        logger.info('Run ASR test with ark dataset ...')
+        logger.info('Downloading ark testsets file ...')
+
         # downloading pos_testsets file
         testsets_file_path = os.path.join(self._workspace,
                                           AISHELL1_TESTSETS_FILE)
@@ -191,8 +203,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             }
         }
         '''
-        print('test_run_with_ark_dataset datasets result: ')
-        print(rec_result['datasets_result'])
+        logger.info('test_run_with_ark_dataset datasets result: ')
+        logger.info(rec_result['datasets_result'])
 
 
 if __name__ == '__main__':
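With the `print`-to-logger switch, pipeline progress now surfaces through `get_logger()` (e.g. "Decoding with wav files ..." and "Computing the result of ASR ..."). A hypothetical end-to-end run; the model id placeholder and the `audio_in` keyword are assumptions based on the surrounding test code:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# '<paraformer-model-id>' is a placeholder; substitute a real ASR model
# id from the ModelScope hub.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition, model='<paraformer-model-id>')

# The new logger.info calls report progress while this runs.
rec_result = inference_pipeline(audio_in='data/test/audios/asr_example.wav')
print(rec_result['rec_result'])
```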