Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9358893master
| @@ -10,7 +10,6 @@ from typing import Any, Optional, Sequence, Tuple, Union | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from espnet2.asr.frontend.default import DefaultFrontend | |||||
| from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer | from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer | ||||
| from espnet2.asr.transducer.beam_search_transducer import \ | from espnet2.asr.transducer.beam_search_transducer import \ | ||||
| ExtendedHypothesis as ExtTransHypothesis # noqa: H301 | ExtendedHypothesis as ExtTransHypothesis # noqa: H301 | ||||
| @@ -35,6 +34,7 @@ from espnet.nets.scorers.length_bonus import LengthBonus | |||||
| from espnet.utils.cli_utils import get_commandline_args | from espnet.utils.cli_utils import get_commandline_args | ||||
| from typeguard import check_argument_types | from typeguard import check_argument_types | ||||
| from .espnet.asr.frontend.wav_frontend import WavFrontend | |||||
| from .espnet.tasks.asr import ASRTaskNAR as ASRTask | from .espnet.tasks.asr import ASRTaskNAR as ASRTask | ||||
| @@ -70,7 +70,7 @@ class Speech2Text: | |||||
| asr_model, asr_train_args = ASRTask.build_model_from_file( | asr_model, asr_train_args = ASRTask.build_model_from_file( | ||||
| asr_train_config, asr_model_file, device) | asr_train_config, asr_model_file, device) | ||||
| if asr_model.frontend is None and frontend_conf is not None: | if asr_model.frontend is None and frontend_conf is not None: | ||||
| frontend = DefaultFrontend(**frontend_conf) | |||||
| frontend = WavFrontend(**frontend_conf) | |||||
| asr_model.frontend = frontend | asr_model.frontend = frontend | ||||
| asr_model.to(dtype=getattr(torch, dtype)).eval() | asr_model.to(dtype=getattr(torch, dtype)).eval() | ||||
| @@ -0,0 +1,113 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| import copy | |||||
| from typing import Optional, Tuple, Union | |||||
| import humanfriendly | |||||
| import numpy as np | |||||
| import torch | |||||
| import torchaudio | |||||
| import torchaudio.compliance.kaldi as kaldi | |||||
| from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||||
| from espnet2.layers.log_mel import LogMel | |||||
| from espnet2.layers.stft import Stft | |||||
| from espnet2.utils.get_default_kwargs import get_default_kwargs | |||||
| from espnet.nets.pytorch_backend.frontends.frontend import Frontend | |||||
| from typeguard import check_argument_types | |||||
| class WavFrontend(AbsFrontend): | |||||
| """Conventional frontend structure for ASR. | |||||
| Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| fs: Union[int, str] = 16000, | |||||
| n_fft: int = 512, | |||||
| win_length: int = 400, | |||||
| hop_length: int = 160, | |||||
| window: Optional[str] = 'hamming', | |||||
| center: bool = True, | |||||
| normalized: bool = False, | |||||
| onesided: bool = True, | |||||
| n_mels: int = 80, | |||||
| fmin: int = None, | |||||
| fmax: int = None, | |||||
| htk: bool = False, | |||||
| frontend_conf: Optional[dict] = get_default_kwargs(Frontend), | |||||
| apply_stft: bool = True, | |||||
| ): | |||||
| assert check_argument_types() | |||||
| super().__init__() | |||||
| if isinstance(fs, str): | |||||
| fs = humanfriendly.parse_size(fs) | |||||
| # Deepcopy (In general, dict shouldn't be used as default arg) | |||||
| frontend_conf = copy.deepcopy(frontend_conf) | |||||
| self.hop_length = hop_length | |||||
| self.win_length = win_length | |||||
| self.window = window | |||||
| self.fs = fs | |||||
| if apply_stft: | |||||
| self.stft = Stft( | |||||
| n_fft=n_fft, | |||||
| win_length=win_length, | |||||
| hop_length=hop_length, | |||||
| center=center, | |||||
| window=window, | |||||
| normalized=normalized, | |||||
| onesided=onesided, | |||||
| ) | |||||
| else: | |||||
| self.stft = None | |||||
| self.apply_stft = apply_stft | |||||
| if frontend_conf is not None: | |||||
| self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) | |||||
| else: | |||||
| self.frontend = None | |||||
| self.logmel = LogMel( | |||||
| fs=fs, | |||||
| n_fft=n_fft, | |||||
| n_mels=n_mels, | |||||
| fmin=fmin, | |||||
| fmax=fmax, | |||||
| htk=htk, | |||||
| ) | |||||
| self.n_mels = n_mels | |||||
| self.frontend_type = 'default' | |||||
| def output_size(self) -> int: | |||||
| return self.n_mels | |||||
| def forward( | |||||
| self, input: torch.Tensor, | |||||
| input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| sample_frequency = self.fs | |||||
| num_mel_bins = self.n_mels | |||||
| frame_length = self.win_length * 1000 / sample_frequency | |||||
| frame_shift = self.hop_length * 1000 / sample_frequency | |||||
| waveform = input * (1 << 15) | |||||
| mat = kaldi.fbank( | |||||
| waveform, | |||||
| num_mel_bins=num_mel_bins, | |||||
| frame_length=frame_length, | |||||
| frame_shift=frame_shift, | |||||
| dither=1.0, | |||||
| energy_floor=0.0, | |||||
| window_type=self.window, | |||||
| sample_frequency=sample_frequency) | |||||
| input_feats = mat[None, :] | |||||
| feats_lens = torch.randn(1) | |||||
| feats_lens.fill_(input_feats.shape[1]) | |||||
| return input_feats, feats_lens | |||||
| @@ -11,8 +11,11 @@ from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import WavToScp | from modelscope.preprocessors import WavToScp | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.logger import get_logger | |||||
| from .asr_engine.common import asr_utils | from .asr_engine.common import asr_utils | ||||
| logger = get_logger() | |||||
| __all__ = ['AutomaticSpeechRecognitionPipeline'] | __all__ = ['AutomaticSpeechRecognitionPipeline'] | ||||
| @@ -76,6 +79,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| """Decoding | """Decoding | ||||
| """ | """ | ||||
| logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||||
| j: int = 0 | j: int = 0 | ||||
| process = [] | process = [] | ||||
| @@ -134,6 +139,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| """process the asr results | """process the asr results | ||||
| """ | """ | ||||
| logger.info('Computing the result of ASR ...') | |||||
| rst = {'rec_result': 'None'} | rst = {'rec_result': 'None'} | ||||
| # single wav task | # single wav task | ||||
| @@ -8,8 +8,11 @@ import requests | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| logger = get_logger() | |||||
| WAV_FILE = 'data/test/audios/asr_example.wav' | WAV_FILE = 'data/test/audios/asr_example.wav' | ||||
| LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz' | LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz' | ||||
| @@ -38,6 +41,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
| '''run with single waveform file | '''run with single waveform file | ||||
| ''' | ''' | ||||
| logger.info('Run ASR test with waveform file ...') | |||||
| wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | ||||
| inference_16k_pipline = pipeline( | inference_16k_pipline = pipeline( | ||||
| @@ -57,7 +62,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
| 'rec_result': 'None' | 'rec_result': 'None' | ||||
| } | } | ||||
| ''' | ''' | ||||
| print('test_run_with_wav rec result: ' + rec_result['rec_result']) | |||||
| logger.info('test_run_with_wav rec result: ' | |||||
| + rec_result['rec_result']) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_wav_dataset(self): | def test_run_with_wav_dataset(self): | ||||
| @@ -78,6 +84,9 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
| data.text # hypothesis text | data.text # hypothesis text | ||||
| ''' | ''' | ||||
| logger.info('Run ASR test with waveform dataset ...') | |||||
| logger.info('Downloading waveform testsets file ...') | |||||
| # downloading pos_testsets file | # downloading pos_testsets file | ||||
| testsets_file_path = os.path.join(self._workspace, | testsets_file_path = os.path.join(self._workspace, | ||||
| LITTLE_TESTSETS_FILE) | LITTLE_TESTSETS_FILE) | ||||
| @@ -124,8 +133,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
| } | } | ||||
| } | } | ||||
| ''' | ''' | ||||
| print('test_run_with_wav_dataset datasets result: ') | |||||
| print(rec_result['datasets_result']) | |||||
| logger.info('test_run_with_wav_dataset datasets result: ') | |||||
| logger.info(rec_result['datasets_result']) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run_with_ark_dataset(self): | def test_run_with_ark_dataset(self): | ||||
| @@ -146,6 +155,9 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
| data.text | data.text | ||||
| ''' | ''' | ||||
| logger.info('Run ASR test with ark dataset ...') | |||||
| logger.info('Downloading ark testsets file ...') | |||||
| # downloading pos_testsets file | # downloading pos_testsets file | ||||
| testsets_file_path = os.path.join(self._workspace, | testsets_file_path = os.path.join(self._workspace, | ||||
| AISHELL1_TESTSETS_FILE) | AISHELL1_TESTSETS_FILE) | ||||
| @@ -191,8 +203,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
| } | } | ||||
| } | } | ||||
| ''' | ''' | ||||
| print('test_run_with_ark_dataset datasets result: ') | |||||
| print(rec_result['datasets_result']) | |||||
| logger.info('test_run_with_ark_dataset datasets result: ') | |||||
| logger.info(rec_result['datasets_result']) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||