diff --git a/data/test/audios/asr_example.wav b/data/test/audios/asr_example.wav
new file mode 100644
index 00000000..5c61b555
--- /dev/null
+++ b/data/test/audios/asr_example.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87bde7feb3b40d75dec27e5824dd1077911f867e3f125c4bf603ec0af954d4db
+size 77864
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index a1dbc95e..f03c0dab 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -23,6 +23,7 @@ class Models(object):
     sambert_hifigan = 'sambert-hifigan'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
+    generic_asr = 'generic-asr'
 
     # multi-modal models
     ofa = 'ofa'
@@ -68,6 +69,7 @@ class Pipelines(object):
     speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
+    asr_inference = 'asr-inference'
 
     # multi-modal tasks
     image_caption = 'image-captioning'
@@ -120,6 +122,7 @@ class Preprocessors(object):
     linear_aec_fbank = 'linear-aec-fbank'
     text_to_tacotron_symbols = 'text-to-tacotron-symbols'
     wav_to_lists = 'wav-to-lists'
+    wav_to_scp = 'wav-to-scp'
 
     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py
index b5913d2c..4767657a 100644
--- a/modelscope/models/__init__.py
+++ b/modelscope/models/__init__.py
@@ -5,6 +5,7 @@ from .base import Model
 from .builder import MODELS, build_model
 
 try:
+    from .audio.asr import GenericAutomaticSpeechRecognition
     from .audio.tts import SambertHifigan
     from .audio.kws import GenericKeyWordSpotting
     from .audio.ans.frcrn import FRCRNModel
diff --git a/modelscope/models/audio/asr/__init__.py b/modelscope/models/audio/asr/__init__.py
new file mode 100644
index 00000000..08dfa27d
--- /dev/null
+++ b/modelscope/models/audio/asr/__init__.py
@@ -0,0 +1 @@
+from .generic_automatic_speech_recognition import *  # noqa F403
diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
new file mode 100644
index 00000000..b057a8b7
--- /dev/null
+++ b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
@@ -0,0 +1,39 @@
+import os
+from typing import Any, Dict
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+
+__all__ = ['GenericAutomaticSpeechRecognition']
+
+
+@MODELS.register_module(
+    Tasks.auto_speech_recognition, module_name=Models.generic_asr)
+class GenericAutomaticSpeechRecognition(Model):
+
+    def __init__(self, model_dir: str, am_model_name: str,
+                 model_config: Dict[str, Any], *args, **kwargs):
+        """Initialize the model info.
+
+        Args:
+            model_dir (str): the model directory path.
+            am_model_name (str): the AM model name from configuration.json.
+        """
+
+        self.model_cfg = {
+            # the recognition model dir path
+            'model_workspace': model_dir,
+            # the am model name
+            'am_model': am_model_name,
+            # the am model file path
+            'am_model_path': os.path.join(model_dir, am_model_name),
+            # the recognition model config dict
+            'model_config': model_config
+        }
+
+    def forward(self) -> Dict[str, Any]:
+        """Return the model configuration info.
+        """
+        return self.model_cfg
diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py
index 80c03c23..84a593b8 100644
--- a/modelscope/pipelines/audio/__init__.py
+++ b/modelscope/pipelines/audio/__init__.py
@@ -3,6 +3,7 @@
 from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR
 
 try:
+    from .asr.asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
     from .kws_kwsbp_pipeline import *  # noqa F403
     from .linear_aec_pipeline import LinearAECPipeline
 except ModuleNotFoundError as e:
diff --git a/modelscope/pipelines/audio/asr/__init__.py b/modelscope/pipelines/audio/asr/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/pipelines/audio/asr/asr_engine/asr_env_checking.py b/modelscope/pipelines/audio/asr/asr_engine/asr_env_checking.py
new file mode 100644
index 00000000..9d9ba3a1
--- /dev/null
+++ b/modelscope/pipelines/audio/asr/asr_engine/asr_env_checking.py
@@ -0,0 +1,12 @@
+import nltk
+
+try:
+    nltk.data.find('taggers/averaged_perceptron_tagger')
+except LookupError:
+    nltk.download(
+        'averaged_perceptron_tagger', halt_on_error=False, raise_on_error=True)
+
+try:
+    nltk.data.find('corpora/cmudict')
+except LookupError:
+    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
diff --git a/modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py b/modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py
new file mode 100755
index 00000000..06e79afa
--- /dev/null
+++ b/modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py
@@ -0,0 +1,690 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+ +import argparse +import logging +import sys +import time +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from espnet2.asr.frontend.default import DefaultFrontend +from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer +from espnet2.asr.transducer.beam_search_transducer import \ + ExtendedHypothesis as ExtTransHypothesis # noqa: H301 +from espnet2.asr.transducer.beam_search_transducer import \ + Hypothesis as TransHypothesis +from espnet2.fileio.datadir_writer import DatadirWriter +from espnet2.tasks.lm import LMTask +from espnet2.text.build_tokenizer import build_tokenizer +from espnet2.text.token_id_converter import TokenIDConverter +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.utils import config_argparse +from espnet2.utils.types import str2bool, str2triple_str, str_or_none +from espnet.nets.batch_beam_search import BatchBeamSearch +from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim +from espnet.nets.beam_search import BeamSearch, Hypothesis +from espnet.nets.pytorch_backend.transformer.subsampling import \ + TooShortUttError +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.ctc import CTCPrefixScorer +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.cli_utils import get_commandline_args +from typeguard import check_argument_types, check_return_type + +from .espnet.tasks.asr import ASRTaskNAR as ASRTask + + +class Speech2Text: + + def __init__(self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + transducer_conf: dict = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + ngram_scorer: str = 'full', + ngram_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = 'cpu', + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = 'float32', + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + frontend_conf: dict = None): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, device) + if asr_model.frontend is None and frontend_conf is not None: + frontend = DefaultFrontend(**frontend_conf) + asr_model.frontend = frontend + asr_model.to(dtype=getattr(torch, dtype)).eval() + + decoder = asr_model.decoder + + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device) + scorers['lm'] = lm.lm + + # 3. Build ngram model + if ngram_file is not None: + if ngram_scorer == 'full': + from espnet.nets.scorers.ngram import NgramFullScorer + + ngram = NgramFullScorer(ngram_file, token_list) + else: + from espnet.nets.scorers.ngram import NgramPartScorer + + ngram = NgramPartScorer(ngram_file, token_list) + else: + ngram = None + scorers['ngram'] = ngram + + # 4. 
Build BeamSearch object + if asr_model.use_transducer_decoder: + beam_search_transducer = BeamSearchTransducer( + decoder=asr_model.decoder, + joint_network=asr_model.joint_network, + beam_size=beam_size, + lm=scorers['lm'] if 'lm' in scorers else None, + lm_weight=lm_weight, + **transducer_conf, + ) + beam_search = None + else: + beam_search_transducer = None + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else 'full', + ) + + # TODO(karita): make all scorers batchfied + if batch_size == 1: + non_batch = [ + k for k, v in beam_search.full_scorers.items() + if not isinstance(v, BatchScorerInterface) + ] + if len(non_batch) == 0: + if streaming: + beam_search.__class__ = BatchBeamSearchOnlineSim + beam_search.set_streaming_config(asr_train_config) + logging.info( + 'BatchBeamSearchOnlineSim implementation is selected.' + ) + else: + beam_search.__class__ = BatchBeamSearch + else: + logging.warning( + f'As non-batch scorers {non_batch} are found, ' + f'fall back to non-batch implementation.') + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to( + device=device, dtype=getattr(torch, dtype)).eval() + + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == 'bpe': + if bpemodel is not None: + tokenizer = build_tokenizer( + token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + + @torch.no_grad() + def __call__(self, speech: Union[torch.Tensor, np.ndarray]): + """Inference + + Args: + data: Input speech data + Returns: + text, token, token_int, hyp + """ + + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + # data: (Nsamples,) -> (1, Nsamples) + speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + # lengths: (1,) + lengths = speech.new_full([1], + dtype=torch.long, + fill_value=speech.size(1)) + batch = {'speech': speech, 'speech_lengths': lengths} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. 
Forward Encoder + enc, enc_len = self.asr_model.encode(**batch) + if isinstance(enc, tuple): + enc = enc[0] + assert len(enc) == 1, len(enc) + + predictor_outs = self.asr_model.calc_predictor(enc, enc_len) + pre_acoustic_embeds, pre_token_length = predictor_outs[ + 0], predictor_outs[1] + pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], + device=pre_acoustic_embeds.device) + decoder_outs = self.asr_model.cal_decoder_with_predictor( + enc, enc_len, pre_acoustic_embeds, pre_token_length) + decoder_out = decoder_outs[0] + + yseq = decoder_out.argmax(dim=-1) + score = decoder_out.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + [self.asr_model.sos] + yseq.tolist()[0] + [self.asr_model.eos], + device=yseq.device) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = None if self.asr_model.use_transducer_decoder else -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + + results.append((text, token, token_int, hyp, speech.size(1))) + + return results + + @staticmethod + def from_pretrained( + model_tag: Optional[str] = None, + **kwargs: Optional[Any], + ): + """Build Speech2Text instance from the pretrained model. + + Args: + model_tag (Optional[str]): Model tag of the pretrained models. + Currently, the tags of espnet_model_zoo are supported. + + Returns: + Speech2Text: Speech2Text instance. + + """ + if model_tag is not None: + try: + from espnet_model_zoo.downloader import ModelDownloader + + except ImportError: + logging.error( + '`espnet_model_zoo` is not installed. 
' + 'Please install via `pip install -U espnet_model_zoo`.') + raise + d = ModelDownloader() + kwargs.update(**d.download_and_unpack(model_tag)) + + return Speech2Text(**kwargs) + + +def inference( + output_dir: str, + maxlenratio: float, + minlenratio: float, + batch_size: int, + dtype: str, + beam_size: int, + ngpu: int, + seed: int, + ctc_weight: float, + lm_weight: float, + ngram_weight: float, + penalty: float, + nbest: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + key_file: Optional[str], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + lm_train_config: Optional[str], + lm_file: Optional[str], + word_lm_train_config: Optional[str], + word_lm_file: Optional[str], + ngram_file: Optional[str], + model_tag: Optional[str], + token_type: Optional[str], + bpemodel: Optional[str], + allow_variable_data_keys: bool, + transducer_conf: Optional[dict], + streaming: bool, + frontend_conf: dict = None, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError('batch decoding is not implemented') + if word_lm_train_config is not None: + raise NotImplementedError('Word LM is not implemented') + if ngpu > 1: + raise NotImplementedError('only single GPU decoding is supported') + + logging.basicConfig( + level=log_level, + format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', + ) + + if ngpu >= 1: + device = 'cuda' + else: + device = 'cpu' + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + transducer_conf=transducer_conf, + lm_train_config=lm_train_config, + lm_file=lm_file, + ngram_file=ngram_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + streaming=streaming, + frontend_conf=frontend_conf, + ) + speech2text = Speech2Text.from_pretrained( + model_tag=model_tag, + **speech2text_kwargs, + ) + + # 3. 
Build data-iterator + loader = ASRTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, + False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + forward_time_total = 0.0 + length_total = 0.0 + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + with DatadirWriter(output_dir) as writer: + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f'{len(keys)} != {_bs}' + batch = { + k: v[0] + for k, v in batch.items() if not k.endswith('_lengths') + } + + # N-best list of (text, token, token_int, hyp_object) + + try: + time_beg = time.time() + results = speech2text(**batch) + time_end = time.time() + forward_time = time_end - time_beg + length = results[0][-1] + results = [results[0][:-1]] + forward_time_total += forward_time + length_total += length + except TooShortUttError as e: + logging.warning(f'Utterance {keys} {e}') + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[' ', [''], [2], hyp]] * nbest + + # Only supporting batch_size==1 + key = keys[0] + for n, (text, token, token_int, + hyp) in zip(range(1, nbest + 1), results): + # Create a directory: outdir/{n}best_recog + ibest_writer = writer[f'{n}best_recog'] + + # Write the result to each file + ibest_writer['token'][key] = ' '.join(token) + ibest_writer['token_int'][key] = ' '.join(map(str, token_int)) + ibest_writer['score'][key] = str(hyp.score) + + if text is not None: + ibest_writer['text'][key] = text + + logging.info( + 'decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}' + .format(length_total, forward_time_total, + 100 * forward_time_total / length_total)) + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description='ASR Decoding', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + '--log_level', + type=lambda x: x.upper(), + default='INFO', + choices=('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'), + help='The verbose level of logging', + ) + + parser.add_argument('--output_dir', type=str, required=True) + parser.add_argument( + '--ngpu', + type=int, + default=0, + help='The number of gpus. 
0 indicates CPU mode', + ) + parser.add_argument('--seed', type=int, default=0, help='Random seed') + parser.add_argument( + '--dtype', + default='float32', + choices=['float16', 'float32', 'float64'], + help='Data type', + ) + parser.add_argument( + '--num_workers', + type=int, + default=1, + help='The number of workers used for DataLoader', + ) + + group = parser.add_argument_group('Input data related') + group.add_argument( + '--data_path_and_name_and_type', + type=str2triple_str, + required=True, + action='append', + ) + group.add_argument('--key_file', type=str_or_none) + group.add_argument( + '--allow_variable_data_keys', type=str2bool, default=False) + + group = parser.add_argument_group('The model configuration related') + group.add_argument( + '--asr_train_config', + type=str, + help='ASR training configuration', + ) + group.add_argument( + '--asr_model_file', + type=str, + help='ASR model parameter file', + ) + group.add_argument( + '--lm_train_config', + type=str, + help='LM training configuration', + ) + group.add_argument( + '--lm_file', + type=str, + help='LM parameter file', + ) + group.add_argument( + '--word_lm_train_config', + type=str, + help='Word LM training configuration', + ) + group.add_argument( + '--word_lm_file', + type=str, + help='Word LM parameter file', + ) + group.add_argument( + '--ngram_file', + type=str, + help='N-gram parameter file', + ) + group.add_argument( + '--model_tag', + type=str, + help='Pretrained model tag. If specify this option, *_train_config and ' + '*_file will be overwritten', + ) + + group = parser.add_argument_group('Beam-search related') + group.add_argument( + '--batch_size', + type=int, + default=1, + help='The batch size for inference', + ) + group.add_argument( + '--nbest', type=int, default=1, help='Output N-best hypotheses') + group.add_argument('--beam_size', type=int, default=20, help='Beam size') + group.add_argument( + '--penalty', type=float, default=0.0, help='Insertion penalty') + group.add_argument( + '--maxlenratio', + type=float, + default=0.0, + help='Input length ratio to obtain max output length. ' + 'If maxlenratio=0.0 (default), it uses a end-detect ' + 'function ' + 'to automatically find maximum hypothesis lengths.' + 'If maxlenratio<0.0, its absolute value is interpreted' + 'as a constant max output length', + ) + group.add_argument( + '--minlenratio', + type=float, + default=0.0, + help='Input length ratio to obtain min output length', + ) + group.add_argument( + '--ctc_weight', + type=float, + default=0.5, + help='CTC weight in joint decoding', + ) + group.add_argument( + '--lm_weight', type=float, default=1.0, help='RNNLM weight') + group.add_argument( + '--ngram_weight', type=float, default=0.9, help='ngram weight') + group.add_argument('--streaming', type=str2bool, default=False) + + group.add_argument( + '--frontend_conf', + default=None, + help='', + ) + + group = parser.add_argument_group('Text converter related') + group.add_argument( + '--token_type', + type=str_or_none, + default=None, + choices=['char', 'bpe', None], + help='The token type for ASR model. ' + 'If not given, refers from the training args', + ) + group.add_argument( + '--bpemodel', + type=str_or_none, + default=None, + help='The model path of sentencepiece. 
' + 'If not given, refers from the training args', + ) + group.add_argument( + '--transducer_conf', + default=None, + help='The keyword arguments for transducer beam search.', + ) + + return parser + + +def asr_inference( + output_dir: str, + maxlenratio: float, + minlenratio: float, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + nbest: int = 1, + num_workers: int = 1, + log_level: Union[int, str] = 'INFO', + batch_size: int = 1, + dtype: str = 'float32', + seed: int = 0, + key_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + word_lm_file: Optional[str] = None, + ngram_file: Optional[str] = None, + ngram_weight: float = 0.9, + model_tag: Optional[str] = None, + token_type: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + transducer_conf: Optional[dict] = None, + streaming: bool = False, + frontend_conf: dict = None, +): + inference( + output_dir=output_dir, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + batch_size=batch_size, + dtype=dtype, + beam_size=beam_size, + ngpu=ngpu, + seed=seed, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + num_workers=num_workers, + log_level=log_level, + data_path_and_name_and_type=data_path_and_name_and_type, + key_file=key_file, + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + word_lm_train_config=word_lm_train_config, + word_lm_file=word_lm_file, + ngram_file=ngram_file, + model_tag=model_tag, + token_type=token_type, + bpemodel=bpemodel, + allow_variable_data_keys=allow_variable_data_keys, + transducer_conf=transducer_conf, + streaming=streaming, + frontend_conf=frontend_conf) + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop('config', None) + inference(**kwargs) + + +if __name__ == '__main__': + main() diff --git a/modelscope/pipelines/audio/asr/asr_engine/common/__init__.py b/modelscope/pipelines/audio/asr/asr_engine/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/pipelines/audio/asr/asr_engine/common/asr_utils.py b/modelscope/pipelines/audio/asr/asr_engine/common/asr_utils.py new file mode 100644 index 00000000..0d9a5f43 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/common/asr_utils.py @@ -0,0 +1,193 @@ +import os +from typing import Any, Dict, List + +import numpy as np + + +def type_checking(wav_path: str, + recog_type: str = None, + audio_format: str = None, + workspace: str = None): + assert os.path.exists(wav_path), f'wav_path:{wav_path} does not exist' + + r_recog_type = recog_type + r_audio_format = audio_format + r_workspace = workspace + r_wav_path = wav_path + + if r_workspace is None or len(r_workspace) == 0: + r_workspace = os.path.join(os.getcwd(), '.tmp') + + if r_recog_type is None: + if os.path.isfile(wav_path): + if wav_path.endswith('.wav') or wav_path.endswith('.WAV'): + r_recog_type = 'wav' + r_audio_format = 'wav' + + elif os.path.isdir(wav_path): + dir_name = os.path.basename(wav_path) + if 'test' in dir_name: + r_recog_type = 'test' + elif 'dev' in dir_name: + r_recog_type = 'dev' + elif 
'train' in dir_name: + r_recog_type = 'train' + + if r_audio_format is None: + if find_file_by_ends(wav_path, '.ark'): + r_audio_format = 'kaldi_ark' + elif find_file_by_ends(wav_path, '.wav') or find_file_by_ends( + wav_path, '.WAV'): + r_audio_format = 'wav' + + if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav': + # datasets with kaldi_ark file + r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) + elif r_audio_format == 'wav' and r_recog_type != 'wav': + # datasets with waveform files + r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../')) + + return r_recog_type, r_audio_format, r_workspace, r_wav_path + + +def find_file_by_ends(dir_path: str, ends: str): + dir_files = os.listdir(dir_path) + for file in dir_files: + file_path = os.path.join(dir_path, file) + if os.path.isfile(file_path): + if file_path.endswith(ends): + return True + elif os.path.isdir(file_path): + if find_file_by_ends(file_path, ends): + return True + + return False + + +def compute_wer(hyp_text_path: str, ref_text_path: str) -> Dict[str, Any]: + assert os.path.exists(hyp_text_path), 'hyp_text does not exist' + assert os.path.exists(ref_text_path), 'ref_text does not exist' + + rst = { + 'Wrd': 0, + 'Corr': 0, + 'Ins': 0, + 'Del': 0, + 'Sub': 0, + 'Snt': 0, + 'Err': 0.0, + 'S.Err': 0.0, + 'wrong_words': 0, + 'wrong_sentences': 0 + } + + with open(ref_text_path, 'r', encoding='utf-8') as r: + r_lines = r.readlines() + + with open(hyp_text_path, 'r', encoding='utf-8') as h: + h_lines = h.readlines() + + for r_line in r_lines: + r_line_item = r_line.split() + r_key = r_line_item[0] + r_sentence = r_line_item[1] + for h_line in h_lines: + # find sentence from hyp text + if r_key in h_line: + h_line_item = h_line.split() + h_sentence = h_line_item[1] + out_item = compute_wer_by_line(h_sentence, r_sentence) + rst['Wrd'] += out_item['nwords'] + rst['Corr'] += out_item['cor'] + rst['wrong_words'] += out_item['wrong'] + rst['Ins'] += out_item['ins'] + rst['Del'] += out_item['del'] + rst['Sub'] += out_item['sub'] + rst['Snt'] += 1 + if out_item['wrong'] > 0: + rst['wrong_sentences'] += 1 + + break + + if rst['Wrd'] > 0: + rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) + if rst['Snt'] > 0: + rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) + + return rst + + +def compute_wer_by_line(hyp: list, ref: list) -> Dict[str, Any]: + len_hyp = len(hyp) + len_ref = len(ref) + cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) + + ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) + + for i in range(len_hyp + 1): + cost_matrix[i][0] = i + for j in range(len_ref + 1): + cost_matrix[0][j] = j + + for i in range(1, len_hyp + 1): + for j in range(1, len_ref + 1): + if hyp[i - 1] == ref[j - 1]: + cost_matrix[i][j] = cost_matrix[i - 1][j - 1] + else: + substitution = cost_matrix[i - 1][j - 1] + 1 + insertion = cost_matrix[i - 1][j] + 1 + deletion = cost_matrix[i][j - 1] + 1 + + compare_val = [substitution, insertion, deletion] + + min_val = min(compare_val) + operation_idx = compare_val.index(min_val) + 1 + cost_matrix[i][j] = min_val + ops_matrix[i][j] = operation_idx + + match_idx = [] + i = len_hyp + j = len_ref + rst = { + 'nwords': len_hyp, + 'cor': 0, + 'wrong': 0, + 'ins': 0, + 'del': 0, + 'sub': 0 + } + while i >= 0 or j >= 0: + i_idx = max(0, i) + j_idx = max(0, j) + + if ops_matrix[i_idx][j_idx] == 0: # correct + if i - 1 >= 0 and j - 1 >= 0: + match_idx.append((j - 1, i - 1)) + rst['cor'] += 1 + + i -= 1 + j -= 1 + + elif 
ops_matrix[i_idx][j_idx] == 2: # insert + i -= 1 + rst['ins'] += 1 + + elif ops_matrix[i_idx][j_idx] == 3: # delete + j -= 1 + rst['del'] += 1 + + elif ops_matrix[i_idx][j_idx] == 1: # substitute + i -= 1 + j -= 1 + rst['sub'] += 1 + + if i < 0 and j >= 0: + rst['del'] += 1 + elif j < 0 and i >= 0: + rst['ins'] += 1 + + match_idx.reverse() + wrong_cnt = cost_matrix[len_hyp][len_ref] + rst['wrong'] = wrong_cnt + + return rst diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/decoder/transformer_decoder.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/decoder/transformer_decoder.py new file mode 100644 index 00000000..e1435db1 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/decoder/transformer_decoder.py @@ -0,0 +1,757 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. +"""Decoder definition.""" +from typing import Any, List, Sequence, Tuple + +import torch +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer +from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ + DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ + DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.lightconv import \ + LightweightConvolution +from espnet.nets.pytorch_backend.transformer.lightconv2d import \ + LightweightConvolution2D +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.scorer_interface import BatchScorerInterface +from typeguard import check_argument_types + + +class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface): + """Base class of Transfomer decoder module. + + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. 
x -> x + att(x) + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + ): + assert check_argument_types() + super().__init__() + attention_dim = encoder_output_size + + if input_layer == 'embed': + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == 'linear': + self.embed = torch.nn.Sequential( + torch.nn.Linear(vocab_size, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + raise ValueError( + f"only 'embed' or 'linear' is supported: {input_layer}") + + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.output_layer = None + + # Must set by the inheritance + self.decoders = None + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. + + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + tgt = ys_in_pad + # tgt_mask: (B, 1, L) + tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) + # m: (1, L, L) + m = subsequent_mask( + tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + + memory = hs_pad + memory_mask = ( + ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( + memory.device) + # Padding for Longformer + if memory_mask.shape[-1] != memory.shape[1]: + padlen = memory.shape[1] - memory_mask.shape[-1] + memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), + 'constant', False) + + x = self.embed(tgt) + x, tgt_mask, memory, memory_mask = self.decoders( + x, tgt_mask, memory, memory_mask) + if self.normalize_before: + x = self.after_norm(x) + if self.output_layer is not None: + x = self.output_layer(x) + + olens = tgt_mask.sum(1) + return x, olens + + def forward_one_step( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + cache: List[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + + Args: + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + memory: encoded memory, float32 (batch, maxlen_in, feat) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. 
+ y.shape` is (batch, maxlen_out, token) + """ + x = self.embed(tgt) + if cache is None: + cache = [None] * len(self.decoders) + new_cache = [] + for c, decoder in zip(cache, self.decoders): + x, tgt_mask, memory, memory_mask = decoder( + x, tgt_mask, memory, None, cache=c) + new_cache.append(x) + + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.output_layer is not None: + y = torch.log_softmax(self.output_layer(y), dim=-1) + + return y, new_cache + + def score(self, ys, state, x): + """Score.""" + ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) + logp, state = self.forward_one_step( + ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state) + return logp.squeeze(0), state + + def batch_score(self, ys: torch.Tensor, states: List[Any], + xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.decoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + torch.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + # batch decoding + ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) + logp, states = self.forward_one_step( + ys, ys_mask, xs, cache=batch_state) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] + for b in range(n_batch)] + return logp, state_list + + +class TransformerDecoder(BaseTransformerDecoder): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + ): + assert check_argument_types() + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention(attention_heads, attention_dim, + self_attention_dropout_rate), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + +class ParaformerDecoder(TransformerDecoder): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + 
input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + ): + assert check_argument_types() + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention(attention_heads, attention_dim, + self_attention_dropout_rate), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. + + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + tgt = ys_in_pad + # tgt_mask: (B, 1, L) + tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) + # m: (1, L, L) + # m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + # tgt_mask = tgt_mask & m + + memory = hs_pad + memory_mask = ( + ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( + memory.device) + # Padding for Longformer + if memory_mask.shape[-1] != memory.shape[1]: + padlen = memory.shape[1] - memory_mask.shape[-1] + memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), + 'constant', False) + + # x = self.embed(tgt) + x = tgt + x, tgt_mask, memory, memory_mask = self.decoders( + x, tgt_mask, memory, memory_mask) + if self.normalize_before: + x = self.after_norm(x) + if self.output_layer is not None: + x = self.output_layer(x) + + olens = tgt_mask.sum(1) + return x, olens + + +class ParaformerDecoderBertEmbed(TransformerDecoder): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + embeds_id: int = 2, + ): + assert check_argument_types() + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + self.decoders = repeat( + embeds_id, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention(attention_heads, attention_dim, + self_attention_dropout_rate), + 
MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if embeds_id == num_blocks: + self.decoders2 = None + else: + self.decoders2 = repeat( + num_blocks - embeds_id, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention(attention_heads, attention_dim, + self_attention_dropout_rate), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. + + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + tgt = ys_in_pad + # tgt_mask: (B, 1, L) + tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) + # m: (1, L, L) + # m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + # tgt_mask = tgt_mask & m + + memory = hs_pad + memory_mask = ( + ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( + memory.device) + # Padding for Longformer + if memory_mask.shape[-1] != memory.shape[1]: + padlen = memory.shape[1] - memory_mask.shape[-1] + memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), + 'constant', False) + + # x = self.embed(tgt) + x = tgt + x, tgt_mask, memory, memory_mask = self.decoders( + x, tgt_mask, memory, memory_mask) + embeds_outputs = x + if self.decoders2 is not None: + x, tgt_mask, memory, memory_mask = self.decoders2( + x, tgt_mask, memory, memory_mask) + if self.normalize_before: + x = self.after_norm(x) + if self.output_layer is not None: + x = self.output_layer(x) + + olens = tgt_mask.sum(1) + return x, olens, embeds_outputs + + +class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + 'conv_kernel_length must have equal number of values to num_blocks: ' + f'{len(conv_kernel_length)} != {num_blocks}') + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = 
encoder_output_size + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + LightweightConvolution( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + +class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + 'conv_kernel_length must have equal number of values to num_blocks: ' + f'{len(conv_kernel_length)} != {num_blocks}') + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + LightweightConvolution2D( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + +class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + 'conv_kernel_length must have equal number of values to num_blocks: ' + f'{len(conv_kernel_length)} != {num_blocks}') + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + attention_dim = encoder_output_size + + self.decoders = repeat( + num_blocks, + 
lambda lnum: DecoderLayer( + attention_dim, + DynamicConvolution( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + +class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = 'embed', + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + 'conv_kernel_length must have equal number of values to num_blocks: ' + f'{len(conv_kernel_length)} != {num_blocks}') + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + attention_dim = encoder_output_size + + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + DynamicConvolution2D( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/conformer_encoder.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/conformer_encoder.py new file mode 100644 index 00000000..463852e9 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/conformer_encoder.py @@ -0,0 +1,710 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. 
+"""Conformer encoder definition.""" + +import logging +from typing import List, Optional, Tuple, Union + +import torch +from espnet2.asr.ctc import CTC +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule +from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.nets_utils import (get_activation, + make_pad_mask) +from espnet.nets.pytorch_backend.transformer.embedding import \ + LegacyRelPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + RelPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.embedding import \ + ScaledPositionalEncoding # noqa: H301 +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import ( + Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, + Conv2dSubsampling8, TooShortUttError, check_short_utt) +from typeguard import check_argument_types + +from ...nets.pytorch_backend.transformer.attention import \ + LegacyRelPositionMultiHeadedAttention # noqa: H301 +from ...nets.pytorch_backend.transformer.attention import \ + MultiHeadedAttention # noqa: H301 +from ...nets.pytorch_backend.transformer.attention import \ + RelPositionMultiHeadedAttention # noqa: H301 +from ...nets.pytorch_backend.transformer.attention import ( + LegacyRelPositionMultiHeadedAttentionSANM, + RelPositionMultiHeadedAttentionSANM) + + +class ConformerEncoder(AbsEncoder): + """Conformer encoder module. + + Args: + input_size (int): Input dimension. + output_size (int): Dimension of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + attention_dropout_rate (float): Dropout rate in attention. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + If True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + If False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + rel_pos_type (str): Whether to use the latest relative positional encoding or + the legacy one. The legacy relative positional encoding will be deprecated + in the future. More Details can be found in + https://github.com/espnet/espnet/pull/2816. + encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. + encoder_attn_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. + macaron_style (bool): Whether to use macaron style for positionwise layer. 
+ use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. + + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = 'conv2d', + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = 'linear', + positionwise_conv_kernel_size: int = 3, + macaron_style: bool = False, + rel_pos_type: str = 'legacy', + pos_enc_layer_type: str = 'rel_pos', + selfattention_layer_type: str = 'rel_selfattn', + activation_type: str = 'swish', + use_cnn_module: bool = True, + zero_triu: bool = False, + cnn_module_kernel: int = 31, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + stochastic_depth_rate: Union[float, List[float]] = 0.0, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if rel_pos_type == 'legacy': + if pos_enc_layer_type == 'rel_pos': + pos_enc_layer_type = 'legacy_rel_pos' + if selfattention_layer_type == 'rel_selfattn': + selfattention_layer_type = 'legacy_rel_selfattn' + elif rel_pos_type == 'latest': + assert selfattention_layer_type != 'legacy_rel_selfattn' + assert pos_enc_layer_type != 'legacy_rel_pos' + else: + raise ValueError('unknown rel_pos_type: ' + rel_pos_type) + + activation = get_activation(activation_type) + if pos_enc_layer_type == 'abs_pos': + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == 'scaled_abs_pos': + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == 'rel_pos': + assert selfattention_layer_type == 'rel_selfattn' + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == 'legacy_rel_pos': + assert selfattention_layer_type == 'legacy_rel_selfattn' + pos_enc_class = LegacyRelPositionalEncoding + else: + raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) + + if input_layer == 'linear': + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d': + self.embed = Conv2dSubsampling( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d2': + self.embed = Conv2dSubsampling2( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d6': + self.embed = Conv2dSubsampling6( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d8': + self.embed = Conv2dSubsampling8( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'embed': + self.embed = torch.nn.Sequential( + torch.nn.Embedding( + input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer is None: + 
self.embed = torch.nn.Sequential( + pos_enc_class(output_size, positional_dropout_rate)) + else: + raise ValueError('unknown input_layer: ' + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == 'linear': + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == 'conv1d': + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == 'conv1d-linear': + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError('Support only linear or conv1d.') + + if selfattention_layer_type == 'selfattn': + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif selfattention_layer_type == 'legacy_rel_selfattn': + assert pos_enc_layer_type == 'legacy_rel_pos' + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif selfattention_layer_type == 'rel_selfattn': + assert pos_enc_layer_type == 'rel_pos' + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + zero_triu, + ) + else: + raise ValueError('unknown encoder_attn_layer: ' + + selfattention_layer_type) + + convolution_layer = ConvolutionModule + convolution_layer_args = (output_size, cnn_module_kernel, activation) + + if isinstance(stochastic_depth_rate, float): + stochastic_depth_rate = [stochastic_depth_rate] * num_blocks + + if len(stochastic_depth_rate) != num_blocks: + raise ValueError( + f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' + f'should be equal to num_blocks ({num_blocks})') + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) + if macaron_style else None, + convolution_layer(*convolution_layer_args) + if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + stochastic_depth_rate[lnum], + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max( + interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Calculate forward propagation. + + Args: + xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). + ilens (torch.Tensor): Input length (#batch). + prev_states (torch.Tensor): Not to be used now. + + Returns: + torch.Tensor: Output tensor (#batch, L, output_size). + torch.Tensor: Output length (#batch). + torch.Tensor: Not to be used now. 
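+
+        Example:
+            A minimal usage sketch (illustrative only; assumes 80-dim
+            filterbank features and the default conv2d subsampling input
+            layer):
+
+                encoder = ConformerEncoder(input_size=80, output_size=256)
+                feats = torch.randn(2, 100, 80)      # (batch, frames, feat_dim)
+                feat_lens = torch.tensor([100, 80])  # valid frames per utterance
+                out, out_lens, _ = encoder(feats, feat_lens)
+                # out: (batch, ~frames/4, 256) after subsampling; out_lens: (batch,)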
+ + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if (isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8)): + short_status, limit_size = check_short_utt(self.embed, + xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f'has {xs_pad.size(1)} frames and is too short for subsampling ' + + # noqa: * + f'(it needs more than {limit_size} frames), return empty results', # noqa: * + xs_pad.size(1), + limit_size) # noqa: * + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + xs_pad, masks = self.encoders(xs_pad, masks) + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks = encoder_layer(xs_pad, masks) + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + + if isinstance(xs_pad, tuple): + x, pos_emb = xs_pad + x = x + self.conditioning_layer(ctc_out) + xs_pad = (x, pos_emb) + else: + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None + + +class SANMEncoder_v2(AbsEncoder): + """Conformer encoder module. + + Args: + input_size (int): Input dimension. + output_size (int): Dimension of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + attention_dropout_rate (float): Dropout rate in attention. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + If True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + If False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + rel_pos_type (str): Whether to use the latest relative positional encoding or + the legacy one. The legacy relative positional encoding will be deprecated + in the future. More Details can be found in + https://github.com/espnet/espnet/pull/2816. + encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. + encoder_attn_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. + macaron_style (bool): Whether to use macaron style for positionwise layer. + use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernerl size of convolution module. 
+ padding_idx (int): Padding idx for input_layer=embed. + + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = 'conv2d', + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = 'linear', + positionwise_conv_kernel_size: int = 3, + macaron_style: bool = False, + rel_pos_type: str = 'legacy', + pos_enc_layer_type: str = 'rel_pos', + selfattention_layer_type: str = 'rel_selfattn', + activation_type: str = 'swish', + use_cnn_module: bool = False, + sanm_shfit: int = 0, + zero_triu: bool = False, + cnn_module_kernel: int = 31, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + stochastic_depth_rate: Union[float, List[float]] = 0.0, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if rel_pos_type == 'legacy': + if pos_enc_layer_type == 'rel_pos': + pos_enc_layer_type = 'legacy_rel_pos' + if selfattention_layer_type == 'rel_selfattn': + selfattention_layer_type = 'legacy_rel_selfattn' + if selfattention_layer_type == 'rel_selfattnsanm': + selfattention_layer_type = 'legacy_rel_selfattnsanm' + + elif rel_pos_type == 'latest': + assert selfattention_layer_type != 'legacy_rel_selfattn' + assert pos_enc_layer_type != 'legacy_rel_pos' + else: + raise ValueError('unknown rel_pos_type: ' + rel_pos_type) + + activation = get_activation(activation_type) + if pos_enc_layer_type == 'abs_pos': + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == 'scaled_abs_pos': + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == 'rel_pos': + # assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == 'legacy_rel_pos': + # assert selfattention_layer_type == "legacy_rel_selfattn" + pos_enc_class = LegacyRelPositionalEncoding + logging.warning( + 'Using legacy_rel_pos and it will be deprecated in the future.' 
+ ) + else: + raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) + + if input_layer == 'linear': + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d': + self.embed = Conv2dSubsampling( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d2': + self.embed = Conv2dSubsampling2( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d6': + self.embed = Conv2dSubsampling6( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d8': + self.embed = Conv2dSubsampling8( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'embed': + self.embed = torch.nn.Sequential( + torch.nn.Embedding( + input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(output_size, positional_dropout_rate)) + else: + raise ValueError('unknown input_layer: ' + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == 'linear': + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == 'conv1d': + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == 'conv1d-linear': + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError('Support only linear or conv1d.') + + if selfattention_layer_type == 'selfattn': + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif selfattention_layer_type == 'legacy_rel_selfattn': + assert pos_enc_layer_type == 'legacy_rel_pos' + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + logging.warning( + 'Using legacy_rel_selfattn and it will be deprecated in the future.' + ) + + elif selfattention_layer_type == 'legacy_rel_selfattnsanm': + assert pos_enc_layer_type == 'legacy_rel_pos' + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttentionSANM + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + logging.warning( + 'Using legacy_rel_selfattn and it will be deprecated in the future.' 
+ ) + + elif selfattention_layer_type == 'rel_selfattn': + assert pos_enc_layer_type == 'rel_pos' + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + zero_triu, + ) + elif selfattention_layer_type == 'rel_selfattnsanm': + assert pos_enc_layer_type == 'rel_pos' + encoder_selfattn_layer = RelPositionMultiHeadedAttentionSANM + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + zero_triu, + cnn_module_kernel, + sanm_shfit, + ) + else: + raise ValueError('unknown encoder_attn_layer: ' + + selfattention_layer_type) + + convolution_layer = ConvolutionModule + convolution_layer_args = (output_size, cnn_module_kernel, activation) + + if isinstance(stochastic_depth_rate, float): + stochastic_depth_rate = [stochastic_depth_rate] * num_blocks + + if len(stochastic_depth_rate) != num_blocks: + raise ValueError( + f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' + f'should be equal to num_blocks ({num_blocks})') + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) + if macaron_style else None, + convolution_layer(*convolution_layer_args) + if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + stochastic_depth_rate[lnum], + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max( + interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Calculate forward propagation. + + Args: + xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). + ilens (torch.Tensor): Input length (#batch). + prev_states (torch.Tensor): Not to be used now. + + Returns: + torch.Tensor: Output tensor (#batch, L, output_size). + torch.Tensor: Output length (#batch). + torch.Tensor: Not to be used now. 
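+
+        Note:
+            When interctc_layer_idx is non-empty, the first returned value is
+            a tuple (encoder_out, intermediate_outs), where intermediate_outs
+            is a list of (layer_index, layer_output) pairs consumed by the
+            intermediate CTC loss.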
+ + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if (isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8)): + short_status, limit_size = check_short_utt(self.embed, + xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f'has {xs_pad.size(1)} frames and is too short for subsampling ' + + # noqa: * + f'(it needs more than {limit_size} frames), return empty results', + xs_pad.size(1), + limit_size) # noqa: * + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + xs_pad, masks = self.encoders(xs_pad, masks) + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks = encoder_layer(xs_pad, masks) + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + + if isinstance(xs_pad, tuple): + x, pos_emb = xs_pad + x = x + self.conditioning_layer(ctc_out) + xs_pad = (x, pos_emb) + else: + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/sanm_encoder.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/sanm_encoder.py new file mode 100644 index 00000000..92e51b2e --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/encoder/sanm_encoder.py @@ -0,0 +1,500 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. +"""Transformer encoder definition.""" + +import logging +from typing import List, Optional, Sequence, Tuple, Union + +import torch +from espnet2.asr.ctc import CTC +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.embedding import \ + PositionalEncoding +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, MultiLayeredConv1d) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ + PositionwiseFeedForward # noqa: H301 +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import ( + Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, + Conv2dSubsampling8, TooShortUttError, check_short_utt) +from typeguard import check_argument_types + +from ...asr.streaming_utilis.chunk_utilis import overlap_chunk +from ...nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, MultiHeadedAttentionSANM) +from ...nets.pytorch_backend.transformer.encoder_layer import ( + EncoderLayer, EncoderLayerChunk) + + +class SANMEncoder(AbsEncoder): + """Transformer encoder module. 
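+
+    SAN-M variant of the Transformer encoder: with the default
+    selfattention_layer_type='sanm', the self-attention blocks are built from
+    MultiHeadedAttentionSANM, whose FSMN-style memory block is configured by
+    kernel_size and sanm_shfit.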
+ + Args: + input_size: input dim + output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + attention_dropout_rate: dropout rate in attention + positional_dropout_rate: dropout rate after adding positional encoding + input_layer: input layer type + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. x -> x + att(x) + positionwise_layer_type: linear of conv1d + positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = 'conv2d', + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = 'linear', + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + kernel_size: int = 11, + sanm_shfit: int = 0, + selfattention_layer_type: str = 'sanm', + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if input_layer == 'linear': + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d': + self.embed = Conv2dSubsampling(input_size, output_size, + dropout_rate) + elif input_layer == 'conv2d2': + self.embed = Conv2dSubsampling2(input_size, output_size, + dropout_rate) + elif input_layer == 'conv2d6': + self.embed = Conv2dSubsampling6(input_size, output_size, + dropout_rate) + elif input_layer == 'conv2d8': + self.embed = Conv2dSubsampling8(input_size, output_size, + dropout_rate) + elif input_layer == 'embed': + self.embed = torch.nn.Sequential( + torch.nn.Embedding( + input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer is None: + if input_size == output_size: + self.embed = None + else: + self.embed = torch.nn.Linear(input_size, output_size) + else: + raise ValueError('unknown input_layer: ' + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == 'linear': + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == 'conv1d': + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == 'conv1d-linear': + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError('Support only linear or conv1d.') + + if selfattention_layer_type 
== 'selfattn': + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif selfattention_layer_type == 'sanm': + encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max( + interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if self.embed is None: + xs_pad = xs_pad + elif (isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8)): + short_status, limit_size = check_short_utt(self.embed, + xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f'has {xs_pad.size(1)} frames and is too short for subsampling ' + + # noqa: * + f'(it needs more than {limit_size} frames), return empty results', + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + xs_pad, masks = self.encoders(xs_pad, masks) + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks = encoder_layer(xs_pad, masks) + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None + + +class SANMEncoderChunk(AbsEncoder): + """Transformer encoder module. 
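+
+    Chunk-based (streaming) variant of SANMEncoder: inputs are split into
+    overlapping chunks by overlap_chunk, controlled by chunk_size, stride,
+    pad_left and encoder_att_look_back_factor, and the chunked outputs are
+    stitched back together after encoding.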
+ + Args: + input_size: input dim + output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + attention_dropout_rate: dropout rate in attention + positional_dropout_rate: dropout rate after adding positional encoding + input_layer: input layer type + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. x -> x + att(x) + positionwise_layer_type: linear of conv1d + positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = 'conv2d', + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = 'linear', + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + kernel_size: int = 11, + sanm_shfit: int = 0, + selfattention_layer_type: str = 'sanm', + chunk_size: Union[int, Sequence[int]] = (16, ), + stride: Union[int, Sequence[int]] = (10, ), + pad_left: Union[int, Sequence[int]] = (0, ), + encoder_att_look_back_factor: Union[int, Sequence[int]] = (1, ), + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if input_layer == 'linear': + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == 'conv2d': + self.embed = Conv2dSubsampling(input_size, output_size, + dropout_rate) + elif input_layer == 'conv2d2': + self.embed = Conv2dSubsampling2(input_size, output_size, + dropout_rate) + elif input_layer == 'conv2d6': + self.embed = Conv2dSubsampling6(input_size, output_size, + dropout_rate) + elif input_layer == 'conv2d8': + self.embed = Conv2dSubsampling8(input_size, output_size, + dropout_rate) + elif input_layer == 'embed': + self.embed = torch.nn.Sequential( + torch.nn.Embedding( + input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer is None: + if input_size == output_size: + self.embed = None + else: + self.embed = torch.nn.Linear(input_size, output_size) + else: + raise ValueError('unknown input_layer: ' + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == 'linear': + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == 'conv1d': + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == 'conv1d-linear': + positionwise_layer = Conv1dLinear + 
positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError('Support only linear or conv1d.') + + if selfattention_layer_type == 'selfattn': + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif selfattention_layer_type == 'sanm': + encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayerChunk( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max( + interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + shfit_fsmn = (kernel_size - 1) // 2 + self.overlap_chunk_cls = overlap_chunk( + chunk_size=chunk_size, + stride=stride, + pad_left=pad_left, + shfit_fsmn=shfit_fsmn, + encoder_att_look_back_factor=encoder_att_look_back_factor, + ) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ind: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. 
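+            ctc: CTC module, used only when interctc_use_conditioning is set.
+            ind: index forwarded to overlap_chunk.gen_chunk_mask when
+                building the chunk masks (defaults to 0).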
+ Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if self.embed is None: + xs_pad = xs_pad + elif (isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8)): + short_status, limit_size = check_short_utt(self.embed, + xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f'has {xs_pad.size(1)} frames and is too short for subsampling ' + + # noqa: * + f'(it needs more than {limit_size} frames), return empty results', + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + mask_shfit_chunk, mask_att_chunk_encoder = None, None + if self.overlap_chunk_cls is not None: + ilens = masks.squeeze(1).sum(1) + chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) + xs_pad, ilens = self.overlap_chunk_cls.split_chunk( + xs_pad, ilens, chunk_outs=chunk_outs) + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk( + chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) + mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder( + chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) + + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + xs_pad, masks, _, _, _ = self.encoders(xs_pad, masks, None, + mask_shfit_chunk, + mask_att_chunk_encoder) + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks, _, _, _ = encoder_layer(xs_pad, masks, None, + mask_shfit_chunk, + mask_att_chunk_encoder) + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + if self.overlap_chunk_cls is not None: + xs_pad, olens = self.overlap_chunk_cls.remove_chunk( + xs_pad, ilens, chunk_outs) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/espnet_model.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/espnet_model.py new file mode 100644 index 00000000..6f5b3688 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/espnet_model.py @@ -0,0 +1,1131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. 
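+# Rough summary of the training objective implemented below (see
+# ESPnetASRModel.forward and AEDStreaming.forward):
+#   loss_ctc = (1 - interctc_weight) * ctc_loss
+#              + interctc_weight * mean(intermediate CTC losses)
+#   attention branch:  loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
+#   transducer branch: loss = loss_transducer + ctc_weight * loss_ctc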
+import logging +from contextlib import contextmanager +from distutils.version import LooseVersion +from typing import Dict, List, Optional, Tuple, Union + +import torch +from espnet2.asr.ctc import CTC +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder +from espnet2.asr.specaug.abs_specaug import AbsSpecAug +from espnet2.asr.transducer.error_calculator import ErrorCalculatorTransducer +from espnet2.asr.transducer.utils import get_transducer_task_io +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 +from typeguard import check_argument_types + +from .streaming_utilis.chunk_utilis import sequence_mask + +if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class ESPnetASRModel(AbsESPnetModel): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], + decoder: AbsDecoder, + ctc: CTC, + joint_network: Optional[torch.nn.Module], + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + ignore_id: int = -1, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = '', + sym_blank: str = '', + extract_feats_in_collect_stats: bool = True, + ): + assert check_argument_types() + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + assert 0.0 <= interctc_weight < 1.0, interctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.blank_id = 0 + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.interctc_weight = interctc_weight + self.token_list = token_list.copy() + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + self.preencoder = preencoder + self.postencoder = postencoder + self.encoder = encoder + + if not hasattr(self.encoder, 'interctc_use_conditioning'): + self.encoder.interctc_use_conditioning = False + if self.encoder.interctc_use_conditioning: + self.encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.encoder.output_size()) + + self.use_transducer_decoder = joint_network is not None + + self.error_calculator = None + + if self.use_transducer_decoder: + # from warprnnt_pytorch import RNNTLoss + from warp_rnnt import rnnt_loss as RNNTLoss + + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = RNNTLoss + + if report_cer or report_wer: + self.error_calculator_trans = 
ErrorCalculatorTransducer( + decoder, + joint_network, + token_list, + sym_space, + sym_blank, + report_cer=report_cer, + report_wer=report_wer, + ) + else: + self.error_calculator_trans = None + + if self.ctc_weight != 0: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, + report_wer) + else: + # we set self.decoder = None in the CTC mode since + # self.decoder parameters were never used and PyTorch complained + # and threw an Exception in the multi-GPU experiment. + # thanks Jeff Farris for pointing out the issue. + if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if report_cer or report_wer: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, report_wer) + + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] + == # noqa: * + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + batch_size = speech.shape[0] + + # for data-parallel + text = text[:, :text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + loss_transducer, cer_transducer, wer_transducer = None, None, None + stats = dict() + + # 1. CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, + encoder_out_lens, text, + text_lengths) + + # Collect CTC branch stats + stats['loss_ctc'] = loss_ctc.detach( + ) if loss_ctc is not None else None + stats['cer_ctc'] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss(intermediate_out, + encoder_out_lens, text, + text_lengths) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats['loss_interctc_layer{}'.format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None) + stats['cer_interctc_layer{}'.format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = (1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + if self.use_transducer_decoder: + # 2a. 
Transducer decoder branch + ( + loss_transducer, + cer_transducer, + wer_transducer, + ) = self._calc_transducer_loss( + encoder_out, + encoder_out_lens, + text, + ) + + if loss_ctc is not None: + loss = loss_transducer + (self.ctc_weight * loss_ctc) + else: + loss = loss_transducer + + # Collect Transducer branch stats + stats['loss_transducer'] = ( + loss_transducer.detach() + if loss_transducer is not None else None) + stats['cer_transducer'] = cer_transducer + stats['wer_transducer'] = wer_transducer + + else: + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight) * loss_att + + # Collect Attn branch stats + stats['loss_att'] = loss_att.detach( + ) if loss_att is not None else None + stats['acc'] = acc_att + stats['cer'] = cer_att + stats['wer'] = wer_att + + # Collect total loss stats + # TODO(wjm): needed to be checked + # TODO(wjm): same problem: https://github.com/espnet/espnet/issues/4136 + # FIXME(wjm): for logger error when accum_grad > 1 + # stats["loss"] = loss.detach() + stats['loss'] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), + loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + 'Generating dummy stats for feats and feats_lengths, ' + 'because encoder_conf.extract_feats_in_collect_stats is ' + f'{self.extract_feats_in_collect_stats}') + feats, feats_lengths = speech, speech_lengths + return {'feats': feats, 'feats_lengths': feats_lengths} + + def encode( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths, ctc=self.ctc) + else: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. 
NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, :speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from transformer-decoder + + Normally, this function is called in batchify_nll. + + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) # [batch, seqlen, dim] + batch_size = decoder_out.size(0) + decoder_num_class = decoder_out.size(2) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + decoder_out.view(-1, decoder_num_class), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction='none', + ) + nll = nll.view(batch_size, -1) + nll = nll.sum(dim=1) + assert nll.size(0) == batch_size + return nll + + def batchify_nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + batch_size: int = 100, + ): + """Compute negative log likelihood(nll) from transformer-decoder + + To avoid OOM, this fuction seperate the input into batches. + Then call nll for each batch and combine and return results. 
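+
+        Example (illustrative; tensor shapes are arbitrary):
+            nll_values = model.batchify_nll(
+                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+                batch_size=50)
+            # nll_values: (Batch,), one negative log likelihood per utterance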
+ Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + batch_size: int, samples each batch contain when computing nll, + you may change this to avoid OOM or increase + GPU memory usage + """ + total_num = encoder_out.size(0) + if total_num <= batch_size: + nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + else: + nll = [] + start_idx = 0 + while True: + end_idx = min(start_idx + batch_size, total_num) + batch_encoder_out = encoder_out[start_idx:end_idx, :, :] + batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] + batch_ys_pad = ys_pad[start_idx:end_idx, :] + batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] + batch_nll = self.nll( + batch_encoder_out, + batch_encoder_out_lens, + batch_ys_pad, + batch_ys_pad_lens, + ) + nll.append(batch_nll) + start_idx = end_idx + if start_idx == total_num: + break + nll = torch.cat(nll) + assert nll.size(0) == total_num + return nll + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator( + ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + labels: torch.Tensor, + ): + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + encoder_out_lens: Encoder output sequences lengths. (B,) + labels: Label ID sequences. (B, L) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. 
+ + """ + decoder_in, target, t_len, u_len = get_transducer_task_io( + labels, + encoder_out_lens, + ignore_id=self.ignore_id, + blank_id=self.blank_id, + ) + + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in) + + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)) + + loss_transducer = self.criterion_transducer( + joint_out, + target, + t_len, + u_len, + reduction='sum', + ) + + cer_transducer, wer_transducer = None, None + if not self.training and self.error_calculator_trans is not None: + cer_transducer, wer_transducer = self.error_calculator_trans( + encoder_out, target) + + return loss_transducer, cer_transducer, wer_transducer + + +class AEDStreaming(AbsESPnetModel): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], + decoder: AbsDecoder, + ctc: CTC, + joint_network: Optional[torch.nn.Module], + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + ignore_id: int = -1, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = '', + sym_blank: str = '', + extract_feats_in_collect_stats: bool = True, + predictor=None, + predictor_weight: float = 0.0, + ): + assert check_argument_types() + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + assert 0.0 <= interctc_weight < 1.0, interctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.blank_id = 0 + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.interctc_weight = interctc_weight + self.token_list = token_list.copy() + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + self.preencoder = preencoder + self.postencoder = postencoder + self.encoder = encoder + + if not hasattr(self.encoder, 'interctc_use_conditioning'): + self.encoder.interctc_use_conditioning = False + if self.encoder.interctc_use_conditioning: + self.encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.encoder.output_size()) + + self.use_transducer_decoder = joint_network is not None + + self.error_calculator = None + + if self.use_transducer_decoder: + # from warprnnt_pytorch import RNNTLoss + from warp_rnnt import rnnt_loss as RNNTLoss + + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = RNNTLoss + + if report_cer or report_wer: + self.error_calculator_trans = ErrorCalculatorTransducer( + decoder, + joint_network, + token_list, + sym_space, + sym_blank, + report_cer=report_cer, + report_wer=report_wer, + ) + else: + self.error_calculator_trans = None + + if self.ctc_weight != 0: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, + report_wer) + else: + # we set self.decoder = None in the CTC mode since + # self.decoder parameters were never used and PyTorch complained + # and threw an Exception in the multi-GPU experiment. + # thanks Jeff Farris for pointing out the issue. 
+ if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if report_cer or report_wer: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, report_wer) + + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + self.predictor = predictor + self.predictor_weight = predictor_weight + self.criterion_pre = torch.nn.L1Loss() + self.step_cur = 0 + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] + == # noqa: * + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + batch_size = speech.shape[0] + + # for data-parallel + text = text[:, :text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + loss_transducer, cer_transducer, wer_transducer = None, None, None + stats = dict() + + # 1. CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, + encoder_out_lens, text, + text_lengths) + + # Collect CTC branch stats + stats['loss_ctc'] = loss_ctc.detach( + ) if loss_ctc is not None else None + stats['cer_ctc'] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss(intermediate_out, + encoder_out_lens, text, + text_lengths) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats['loss_interctc_layer{}'.format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None) + stats['cer_interctc_layer{}'.format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = (1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + if self.use_transducer_decoder: + # 2a. Transducer decoder branch + ( + loss_transducer, + cer_transducer, + wer_transducer, + ) = self._calc_transducer_loss( + encoder_out, + encoder_out_lens, + text, + ) + + if loss_ctc is not None: + loss = loss_transducer + (self.ctc_weight * loss_ctc) + else: + loss = loss_transducer + + # Collect Transducer branch stats + stats['loss_transducer'] = ( + loss_transducer.detach() + if loss_transducer is not None else None) + stats['cer_transducer'] = cer_transducer + stats['wer_transducer'] = wer_transducer + + else: + # 2b. 
Attention decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight) * loss_att + + # Collect Attn branch stats + stats['loss_att'] = loss_att.detach( + ) if loss_att is not None else None + stats['acc'] = acc_att + stats['cer'] = cer_att + stats['wer'] = wer_att + + # Collect total loss stats + # TODO(wjm): needed to be checked + # TODO(wjm): same problem: https://github.com/espnet/espnet/issues/4136 + # FIXME(wjm): for logger error when accum_grad > 1 + # stats["loss"] = loss.detach() + stats['loss'] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), + loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + 'Generating dummy stats for feats and feats_lengths, ' + 'because encoder_conf.extract_feats_in_collect_stats is ' + f'{self.extract_feats_in_collect_stats}') + feats, feats_lengths = speech, speech_lengths + return {'feats': feats, 'feats_lengths': feats_lengths} + + def encode( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths, ctc=self.ctc) + else: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. 
NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, :speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from transformer-decoder + + Normally, this function is called in batchify_nll. + + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) # [batch, seqlen, dim] + batch_size = decoder_out.size(0) + decoder_num_class = decoder_out.size(2) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + decoder_out.view(-1, decoder_num_class), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction='none', + ) + nll = nll.view(batch_size, -1) + nll = nll.sum(dim=1) + assert nll.size(0) == batch_size + return nll + + def batchify_nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + batch_size: int = 100, + ): + """Compute negative log likelihood(nll) from transformer-decoder + + To avoid OOM, this fuction seperate the input into batches. + Then call nll for each batch and combine and return results. 
+ Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + batch_size: int, samples each batch contain when computing nll, + you may change this to avoid OOM or increase + GPU memory usage + """ + total_num = encoder_out.size(0) + if total_num <= batch_size: + nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + else: + nll = [] + start_idx = 0 + while True: + end_idx = min(start_idx + batch_size, total_num) + batch_encoder_out = encoder_out[start_idx:end_idx, :, :] + batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] + batch_ys_pad = ys_pad[start_idx:end_idx, :] + batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] + batch_nll = self.nll( + batch_encoder_out, + batch_encoder_out_lens, + batch_ys_pad, + batch_ys_pad_lens, + ) + nll.append(batch_nll) + start_idx = end_idx + if start_idx == total_num: + break + nll = torch.cat(nll) + assert nll.size(0) == total_num + return nll + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_att_predictor_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + encoder_out_mask = sequence_mask( + encoder_out_lens, + maxlen=encoder_out.size(1), + dtype=encoder_out.dtype, + device=encoder_out.device)[:, None, :] + # logging.info( + # "encoder_out_mask size: {}".format(encoder_out_mask.size())) + pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor( + encoder_out, + ys_out_pad, + encoder_out_mask, + ignore_id=self.ignore_id, + target_label_length=ys_in_lens) + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) + + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator( + ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + labels: torch.Tensor, + ): + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + encoder_out_lens: Encoder output sequences lengths. (B,) + labels: Label ID sequences. (B, L) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + + """ + decoder_in, target, t_len, u_len = get_transducer_task_io( + labels, + encoder_out_lens, + ignore_id=self.ignore_id, + blank_id=self.blank_id, + ) + + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in) + + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)) + + loss_transducer = self.criterion_transducer( + joint_out, + target, + t_len, + u_len, + reduction='sum', + ) + + cer_transducer, wer_transducer = None, None + if not self.training and self.error_calculator_trans is not None: + cer_transducer, wer_transducer = self.error_calculator_trans( + encoder_out, target) + + return loss_transducer, cer_transducer, wer_transducer diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/espnet_model_paraformer.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/espnet_model_paraformer.py new file mode 100644 index 00000000..9b3ac624 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/espnet_model_paraformer.py @@ -0,0 +1,1444 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. 
+import logging +from contextlib import contextmanager +from distutils.version import LooseVersion +from typing import Dict, List, Optional, Tuple, Union + +import torch +from espnet2.asr.ctc import CTC +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.espnet_model import ESPnetASRModel +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder +from espnet2.asr.specaug.abs_specaug import AbsSpecAug +from espnet2.asr.transducer.error_calculator import ErrorCalculatorTransducer +from espnet2.asr.transducer.utils import get_transducer_task_io +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, th_accuracy +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import \ + LabelSmoothingLoss # noqa: H301 +from typeguard import check_argument_types + +from ...espnet.nets.pytorch_backend.cif_utils.cif import \ + CIF_Model as cif_predictor + +if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class Paraformer(AbsESPnetModel): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], + decoder: AbsDecoder, + ctc: CTC, + joint_network: Optional[torch.nn.Module], + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + ignore_id: int = -1, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = '', + sym_blank: str = '', + extract_feats_in_collect_stats: bool = True, + predictor=None, + predictor_weight: float = 0.0, + glat_context_p: float = 0.2, + ): + assert check_argument_types() + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + assert 0.0 <= interctc_weight < 1.0, interctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.blank_id = 0 + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.interctc_weight = interctc_weight + self.token_list = token_list.copy() + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + self.preencoder = preencoder + self.postencoder = postencoder + self.encoder = encoder + + if not hasattr(self.encoder, 'interctc_use_conditioning'): + self.encoder.interctc_use_conditioning = False + if self.encoder.interctc_use_conditioning: + self.encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.encoder.output_size()) + + self.use_transducer_decoder = joint_network is not None + + self.error_calculator = None + + if self.use_transducer_decoder: + # from warprnnt_pytorch import RNNTLoss + from warp_rnnt import rnnt_loss as RNNTLoss + + 
self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = RNNTLoss + + if report_cer or report_wer: + self.error_calculator_trans = ErrorCalculatorTransducer( + decoder, + joint_network, + token_list, + sym_space, + sym_blank, + report_cer=report_cer, + report_wer=report_wer, + ) + else: + self.error_calculator_trans = None + + if self.ctc_weight != 0: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, + report_wer) + else: + # we set self.decoder = None in the CTC mode since + # self.decoder parameters were never used and PyTorch complained + # and threw an Exception in the multi-GPU experiment. + # thanks Jeff Farris for pointing out the issue. + if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if report_cer or report_wer: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, report_wer) + + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + self.predictor = predictor + self.predictor_weight = predictor_weight + self.glat_context_p = glat_context_p + self.criterion_pre = torch.nn.L1Loss() + self.step_cur = 0 + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]), \ + (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + self.step_cur += 1 + # for data-parallel + text = text[:, :text_lengths.max()] + speech = speech[:, :speech_lengths.max(), :] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + loss_transducer, cer_transducer, wer_transducer = None, None, None + loss_pre = None + stats = dict() + + # 1. 
CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, + encoder_out_lens, text, + text_lengths) + + # Collect CTC branch stats + stats['loss_ctc'] = loss_ctc.detach( + ) if loss_ctc is not None else None + stats['cer_ctc'] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss(intermediate_out, + encoder_out_lens, text, + text_lengths) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats['loss_interctc_layer{}'.format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None) + stats['cer_interctc_layer{}'.format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = (1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + if self.use_transducer_decoder: + # 2a. Transducer decoder branch + ( + loss_transducer, + cer_transducer, + wer_transducer, + ) = self._calc_transducer_loss( + encoder_out, + encoder_out_lens, + text, + ) + + if loss_ctc is not None: + loss = loss_transducer + (self.ctc_weight * loss_ctc) + else: + loss = loss_transducer + + # Collect Transducer branch stats + stats['loss_transducer'] = ( + loss_transducer.detach() + if loss_transducer is not None else None) + stats['cer_transducer'] = cer_transducer + stats['wer_transducer'] = wer_transducer + + else: + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths) + + # 3. 
CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight + ) * loss_att + loss_pre * self.predictor_weight + + # Collect Attn branch stats + stats['loss_att'] = loss_att.detach( + ) if loss_att is not None else None + stats['acc'] = acc_att + stats['cer'] = cer_att + stats['wer'] = wer_att + stats['loss_pre'] = loss_pre.detach().cpu( + ) if loss_pre is not None else None + + # Collect total loss stats + # TODO(wjm): needed to be checked + # TODO(wjm): same problem: https://github.com/espnet/espnet/issues/4136 + # FIXME(wjm): for logger error when accum_grad > 1 + # stats["loss"] = loss.detach() + stats['loss'] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), + loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + 'Generating dummy stats for feats and feats_lengths, ' + 'because encoder_conf.extract_feats_in_collect_stats is ' + f'{self.extract_feats_in_collect_stats}') + feats, feats_lengths = speech, speech_lengths + return {'feats': feats, 'feats_lengths': feats_lengths} + + def encode( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths, ctc=self.ctc) + else: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. 
NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def calc_predictor(self, encoder_out, encoder_out_lens): + + encoder_out_mask = (~make_pad_mask( + encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor( + encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id) + return pre_acoustic_embeds, pre_token_length + + def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, + sematic_embeds, ys_pad_lens): + + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, + sematic_embeds, ys_pad_lens) + decoder_out = torch.log_softmax(decoder_out, dim=-1) + return decoder_out, ys_pad_lens + + def _extract_feats( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, :speech_lengths.max()] + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from transformer-decoder + + Normally, this function is called in batchify_nll. + + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) # [batch, seqlen, dim] + batch_size = decoder_out.size(0) + decoder_num_class = decoder_out.size(2) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + decoder_out.view(-1, decoder_num_class), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction='none', + ) + nll = nll.view(batch_size, -1) + nll = nll.sum(dim=1) + assert nll.size(0) == batch_size + return nll + + def batchify_nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + batch_size: int = 100, + ): + """Compute negative log likelihood(nll) from transformer-decoder + + To avoid OOM, this fuction seperate the input into batches. + Then call nll for each batch and combine and return results. 
+ Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + batch_size: int, samples each batch contain when computing nll, + you may change this to avoid OOM or increase + GPU memory usage + """ + total_num = encoder_out.size(0) + if total_num <= batch_size: + nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + else: + nll = [] + start_idx = 0 + while True: + end_idx = min(start_idx + batch_size, total_num) + batch_encoder_out = encoder_out[start_idx:end_idx, :, :] + batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] + batch_ys_pad = ys_pad[start_idx:end_idx, :] + batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] + batch_nll = self.nll( + batch_encoder_out, + batch_encoder_out_lens, + batch_ys_pad, + batch_ys_pad_lens, + ) + nll.append(batch_nll) + start_idx = end_idx + if start_idx == total_num: + break + nll = torch.cat(nll) + assert nll.size(0) == total_num + return nll + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + encoder_out_mask = (~make_pad_mask( + encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor( + encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) + + # 0. sampler + decoder_out_1st = None + if self.glat_context_p > 0.0: + if self.step_cur < 2: + logging.info( + 'enable sampler in paraformer, glat_context_p: {}'.format( + self.glat_context_p)) + sematic_embeds, decoder_out_1st = self.sampler( + encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds) + else: + if self.step_cur < 2: + logging.info( + 'disable sampler in paraformer, glat_context_p: {}'.format( + self.glat_context_p)) + sematic_embeds = pre_acoustic_embeds + + # 1. Forward decoder + decoder_outs = self.decoder(encoder_out, encoder_out_lens, + sematic_embeds, ys_pad_lens) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + + if decoder_out_1st is None: + decoder_out_1st = decoder_out + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_pad) + acc_att = th_accuracy( + decoder_out_1st.view(-1, self.vocab_size), + ys_pad, + ignore_label=self.ignore_id, + ) + loss_pre = self.criterion_pre( + ys_pad_lens.type_as(pre_token_length), pre_token_length) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out_1st.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att, loss_pre + + def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds): + + tgt_mask = (~make_pad_mask(ys_pad_lens, + maxlen=ys_pad_lens.max())[:, :, None]).to( + ys_pad.device) + ys_pad *= tgt_mask[:, :, 0] + ys_pad_embed = self.decoder.embed(ys_pad) + with torch.no_grad(): + decoder_outs = self.decoder(encoder_out, encoder_out_lens, + pre_acoustic_embeds, ys_pad_lens) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + pred_tokens = decoder_out.argmax(-1) + nonpad_positions = ys_pad.ne(self.ignore_id) + seq_lens = (nonpad_positions).sum(1) + same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1) + input_mask = torch.ones_like(nonpad_positions) + bsz, seq_len = ys_pad.size() + for li in range(bsz): + target_num = (((seq_lens[li] - same_num[li].sum()).float()) + * self.glat_context_p).long() + if target_num > 0: + input_mask[li].scatter_( + dim=0, + index=torch.randperm(seq_lens[li])[:target_num].cuda(), + value=0) + input_mask = input_mask.eq(1) + input_mask = input_mask.masked_fill(~nonpad_positions, False) + input_mask_expand_dim = input_mask.unsqueeze(2).to( + pre_acoustic_embeds.device) + + sematic_embeds = pre_acoustic_embeds.masked_fill( + ~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill( + input_mask_expand_dim, 0) + return sematic_embeds * tgt_mask, decoder_out * tgt_mask + + def _calc_att_loss_ar( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator( + ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + labels: torch.Tensor, + ): + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. 
(B, T, D_enc) + encoder_out_lens: Encoder output sequences lengths. (B,) + labels: Label ID sequences. (B, L) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + + """ + decoder_in, target, t_len, u_len = get_transducer_task_io( + labels, + encoder_out_lens, + ignore_id=self.ignore_id, + blank_id=self.blank_id, + ) + + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in) + + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)) + + loss_transducer = self.criterion_transducer( + joint_out, + target, + t_len, + u_len, + reduction='sum', + ) + + cer_transducer, wer_transducer = None, None + if not self.training and self.error_calculator_trans is not None: + cer_transducer, wer_transducer = self.error_calculator_trans( + encoder_out, target) + + return loss_transducer, cer_transducer, wer_transducer + + +class ParaformerBertEmbed(AbsESPnetModel): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + postencoder: Optional[AbsPostEncoder], + decoder: AbsDecoder, + ctc: CTC, + joint_network: Optional[torch.nn.Module], + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + ignore_id: int = -1, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = '', + sym_blank: str = '', + extract_feats_in_collect_stats: bool = True, + predictor: cif_predictor = None, + predictor_weight: float = 0.0, + glat_context_p: float = 0.2, + embed_dims: int = 768, + embeds_loss_weight: float = 0.0, + ): + assert check_argument_types() + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + assert 0.0 <= interctc_weight < 1.0, interctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.blank_id = 0 + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.interctc_weight = interctc_weight + self.token_list = token_list.copy() + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + self.preencoder = preencoder + self.postencoder = postencoder + self.encoder = encoder + + if not hasattr(self.encoder, 'interctc_use_conditioning'): + self.encoder.interctc_use_conditioning = False + if self.encoder.interctc_use_conditioning: + self.encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.encoder.output_size()) + + self.use_transducer_decoder = joint_network is not None + + self.error_calculator = None + + if self.use_transducer_decoder: + # from warprnnt_pytorch import RNNTLoss + from warp_rnnt import rnnt_loss as RNNTLoss + + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = RNNTLoss + + if report_cer or report_wer: + self.error_calculator_trans = ErrorCalculatorTransducer( + decoder, + joint_network, + token_list, + sym_space, + sym_blank, + report_cer=report_cer, + report_wer=report_wer, + ) + else: + self.error_calculator_trans = None + + if self.ctc_weight != 0: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, + report_wer) + else: + # we set self.decoder = None 
in the CTC mode since + # self.decoder parameters were never used and PyTorch complained + # and threw an Exception in the multi-GPU experiment. + # thanks Jeff Farris for pointing out the issue. + if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if report_cer or report_wer: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, report_wer) + + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + self.predictor = predictor + self.predictor_weight = predictor_weight + self.glat_context_p = glat_context_p + self.criterion_pre = torch.nn.L1Loss() + self.step_cur = 0 + self.pro_nn = torch.nn.Linear(encoder.output_size(), embed_dims) + self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + self.embeds_loss_weight = embeds_loss_weight + self.length_normalized_loss = length_normalized_loss + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + embed: torch.Tensor = None, + embed_lengths: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]), \ + (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + self.step_cur += 1 + # for data-parallel + text = text[:, :text_lengths.max()] + speech = speech[:, :speech_lengths.max(), :] + if embed is not None: + embed = embed[:, :embed_lengths.max(), :] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + loss_transducer, cer_transducer, wer_transducer = None, None, None + loss_pre = None + cos_loss = None + stats = dict() + + # 1. 
CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, + encoder_out_lens, text, + text_lengths) + + # Collect CTC branch stats + stats['loss_ctc'] = loss_ctc.detach( + ) if loss_ctc is not None else None + stats['cer_ctc'] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss(intermediate_out, + encoder_out_lens, text, + text_lengths) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats['loss_interctc_layer{}'.format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None) + stats['cer_interctc_layer{}'.format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = (1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + if self.use_transducer_decoder: + # 2a. Transducer decoder branch + ( + loss_transducer, + cer_transducer, + wer_transducer, + ) = self._calc_transducer_loss( + encoder_out, + encoder_out_lens, + text, + ) + + if loss_ctc is not None: + loss = loss_transducer + (self.ctc_weight * loss_ctc) + else: + loss = loss_transducer + + # Collect Transducer branch stats + stats['loss_transducer'] = ( + loss_transducer.detach() + if loss_transducer is not None else None) + stats['cer_transducer'] = cer_transducer + stats['wer_transducer'] = wer_transducer + + else: + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + + if embed is None or self.embeds_loss_weight <= 0.0: + loss_ret = self._calc_att_loss(encoder_out, + encoder_out_lens, text, + text_lengths) + loss_att, acc_att, cer_att, wer_att, loss_pre = loss_ret[ + 0], loss_ret[1], loss_ret[2], loss_ret[3], loss_ret[4] + else: + loss_ret = self._calc_att_loss_embed( + encoder_out, encoder_out_lens, text, text_lengths, + embed, embed_lengths) + loss_att, acc_att, cer_att, wer_att, loss_pre = loss_ret[ + 0], loss_ret[1], loss_ret[2], loss_ret[3], loss_ret[4] + embeds_outputs = None + if len(loss_ret) > 5: + embeds_outputs = loss_ret[5] + if embeds_outputs is not None: + cos_loss = self._calc_embed_loss( + text, text_lengths, embed, embed_lengths, + embeds_outputs) + # 3. 
CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + elif self.embeds_loss_weight > 0.0: + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight + ) * loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight + else: + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight + ) * loss_att + loss_pre * self.predictor_weight + + # Collect Attn branch stats + stats['loss_att'] = loss_att.detach( + ) if loss_att is not None else None + stats['acc'] = acc_att + stats['cer'] = cer_att + stats['wer'] = wer_att + stats['loss_pre'] = loss_pre.detach().cpu( + ) if loss_pre is not None else None + stats['cos_loss'] = cos_loss.detach().cpu( + ) if cos_loss is not None else None + + # Collect total loss stats + # TODO(wjm): needed to be checked + # TODO(wjm): same problem: https://github.com/espnet/espnet/issues/4136 + # FIXME(wjm): for logger error when accum_grad > 1 + # stats["loss"] = loss.detach() + stats['loss'] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), + loss.device) + return loss, stats, weight + + def _calc_embed_loss( + self, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + embed: torch.Tensor = None, + embed_lengths: torch.Tensor = None, + embeds_outputs: torch.Tensor = None, + ): + embeds_outputs = self.pro_nn(embeds_outputs) + tgt_mask = (~make_pad_mask(ys_pad_lens, + maxlen=ys_pad_lens.max())[:, :, None]).to( + ys_pad.device) + embeds_outputs *= tgt_mask # b x l x d + embed *= tgt_mask # b x l x d + cos_loss = 1.0 - self.cos(embeds_outputs, embed) + cos_loss *= tgt_mask.squeeze(2) + if self.length_normalized_loss: + token_num_total = torch.sum(tgt_mask) + else: + token_num_total = tgt_mask.size()[0] + cos_loss_total = torch.sum(cos_loss) + cos_loss = cos_loss_total / token_num_total + return cos_loss + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + 'Generating dummy stats for feats and feats_lengths, ' + 'because encoder_conf.extract_feats_in_collect_stats is ' + f'{self.extract_feats_in_collect_stats}') + feats, feats_lengths = speech, speech_lengths + return {'feats': feats, 'feats_lengths': feats_lengths} + + def encode( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. 
Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths, ctc=self.ctc) + else: + encoder_out, encoder_out_lens, _ = self.encoder( + feats, feats_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def calc_predictor(self, encoder_out, encoder_out_lens): + + encoder_out_mask = (~make_pad_mask( + encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + # logging.info( + # "encoder_out_mask size: {}".format(encoder_out_mask.size())) + pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor( + encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id) + return pre_acoustic_embeds, pre_token_length + + def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, + sematic_embeds, ys_pad_lens): + + decoder_outs = self.decoder(encoder_out, encoder_out_lens, + sematic_embeds, ys_pad_lens) + decoder_out = decoder_outs[0] + decoder_out = torch.log_softmax(decoder_out, dim=-1) + return decoder_out, ys_pad_lens + + def _extract_feats( + self, speech: torch.Tensor, + speech_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, :speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from transformer-decoder + + Normally, this function is called in batchify_nll. + + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. 
Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) # [batch, seqlen, dim] + batch_size = decoder_out.size(0) + decoder_num_class = decoder_out.size(2) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + decoder_out.view(-1, decoder_num_class), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction='none', + ) + nll = nll.view(batch_size, -1) + nll = nll.sum(dim=1) + assert nll.size(0) == batch_size + return nll + + def batchify_nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + batch_size: int = 100, + ): + """Compute negative log likelihood(nll) from transformer-decoder + + To avoid OOM, this fuction seperate the input into batches. + Then call nll for each batch and combine and return results. + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + batch_size: int, samples each batch contain when computing nll, + you may change this to avoid OOM or increase + GPU memory usage + """ + total_num = encoder_out.size(0) + if total_num <= batch_size: + nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + else: + nll = [] + start_idx = 0 + while True: + end_idx = min(start_idx + batch_size, total_num) + batch_encoder_out = encoder_out[start_idx:end_idx, :, :] + batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] + batch_ys_pad = ys_pad[start_idx:end_idx, :] + batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] + batch_nll = self.nll( + batch_encoder_out, + batch_encoder_out_lens, + batch_ys_pad, + batch_ys_pad_lens, + ) + nll.append(batch_nll) + start_idx = end_idx + if start_idx == total_num: + break + nll = torch.cat(nll) + assert nll.size(0) == total_num + return nll + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + encoder_out_mask = (~make_pad_mask( + encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor( + encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) + + # 0. sampler + decoder_out_1st = None + if self.glat_context_p > 0.0: + if self.step_cur < 2: + logging.info( + 'enable sampler in paraformer, glat_context_p: {}'.format( + self.glat_context_p)) + sematic_embeds, decoder_out_1st = self.sampler( + encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds) + else: + if self.step_cur < 2: + logging.info( + 'disable sampler in paraformer, glat_context_p: {}'.format( + self.glat_context_p)) + sematic_embeds = pre_acoustic_embeds + + # 1. Forward decoder + decoder_outs = self.decoder(encoder_out, encoder_out_lens, + sematic_embeds, ys_pad_lens) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + + if decoder_out_1st is None: + decoder_out_1st = decoder_out + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_pad) + acc_att = th_accuracy( + decoder_out_1st.view(-1, self.vocab_size), + ys_pad, + ignore_label=self.ignore_id, + ) + loss_pre = self.criterion_pre( + ys_pad_lens.type_as(pre_token_length), pre_token_length) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out_1st.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att, loss_pre + + def _calc_att_loss_embed( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + embed: torch.Tensor = None, + embed_lengths: torch.Tensor = None, + ): + encoder_out_mask = (~make_pad_mask( + encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( + encoder_out.device) + pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor( + encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) + + # 0. sampler + decoder_out_1st = None + if self.glat_context_p > 0.0: + if self.step_cur < 2: + logging.info( + 'enable sampler in paraformer, glat_context_p: {}'.format( + self.glat_context_p)) + sematic_embeds, decoder_out_1st = self.sampler( + encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds) + else: + if self.step_cur < 2: + logging.info( + 'disable sampler in paraformer, glat_context_p: {}'.format( + self.glat_context_p)) + sematic_embeds = pre_acoustic_embeds + + # 1. Forward decoder + decoder_outs = self.decoder(encoder_out, encoder_out_lens, + sematic_embeds, ys_pad_lens) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + if len(decoder_outs) > 2: + embeds_outputs = decoder_outs[2] + if decoder_out_1st is None: + decoder_out_1st = decoder_out + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_pad) + acc_att = th_accuracy( + decoder_out_1st.view(-1, self.vocab_size), + ys_pad, + ignore_label=self.ignore_id, + ) + loss_pre = self.criterion_pre( + ys_pad_lens.type_as(pre_token_length), pre_token_length) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out_1st.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs + + def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, + pre_acoustic_embeds): + + tgt_mask = (~make_pad_mask(ys_pad_lens, + maxlen=ys_pad_lens.max())[:, :, None]).to( + ys_pad.device) + ys_pad *= tgt_mask[:, :, 0] + ys_pad_embed = self.decoder.embed(ys_pad) + with torch.no_grad(): + decoder_outs = self.decoder(encoder_out, encoder_out_lens, + pre_acoustic_embeds, ys_pad_lens) + decoder_out, _ = decoder_outs[0], decoder_outs[1] + pred_tokens = decoder_out.argmax(-1) + nonpad_positions = ys_pad.ne(self.ignore_id) + seq_lens = (nonpad_positions).sum(1) + same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1) + input_mask = torch.ones_like(nonpad_positions) + bsz, seq_len = ys_pad.size() + for li in range(bsz): + target_num = (((seq_lens[li] - same_num[li].sum()).float()) + * self.glat_context_p).long() + if target_num > 0: + input_mask[li].scatter_( + dim=0, + index=torch.randperm(seq_lens[li])[:target_num].cuda(), + value=0) + input_mask = input_mask.eq(1) + input_mask = input_mask.masked_fill(~nonpad_positions, False) + input_mask_expand_dim = input_mask.unsqueeze(2).to( + pre_acoustic_embeds.device) + + sematic_embeds = pre_acoustic_embeds.masked_fill( + ~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill( + input_mask_expand_dim, 0) + return sematic_embeds * tgt_mask, decoder_out * tgt_mask + + def _calc_att_loss_ar( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), + ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator( + ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + labels: torch.Tensor, + ): + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. 
(B, T, D_enc)
+            encoder_out_lens: Encoder output sequences lengths. (B,)
+            labels: Label ID sequences. (B, L)
+
+        Return:
+            loss_transducer: Transducer loss value.
+            cer_transducer: Character error rate for Transducer.
+            wer_transducer: Word Error Rate for Transducer.
+
+        """
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            labels,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+            blank_id=self.blank_id,
+        )
+
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in)
+
+        joint_out = self.joint_network(
+            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1))
+
+        loss_transducer = self.criterion_transducer(
+            joint_out,
+            target,
+            t_len,
+            u_len,
+            reduction='sum',
+        )
+
+        cer_transducer, wer_transducer = None, None
+        if not self.training and self.error_calculator_trans is not None:
+            cer_transducer, wer_transducer = self.error_calculator_trans(
+                encoder_out, target)
+
+        return loss_transducer, cer_transducer, wer_transducer
diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/streaming_utilis/chunk_utilis.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/streaming_utilis/chunk_utilis.py
new file mode 100644
index 00000000..e9f1b785
--- /dev/null
+++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/asr/streaming_utilis/chunk_utilis.py
@@ -0,0 +1,321 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+import logging
+import math
+
+import numpy as np
+import torch
+from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
+
+from ...nets.pytorch_backend.cif_utils.cif import \
+    cif_predictor as cif_predictor
+
+np.set_printoptions(threshold=np.inf)
+torch.set_printoptions(profile='full', precision=100000, linewidth=None)
+
+
+def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device='cpu'):
+    if maxlen is None:
+        maxlen = lengths.max()
+    row_vector = torch.arange(0, maxlen, 1)
+    matrix = torch.unsqueeze(lengths, dim=-1)
+    mask = row_vector < matrix
+
+    return mask.type(dtype).to(device)
+
+
+class overlap_chunk():
+
+    def __init__(
+        self,
+        chunk_size: tuple = (16, ),
+        stride: tuple = (10, ),
+        pad_left: tuple = (0, ),
+        encoder_att_look_back_factor: tuple = (1, ),
+        shfit_fsmn: int = 0,
+    ):
+        self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor \
+            = chunk_size, stride, pad_left, encoder_att_look_back_factor
+        self.shfit_fsmn = shfit_fsmn
+        self.x_add_mask = None
+        self.x_rm_mask = None
+        self.x_len = None
+        self.mask_shfit_chunk = None
+        self.mask_chunk_predictor = None
+        self.mask_att_chunk_encoder = None
+        self.mask_shift_att_chunk_decoder = None
+        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur \
+            = None, None, None, None
+
+    def get_chunk_size(self, ind: int = 0):
+        # with torch.no_grad:
+        chunk_size, stride, pad_left, encoder_att_look_back_factor = \
+            self.chunk_size[ind], self.stride[ind], self.pad_left[ind], \
+            self.encoder_att_look_back_factor[ind]
+        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, \
+            self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
+            = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn
+        return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur
+
+    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
+
+        with torch.no_grad():
+            x_len = x_len.cpu().numpy()
+            x_len_max = x_len.max()
+
+            chunk_size, stride, 
pad_left, encoder_att_look_back_factor = self.get_chunk_size( + ind) + shfit_fsmn = self.shfit_fsmn + chunk_size_pad_shift = chunk_size + shfit_fsmn + + chunk_num_batch = np.ceil(x_len / stride).astype(np.int32) + x_len_chunk = ( + chunk_num_batch - 1 + ) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - ( + chunk_num_batch - 1) * stride + x_len_chunk = x_len_chunk.astype(x_len.dtype) + x_len_chunk_max = x_len_chunk.max() + + chunk_num = int(math.ceil(x_len_max / stride)) + dtype = np.int32 + max_len_for_x_mask_tmp = max(chunk_size, x_len_max) + x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) + x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) + mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) + mask_chunk_predictor = np.zeros([0, num_units_predictor], + dtype=dtype) + mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) + mask_att_chunk_encoder = np.zeros( + [0, chunk_num * chunk_size_pad_shift], dtype=dtype) + for chunk_ids in range(chunk_num): + # x_mask add + fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), + dtype=dtype) + x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) + x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), + dtype=dtype) + x_mask_pad_right = np.zeros( + (chunk_size, max_len_for_x_mask_tmp), dtype=dtype) + x_cur_pad = np.concatenate( + [x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1) + x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp] + x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], + axis=0) + x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], + axis=0) + + # x_mask rm + fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), + dtype=dtype) + x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) + x_mask_right = np.zeros((stride, chunk_size - stride), + dtype=dtype) + x_mask_cur = np.concatenate([x_mask_cur, x_mask_right], axis=1) + x_mask_cur_pad_top = np.zeros((chunk_ids * stride, chunk_size), + dtype=dtype) + x_mask_cur_pad_bottom = np.zeros( + (max_len_for_x_mask_tmp, chunk_size), dtype=dtype) + x_rm_mask_cur = np.concatenate( + [x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], + axis=0) + x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, : + chunk_size] + x_rm_mask_cur_fsmn = np.concatenate( + [fsmn_padding, x_rm_mask_cur], axis=1) + x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], + axis=1) + + # fsmn_padding_mask + pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) + ones_1 = np.ones([chunk_size, num_units], dtype=dtype) + mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], + axis=0) + mask_shfit_chunk = np.concatenate( + [mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) + + # predictor mask + zeros_1 = np.zeros( + [shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) + ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) + zeros_3 = np.zeros( + [chunk_size - stride - pad_left, num_units_predictor], + dtype=dtype) + ones_zeros = np.concatenate([ones_2, zeros_3], axis=0) + mask_chunk_predictor_cur = np.concatenate( + [zeros_1, ones_zeros], axis=0) + mask_chunk_predictor = np.concatenate( + [mask_chunk_predictor, mask_chunk_predictor_cur], axis=0) + + # encoder att mask + zeros_1_top = np.zeros( + [shfit_fsmn, chunk_num * chunk_size_pad_shift], + dtype=dtype) + + zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) + zeros_2 = np.zeros( + [chunk_size, zeros_2_num * chunk_size_pad_shift], + dtype=dtype) + + encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 
0)
+                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                ones_2_mid = np.ones([stride, stride], dtype=dtype)
+                zeros_2_bottom = np.zeros([chunk_size - stride, stride],
+                                          dtype=dtype)
+                zeros_2_right = np.zeros([chunk_size, chunk_size - stride],
+                                         dtype=dtype)
+                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
+                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right],
+                                        axis=1)
+                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
+
+                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
+                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
+
+                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
+                zeros_remain = np.zeros(
+                    [chunk_size, zeros_remain_num * chunk_size_pad_shift],
+                    dtype=dtype)
+
+                ones2_bottom = np.concatenate(
+                    [zeros_2, ones_2, ones_3, zeros_remain], axis=1)
+                mask_att_chunk_encoder_cur = np.concatenate(
+                    [zeros_1_top, ones2_bottom], axis=0)
+                mask_att_chunk_encoder = np.concatenate(
+                    [mask_att_chunk_encoder, mask_att_chunk_encoder_cur],
+                    axis=0)
+
+                # decoder fsmn_shift_att_mask
+                zeros_1 = np.zeros([shfit_fsmn, 1])
+                ones_1 = np.ones([chunk_size, 1])
+                mask_shift_att_chunk_decoder_cur = np.concatenate(
+                    [zeros_1, ones_1], axis=0)
+                mask_shift_att_chunk_decoder = np.concatenate(
+                    [
+                        mask_shift_att_chunk_decoder,
+                        mask_shift_att_chunk_decoder_cur
+                    ],
+                    axis=0)  # noqa: *
+
+            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max]
+            self.x_len_chunk = x_len_chunk
+            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
+            self.x_len = x_len
+            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
+            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
+            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
+            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
+
+        return (self.x_add_mask, self.x_len_chunk, self.x_rm_mask, self.x_len,
+                self.mask_shfit_chunk, self.mask_chunk_predictor,
+                self.mask_att_chunk_encoder, self.mask_shift_att_chunk_decoder)
+
+    def split_chunk(self, x, x_len, chunk_outs):
+        """
+        :param x: (b, t, d)
+        :param x_length: (b)
+        :param ind: int
+        :return:
+        """
+        x = x[:, :x_len.max(), :]
+        b, t, d = x.size()
+        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device)
+        x *= x_len_mask[:, :, None]
+
+        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
+        x_len_chunk = self.get_x_len_chunk(
+            chunk_outs, x_len.device, dtype=x_len.dtype)
+        x = torch.transpose(x, 1, 0)
+        x = torch.reshape(x, [t, -1])
+        x_chunk = torch.mm(x_add_mask, x)
+        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
+
+        return x_chunk, x_len_chunk
+
+    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
+        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
+        b, t, d = x_chunk.size()
+        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
+            x_chunk.device)
+        x_chunk *= x_len_chunk_mask[:, :, None]
+
+        x_rm_mask = self.get_x_rm_mask(
+            chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
+        x_len = self.get_x_len(
+            chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
+        x_chunk = torch.transpose(x_chunk, 1, 0)
+        x_chunk = torch.reshape(x_chunk, [t, -1])
+        x = torch.mm(x_rm_mask, x_chunk)
+        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
+
+        return x, x_len
+
+    def get_x_add_mask(self, chunk_outs, device, idx=0, dtype=torch.float32):
+        x = chunk_outs[idx]
+        x = torch.from_numpy(x).type(dtype).to(device)
+        return
x.detach() + + def get_x_len_chunk(self, chunk_outs, device, idx=1, dtype=torch.float32): + x = chunk_outs[idx] + x = torch.from_numpy(x).type(dtype).to(device) + return x.detach() + + def get_x_rm_mask(self, chunk_outs, device, idx=2, dtype=torch.float32): + x = chunk_outs[idx] + x = torch.from_numpy(x).type(dtype).to(device) + return x.detach() + + def get_x_len(self, chunk_outs, device, idx=3, dtype=torch.float32): + x = chunk_outs[idx] + x = torch.from_numpy(x).type(dtype).to(device) + return x.detach() + + def get_mask_shfit_chunk(self, + chunk_outs, + device, + batch_size=1, + num_units=1, + idx=4, + dtype=torch.float32): + x = chunk_outs[idx] + x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) + x = torch.from_numpy(x).type(dtype).to(device) + return x.detach() + + def get_mask_chunk_predictor(self, + chunk_outs, + device, + batch_size=1, + num_units=1, + idx=5, + dtype=torch.float32): + x = chunk_outs[idx] + x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) + x = torch.from_numpy(x).type(dtype).to(device) + return x.detach() + + def get_mask_att_chunk_encoder(self, + chunk_outs, + device, + batch_size=1, + idx=6, + dtype=torch.float32): + x = chunk_outs[idx] + x = np.tile(x[None, :, :, ], [batch_size, 1, 1]) + x = torch.from_numpy(x).type(dtype).to(device) + return x.detach() + + def get_mask_shift_att_chunk_decoder(self, + chunk_outs, + device, + batch_size=1, + idx=7, + dtype=torch.float32): + x = chunk_outs[idx] + x = np.tile(x[None, None, :, 0], [batch_size, 1, 1]) + x = torch.from_numpy(x).type(dtype).to(device) + return x.detach() diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/cif_utils/cif.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/cif_utils/cif.py new file mode 100644 index 00000000..9381fb98 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/cif_utils/cif.py @@ -0,0 +1,250 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. 
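+# The `cif` function at the bottom of this file implements continuous
+# integrate-and-fire (CIF): per-frame weights `alphas` are accumulated over time,
+# and each time the accumulator reaches `threshold` one token embedding is fired
+# as the weighted sum of the frames consumed since the previous firing.
+# A minimal usage sketch (illustrative shapes and values):
+#
+#     hidden = torch.randn(1, 5, 8)                       # (batch, time, dim)
+#     alphas = torch.tensor([[0.4, 0.4, 0.4, 0.4, 0.4]])  # per-frame weights
+#     frames, fires = cif(hidden, alphas, threshold=1.0)
+#     # the accumulator crosses 1.0 at frames 2 and 4, so frames.shape == (1, 2, 8)
+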
+import logging + +import numpy as np +import torch +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from torch import nn + + +class CIF_Model(nn.Module): + + def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): + super(CIF_Model, self).__init__() + + self.pad = nn.ConstantPad1d((l_order, r_order), 0) + self.cif_conv1d = nn.Conv1d( + idim, idim, l_order + r_order + 1, groups=idim) + self.cif_output = nn.Linear(idim, 1) + self.dropout = torch.nn.Dropout(p=dropout) + self.threshold = threshold + + def forward(self, hidden, target_label=None, mask=None, ignore_id=-1): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + memory = self.cif_conv1d(queries) + output = memory + context + output = self.dropout(output) + output = output.transpose(1, 2) + output = torch.relu(output) + output = self.cif_output(output) + alphas = torch.sigmoid(output) + if mask is not None: + alphas = alphas * mask.transpose(-1, -2).float() + alphas = alphas.squeeze(-1) + if target_label is not None: + target_length = (target_label != ignore_id).float().sum(-1) + else: + target_length = None + cif_length = alphas.sum(-1) + if target_label is not None: + alphas *= (target_length / cif_length)[:, None].repeat( + 1, alphas.size(1)) + cif_output, cif_peak = cif(hidden, alphas, self.threshold) + return cif_output, cif_length, target_length, cif_peak + + def gen_frame_alignments(self, + alphas: torch.Tensor = None, + memory_sequence_length: torch.Tensor = None, + is_training: bool = True, + dtype: torch.dtype = torch.float32): + batch_size, maximum_length = alphas.size() + int_type = torch.int32 + token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) + + max_token_num = torch.max(token_num).item() + + alphas_cumsum = torch.cumsum(alphas, dim=1) + alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) + alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], + [1, max_token_num, 1]) + + index = torch.ones([batch_size, max_token_num], dtype=int_type) + index = torch.cumsum(index, dim=1) + index = torch.tile(index[:, :, None], [1, 1, maximum_length]) + + index_div = torch.floor(torch.divide(alphas_cumsum, + index)).type(int_type) + index_div_bool_zeros = index_div.eq(0) + index_div_bool_zeros_count = torch.sum( + index_div_bool_zeros, dim=-1) + 1 + index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, + memory_sequence_length.max()) + token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( + token_num.device) + index_div_bool_zeros_count *= token_num_mask + + index_div_bool_zeros_count_tile = torch.tile( + index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) + ones = torch.ones_like(index_div_bool_zeros_count_tile) + zeros = torch.zeros_like(index_div_bool_zeros_count_tile) + ones = torch.cumsum(ones, dim=2) + cond = index_div_bool_zeros_count_tile == ones + index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) + + index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( + torch.bool) + index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( + int_type) + index_div_bool_zeros_count_tile_out = torch.sum( + index_div_bool_zeros_count_tile, dim=1) + index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( + int_type) + predictor_mask = (~make_pad_mask( + memory_sequence_length, + maxlen=memory_sequence_length.max())).type(int_type).to( + memory_sequence_length.device) # noqa: * + index_div_bool_zeros_count_tile_out = 
index_div_bool_zeros_count_tile_out * predictor_mask + return index_div_bool_zeros_count_tile_out.detach( + ), index_div_bool_zeros_count.detach() + + +class cif_predictor(nn.Module): + + def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): + super(cif_predictor, self).__init__() + + self.pad = nn.ConstantPad1d((l_order, r_order), 0) + self.cif_conv1d = nn.Conv1d( + idim, idim, l_order + r_order + 1, groups=idim) + self.cif_output = nn.Linear(idim, 1) + self.dropout = torch.nn.Dropout(p=dropout) + self.threshold = threshold + + def forward(self, + hidden, + target_label=None, + mask=None, + ignore_id=-1, + mask_chunk_predictor=None, + target_label_length=None): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + memory = self.cif_conv1d(queries) + output = memory + context + output = self.dropout(output) + output = output.transpose(1, 2) + output = torch.relu(output) + output = self.cif_output(output) + alphas = torch.sigmoid(output) + if mask is not None: + alphas = alphas * mask.transpose(-1, -2).float() + if mask_chunk_predictor is not None: + alphas = alphas * mask_chunk_predictor + alphas = alphas.squeeze(-1) + if target_label_length is not None: + target_length = target_label_length + elif target_label is not None: + target_length = (target_label != ignore_id).float().sum(-1) + else: + target_length = None + token_num = alphas.sum(-1) + if target_length is not None: + alphas *= (target_length / token_num)[:, None].repeat( + 1, alphas.size(1)) + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) + return acoustic_embeds, token_num, alphas, cif_peak + + def gen_frame_alignments(self, + alphas: torch.Tensor = None, + memory_sequence_length: torch.Tensor = None, + is_training: bool = True, + dtype: torch.dtype = torch.float32): + batch_size, maximum_length = alphas.size() + int_type = torch.int32 + token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) + + max_token_num = torch.max(token_num).item() + + alphas_cumsum = torch.cumsum(alphas, dim=1) + alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) + alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], + [1, max_token_num, 1]) + + index = torch.ones([batch_size, max_token_num], dtype=int_type) + index = torch.cumsum(index, dim=1) + index = torch.tile(index[:, :, None], [1, 1, maximum_length]) + + index_div = torch.floor(torch.divide(alphas_cumsum, + index)).type(int_type) + index_div_bool_zeros = index_div.eq(0) + index_div_bool_zeros_count = torch.sum( + index_div_bool_zeros, dim=-1) + 1 + index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, + memory_sequence_length.max()) + token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( + token_num.device) + index_div_bool_zeros_count *= token_num_mask + + index_div_bool_zeros_count_tile = torch.tile( + index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) + ones = torch.ones_like(index_div_bool_zeros_count_tile) + zeros = torch.zeros_like(index_div_bool_zeros_count_tile) + ones = torch.cumsum(ones, dim=2) + cond = index_div_bool_zeros_count_tile == ones + index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) + + index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( + torch.bool) + index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( + int_type) + index_div_bool_zeros_count_tile_out = torch.sum( + index_div_bool_zeros_count_tile, dim=1) + index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( + 
int_type) + predictor_mask = (~make_pad_mask( + memory_sequence_length, + maxlen=memory_sequence_length.max())).type(int_type).to( + memory_sequence_length.device) # noqa: * + index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask + return index_div_bool_zeros_count_tile_out.detach( + ), index_div_bool_zeros_count.detach() + + +def cif(hidden, alphas, threshold): + batch_size, len_time, hidden_size = hidden.size() + + # loop varss + integrate = torch.zeros([batch_size], device=hidden.device) + frame = torch.zeros([batch_size, hidden_size], device=hidden.device) + # intermediate vars along time + list_fires = [] + list_frames = [] + + for t in range(len_time): + alpha = alphas[:, t] + distribution_completion = torch.ones([batch_size], + device=hidden.device) - integrate + + integrate += alpha + list_fires.append(integrate) + + fire_place = integrate >= threshold + integrate = torch.where( + fire_place, + integrate - torch.ones([batch_size], device=hidden.device), + integrate) + cur = torch.where(fire_place, distribution_completion, alpha) + remainds = alpha - cur + + frame += cur[:, None] * hidden[:, t, :] + list_frames.append(frame) + frame = torch.where(fire_place[:, None].repeat(1, hidden_size), + remainds[:, None] * hidden[:, t, :], frame) + + fires = torch.stack(list_fires, 1) + frames = torch.stack(list_frames, 1) + list_ls = [] + len_labels = torch.round(alphas.sum(-1)).int() + max_label_len = len_labels.max() + for b in range(batch_size): + fire = fires[b, :] + ls = torch.index_select(frames[b, :, :], 0, + torch.nonzero(fire >= threshold).squeeze()) + pad_l = torch.zeros([max_label_len - ls.size(0), hidden_size], + device=hidden.device) + list_ls.append(torch.cat([ls, pad_l], 0)) + return torch.stack(list_ls, 0), fires diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/attention.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/attention.py new file mode 100644 index 00000000..53766246 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/attention.py @@ -0,0 +1,680 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. +"""Multi-Head Attention layer definition.""" + +import logging +import math + +import numpy +import torch +from torch import nn + +torch.set_printoptions(profile='full', precision=1) + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). 
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float( + numpy.finfo(torch.tensor( + 0, dtype=scores.dtype).numpy().dtype).min) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax( + scores, dim=-1).masked_fill(mask, + 0.0) # (batch, head, time1, time2) + else: + self.attn = torch.softmax( + scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class MultiHeadedAttentionSANM(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, + n_head, + n_feat, + dropout_rate, + kernel_size, + sanm_shfit=0): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadedAttentionSANM, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + self.fsmn_block = nn.Conv1d( + n_feat, + n_feat, + kernel_size, + stride=1, + padding=0, + groups=n_feat, + bias=False) + # padding + left_padding = (kernel_size - 1) // 2 + if sanm_shfit > 0: + left_padding = left_padding + sanm_shfit + right_padding = kernel_size - 1 - left_padding + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + + def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): + ''' + :param x: (#batch, time1, size). 
+ :param mask: Mask tensor (#batch, 1, time) + :return: + ''' + # b, t, d = inputs.size() + mask = mask[:, 0, :, None] + if mask_shfit_chunk is not None: + mask = mask * mask_shfit_chunk + inputs *= mask + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + x += inputs + x = self.dropout(x) + return x * mask + + def forward_qkv(self, query, key, value): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, + value, + scores, + mask, + mask_att_chunk_encoder=None): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + if mask is not None: + if mask_att_chunk_encoder is not None: + mask = mask * mask_att_chunk_encoder + + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + + min_value = float( + numpy.finfo(torch.tensor( + 0, dtype=scores.dtype).numpy().dtype).min) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax( + scores, dim=-1).masked_fill(mask, + 0.0) # (batch, head, time1, time2) + else: + self.attn = torch.softmax( + scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, + query, + key, + value, + mask, + mask_shfit_chunk=None, + mask_att_chunk_encoder=None): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + fsmn_memory = self.forward_fsmn(value, mask, mask_shfit_chunk) + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + att_outs = self.forward_attention(v, scores, mask, + mask_att_chunk_encoder) + return att_outs + fsmn_memory + + +class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. 
+ + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, time2). + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class LegacyRelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): + """Multi-Head Attention layer with relative position encoding (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. 
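+        kernel_size (int): Kernel size of the depthwise Conv1d used as the FSMN memory block.
+        sanm_shfit (int): Extra left shift applied when padding the FSMN memory block.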
+ + """ + + def __init__(self, + n_head, + n_feat, + dropout_rate, + zero_triu=False, + kernel_size=15, + sanm_shfit=0): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, time2). + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + fsmn_memory = self.forward_fsmn(value, mask) + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + att_outs = self.forward_attention(v, scores, mask) + return att_outs + fsmn_memory + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. 
+ + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as( + x)[:, :, :, :x.size(-1) // 2 + + 1] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, 2*time1-1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): + """Multi-Head Attention layer with relative position encoding (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. 
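+        kernel_size (int): Kernel size of the depthwise Conv1d used as the FSMN memory block.
+        sanm_shfit (int): Extra left shift applied when padding the FSMN memory block.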
+ + """ + + def __init__(self, + n_head, + n_feat, + dropout_rate, + zero_triu=False, + kernel_size=15, + sanm_shfit=0): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as( + x)[:, :, :, :x.size(-1) // 2 + + 1] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, 2*time1-1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + fsmn_memory = self.forward_fsmn(value, mask) + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + att_outs = self.forward_attention(v, scores, mask) + return att_outs + fsmn_memory diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/encoder_layer.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/encoder_layer.py new file mode 100644 index 00000000..91466b05 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/nets/pytorch_backend/transformer/encoder_layer.py @@ -0,0 +1,239 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
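+# Defines EncoderLayer and its chunk-streaming variant EncoderLayerChunk.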
+# Part of the implementation is borrowed from espnet/espnet. +"""Encoder self-attention layer definition.""" + +import torch +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from torch import nn + + +class EncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + stochastic_depth_rate (float): Proability to skip this layer. + During training, the layer may skip residual computation and return input + as-is with given probability. + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + + def forward(self, x, mask, cache=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. 
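+        # e.g. with stochastic_depth_rate=0.1 the layer is skipped on roughly 10% of
+        # training steps; when it does run, its output is scaled by 1 / 0.9 so that the
+        # expected residual contribution matches inference, where no skipping occurs.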
+ stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn(x_q, x, x, mask)) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, mask + + +class EncoderLayerChunk(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + stochastic_depth_rate (float): Proability to skip this layer. + During training, the layer may skip residual computation and return input + as-is with given probability. + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayerChunk, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + + def forward(self, + x, + mask, + cache=None, + mask_shfit_chunk=None, + mask_att_chunk_encoder=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. 
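+        # unlike EncoderLayer above, this variant also passes mask_shfit_chunk and
+        # mask_att_chunk_encoder (built by overlap_chunk.gen_chunk_mask) down to the
+        # SANM self-attention, restricting attention to the current chunk and its
+        # configured look-back.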
+ stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat( + (x, + self.self_attn( + x_q, + x, + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder)), + dim=-1) + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn( + x_q, + x, + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder)) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder diff --git a/modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/asr.py b/modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/asr.py new file mode 100644 index 00000000..7419abd4 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/asr.py @@ -0,0 +1,890 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. 
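+# This task definition adapts espnet2.tasks.asr but additionally registers the
+# SANM/Paraformer encoders, decoders, models and the CIF predictor implemented in
+# this package. A rough sketch of how a configured choice name is resolved to a
+# class (values are illustrative; the real settings come from the model's yaml
+# config):
+#
+#     encoder_class = encoder_choices.get_class('sanm')       # -> SANMEncoder
+#     encoder = encoder_class(input_size=80, **encoder_conf)  # e.g. 80-dim fbank
+#
+# ASRTaskNAR additionally exposes --predictor/--predictor_conf for the
+# non-autoregressive Paraformer models.
+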
+import argparse +import logging +import os +from pathlib import Path +from typing import Callable, Collection, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import yaml +from espnet2.asr.ctc import CTC +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet2.asr.decoder.mlm_decoder import MLMDecoder +from espnet2.asr.decoder.rnn_decoder import RNNDecoder +from espnet2.asr.decoder.transformer_decoder import \ + DynamicConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolution2DTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import \ + LightweightConvolutionTransformerDecoder # noqa: H301 +from espnet2.asr.decoder.transformer_decoder import ( + DynamicConvolutionTransformerDecoder, TransformerDecoder) +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.encoder.contextual_block_conformer_encoder import \ + ContextualBlockConformerEncoder # noqa: H301 +from espnet2.asr.encoder.contextual_block_transformer_encoder import \ + ContextualBlockTransformerEncoder # noqa: H301 +from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder, + FairseqHubertPretrainEncoder) +from espnet2.asr.encoder.longformer_encoder import LongformerEncoder +from espnet2.asr.encoder.rnn_encoder import RNNEncoder +from espnet2.asr.encoder.transformer_encoder import TransformerEncoder +from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder +from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder +from espnet2.asr.espnet_model import ESPnetASRModel +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.frontend.default import DefaultFrontend +from espnet2.asr.frontend.fused import FusedFrontends +from espnet2.asr.frontend.s3prl import S3prlFrontend +from espnet2.asr.frontend.windowing import SlidingWindow +from espnet2.asr.maskctc_model import MaskCTCModel +from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder +from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ + HuggingFaceTransformersPostEncoder # noqa: H301 +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder +from espnet2.asr.preencoder.linear import LinearProjection +from espnet2.asr.preencoder.sinc import LightweightSincConvs +from espnet2.asr.specaug.abs_specaug import AbsSpecAug +from espnet2.asr.specaug.specaug import SpecAug +from espnet2.asr.transducer.joint_network import JointNetwork +from espnet2.asr.transducer.transducer_decoder import TransducerDecoder +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.layers.global_mvn import GlobalMVN +from espnet2.layers.utterance_mvn import UtteranceMVN +from espnet2.tasks.abs_task import AbsTask +from espnet2.text.phoneme_tokenizer import g2p_choices +from espnet2.torch_utils.initialize import initialize +from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet2.train.class_choices import ClassChoices +from espnet2.train.collate_fn import CommonCollateFn +from espnet2.train.preprocessor import CommonPreprocessor +from espnet2.train.trainer import Trainer +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import (float_or_none, int_or_none, str2bool, + str_or_none) +from typeguard import check_argument_types, check_return_type + +from ..asr.decoder.transformer_decoder import (ParaformerDecoder, + ParaformerDecoderBertEmbed) +from 
..asr.encoder.conformer_encoder import ConformerEncoder, SANMEncoder_v2 +from ..asr.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunk +from ..asr.espnet_model import AEDStreaming +from ..asr.espnet_model_paraformer import Paraformer, ParaformerBertEmbed +from ..nets.pytorch_backend.cif_utils.cif import cif_predictor + +# FIXME(wjm): suggested by fairseq, We need to setup root logger before importing any fairseq libraries. +logging.basicConfig( + level='INFO', + format=f"[{os.uname()[1].split('.')[0]}]" + f' %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', +) +# FIXME(wjm): create logger to set level, unset __name__ for different files to share the same logger +logger = logging.getLogger() + +frontend_choices = ClassChoices( + name='frontend', + classes=dict( + default=DefaultFrontend, + sliding_window=SlidingWindow, + s3prl=S3prlFrontend, + fused=FusedFrontends, + ), + type_check=AbsFrontend, + default='default', +) +specaug_choices = ClassChoices( + name='specaug', + classes=dict(specaug=SpecAug, ), + type_check=AbsSpecAug, + default=None, + optional=True, +) +normalize_choices = ClassChoices( + 'normalize', + classes=dict( + global_mvn=GlobalMVN, + utterance_mvn=UtteranceMVN, + ), + type_check=AbsNormalize, + default='utterance_mvn', + optional=True, +) +model_choices = ClassChoices( + 'model', + classes=dict( + espnet=ESPnetASRModel, + maskctc=MaskCTCModel, + paraformer=Paraformer, + paraformer_bert_embed=ParaformerBertEmbed, + aedstreaming=AEDStreaming, + ), + type_check=AbsESPnetModel, + default='espnet', +) +preencoder_choices = ClassChoices( + name='preencoder', + classes=dict( + sinc=LightweightSincConvs, + linear=LinearProjection, + ), + type_check=AbsPreEncoder, + default=None, + optional=True, +) +encoder_choices = ClassChoices( + 'encoder', + classes=dict( + conformer=ConformerEncoder, + transformer=TransformerEncoder, + contextual_block_transformer=ContextualBlockTransformerEncoder, + contextual_block_conformer=ContextualBlockConformerEncoder, + vgg_rnn=VGGRNNEncoder, + rnn=RNNEncoder, + wav2vec2=FairSeqWav2Vec2Encoder, + hubert=FairseqHubertEncoder, + hubert_pretrain=FairseqHubertPretrainEncoder, + longformer=LongformerEncoder, + sanm=SANMEncoder, + sanm_v2=SANMEncoder_v2, + sanm_chunk=SANMEncoderChunk, + ), + type_check=AbsEncoder, + default='rnn', +) +postencoder_choices = ClassChoices( + name='postencoder', + classes=dict( + hugging_face_transformers=HuggingFaceTransformersPostEncoder, ), + type_check=AbsPostEncoder, + default=None, + optional=True, +) +decoder_choices = ClassChoices( + 'decoder', + classes=dict( + transformer=TransformerDecoder, + lightweight_conv=LightweightConvolutionTransformerDecoder, + lightweight_conv2d=LightweightConvolution2DTransformerDecoder, + dynamic_conv=DynamicConvolutionTransformerDecoder, + dynamic_conv2d=DynamicConvolution2DTransformerDecoder, + rnn=RNNDecoder, + transducer=TransducerDecoder, + mlm=MLMDecoder, + paraformer_decoder=ParaformerDecoder, + paraformer_decoder_bert_embed=ParaformerDecoderBertEmbed, + ), + type_check=AbsDecoder, + default='rnn', +) + +predictor_choices = ClassChoices( + name='predictor', + classes=dict( + cif_predictor=cif_predictor, + ctc_predictor=None, + ), + type_check=None, + default='cif_predictor', + optional=True, +) + + +class ASRTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --specaug and --specaug_conf 
+ specaug_choices, + # --normalize and --normalize_conf + normalize_choices, + # --model and --model_conf + model_choices, + # --preencoder and --preencoder_conf + preencoder_choices, + # --encoder and --encoder_conf + encoder_choices, + # --postencoder and --postencoder_conf + postencoder_choices, + # --decoder and --decoder_conf + decoder_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description='Task related') + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + required = parser.get_default('required') + required += ['token_list'] + + group.add_argument( + '--token_list', + type=str_or_none, + default=None, + help='A text mapping int-id to token', + ) + group.add_argument( + '--init', + type=lambda x: str_or_none(x.lower()), + default=None, + help='The initialization method', + choices=[ + 'chainer', + 'xavier_uniform', + 'xavier_normal', + 'kaiming_uniform', + 'kaiming_normal', + None, + ], + ) + + group.add_argument( + '--input_size', + type=int_or_none, + default=None, + help='The number of input dimension of the feature', + ) + + group.add_argument( + '--ctc_conf', + action=NestedDictAction, + default=get_default_kwargs(CTC), + help='The keyword arguments for CTC class.', + ) + group.add_argument( + '--joint_net_conf', + action=NestedDictAction, + default=None, + help='The keyword arguments for joint network class.', + ) + + group = parser.add_argument_group(description='Preprocess related') + group.add_argument( + '--use_preprocessor', + type=str2bool, + default=True, + help='Apply preprocessing to data or not', + ) + group.add_argument( + '--token_type', + type=str, + default='bpe', + choices=['bpe', 'char', 'word', 'phn'], + help='The text will be tokenized ' + 'in the specified level token', + ) + group.add_argument( + '--bpemodel', + type=str_or_none, + default=None, + help='The model file of sentencepiece', + ) + parser.add_argument( + '--non_linguistic_symbols', + type=str_or_none, + help='non_linguistic_symbols file path', + ) + parser.add_argument( + '--cleaner', + type=str_or_none, + choices=[None, 'tacotron', 'jaconv', 'vietnamese'], + default=None, + help='Apply text cleaning', + ) + parser.add_argument( + '--g2p', + type=str_or_none, + choices=g2p_choices, + default=None, + help='Specify g2p method if --token_type=phn', + ) + parser.add_argument( + '--speech_volume_normalize', + type=float_or_none, + default=None, + help='Scale the maximum amplitude to the given value.', + ) + parser.add_argument( + '--rir_scp', + type=str_or_none, + default=None, + help='The file path of rir scp file.', + ) + parser.add_argument( + '--rir_apply_prob', + type=float, + default=1.0, + help='THe probability for applying RIR convolution.', + ) + parser.add_argument( + '--noise_scp', + type=str_or_none, + default=None, + help='The file path of noise scp file.', + ) + parser.add_argument( + '--noise_apply_prob', + type=float, + default=1.0, + help='The probability applying Noise adding.', + ) + parser.add_argument( + '--noise_db_range', + type=str, + default='13_15', + help='The range of noise decibel level.', + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. 
--encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ + List[str], Dict[str, torch.Tensor]], ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, + rir_apply_prob=args.rir_apply_prob if hasattr( + args, 'rir_apply_prob') else 1.0, + noise_scp=args.noise_scp + if hasattr(args, 'noise_scp') else None, + noise_apply_prob=args.noise_apply_prob if hasattr( + args, 'noise_apply_prob') else 1.0, + noise_db_range=args.noise_db_range if hasattr( + args, 'noise_db_range') else '13_15', + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, 'rir_scp') else None, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names(cls, + train: bool = True, + inference: bool = False) -> Tuple[str, ...]: + if not inference: + retval = ('speech', 'text') + else: + # Recognition mode + retval = ('speech', ) + return retval + + @classmethod + def optional_data_names(cls, + train: bool = True, + inference: bool = False) -> Tuple[str, ...]: + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel: + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding='utf-8') as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError('token_list must be str or list') + vocab_size = len(token_list) + logger.info(f'Vocabulary size: {vocab_size }') + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. 
Pre-encoder input block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + if getattr(args, 'preencoder', None) is not None: + preencoder_class = preencoder_choices.get_class(args.preencoder) + preencoder = preencoder_class(**args.preencoder_conf) + input_size = preencoder.output_size() + else: + preencoder = None + + # 4. Encoder + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size=input_size, **args.encoder_conf) + + # 5. Post-encoder block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + encoder_output_size = encoder.output_size() + if getattr(args, 'postencoder', None) is not None: + postencoder_class = postencoder_choices.get_class(args.postencoder) + postencoder = postencoder_class( + input_size=encoder_output_size, **args.postencoder_conf) + encoder_output_size = postencoder.output_size() + else: + postencoder = None + + # 5. Decoder + decoder_class = decoder_choices.get_class(args.decoder) + + if args.decoder == 'transducer': + decoder = decoder_class( + vocab_size, + embed_pad=0, + **args.decoder_conf, + ) + + joint_network = JointNetwork( + vocab_size, + encoder.output_size(), + decoder.dunits, + **args.joint_net_conf, + ) + else: + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **args.decoder_conf, + ) + + joint_network = None + + # 6. CTC + ctc = CTC( + odim=vocab_size, + encoder_output_size=encoder_output_size, + **args.ctc_conf) + + # 7. Build model + try: + model_class = model_choices.get_class(args.model) + except AttributeError: + model_class = model_choices.get_class('espnet') + model = model_class( + vocab_size=vocab_size, + frontend=frontend, + specaug=specaug, + normalize=normalize, + preencoder=preencoder, + encoder=encoder, + postencoder=postencoder, + decoder=decoder, + ctc=ctc, + joint_network=joint_network, + token_list=token_list, + **args.model_conf, + ) + + # FIXME(kamo): Should be done in model? + # 8. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model + + +class ASRTaskNAR(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --specaug and --specaug_conf + specaug_choices, + # --normalize and --normalize_conf + normalize_choices, + # --model and --model_conf + model_choices, + # --preencoder and --preencoder_conf + preencoder_choices, + # --encoder and --encoder_conf + encoder_choices, + # --postencoder and --postencoder_conf + postencoder_choices, + # --decoder and --decoder_conf + decoder_choices, + # --predictor and --predictor_conf + predictor_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description='Task related') + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. 
Instead of it, do as + required = parser.get_default('required') + required += ['token_list'] + + group.add_argument( + '--token_list', + type=str_or_none, + default=None, + help='A text mapping int-id to token', + ) + group.add_argument( + '--init', + type=lambda x: str_or_none(x.lower()), + default=None, + help='The initialization method', + choices=[ + 'chainer', + 'xavier_uniform', + 'xavier_normal', + 'kaiming_uniform', + 'kaiming_normal', + None, + ], + ) + + group.add_argument( + '--input_size', + type=int_or_none, + default=None, + help='The number of input dimension of the feature', + ) + + group.add_argument( + '--ctc_conf', + action=NestedDictAction, + default=get_default_kwargs(CTC), + help='The keyword arguments for CTC class.', + ) + group.add_argument( + '--joint_net_conf', + action=NestedDictAction, + default=None, + help='The keyword arguments for joint network class.', + ) + + group = parser.add_argument_group(description='Preprocess related') + group.add_argument( + '--use_preprocessor', + type=str2bool, + default=True, + help='Apply preprocessing to data or not', + ) + group.add_argument( + '--token_type', + type=str, + default='bpe', + choices=['bpe', 'char', 'word', 'phn'], + help='The text will be tokenized ' + 'in the specified level token', + ) + group.add_argument( + '--bpemodel', + type=str_or_none, + default=None, + help='The model file of sentencepiece', + ) + parser.add_argument( + '--non_linguistic_symbols', + type=str_or_none, + help='non_linguistic_symbols file path', + ) + parser.add_argument( + '--cleaner', + type=str_or_none, + choices=[None, 'tacotron', 'jaconv', 'vietnamese'], + default=None, + help='Apply text cleaning', + ) + parser.add_argument( + '--g2p', + type=str_or_none, + choices=g2p_choices, + default=None, + help='Specify g2p method if --token_type=phn', + ) + parser.add_argument( + '--speech_volume_normalize', + type=float_or_none, + default=None, + help='Scale the maximum amplitude to the given value.', + ) + parser.add_argument( + '--rir_scp', + type=str_or_none, + default=None, + help='The file path of rir scp file.', + ) + parser.add_argument( + '--rir_apply_prob', + type=float, + default=1.0, + help='THe probability for applying RIR convolution.', + ) + parser.add_argument( + '--noise_scp', + type=str_or_none, + default=None, + help='The file path of noise scp file.', + ) + parser.add_argument( + '--noise_apply_prob', + type=float, + default=1.0, + help='The probability applying Noise adding.', + ) + parser.add_argument( + '--noise_db_range', + type=str, + default='13_15', + help='The range of noise decibel level.', + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. 
--encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ + List[str], Dict[str, torch.Tensor]], ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, + rir_apply_prob=args.rir_apply_prob if hasattr( + args, 'rir_apply_prob') else 1.0, + noise_scp=args.noise_scp + if hasattr(args, 'noise_scp') else None, + noise_apply_prob=args.noise_apply_prob if hasattr( + args, 'noise_apply_prob') else 1.0, + noise_db_range=args.noise_db_range if hasattr( + args, 'noise_db_range') else '13_15', + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, 'rir_scp') else None, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names(cls, + train: bool = True, + inference: bool = False) -> Tuple[str, ...]: + if not inference: + retval = ('speech', 'text') + else: + # Recognition mode + retval = ('speech', ) + return retval + + @classmethod + def optional_data_names(cls, + train: bool = True, + inference: bool = False) -> Tuple[str, ...]: + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace): + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding='utf-8') as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError('token_list must be str or list') + vocab_size = len(token_list) + # logger.info(f'Vocabulary size: {vocab_size }') + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. 
Pre-encoder input block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + if getattr(args, 'preencoder', None) is not None: + preencoder_class = preencoder_choices.get_class(args.preencoder) + preencoder = preencoder_class(**args.preencoder_conf) + input_size = preencoder.output_size() + else: + preencoder = None + + # 4. Encoder + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size=input_size, **args.encoder_conf) + + # 5. Post-encoder block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + encoder_output_size = encoder.output_size() + if getattr(args, 'postencoder', None) is not None: + postencoder_class = postencoder_choices.get_class(args.postencoder) + postencoder = postencoder_class( + input_size=encoder_output_size, **args.postencoder_conf) + encoder_output_size = postencoder.output_size() + else: + postencoder = None + + # 5. Decoder + decoder_class = decoder_choices.get_class(args.decoder) + + if args.decoder == 'transducer': + decoder = decoder_class( + vocab_size, + embed_pad=0, + **args.decoder_conf, + ) + + joint_network = JointNetwork( + vocab_size, + encoder.output_size(), + decoder.dunits, + **args.joint_net_conf, + ) + else: + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **args.decoder_conf, + ) + + joint_network = None + + # 6. CTC + ctc = CTC( + odim=vocab_size, + encoder_output_size=encoder_output_size, + **args.ctc_conf) + + predictor_class = predictor_choices.get_class(args.predictor) + predictor = predictor_class(**args.predictor_conf) + + # 7. Build model + try: + model_class = model_choices.get_class(args.model) + except AttributeError: + model_class = model_choices.get_class('espnet') + model = model_class( + vocab_size=vocab_size, + frontend=frontend, + specaug=specaug, + normalize=normalize, + preencoder=preencoder, + encoder=encoder, + postencoder=postencoder, + decoder=decoder, + ctc=ctc, + joint_network=joint_network, + token_list=token_list, + predictor=predictor, + **args.model_conf, + ) + + # FIXME(kamo): Should be done in model? + # 8. 
Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/modelscope/pipelines/audio/asr/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr/asr_inference_pipeline.py new file mode 100644 index 00000000..2bb1c0a0 --- /dev/null +++ b/modelscope/pipelines/audio/asr/asr_inference_pipeline.py @@ -0,0 +1,217 @@ +import io +import os +import shutil +import threading +from typing import Any, Dict, List, Union + +import yaml + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import WavToScp +from modelscope.utils.constant import Tasks +from .asr_engine import asr_env_checking, asr_inference_paraformer_espnet +from .asr_engine.common import asr_utils + +__all__ = ['AutomaticSpeechRecognitionPipeline'] + + +@PIPELINES.register_module( + Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) +class AutomaticSpeechRecognitionPipeline(Pipeline): + """ASR Pipeline + """ + + def __init__(self, + model: Union[List[Model], List[str]] = None, + preprocessor: WavToScp = None, + **kwargs): + """use `model` and `preprocessor` to create an asr pipeline for prediction + """ + + assert model is not None, 'asr model should be provided' + + model_list: List = [] + if isinstance(model[0], Model): + model_list = model + else: + model_list.append(Model.from_pretrained(model[0])) + if len(model) == 2 and model[1] is not None: + model_list.append(Model.from_pretrained(model[1])) + + super().__init__(model=model_list, preprocessor=preprocessor, **kwargs) + + self._preprocessor = preprocessor + self._am_model = model_list[0] + if len(model_list) == 2 and model_list[1] is not None: + self._lm_model = model_list[1] + + def __call__(self, + wav_path: str, + recog_type: str = None, + audio_format: str = None, + workspace: str = None) -> Dict[str, Any]: + assert len(wav_path) > 0, 'wav_path should be provided' + + self._recog_type = recog_type + self._audio_format = audio_format + self._workspace = workspace + self._wav_path = wav_path + + if recog_type is None or audio_format is None or workspace is None: + self._recog_type, self._audio_format, self._workspace, self._wav_path = asr_utils.type_checking( + wav_path, recog_type, audio_format, workspace) + + if self._preprocessor is None: + self._preprocessor = WavToScp(workspace=self._workspace) + + output = self._preprocessor.forward(self._am_model.forward(), + self._recog_type, + self._audio_format, self._wav_path) + output = self.forward(output) + rst = self.postprocess(output) + return rst + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Decoding + """ + + j: int = 0 + process = [] + + while j < inputs['thread_count']: + data_cmd: Sequence[Tuple[str, str, str]] + if inputs['audio_format'] == 'wav': + data_cmd = [(os.path.join(inputs['workspace'], + 'data.' + str(j) + '.scp'), 'speech', + 'sound')] + elif inputs['audio_format'] == 'kaldi_ark': + data_cmd = [(os.path.join(inputs['workspace'], + 'data.' + str(j) + '.scp'), 'speech', + 'kaldi_ark')] + + output_dir: str = os.path.join(inputs['output'], + 'output.' 
+ str(j)) + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + config_file = open(inputs['asr_model_config']) + root = yaml.full_load(config_file) + config_file.close() + frontend_conf = None + if 'frontend_conf' in root: + frontend_conf = root['frontend_conf'] + + cmd = { + 'model_type': inputs['model_type'], + 'beam_size': root['beam_size'], + 'penalty': root['penalty'], + 'maxlenratio': root['maxlenratio'], + 'minlenratio': root['minlenratio'], + 'ctc_weight': root['ctc_weight'], + 'lm_weight': root['lm_weight'], + 'output_dir': output_dir, + 'ngpu': 0, + 'log_level': 'ERROR', + 'data_path_and_name_and_type': data_cmd, + 'asr_train_config': inputs['am_model_config'], + 'asr_model_file': inputs['am_model_path'], + 'batch_size': inputs['model_config']['batch_size'], + 'frontend_conf': frontend_conf + } + + thread = AsrInferenceThread(j, cmd) + thread.start() + j += 1 + process.append(thread) + + for p in process: + p.join() + + return inputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the asr results + """ + + rst = {'rec_result': 'None'} + + # single wav task + if inputs['recog_type'] == 'wav' and inputs['audio_format'] == 'wav': + text_file: str = os.path.join(inputs['output'], 'output.0', + '1best_recog', 'text') + + if os.path.exists(text_file): + f = open(text_file, 'r') + result_str: str = f.readline() + f.close() + if len(result_str) > 0: + result_list = result_str.split() + if len(result_list) >= 2: + rst['rec_result'] = result_list[1] + + # run with datasets, and audio format is waveform or kaldi_ark + elif inputs['recog_type'] != 'wav': + inputs['reference_text'] = self._ref_text_tidy(inputs) + inputs['datasets_result'] = asr_utils.compute_wer( + inputs['hypothesis_text'], inputs['reference_text']) + + else: + raise ValueError('recog_type and audio_format are mismatching') + + if 'datasets_result' in inputs: + rst['datasets_result'] = inputs['datasets_result'] + + # remove workspace dir (.tmp) + if os.path.exists(self._workspace): + shutil.rmtree(self._workspace) + + return rst + + def _ref_text_tidy(self, inputs: Dict[str, Any]) -> str: + ref_text: str = os.path.join(inputs['output'], 'text.ref') + k: int = 0 + + while k < inputs['thread_count']: + output_text = os.path.join(inputs['output'], 'output.' + str(k), + '1best_recog', 'text') + if os.path.exists(output_text): + with open(output_text, 'r', encoding='utf-8') as i: + lines = i.readlines() + + with open(ref_text, 'a', encoding='utf-8') as o: + for line in lines: + o.write(line) + + k += 1 + + return ref_text + + +class AsrInferenceThread(threading.Thread): + + def __init__(self, threadID, cmd): + threading.Thread.__init__(self) + self._threadID = threadID + self._cmd = cmd + + def run(self): + if self._cmd['model_type'] == 'pytorch': + asr_inference_paraformer_espnet.asr_inference( + batch_size=self._cmd['batch_size'], + output_dir=self._cmd['output_dir'], + maxlenratio=self._cmd['maxlenratio'], + minlenratio=self._cmd['minlenratio'], + beam_size=self._cmd['beam_size'], + ngpu=self._cmd['ngpu'], + ctc_weight=self._cmd['ctc_weight'], + lm_weight=self._cmd['lm_weight'], + penalty=self._cmd['penalty'], + log_level=self._cmd['log_level'], + data_path_and_name_and_type=self. 
+ _cmd['data_path_and_name_and_type'], + asr_train_config=self._cmd['asr_train_config'], + asr_model_file=self._cmd['asr_model_file'], + frontend_conf=self._cmd['frontend_conf']) diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 95d1f3b2..a2e3ee42 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from modelscope.utils.error import AUDIO_IMPORT_ERROR, TENSORFLOW_IMPORT_ERROR +from .asr import WavToScp from .base import Preprocessor from .builder import PREPROCESSORS, build_preprocessor from .common import Compose diff --git a/modelscope/preprocessors/asr.py b/modelscope/preprocessors/asr.py new file mode 100644 index 00000000..f13cc2e7 --- /dev/null +++ b/modelscope/preprocessors/asr.py @@ -0,0 +1,254 @@ +import io +import os +import shutil +from pathlib import Path +from typing import Any, Dict, List + +import yaml + +from modelscope.metainfo import Preprocessors +from modelscope.models.base import Model +from modelscope.utils.constant import Fields +from .base import Preprocessor +from .builder import PREPROCESSORS + +__all__ = ['WavToScp'] + + +@PREPROCESSORS.register_module( + Fields.audio, module_name=Preprocessors.wav_to_scp) +class WavToScp(Preprocessor): + """generate audio scp from wave or ark + + Args: + workspace (str): + """ + + def __init__(self, workspace: str = None): + # the workspace path + if workspace is None or len(workspace) == 0: + self._workspace = os.path.join(os.getcwd(), '.tmp') + else: + self._workspace = workspace + + if not os.path.exists(self._workspace): + os.mkdir(self._workspace) + + def __call__(self, + model: List[Model] = None, + recog_type: str = None, + audio_format: str = None, + wav_path: str = None) -> Dict[str, Any]: + assert len(model) > 0, 'preprocess model is invalid' + assert len(recog_type) > 0, 'preprocess recog_type is empty' + assert len(audio_format) > 0, 'preprocess audio_format is empty' + assert len(wav_path) > 0, 'preprocess wav_path is empty' + + self._am_model = model[0] + if len(model) == 2 and model[1] is not None: + self._lm_model = model[1] + out = self.forward(self._am_model.forward(), recog_type, audio_format, + wav_path) + return out + + def forward(self, model: Dict[str, Any], recog_type: str, + audio_format: str, wav_path: str) -> Dict[str, Any]: + assert len(recog_type) > 0, 'preprocess recog_type is empty' + assert len(audio_format) > 0, 'preprocess audio_format is empty' + assert len(wav_path) > 0, 'preprocess wav_path is empty' + assert os.path.exists(wav_path), 'preprocess wav_path does not exist' + assert len( + model['am_model']) > 0, 'preprocess model[am_model] is empty' + assert len(model['am_model_path'] + ) > 0, 'preprocess model[am_model_path] is empty' + assert os.path.exists( + model['am_model_path']), 'preprocess am_model_path does not exist' + assert len(model['model_workspace'] + ) > 0, 'preprocess model[model_workspace] is empty' + assert os.path.exists(model['model_workspace'] + ), 'preprocess model_workspace does not exist' + assert len(model['model_config'] + ) > 0, 'preprocess model[model_config] is empty' + + # the am model name + am_model: str = model['am_model'] + # the am model file path + am_model_path: str = model['am_model_path'] + # the recognition model dir path + model_workspace: str = model['model_workspace'] + # the recognition model config dict + global_model_config_dict: str = model['model_config'] + + rst = { + 'workspace': 
os.path.join(self._workspace, recog_type), + 'am_model': am_model, + 'am_model_path': am_model_path, + 'model_workspace': model_workspace, + # the asr type setting, eg: test dev train wav + 'recog_type': recog_type, + # the asr audio format setting, eg: wav, kaldi_ark + 'audio_format': audio_format, + # the test wav file path or the dataset path + 'wav_path': wav_path, + 'model_config': global_model_config_dict + } + + out = self._config_checking(rst) + out = self._env_setting(out) + if audio_format == 'wav': + out = self._scp_generation_from_wav(out) + elif audio_format == 'kaldi_ark': + out = self._scp_generation_from_ark(out) + + return out + + def _config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """config checking + """ + + assert inputs['model_config'].__contains__( + 'type'), 'model type does not exist' + assert inputs['model_config'].__contains__( + 'batch_size'), 'batch_size does not exist' + assert inputs['model_config'].__contains__( + 'am_model_config'), 'am_model_config does not exist' + assert inputs['model_config'].__contains__( + 'asr_model_config'), 'asr_model_config does not exist' + assert inputs['model_config'].__contains__( + 'asr_model_wav_config'), 'asr_model_wav_config does not exist' + + am_model_config: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['am_model_config']) + assert os.path.exists( + am_model_config), 'am_model_config does not exist' + inputs['am_model_config'] = am_model_config + + asr_model_config: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['asr_model_config']) + assert os.path.exists( + asr_model_config), 'asr_model_config does not exist' + + asr_model_wav_config: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['asr_model_wav_config']) + assert os.path.exists( + asr_model_wav_config), 'asr_model_wav_config does not exist' + + inputs['model_type'] = inputs['model_config']['type'] + + if inputs['audio_format'] == 'wav': + inputs['asr_model_config'] = asr_model_wav_config + else: + inputs['asr_model_config'] = asr_model_config + + return inputs + + def _env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + if not os.path.exists(inputs['workspace']): + os.mkdir(inputs['workspace']) + + inputs['output'] = os.path.join(inputs['workspace'], 'logdir') + if not os.path.exists(inputs['output']): + os.mkdir(inputs['output']) + + # run with datasets, should set datasets_path and text_path + if inputs['recog_type'] != 'wav': + inputs['datasets_path'] = inputs['wav_path'] + + # run with datasets, and audio format is waveform + if inputs['audio_format'] == 'wav': + inputs['wav_path'] = os.path.join(inputs['datasets_path'], + 'wav', inputs['recog_type']) + inputs['hypothesis_text'] = os.path.join( + inputs['datasets_path'], 'transcript', 'data.text') + assert os.path.exists(inputs['hypothesis_text'] + ), 'hypothesis text does not exist' + + elif inputs['audio_format'] == 'kaldi_ark': + inputs['wav_path'] = os.path.join(inputs['datasets_path'], + inputs['recog_type']) + inputs['hypothesis_text'] = os.path.join( + inputs['wav_path'], 'data.text') + assert os.path.exists(inputs['hypothesis_text'] + ), 'hypothesis text does not exist' + + return inputs + + def _scp_generation_from_wav(self, inputs: Dict[str, + Any]) -> Dict[str, Any]: + """scp generation from waveform files + """ + + # find all waveform files + wav_list = [] + if inputs['recog_type'] == 'wav': + file_path = inputs['wav_path'] + if os.path.isfile(file_path): + if file_path.endswith('.wav') or 
file_path.endswith('.WAV'): + wav_list.append(file_path) + else: + wav_dir: str = inputs['wav_path'] + wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) + + list_count: int = len(wav_list) + inputs['wav_count'] = list_count + + # store all wav into data.0.scp + inputs['thread_count'] = 1 + j: int = 0 + wav_list_path = os.path.join(inputs['workspace'], 'data.0.scp') + with open(wav_list_path, 'a') as f: + while j < list_count: + wav_file = wav_list[j] + wave_scp_content: str = os.path.splitext( + os.path.basename(wav_file))[0] + wave_scp_content += ' ' + wav_file + '\n' + f.write(wave_scp_content) + j += 1 + + return inputs + + def _scp_generation_from_ark(self, inputs: Dict[str, + Any]) -> Dict[str, Any]: + """scp generation from kaldi ark file + """ + + inputs['thread_count'] = 1 + ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp') + ark_file_path = os.path.join(inputs['wav_path'], 'data.ark') + assert os.path.exists(ark_scp_path), 'data.scp does not exist' + assert os.path.exists(ark_file_path), 'data.ark does not exist' + + new_ark_scp_path = os.path.join(inputs['workspace'], 'data.0.scp') + + with open(ark_scp_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + with open(new_ark_scp_path, 'w', encoding='utf-8') as n: + for line in lines: + outs = line.strip().split(' ') + if len(outs) == 2: + key = outs[0] + sub = outs[1].split(':') + if len(sub) == 2: + nums = sub[1] + content = key + ' ' + ark_file_path + ':' + nums + '\n' + n.write(content) + + return inputs + + def _recursion_dir_all_wave(self, wav_list, + dir_path: str) -> Dict[str, Any]: + dir_files = os.listdir(dir_path) + for file in dir_files: + file_path = os.path.join(dir_path, file) + if os.path.isfile(file_path): + if file_path.endswith('.wav') or file_path.endswith('.WAV'): + wav_list.append(file_path) + elif os.path.isdir(file_path): + self._recursion_dir_all_wave(wav_list, file_path) + + return wav_list diff --git a/requirements/audio.txt b/requirements/audio.txt index feb4eb82..255b478a 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,3 +1,4 @@ +espnet==202204 #tts h5py inflect diff --git a/tests/pipelines/test_automatic_speech_recognition.py b/tests/pipelines/test_automatic_speech_recognition.py new file mode 100644 index 00000000..14d33b8f --- /dev/null +++ b/tests/pipelines/test_automatic_speech_recognition.py @@ -0,0 +1,199 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tarfile +import unittest + +import requests + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + +WAV_FILE = 'data/test/audios/asr_example.wav' + +LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz' +LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz' + +AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz' +AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz' + + +def un_tar_gz(fname, dirs): + t = tarfile.open(fname) + t.extractall(path=dirs) + + +class AutomaticSpeechRecognitionTest(unittest.TestCase): + + def setUp(self) -> None: + self._am_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' + # this temporary workspace dir will store waveform files + self._workspace = os.path.join(os.getcwd(), '.tmp') + if not os.path.exists(self._workspace): + os.mkdir(self._workspace) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav(self): + '''run with single waveform file + ''' + + wav_file_path = os.path.join(os.getcwd(), WAV_FILE) + + inference_16k_pipline = pipeline( + task=Tasks.auto_speech_recognition, model=[self._am_model_id]) + self.assertTrue(inference_16k_pipline is not None) + + rec_result = inference_16k_pipline(wav_file_path) + self.assertTrue(len(rec_result['rec_result']) > 0) + self.assertTrue(rec_result['rec_result'] != 'None') + ''' + result structure: + { + 'rec_result': '每一天都要快乐喔' + } + or + { + 'rec_result': 'None' + } + ''' + print('test_run_with_wav rec result: ' + rec_result['rec_result']) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_wav_dataset(self): + '''run with datasets, and audio format is waveform + datasets directory: + + wav + test # testsets + xx.wav + ... + dev # devsets + yy.wav + ... + train # trainsets + zz.wav + ... 
+ transcript + data.text # hypothesis text + ''' + + # downloading pos_testsets file + testsets_file_path = os.path.join(self._workspace, + LITTLE_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(LITTLE_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename( + os.path.splitext( + os.path.basename(LITTLE_TESTSETS_FILE))[0]))[0] + # dataset_path = /.tmp/data_aishell/wav/test + dataset_path = os.path.join(self._workspace, testsets_dir_name, 'wav', + 'test') + + # untar the dataset_path file + if not os.path.exists(dataset_path): + un_tar_gz(testsets_file_path, self._workspace) + + inference_16k_pipline = pipeline( + task=Tasks.auto_speech_recognition, model=[self._am_model_id]) + self.assertTrue(inference_16k_pipline is not None) + + rec_result = inference_16k_pipline(wav_path=dataset_path) + self.assertTrue(len(rec_result['datasets_result']) > 0) + self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) + ''' + result structure: + { + 'rec_result': 'None', + 'datasets_result': + { + 'Wrd': 1654, # the number of words + 'Snt': 128, # the number of sentences + 'Corr': 1573, # the number of correct words + 'Ins': 1, # the number of insert words + 'Del': 1, # the number of delete words + 'Sub': 80, # the number of substitution words + 'wrong_words': 82, # the number of wrong words + 'wrong_sentences': 47, # the number of wrong sentences + 'Err': 4.96, # WER/CER + 'S.Err': 36.72 # SER + } + } + ''' + print('test_run_with_wav_dataset datasets result: ') + print(rec_result['datasets_result']) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_ark_dataset(self): + '''run with datasets, and audio format is kaldi_ark + datasets directory: + + test # testsets + data.ark + data.scp + data.text + dev # devsets + data.ark + data.scp + data.text + train # trainsets + data.ark + data.scp + data.text + ''' + + # downloading pos_testsets file + testsets_file_path = os.path.join(self._workspace, + AISHELL1_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(AISHELL1_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename( + os.path.splitext( + os.path.basename(AISHELL1_TESTSETS_FILE))[0]))[0] + # dataset_path = /.tmp/aishell1/test + dataset_path = os.path.join(self._workspace, testsets_dir_name, 'test') + + # untar the dataset_path file + if not os.path.exists(dataset_path): + un_tar_gz(testsets_file_path, self._workspace) + + inference_16k_pipline = pipeline( + task=Tasks.auto_speech_recognition, model=[self._am_model_id]) + self.assertTrue(inference_16k_pipline is not None) + + rec_result = inference_16k_pipline(wav_path=dataset_path) + self.assertTrue(len(rec_result['datasets_result']) > 0) + self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) + ''' + result structure: + { + 'rec_result': 'None', + 'datasets_result': + { + 'Wrd': 104816, # the number of words + 'Snt': 7176, # the number of sentences + 'Corr': 99327, # the number of correct words + 'Ins': 104, # the number of insert words + 'Del': 155, # the number of delete words + 'Sub': 5334, # the number of substitution words + 'wrong_words': 5593, # the number of wrong words + 'wrong_sentences': 2898, # the number of wrong sentences + 'Err': 5.34, # WER/CER + 'S.Err': 40.38 # SER + } + } + ''' + print('test_run_with_ark_dataset datasets result: ') + 
print(rec_result['datasets_result']) + + +if __name__ == '__main__': + unittest.main()
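For the dataset-driven tests, asr_utils.compute_wer produces the datasets_result dictionary shown in the docstrings. Judging from those sample numbers, 'Err' and 'S.Err' appear to be the word- and sentence-level error counts expressed as percentages; the check below only reproduces the documented example, since compute_wer itself is not part of this diff:

    # Reproducing the first sample result: 82 wrong words out of 1654,
    # 47 wrong sentences out of 128.
    wrd, snt = 1654, 128
    wrong_words, wrong_sentences = 82, 47

    err = round(100.0 * wrong_words / wrd, 2)        # 4.96  -> 'Err'
    s_err = round(100.0 * wrong_sentences / snt, 2)  # 36.72 -> 'S.Err'
    print(err, s_err)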
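AutomaticSpeechRecognitionPipeline.forward loads the active decoding config (asr_model_wav_config for wav input, asr_model_config otherwise) with yaml.full_load and hands one data.<j>.scp shard to each AsrInferenceThread. The sketch below shows the dictionary shape that code expects after parsing; the numbers are placeholders, not the values shipped with any particular model:

    # Decoding keys read by the pipeline before spawning AsrInferenceThread workers;
    # placeholder values for illustration only.
    example_asr_model_config = {
        'beam_size': 10,
        'penalty': 0.0,
        'maxlenratio': 0.0,
        'minlenratio': 0.0,
        'ctc_weight': 0.5,
        'lm_weight': 0.0,
        'frontend_conf': {'fs': 16000},  # optional; forwarded as-is when present
    }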
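On the preprocessing side, WavToScp._scp_generation_from_wav gathers the waveform paths (a single file or a recursive directory scan) and writes them into data.0.scp as '<utt_id> <path>' lines, which the pipeline then registers as ('data.0.scp', 'speech', 'sound'). A minimal standalone sketch of that step, using an illustrative directory layout:

    import os

    def write_wav_scp(wav_dir: str, scp_path: str) -> None:
        # One line per waveform: '<basename without extension> <path>'
        with open(scp_path, 'w') as scp:
            for root, _, files in os.walk(wav_dir):
                for name in files:
                    if name.lower().endswith('.wav'):
                        utt_id = os.path.splitext(name)[0]
                        scp.write(f'{utt_id} {os.path.join(root, name)}\n')

    # e.g. 'asr_example data/test/audios/asr_example.wav'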
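For kaldi_ark datasets, _scp_generation_from_ark keeps each utterance key and byte offset from the original data.scp but repoints every entry at the local data.ark. A sketch of that rewrite, with the input and output line shapes shown as comments (paths are placeholders):

    def rewrite_ark_scp(src_scp: str, ark_path: str, dst_scp: str) -> None:
        # input line:  'utt_001 /original/location/data.ark:17'
        # output line: 'utt_001 <ark_path>:17'
        with open(src_scp, 'r', encoding='utf-8') as src, \
                open(dst_scp, 'w', encoding='utf-8') as dst:
            for line in src:
                parts = line.strip().split(' ')
                if len(parts) != 2:
                    continue
                key, location = parts
                offset = location.split(':')[-1]
                dst.write(f'{key} {ark_path}:{offset}\n')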