Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9273537master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:87bde7feb3b40d75dec27e5824dd1077911f867e3f125c4bf603ec0af954d4db
size 77864
@@ -23,6 +23,7 @@ class Models(object):
     sambert_hifigan = 'sambert-hifigan'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
+    generic_asr = 'generic-asr'

     # multi-modal models
     ofa = 'ofa'
@@ -68,6 +69,7 @@ class Pipelines(object):
     speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
+    asr_inference = 'asr-inference'

     # multi-modal tasks
     image_caption = 'image-captioning'
@@ -120,6 +122,7 @@ class Preprocessors(object):
     linear_aec_fbank = 'linear-aec-fbank'
     text_to_tacotron_symbols = 'text-to-tacotron-symbols'
     wav_to_lists = 'wav-to-lists'
+    wav_to_scp = 'wav-to-scp'

     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
@@ -5,6 +5,7 @@ from .base import Model
 from .builder import MODELS, build_model

 try:
+    from .audio.asr import GenericAutomaticSpeechRecognition
     from .audio.tts import SambertHifigan
     from .audio.kws import GenericKeyWordSpotting
     from .audio.ans.frcrn import FRCRNModel
@@ -0,0 +1 @@
from .generic_automatic_speech_recognition import *  # noqa F403
@@ -0,0 +1,39 @@
import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['GenericAutomaticSpeechRecognition']


@MODELS.register_module(
    Tasks.auto_speech_recognition, module_name=Models.generic_asr)
class GenericAutomaticSpeechRecognition(Model):

    def __init__(self, model_dir: str, am_model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """Initialize the model info.

        Args:
            model_dir (str): the model path.
            am_model_name (str): the am model name from configuration.json.
            model_config (Dict[str, Any]): the recognition model config dict
                from configuration.json.
        """
        # initialize the base Model so that model_dir is tracked there as well
        super().__init__(model_dir, *args, **kwargs)
        self.model_cfg = {
            # the recognition model dir path
            'model_workspace': model_dir,
            # the am model name
            'am_model': am_model_name,
            # the am model file path
            'am_model_path': os.path.join(model_dir, am_model_name),
            # the recognition model config dict
            'model_config': model_config
        }

    def forward(self) -> Dict[str, Any]:
        """Return the model configuration info for the downstream pipeline."""
        return self.model_cfg
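Review note: a minimal sketch of how this registered model can be resolved through the MODELS registry. It assumes the usual mmcv-style registry semantics of modelscope's build_model; the model directory, am model name, and config values below are placeholders, not real files.

    from modelscope.metainfo import Models
    from modelscope.models.builder import build_model
    from modelscope.utils.constant import Tasks

    # 'type' selects the class registered above; the remaining keys are
    # forwarded to GenericAutomaticSpeechRecognition.__init__.
    cfg = dict(
        type=Models.generic_asr,
        model_dir='/path/to/model_dir',
        am_model_name='am.model',
        model_config={'type': 'pytorch'})
    model = build_model(cfg, task_name=Tasks.auto_speech_recognition)
    assert model.forward()['am_model_path'].endswith('am.model')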
@@ -3,6 +3,7 @@
 from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR

 try:
+    from .asr.asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
     from .kws_kwsbp_pipeline import *  # noqa F403
     from .linear_aec_pipeline import LinearAECPipeline
 except ModuleNotFoundError as e:
@@ -0,0 +1,12 @@
import nltk

try:
    nltk.data.find('taggers/averaged_perceptron_tagger')
except LookupError:
    nltk.download(
        'averaged_perceptron_tagger', halt_on_error=False, raise_on_error=True)

try:
    nltk.data.find('corpora/cmudict')
except LookupError:
    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
@@ -0,0 +1,690 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer
from espnet2.asr.transducer.beam_search_transducer import \
    ExtendedHypothesis as ExtTransHypothesis  # noqa: H301
from espnet2.asr.transducer.beam_search_transducer import \
    Hypothesis as TransHypothesis
from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.tasks.lm import LMTask
from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.utils import config_argparse
from espnet2.utils.types import str2bool, str2triple_str, str_or_none
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from espnet.nets.beam_search import BeamSearch, Hypothesis
from espnet.nets.pytorch_backend.transformer.subsampling import \
    TooShortUttError
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.utils.cli_utils import get_commandline_args
from typeguard import check_argument_types, check_return_type

from .espnet.tasks.asr import ASRTaskNAR as ASRTask
class Speech2Text:

    def __init__(self,
                 asr_train_config: Union[Path, str] = None,
                 asr_model_file: Union[Path, str] = None,
                 transducer_conf: dict = None,
                 lm_train_config: Union[Path, str] = None,
                 lm_file: Union[Path, str] = None,
                 ngram_scorer: str = 'full',
                 ngram_file: Union[Path, str] = None,
                 token_type: str = None,
                 bpemodel: str = None,
                 device: str = 'cpu',
                 maxlenratio: float = 0.0,
                 minlenratio: float = 0.0,
                 batch_size: int = 1,
                 dtype: str = 'float32',
                 beam_size: int = 20,
                 ctc_weight: float = 0.5,
                 lm_weight: float = 1.0,
                 ngram_weight: float = 0.9,
                 penalty: float = 0.0,
                 nbest: int = 1,
                 streaming: bool = False,
                 frontend_conf: dict = None):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, device)
        if asr_model.frontend is None and frontend_conf is not None:
            frontend = DefaultFrontend(**frontend_conf)
            asr_model.frontend = frontend
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device)
            scorers['lm'] = lm.lm

        # 3. Build ngram model
        if ngram_file is not None:
            if ngram_scorer == 'full':
                from espnet.nets.scorers.ngram import NgramFullScorer
                ngram = NgramFullScorer(ngram_file, token_list)
            else:
                from espnet.nets.scorers.ngram import NgramPartScorer
                ngram = NgramPartScorer(ngram_file, token_list)
        else:
            ngram = None
        scorers['ngram'] = ngram

        # 4. Build BeamSearch object
        if asr_model.use_transducer_decoder:
            beam_search_transducer = BeamSearchTransducer(
                decoder=asr_model.decoder,
                joint_network=asr_model.joint_network,
                beam_size=beam_size,
                lm=scorers['lm'] if 'lm' in scorers else None,
                lm_weight=lm_weight,
                **transducer_conf,
            )
            beam_search = None
        else:
            beam_search_transducer = None

            weights = dict(
                decoder=1.0 - ctc_weight,
                ctc=ctc_weight,
                lm=lm_weight,
                ngram=ngram_weight,
                length_bonus=penalty,
            )
            beam_search = BeamSearch(
                beam_size=beam_size,
                weights=weights,
                scorers=scorers,
                sos=asr_model.sos,
                eos=asr_model.eos,
                vocab_size=len(token_list),
                token_list=token_list,
                pre_beam_score_key=None if ctc_weight == 1.0 else 'full',
            )

            # TODO(karita): make all scorers batchfied
            if batch_size == 1:
                non_batch = [
                    k for k, v in beam_search.full_scorers.items()
                    if not isinstance(v, BatchScorerInterface)
                ]
                if len(non_batch) == 0:
                    if streaming:
                        beam_search.__class__ = BatchBeamSearchOnlineSim
                        beam_search.set_streaming_config(asr_train_config)
                        logging.info(
                            'BatchBeamSearchOnlineSim implementation is selected.'
                        )
                    else:
                        beam_search.__class__ = BatchBeamSearch
                else:
                    logging.warning(
                        f'As non-batch scorers {non_batch} are found, '
                        f'fall back to non-batch implementation.')

            beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
            for scorer in scorers.values():
                if isinstance(scorer, torch.nn.Module):
                    scorer.to(
                        device=device, dtype=getattr(torch, dtype)).eval()

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == 'bpe':
            if bpemodel is not None:
                tokenizer = build_tokenizer(
                    token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]):
        """Inference.

        Args:
            speech: Input speech data
        Returns:
            A list of n-best tuples (text, token, token_int, hyp, input_length)
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # speech: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        lengths = speech.new_full([1],
                                  dtype=torch.long,
                                  fill_value=speech.size(1))
        batch = {'speech': speech, 'speech_lengths': lengths}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)

        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length = predictor_outs[
            0], predictor_outs[1]
        pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)],
                                        device=pre_acoustic_embeds.device)
        decoder_outs = self.asr_model.cal_decoder_with_predictor(
            enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = decoder_outs[0]
        yseq = decoder_out.argmax(dim=-1)
        score = decoder_out.max(dim=-1)[0]
        score = torch.sum(score, dim=-1)
        # wrap the greedy decoding result with sos/eos so it matches the
        # Hypothesis layout expected by the postprocessing below
        yseq = torch.tensor(
            [self.asr_model.sos] + yseq.tolist()[0] + [self.asr_model.eos],
            device=yseq.device)
        nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp)

            # remove sos/eos and get results
            last_pos = None if self.asr_model.use_transducer_decoder else -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp, speech.size(1)))

        return results
    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Speech2Text instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
        Returns:
            Speech2Text: Speech2Text instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    '`espnet_model_zoo` is not installed. '
                    'Please install via `pip install -U espnet_model_zoo`.')
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2Text(**kwargs)
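Review note: a minimal usage sketch of the Speech2Text wrapper above. The config and model paths are placeholders, and the one-second silent waveform is only a stand-in for real 16 kHz mono audio; a real checkpoint is required for meaningful output.

    import numpy as np

    speech2text = Speech2Text(
        asr_train_config='/path/to/config.yaml',
        asr_model_file='/path/to/model.pth',
        device='cpu')
    wav = np.zeros(16000, dtype=np.float32)  # dummy 16 kHz mono input
    # each entry of the returned list is (text, token, token_int, hyp, length)
    text, token, token_int, hyp, length = speech2text(wav)[0]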
def inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    ngram_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    word_lm_file: Optional[str],
    ngram_file: Optional[str],
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
    transducer_conf: Optional[dict],
    streaming: bool,
    frontend_conf: dict = None,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError('batch decoding is not implemented')
    if word_lm_train_config is not None:
        raise NotImplementedError('Word LM is not implemented')
    if ngpu > 1:
        raise NotImplementedError('only single GPU decoding is supported')

    logging.basicConfig(
        level=log_level,
        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
    )

    if ngpu >= 1:
        device = 'cuda'
    else:
        device = 'cpu'

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        transducer_conf=transducer_conf,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        ngram_file=ngram_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
        frontend_conf=frontend_conf,
    )
    speech2text = Speech2Text.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )

    # 3. Build data-iterator
    loader = ASRTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args,
                                                  False),
        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )
    forward_time_total = 0.0
    length_total = 0.0

    # 4. Start the decoding loop
    # FIXME(kamo): The output format should be discussed
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f'{len(keys)} != {_bs}'
            batch = {
                k: v[0]
                for k, v in batch.items() if not k.endswith('_lengths')
            }

            # N-best list of (text, token, token_int, hyp_object)
            try:
                time_beg = time.time()
                results = speech2text(**batch)
                time_end = time.time()
                forward_time = time_end - time_beg
                length = results[0][-1]
                results = [results[0][:-1]]
                forward_time_total += forward_time
                length_total += length
            except TooShortUttError as e:
                logging.warning(f'Utterance {keys} {e}')
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[' ', ['<space>'], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int,
                    hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                ibest_writer = writer[f'{n}best_recog']

                # Write the result to each file
                ibest_writer['token'][key] = ' '.join(token)
                ibest_writer['token_int'][key] = ' '.join(map(str, token_int))
                ibest_writer['score'][key] = str(hyp.score)

                if text is not None:
                    ibest_writer['text'][key] = text

    logging.info(
        'decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}'
        .format(length_total, forward_time_total,
                100 * forward_time_total / length_total))
def get_parser():
    parser = config_argparse.ArgumentParser(
        description='ASR Decoding',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        '--log_level',
        type=lambda x: x.upper(),
        default='INFO',
        choices=('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'),
        help='The verbose level of logging',
    )

    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument(
        '--ngpu',
        type=int,
        default=0,
        help='The number of gpus. 0 indicates CPU mode',
    )
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument(
        '--dtype',
        default='float32',
        choices=['float16', 'float32', 'float64'],
        help='Data type',
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=1,
        help='The number of workers used for DataLoader',
    )

    group = parser.add_argument_group('Input data related')
    group.add_argument(
        '--data_path_and_name_and_type',
        type=str2triple_str,
        required=True,
        action='append',
    )
    group.add_argument('--key_file', type=str_or_none)
    group.add_argument(
        '--allow_variable_data_keys', type=str2bool, default=False)

    group = parser.add_argument_group('The model configuration related')
    group.add_argument(
        '--asr_train_config',
        type=str,
        help='ASR training configuration',
    )
    group.add_argument(
        '--asr_model_file',
        type=str,
        help='ASR model parameter file',
    )
    group.add_argument(
        '--lm_train_config',
        type=str,
        help='LM training configuration',
    )
    group.add_argument(
        '--lm_file',
        type=str,
        help='LM parameter file',
    )
    group.add_argument(
        '--word_lm_train_config',
        type=str,
        help='Word LM training configuration',
    )
    group.add_argument(
        '--word_lm_file',
        type=str,
        help='Word LM parameter file',
    )
    group.add_argument(
        '--ngram_file',
        type=str,
        help='N-gram parameter file',
    )
    group.add_argument(
        '--model_tag',
        type=str,
        help='Pretrained model tag. If this option is specified, '
        '*_train_config and *_file will be overwritten',
    )

    group = parser.add_argument_group('Beam-search related')
    group.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='The batch size for inference',
    )
    group.add_argument(
        '--nbest', type=int, default=1, help='Output N-best hypotheses')
    group.add_argument('--beam_size', type=int, default=20, help='Beam size')
    group.add_argument(
        '--penalty', type=float, default=0.0, help='Insertion penalty')
    group.add_argument(
        '--maxlenratio',
        type=float,
        default=0.0,
        help='Input length ratio to obtain max output length. '
        'If maxlenratio=0.0 (default), it uses an end-detect '
        'function to automatically find maximum hypothesis lengths. '
        'If maxlenratio<0.0, its absolute value is interpreted '
        'as a constant max output length',
    )
    group.add_argument(
        '--minlenratio',
        type=float,
        default=0.0,
        help='Input length ratio to obtain min output length',
    )
    group.add_argument(
        '--ctc_weight',
        type=float,
        default=0.5,
        help='CTC weight in joint decoding',
    )
    group.add_argument(
        '--lm_weight', type=float, default=1.0, help='RNNLM weight')
    group.add_argument(
        '--ngram_weight', type=float, default=0.9, help='ngram weight')
    group.add_argument('--streaming', type=str2bool, default=False)
    group.add_argument(
        '--frontend_conf',
        default=None,
        help='The keyword arguments for the frontend',
    )

    group = parser.add_argument_group('Text converter related')
    group.add_argument(
        '--token_type',
        type=str_or_none,
        default=None,
        choices=['char', 'bpe', None],
        help='The token type for ASR model. '
        'If not given, refers from the training args',
    )
    group.add_argument(
        '--bpemodel',
        type=str_or_none,
        default=None,
        help='The model path of sentencepiece. '
        'If not given, refers from the training args',
    )
    group.add_argument(
        '--transducer_conf',
        default=None,
        help='The keyword arguments for transducer beam search.',
    )

    return parser
def asr_inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    nbest: int = 1,
    num_workers: int = 1,
    log_level: Union[int, str] = 'INFO',
    batch_size: int = 1,
    dtype: str = 'float32',
    seed: int = 0,
    key_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    word_lm_file: Optional[str] = None,
    ngram_file: Optional[str] = None,
    ngram_weight: float = 0.9,
    model_tag: Optional[str] = None,
    token_type: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    transducer_conf: Optional[dict] = None,
    streaming: bool = False,
    frontend_conf: dict = None,
):
    inference(
        output_dir=output_dir,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        batch_size=batch_size,
        dtype=dtype,
        beam_size=beam_size,
        ngpu=ngpu,
        seed=seed,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        num_workers=num_workers,
        log_level=log_level,
        data_path_and_name_and_type=data_path_and_name_and_type,
        key_file=key_file,
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        word_lm_train_config=word_lm_train_config,
        word_lm_file=word_lm_file,
        ngram_file=ngram_file,
        model_tag=model_tag,
        token_type=token_type,
        bpemodel=bpemodel,
        allow_variable_data_keys=allow_variable_data_keys,
        transducer_conf=transducer_conf,
        streaming=streaming,
        frontend_conf=frontend_conf)


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop('config', None)
    inference(**kwargs)


if __name__ == '__main__':
    main()
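Review note: a minimal sketch of driving the decoder through the asr_inference() wrapper defined above. Every path below is a placeholder; 'sound' is the espnet data type used for raw wav.scp input.

    asr_inference(
        output_dir='./decode_out',
        maxlenratio=0.0,
        minlenratio=0.0,
        beam_size=20,
        ngpu=0,
        ctc_weight=0.5,
        lm_weight=1.0,
        penalty=0.0,
        data_path_and_name_and_type=[('/path/to/wav.scp', 'speech', 'sound')],
        asr_train_config='/path/to/config.yaml',
        asr_model_file='/path/to/model.pth')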
@@ -0,0 +1,193 @@
import os
from typing import Any, Dict, List, Sequence

import numpy as np


def type_checking(wav_path: str,
                  recog_type: str = None,
                  audio_format: str = None,
                  workspace: str = None):
    assert os.path.exists(wav_path), f'wav_path:{wav_path} does not exist'

    r_recog_type = recog_type
    r_audio_format = audio_format
    r_workspace = workspace
    r_wav_path = wav_path

    if r_workspace is None or len(r_workspace) == 0:
        r_workspace = os.path.join(os.getcwd(), '.tmp')

    if r_recog_type is None:
        if os.path.isfile(wav_path):
            if wav_path.endswith('.wav') or wav_path.endswith('.WAV'):
                r_recog_type = 'wav'
                r_audio_format = 'wav'
        elif os.path.isdir(wav_path):
            dir_name = os.path.basename(wav_path)
            if 'test' in dir_name:
                r_recog_type = 'test'
            elif 'dev' in dir_name:
                r_recog_type = 'dev'
            elif 'train' in dir_name:
                r_recog_type = 'train'

    if r_audio_format is None:
        if find_file_by_ends(wav_path, '.ark'):
            r_audio_format = 'kaldi_ark'
        elif find_file_by_ends(wav_path, '.wav') or find_file_by_ends(
                wav_path, '.WAV'):
            r_audio_format = 'wav'

    if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
        # datasets with kaldi_ark file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'wav' and r_recog_type != 'wav':
        # datasets with waveform files
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))

    return r_recog_type, r_audio_format, r_workspace, r_wav_path
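Review note: a short sketch of what type_checking is expected to infer for a typical dataset layout; the directory below is hypothetical.

    # given a directory such as '/data/aishell1/test' that contains
    # *.wav files somewhere beneath it:
    recog_type, audio_format, workspace, wav_path = type_checking(
        '/data/aishell1/test')
    # recog_type == 'test'     (taken from the directory name)
    # audio_format == 'wav'    (found by the recursive search below)
    # wav_path is rewritten two directory levels up for waveform datasets
    # workspace defaults to os.path.join(os.getcwd(), '.tmp')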
def find_file_by_ends(dir_path: str, ends: str):
    dir_files = os.listdir(dir_path)
    for file in dir_files:
        file_path = os.path.join(dir_path, file)
        if os.path.isfile(file_path):
            if file_path.endswith(ends):
                return True
        elif os.path.isdir(file_path):
            if find_file_by_ends(file_path, ends):
                return True

    return False
def compute_wer(hyp_text_path: str, ref_text_path: str) -> Dict[str, Any]:
    assert os.path.exists(hyp_text_path), 'hyp_text does not exist'
    assert os.path.exists(ref_text_path), 'ref_text does not exist'

    rst = {
        'Wrd': 0,
        'Corr': 0,
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,
        'Err': 0.0,
        'S.Err': 0.0,
        'wrong_words': 0,
        'wrong_sentences': 0
    }

    with open(ref_text_path, 'r', encoding='utf-8') as r:
        r_lines = r.readlines()

    with open(hyp_text_path, 'r', encoding='utf-8') as h:
        h_lines = h.readlines()

    for r_line in r_lines:
        r_line_item = r_line.split()
        r_key = r_line_item[0]
        # the transcript is taken as a single whitespace-free field,
        # e.g. an unsegmented Chinese character sequence
        r_sentence = r_line_item[1]
        for h_line in h_lines:
            # find the sentence with the same key in the hyp text
            if r_key in h_line:
                h_line_item = h_line.split()
                h_sentence = h_line_item[1]
                out_item = compute_wer_by_line(h_sentence, r_sentence)
                rst['Wrd'] += out_item['nwords']
                rst['Corr'] += out_item['cor']
                rst['wrong_words'] += out_item['wrong']
                rst['Ins'] += out_item['ins']
                rst['Del'] += out_item['del']
                rst['Sub'] += out_item['sub']
                rst['Snt'] += 1
                if out_item['wrong'] > 0:
                    rst['wrong_sentences'] += 1
                break

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
    if rst['Snt'] > 0:
        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)

    return rst
def compute_wer_by_line(hyp: Sequence[Any],
                        ref: Sequence[Any]) -> Dict[str, Any]:
    len_hyp = len(hyp)
    len_ref = len(ref)
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1

                compare_val = [substitution, insertion, deletion]

                min_val = min(compare_val)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx

    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_hyp,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)

        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1

            i -= 1
            j -= 1

        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1

        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1

        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1

        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1

    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt

    return rst
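Review note: a worked example of the edit-distance accounting in compute_wer_by_line. Strings are valid inputs here since they are indexed per character, which matches the per-utterance call in compute_wer.

    rst = compute_wer_by_line('abcd', 'abed')
    # one substitution ('c' vs 'e'), three matches:
    # rst == {'nwords': 4, 'cor': 3, 'wrong': 1, 'ins': 0, 'del': 0, 'sub': 1}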
@@ -0,0 +1,757 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Decoder definition."""
from typing import Any, List, Sequence, Tuple

import torch
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.attention import \
    MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.dynamic_conv import \
    DynamicConvolution
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \
    DynamicConvolution2D
from espnet.nets.pytorch_backend.transformer.embedding import \
    PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.lightconv import \
    LightweightConvolution
from espnet.nets.pytorch_backend.transformer.lightconv2d import \
    LightweightConvolution2D
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \
    PositionwiseFeedForward  # noqa: H301
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import BatchScorerInterface
from typeguard import check_argument_types
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = 'embed',
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == 'embed':
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == 'linear':
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(
                f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        # Must be set by the inheriting class
        self.decoders = None
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(
            tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (
            ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
                memory.device)
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
                                                  'constant', False)

        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens
    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c)
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
        return logp.squeeze(0), state

    def batch_score(self, ys: torch.Tensor, states: List[Any],
                    xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(
            ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)]
                      for b in range(n_batch)]
        return logp, state_list
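Review note: a small illustration of the subsequent_mask used by score() and batch_score() above; for a prefix of length 3 it yields a lower-triangular mask (boolean in recent espnet versions), so each position may only attend to itself and earlier tokens.

    from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

    m = subsequent_mask(3)
    # tensor([[ True, False, False],
    #         [ True,  True, False],
    #         [ True,  True,  True]])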
class TransformerDecoder(BaseTransformerDecoder):

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = 'embed',
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(attention_heads, attention_dim,
                                     self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim,
                                     src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class ParaformerDecoder(TransformerDecoder):

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = 'embed',
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(attention_heads, attention_dim,
                                     self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim,
                                     src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder (non-autoregressive).

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: acoustic embeddings from the predictor,
                float32 (batch, maxlen_out, feat); they are fed to the
                decoder layers directly instead of token-id embeddings
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # unlike the autoregressive decoder above, no causal subsequent_mask
        # is applied, so all output positions are decoded in parallel:
        # m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (
            ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
                memory.device)
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
                                                  'constant', False)

        # the input is already an embedding, so the token embedding layer
        # is skipped:
        # x = self.embed(tgt)
        x = tgt
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens
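Review note: a minimal shape sketch of the non-autoregressive ParaformerDecoder above; all dimensions are illustrative, and the random tensors stand in for real encoder outputs and predictor embeddings.

    import torch

    decoder = ParaformerDecoder(vocab_size=100, encoder_output_size=256)
    hs = torch.randn(2, 50, 256)      # encoder memory (batch, T_in, feat)
    hlens = torch.tensor([50, 40])
    embeds = torch.randn(2, 10, 256)  # predictor acoustic embeddings
    ylens = torch.tensor([10, 8])
    logits, olens = decoder(hs, hlens, embeds, ylens)
    # logits: (2, 10, 100) token scores for all positions at once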
class ParaformerDecoderBertEmbed(TransformerDecoder):

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = 'embed',
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = 2,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            embeds_id,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(attention_heads, attention_dim,
                                     self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim,
                                     src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if embeds_id == num_blocks:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - embeds_id,
                lambda lnum: DecoderLayer(
                    attention_dim,
                    MultiHeadedAttention(attention_heads, attention_dim,
                                         self_attention_dropout_rate),
                    MultiHeadedAttention(attention_heads, attention_dim,
                                         src_attention_dropout_rate),
                    PositionwiseFeedForward(attention_dim, linear_units,
                                            dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder (non-autoregressive).

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: acoustic embeddings from the predictor,
                float32 (batch, maxlen_out, feat)
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
            embeds_outputs: hidden states after the first `embeds_id`
                decoder blocks, (batch, maxlen_out, feat)
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # no causal subsequent_mask, as in ParaformerDecoder above:
        # m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (
            ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
                memory.device)
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
                                                  'constant', False)

        # the input is already an embedding:
        # x = self.embed(tgt)
        x = tgt
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask)
        embeds_outputs = x
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask = self.decoders2(
                x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens, embeds_outputs
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = 'embed',
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        assert check_argument_types()
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                'conv_kernel_length must have equal number of values to num_blocks: '
                f'{len(conv_kernel_length)} != {num_blocks}')
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(attention_heads, attention_dim,
                                     src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = 'embed',
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        conv_wshare: int = 4,
        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
    ):
        assert check_argument_types()
        if len(conv_kernel_length) != num_blocks:
            raise ValueError(
                'conv_kernel_length must have equal number of values to num_blocks: '
                f'{len(conv_kernel_length)} != {num_blocks}')
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                LightweightConvolution2D(
                    wshare=conv_wshare,
                    n_feat=attention_dim,
                    dropout_rate=self_attention_dropout_rate,
                    kernel_size=conv_kernel_length[lnum],
                    use_kernel_mask=True,
                    use_bias=conv_usebias,
                ),
                MultiHeadedAttention(attention_heads, attention_dim,
                                     src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
| class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): | |||||
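| """Transformer decoder whose self-attention is replaced by dynamic | |||||
| convolution (Wu et al., 2019).""" | |||||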
| def __init__( | |||||
| self, | |||||
| vocab_size: int, | |||||
| encoder_output_size: int, | |||||
| attention_heads: int = 4, | |||||
| linear_units: int = 2048, | |||||
| num_blocks: int = 6, | |||||
| dropout_rate: float = 0.1, | |||||
| positional_dropout_rate: float = 0.1, | |||||
| self_attention_dropout_rate: float = 0.0, | |||||
| src_attention_dropout_rate: float = 0.0, | |||||
| input_layer: str = 'embed', | |||||
| use_output_layer: bool = True, | |||||
| pos_enc_class=PositionalEncoding, | |||||
| normalize_before: bool = True, | |||||
| concat_after: bool = False, | |||||
| conv_wshare: int = 4, | |||||
| conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||||
| conv_usebias: bool = False, | |||||
| ): | |||||
| assert check_argument_types() | |||||
| if len(conv_kernel_length) != num_blocks: | |||||
| raise ValueError( | |||||
| 'conv_kernel_length must have equal number of values to num_blocks: ' | |||||
| f'{len(conv_kernel_length)} != {num_blocks}') | |||||
| super().__init__( | |||||
| vocab_size=vocab_size, | |||||
| encoder_output_size=encoder_output_size, | |||||
| dropout_rate=dropout_rate, | |||||
| positional_dropout_rate=positional_dropout_rate, | |||||
| input_layer=input_layer, | |||||
| use_output_layer=use_output_layer, | |||||
| pos_enc_class=pos_enc_class, | |||||
| normalize_before=normalize_before, | |||||
| ) | |||||
| attention_dim = encoder_output_size | |||||
| self.decoders = repeat( | |||||
| num_blocks, | |||||
| lambda lnum: DecoderLayer( | |||||
| attention_dim, | |||||
| DynamicConvolution( | |||||
| wshare=conv_wshare, | |||||
| n_feat=attention_dim, | |||||
| dropout_rate=self_attention_dropout_rate, | |||||
| kernel_size=conv_kernel_length[lnum], | |||||
| use_kernel_mask=True, | |||||
| use_bias=conv_usebias, | |||||
| ), | |||||
| MultiHeadedAttention(attention_heads, attention_dim, | |||||
| src_attention_dropout_rate), | |||||
| PositionwiseFeedForward(attention_dim, linear_units, | |||||
| dropout_rate), | |||||
| dropout_rate, | |||||
| normalize_before, | |||||
| concat_after, | |||||
| ), | |||||
| ) | |||||
| class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): | |||||
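| """Transformer decoder whose self-attention is replaced by 2-D dynamic | |||||
| convolution (Wu et al., 2019).""" | |||||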
| def __init__( | |||||
| self, | |||||
| vocab_size: int, | |||||
| encoder_output_size: int, | |||||
| attention_heads: int = 4, | |||||
| linear_units: int = 2048, | |||||
| num_blocks: int = 6, | |||||
| dropout_rate: float = 0.1, | |||||
| positional_dropout_rate: float = 0.1, | |||||
| self_attention_dropout_rate: float = 0.0, | |||||
| src_attention_dropout_rate: float = 0.0, | |||||
| input_layer: str = 'embed', | |||||
| use_output_layer: bool = True, | |||||
| pos_enc_class=PositionalEncoding, | |||||
| normalize_before: bool = True, | |||||
| concat_after: bool = False, | |||||
| conv_wshare: int = 4, | |||||
| conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||||
| conv_usebias: bool = False, | |||||
| ): | |||||
| assert check_argument_types() | |||||
| if len(conv_kernel_length) != num_blocks: | |||||
| raise ValueError( | |||||
| 'conv_kernel_length must have equal number of values to num_blocks: ' | |||||
| f'{len(conv_kernel_length)} != {num_blocks}') | |||||
| super().__init__( | |||||
| vocab_size=vocab_size, | |||||
| encoder_output_size=encoder_output_size, | |||||
| dropout_rate=dropout_rate, | |||||
| positional_dropout_rate=positional_dropout_rate, | |||||
| input_layer=input_layer, | |||||
| use_output_layer=use_output_layer, | |||||
| pos_enc_class=pos_enc_class, | |||||
| normalize_before=normalize_before, | |||||
| ) | |||||
| attention_dim = encoder_output_size | |||||
| self.decoders = repeat( | |||||
| num_blocks, | |||||
| lambda lnum: DecoderLayer( | |||||
| attention_dim, | |||||
| DynamicConvolution2D( | |||||
| wshare=conv_wshare, | |||||
| n_feat=attention_dim, | |||||
| dropout_rate=self_attention_dropout_rate, | |||||
| kernel_size=conv_kernel_length[lnum], | |||||
| use_kernel_mask=True, | |||||
| use_bias=conv_usebias, | |||||
| ), | |||||
| MultiHeadedAttention(attention_heads, attention_dim, | |||||
| src_attention_dropout_rate), | |||||
| PositionwiseFeedForward(attention_dim, linear_units, | |||||
| dropout_rate), | |||||
| dropout_rate, | |||||
| normalize_before, | |||||
| concat_after, | |||||
| ), | |||||
| ) | |||||
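| # Usage sketch (illustrative values): each decoder above validates that | |||||
| # conv_kernel_length supplies one kernel size per block, e.g. | |||||
| # decoder = DynamicConvolution2DTransformerDecoder( | |||||
| #     vocab_size=5000, | |||||
| #     encoder_output_size=256, | |||||
| #     num_blocks=4, | |||||
| #     conv_kernel_length=(5, 7, 11, 15)) | |||||
| # Any other tuple length raises the ValueError checked in __init__. | |||||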
| @@ -0,0 +1,710 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| """Conformer encoder definition.""" | |||||
| import logging | |||||
| from typing import List, Optional, Tuple, Union | |||||
| import torch | |||||
| from espnet2.asr.ctc import CTC | |||||
| from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||||
| from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule | |||||
| from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer | |||||
| from espnet.nets.pytorch_backend.nets_utils import (get_activation, | |||||
| make_pad_mask) | |||||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
| LegacyRelPositionalEncoding # noqa: H301 | |||||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
| PositionalEncoding # noqa: H301 | |||||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
| RelPositionalEncoding # noqa: H301 | |||||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
| ScaledPositionalEncoding # noqa: H301 | |||||
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||||
| from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||||
| Conv1dLinear, MultiLayeredConv1d) | |||||
| from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||||
| PositionwiseFeedForward # noqa: H301 | |||||
| from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||||
| from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||||
| Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||||
| Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||||
| from typeguard import check_argument_types | |||||
| from ...nets.pytorch_backend.transformer.attention import \ | |||||
| LegacyRelPositionMultiHeadedAttention # noqa: H301 | |||||
| from ...nets.pytorch_backend.transformer.attention import \ | |||||
| MultiHeadedAttention # noqa: H301 | |||||
| from ...nets.pytorch_backend.transformer.attention import \ | |||||
| RelPositionMultiHeadedAttention # noqa: H301 | |||||
| from ...nets.pytorch_backend.transformer.attention import ( | |||||
| LegacyRelPositionMultiHeadedAttentionSANM, | |||||
| RelPositionMultiHeadedAttentionSANM) | |||||
| class ConformerEncoder(AbsEncoder): | |||||
| """Conformer encoder module. | |||||
| Args: | |||||
| input_size (int): Input dimension. | |||||
| output_size (int): Dimension of attention. | |||||
| attention_heads (int): The number of heads of multi head attention. | |||||
| linear_units (int): The number of units of position-wise feed forward. | |||||
| num_blocks (int): The number of encoder blocks. | |||||
| dropout_rate (float): Dropout rate. | |||||
| attention_dropout_rate (float): Dropout rate in attention. | |||||
| positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||||
| input_layer (Union[str, torch.nn.Module]): Input layer type. | |||||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||||
| concat_after (bool): Whether to concat attention layer's input and output. | |||||
| If True, additional linear will be applied. | |||||
| i.e. x -> x + linear(concat(x, att(x))) | |||||
| If False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
| positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||||
| positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||||
| rel_pos_type (str): Whether to use the latest relative positional encoding or | |||||
| the legacy one. The legacy relative positional encoding will be deprecated | |||||
| in the future. More Details can be found in | |||||
| https://github.com/espnet/espnet/pull/2816. | |||||
| pos_enc_layer_type (str): Encoder positional encoding layer type. | |||||
| selfattention_layer_type (str): Encoder attention layer type. | |||||
| activation_type (str): Encoder activation function type. | |||||
| macaron_style (bool): Whether to use macaron style for positionwise layer. | |||||
| use_cnn_module (bool): Whether to use convolution module. | |||||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
| cnn_module_kernel (int): Kernel size of convolution module. | |||||
| padding_idx (int): Padding idx for input_layer=embed. | |||||
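| Example (illustrative shapes; assumes the espnet dependencies above): | |||||
| >>> encoder = ConformerEncoder(input_size=80, output_size=256) | |||||
| >>> feats = torch.randn(4, 200, 80) | |||||
| >>> out, olens, _ = encoder(feats, torch.tensor([200, 180, 160, 120])) | |||||
| >>> out.shape[-1] == encoder.output_size() | |||||
| True | |||||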
| """ | |||||
| def __init__( | |||||
| self, | |||||
| input_size: int, | |||||
| output_size: int = 256, | |||||
| attention_heads: int = 4, | |||||
| linear_units: int = 2048, | |||||
| num_blocks: int = 6, | |||||
| dropout_rate: float = 0.1, | |||||
| positional_dropout_rate: float = 0.1, | |||||
| attention_dropout_rate: float = 0.0, | |||||
| input_layer: str = 'conv2d', | |||||
| normalize_before: bool = True, | |||||
| concat_after: bool = False, | |||||
| positionwise_layer_type: str = 'linear', | |||||
| positionwise_conv_kernel_size: int = 3, | |||||
| macaron_style: bool = False, | |||||
| rel_pos_type: str = 'legacy', | |||||
| pos_enc_layer_type: str = 'rel_pos', | |||||
| selfattention_layer_type: str = 'rel_selfattn', | |||||
| activation_type: str = 'swish', | |||||
| use_cnn_module: bool = True, | |||||
| zero_triu: bool = False, | |||||
| cnn_module_kernel: int = 31, | |||||
| padding_idx: int = -1, | |||||
| interctc_layer_idx: List[int] = [], | |||||
| interctc_use_conditioning: bool = False, | |||||
| stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||||
| ): | |||||
| assert check_argument_types() | |||||
| super().__init__() | |||||
| self._output_size = output_size | |||||
| if rel_pos_type == 'legacy': | |||||
| if pos_enc_layer_type == 'rel_pos': | |||||
| pos_enc_layer_type = 'legacy_rel_pos' | |||||
| if selfattention_layer_type == 'rel_selfattn': | |||||
| selfattention_layer_type = 'legacy_rel_selfattn' | |||||
| elif rel_pos_type == 'latest': | |||||
| assert selfattention_layer_type != 'legacy_rel_selfattn' | |||||
| assert pos_enc_layer_type != 'legacy_rel_pos' | |||||
| else: | |||||
| raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||||
| activation = get_activation(activation_type) | |||||
| if pos_enc_layer_type == 'abs_pos': | |||||
| pos_enc_class = PositionalEncoding | |||||
| elif pos_enc_layer_type == 'scaled_abs_pos': | |||||
| pos_enc_class = ScaledPositionalEncoding | |||||
| elif pos_enc_layer_type == 'rel_pos': | |||||
| assert selfattention_layer_type == 'rel_selfattn' | |||||
| pos_enc_class = RelPositionalEncoding | |||||
| elif pos_enc_layer_type == 'legacy_rel_pos': | |||||
| assert selfattention_layer_type == 'legacy_rel_selfattn' | |||||
| pos_enc_class = LegacyRelPositionalEncoding | |||||
| else: | |||||
| raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||||
| if input_layer == 'linear': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Linear(input_size, output_size), | |||||
| torch.nn.LayerNorm(output_size), | |||||
| torch.nn.Dropout(dropout_rate), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d': | |||||
| self.embed = Conv2dSubsampling( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d2': | |||||
| self.embed = Conv2dSubsampling2( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d6': | |||||
| self.embed = Conv2dSubsampling6( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d8': | |||||
| self.embed = Conv2dSubsampling8( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'embed': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Embedding( | |||||
| input_size, output_size, padding_idx=padding_idx), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif isinstance(input_layer, torch.nn.Module): | |||||
| self.embed = torch.nn.Sequential( | |||||
| input_layer, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer is None: | |||||
| self.embed = torch.nn.Sequential( | |||||
| pos_enc_class(output_size, positional_dropout_rate)) | |||||
| else: | |||||
| raise ValueError('unknown input_layer: ' + input_layer) | |||||
| self.normalize_before = normalize_before | |||||
| if positionwise_layer_type == 'linear': | |||||
| positionwise_layer = PositionwiseFeedForward | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| dropout_rate, | |||||
| activation, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d': | |||||
| positionwise_layer = MultiLayeredConv1d | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d-linear': | |||||
| positionwise_layer = Conv1dLinear | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| else: | |||||
| raise NotImplementedError( | |||||
| 'Support only linear, conv1d or conv1d-linear.') | |||||
| if selfattention_layer_type == 'selfattn': | |||||
| encoder_selfattn_layer = MultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| ) | |||||
| elif selfattention_layer_type == 'legacy_rel_selfattn': | |||||
| assert pos_enc_layer_type == 'legacy_rel_pos' | |||||
| encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| ) | |||||
| elif selfattention_layer_type == 'rel_selfattn': | |||||
| assert pos_enc_layer_type == 'rel_pos' | |||||
| encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| zero_triu, | |||||
| ) | |||||
| else: | |||||
| raise ValueError('unknown encoder_attn_layer: ' | |||||
| + selfattention_layer_type) | |||||
| convolution_layer = ConvolutionModule | |||||
| convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||||
| if isinstance(stochastic_depth_rate, float): | |||||
| stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||||
| if len(stochastic_depth_rate) != num_blocks: | |||||
| raise ValueError( | |||||
| f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||||
| f'should be equal to num_blocks ({num_blocks})') | |||||
| self.encoders = repeat( | |||||
| num_blocks, | |||||
| lambda lnum: EncoderLayer( | |||||
| output_size, | |||||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
| positionwise_layer(*positionwise_layer_args), | |||||
| positionwise_layer(*positionwise_layer_args) | |||||
| if macaron_style else None, | |||||
| convolution_layer(*convolution_layer_args) | |||||
| if use_cnn_module else None, | |||||
| dropout_rate, | |||||
| normalize_before, | |||||
| concat_after, | |||||
| stochastic_depth_rate[lnum], | |||||
| ), | |||||
| ) | |||||
| if self.normalize_before: | |||||
| self.after_norm = LayerNorm(output_size) | |||||
| self.interctc_layer_idx = interctc_layer_idx | |||||
| if len(interctc_layer_idx) > 0: | |||||
| assert 0 < min(interctc_layer_idx) and max( | |||||
| interctc_layer_idx) < num_blocks | |||||
| self.interctc_use_conditioning = interctc_use_conditioning | |||||
| self.conditioning_layer = None | |||||
| def output_size(self) -> int: | |||||
| return self._output_size | |||||
| def forward( | |||||
| self, | |||||
| xs_pad: torch.Tensor, | |||||
| ilens: torch.Tensor, | |||||
| prev_states: torch.Tensor = None, | |||||
| ctc: CTC = None, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
| """Calculate forward propagation. | |||||
| Args: | |||||
| xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||||
| ilens (torch.Tensor): Input length (#batch). | |||||
| prev_states (torch.Tensor): Not used for now. | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, L, output_size). | |||||
| torch.Tensor: Output length (#batch). | |||||
| torch.Tensor: Not used for now. | |||||
| """ | |||||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
| if (isinstance(self.embed, Conv2dSubsampling) | |||||
| or isinstance(self.embed, Conv2dSubsampling2) | |||||
| or isinstance(self.embed, Conv2dSubsampling6) | |||||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||||
| short_status, limit_size = check_short_utt(self.embed, | |||||
| xs_pad.size(1)) | |||||
| if short_status: | |||||
| raise TooShortUttError( | |||||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
| + # noqa: * | |||||
| f'(it needs more than {limit_size} frames), return empty results', # noqa: * | |||||
| xs_pad.size(1), | |||||
| limit_size) # noqa: * | |||||
| xs_pad, masks = self.embed(xs_pad, masks) | |||||
| else: | |||||
| xs_pad = self.embed(xs_pad) | |||||
| intermediate_outs = [] | |||||
| if len(self.interctc_layer_idx) == 0: | |||||
| xs_pad, masks = self.encoders(xs_pad, masks) | |||||
| else: | |||||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
| xs_pad, masks = encoder_layer(xs_pad, masks) | |||||
| if layer_idx + 1 in self.interctc_layer_idx: | |||||
| encoder_out = xs_pad | |||||
| if isinstance(encoder_out, tuple): | |||||
| encoder_out = encoder_out[0] | |||||
| # intermediate outputs are also normalized | |||||
| if self.normalize_before: | |||||
| encoder_out = self.after_norm(encoder_out) | |||||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
| if self.interctc_use_conditioning: | |||||
| ctc_out = ctc.softmax(encoder_out) | |||||
| if isinstance(xs_pad, tuple): | |||||
| x, pos_emb = xs_pad | |||||
| x = x + self.conditioning_layer(ctc_out) | |||||
| xs_pad = (x, pos_emb) | |||||
| else: | |||||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
| if isinstance(xs_pad, tuple): | |||||
| xs_pad = xs_pad[0] | |||||
| if self.normalize_before: | |||||
| xs_pad = self.after_norm(xs_pad) | |||||
| olens = masks.squeeze(1).sum(1) | |||||
| if len(intermediate_outs) > 0: | |||||
| return (xs_pad, intermediate_outs), olens, None | |||||
| return xs_pad, olens, None | |||||
| class SANMEncoder_v2(AbsEncoder): | |||||
| """Conformer encoder module. | |||||
| Args: | |||||
| input_size (int): Input dimension. | |||||
| output_size (int): Dimension of attention. | |||||
| attention_heads (int): The number of heads of multi head attention. | |||||
| linear_units (int): The number of units of position-wise feed forward. | |||||
| num_blocks (int): The number of encoder blocks. | |||||
| dropout_rate (float): Dropout rate. | |||||
| attention_dropout_rate (float): Dropout rate in attention. | |||||
| positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||||
| input_layer (Union[str, torch.nn.Module]): Input layer type. | |||||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||||
| concat_after (bool): Whether to concat attention layer's input and output. | |||||
| If True, additional linear will be applied. | |||||
| i.e. x -> x + linear(concat(x, att(x))) | |||||
| If False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
| positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||||
| positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||||
| rel_pos_type (str): Whether to use the latest relative positional encoding or | |||||
| the legacy one. The legacy relative positional encoding will be deprecated | |||||
| in the future. More Details can be found in | |||||
| https://github.com/espnet/espnet/pull/2816. | |||||
| pos_enc_layer_type (str): Encoder positional encoding layer type. | |||||
| selfattention_layer_type (str): Encoder attention layer type. | |||||
| activation_type (str): Encoder activation function type. | |||||
| macaron_style (bool): Whether to use macaron style for positionwise layer. | |||||
| use_cnn_module (bool): Whether to use convolution module. | |||||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
| cnn_module_kernel (int): Kernel size of convolution module. | |||||
| padding_idx (int): Padding idx for input_layer=embed. | |||||
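| Example (illustrative; exercises the SANM attention branch defined in | |||||
| this repo's nets package): | |||||
| >>> encoder = SANMEncoder_v2( | |||||
| ... input_size=80, | |||||
| ... selfattention_layer_type='rel_selfattnsanm') | |||||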
| """ | |||||
| def __init__( | |||||
| self, | |||||
| input_size: int, | |||||
| output_size: int = 256, | |||||
| attention_heads: int = 4, | |||||
| linear_units: int = 2048, | |||||
| num_blocks: int = 6, | |||||
| dropout_rate: float = 0.1, | |||||
| positional_dropout_rate: float = 0.1, | |||||
| attention_dropout_rate: float = 0.0, | |||||
| input_layer: str = 'conv2d', | |||||
| normalize_before: bool = True, | |||||
| concat_after: bool = False, | |||||
| positionwise_layer_type: str = 'linear', | |||||
| positionwise_conv_kernel_size: int = 3, | |||||
| macaron_style: bool = False, | |||||
| rel_pos_type: str = 'legacy', | |||||
| pos_enc_layer_type: str = 'rel_pos', | |||||
| selfattention_layer_type: str = 'rel_selfattn', | |||||
| activation_type: str = 'swish', | |||||
| use_cnn_module: bool = False, | |||||
| sanm_shfit: int = 0, | |||||
| zero_triu: bool = False, | |||||
| cnn_module_kernel: int = 31, | |||||
| padding_idx: int = -1, | |||||
| interctc_layer_idx: List[int] = [], | |||||
| interctc_use_conditioning: bool = False, | |||||
| stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||||
| ): | |||||
| assert check_argument_types() | |||||
| super().__init__() | |||||
| self._output_size = output_size | |||||
| if rel_pos_type == 'legacy': | |||||
| if pos_enc_layer_type == 'rel_pos': | |||||
| pos_enc_layer_type = 'legacy_rel_pos' | |||||
| if selfattention_layer_type == 'rel_selfattn': | |||||
| selfattention_layer_type = 'legacy_rel_selfattn' | |||||
| if selfattention_layer_type == 'rel_selfattnsanm': | |||||
| selfattention_layer_type = 'legacy_rel_selfattnsanm' | |||||
| elif rel_pos_type == 'latest': | |||||
| assert selfattention_layer_type != 'legacy_rel_selfattn' | |||||
| assert pos_enc_layer_type != 'legacy_rel_pos' | |||||
| else: | |||||
| raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||||
| activation = get_activation(activation_type) | |||||
| if pos_enc_layer_type == 'abs_pos': | |||||
| pos_enc_class = PositionalEncoding | |||||
| elif pos_enc_layer_type == 'scaled_abs_pos': | |||||
| pos_enc_class = ScaledPositionalEncoding | |||||
| elif pos_enc_layer_type == 'rel_pos': | |||||
| # assert selfattention_layer_type == "rel_selfattn" | |||||
| pos_enc_class = RelPositionalEncoding | |||||
| elif pos_enc_layer_type == 'legacy_rel_pos': | |||||
| # assert selfattention_layer_type == "legacy_rel_selfattn" | |||||
| pos_enc_class = LegacyRelPositionalEncoding | |||||
| logging.warning( | |||||
| 'Using legacy_rel_pos and it will be deprecated in the future.' | |||||
| ) | |||||
| else: | |||||
| raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||||
| if input_layer == 'linear': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Linear(input_size, output_size), | |||||
| torch.nn.LayerNorm(output_size), | |||||
| torch.nn.Dropout(dropout_rate), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d': | |||||
| self.embed = Conv2dSubsampling( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d2': | |||||
| self.embed = Conv2dSubsampling2( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d6': | |||||
| self.embed = Conv2dSubsampling6( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d8': | |||||
| self.embed = Conv2dSubsampling8( | |||||
| input_size, | |||||
| output_size, | |||||
| dropout_rate, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'embed': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Embedding( | |||||
| input_size, output_size, padding_idx=padding_idx), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif isinstance(input_layer, torch.nn.Module): | |||||
| self.embed = torch.nn.Sequential( | |||||
| input_layer, | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer is None: | |||||
| self.embed = torch.nn.Sequential( | |||||
| pos_enc_class(output_size, positional_dropout_rate)) | |||||
| else: | |||||
| raise ValueError('unknown input_layer: ' + input_layer) | |||||
| self.normalize_before = normalize_before | |||||
| if positionwise_layer_type == 'linear': | |||||
| positionwise_layer = PositionwiseFeedForward | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| dropout_rate, | |||||
| activation, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d': | |||||
| positionwise_layer = MultiLayeredConv1d | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d-linear': | |||||
| positionwise_layer = Conv1dLinear | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| else: | |||||
| raise NotImplementedError( | |||||
| 'Support only linear, conv1d or conv1d-linear.') | |||||
| if selfattention_layer_type == 'selfattn': | |||||
| encoder_selfattn_layer = MultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| ) | |||||
| elif selfattention_layer_type == 'legacy_rel_selfattn': | |||||
| assert pos_enc_layer_type == 'legacy_rel_pos' | |||||
| encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| ) | |||||
| logging.warning( | |||||
| 'Using legacy_rel_selfattn and it will be deprecated in the future.' | |||||
| ) | |||||
| elif selfattention_layer_type == 'legacy_rel_selfattnsanm': | |||||
| assert pos_enc_layer_type == 'legacy_rel_pos' | |||||
| encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttentionSANM | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| ) | |||||
| logging.warning( | |||||
| 'Using legacy_rel_selfattnsanm and it will be deprecated in the future.' | |||||
| ) | |||||
| elif selfattention_layer_type == 'rel_selfattn': | |||||
| assert pos_enc_layer_type == 'rel_pos' | |||||
| encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| zero_triu, | |||||
| ) | |||||
| elif selfattention_layer_type == 'rel_selfattnsanm': | |||||
| assert pos_enc_layer_type == 'rel_pos' | |||||
| encoder_selfattn_layer = RelPositionMultiHeadedAttentionSANM | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| zero_triu, | |||||
| cnn_module_kernel, | |||||
| sanm_shfit, | |||||
| ) | |||||
| else: | |||||
| raise ValueError('unknown encoder_attn_layer: ' | |||||
| + selfattention_layer_type) | |||||
| convolution_layer = ConvolutionModule | |||||
| convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||||
| if isinstance(stochastic_depth_rate, float): | |||||
| stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||||
| if len(stochastic_depth_rate) != num_blocks: | |||||
| raise ValueError( | |||||
| f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||||
| f'should be equal to num_blocks ({num_blocks})') | |||||
| self.encoders = repeat( | |||||
| num_blocks, | |||||
| lambda lnum: EncoderLayer( | |||||
| output_size, | |||||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
| positionwise_layer(*positionwise_layer_args), | |||||
| positionwise_layer(*positionwise_layer_args) | |||||
| if macaron_style else None, | |||||
| convolution_layer(*convolution_layer_args) | |||||
| if use_cnn_module else None, | |||||
| dropout_rate, | |||||
| normalize_before, | |||||
| concat_after, | |||||
| stochastic_depth_rate[lnum], | |||||
| ), | |||||
| ) | |||||
| if self.normalize_before: | |||||
| self.after_norm = LayerNorm(output_size) | |||||
| self.interctc_layer_idx = interctc_layer_idx | |||||
| if len(interctc_layer_idx) > 0: | |||||
| assert 0 < min(interctc_layer_idx) and max( | |||||
| interctc_layer_idx) < num_blocks | |||||
| self.interctc_use_conditioning = interctc_use_conditioning | |||||
| self.conditioning_layer = None | |||||
| def output_size(self) -> int: | |||||
| return self._output_size | |||||
| def forward( | |||||
| self, | |||||
| xs_pad: torch.Tensor, | |||||
| ilens: torch.Tensor, | |||||
| prev_states: torch.Tensor = None, | |||||
| ctc: CTC = None, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
| """Calculate forward propagation. | |||||
| Args: | |||||
| xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||||
| ilens (torch.Tensor): Input length (#batch). | |||||
| prev_states (torch.Tensor): Not used for now. | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, L, output_size). | |||||
| torch.Tensor: Output length (#batch). | |||||
| torch.Tensor: Not used for now. | |||||
| """ | |||||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
| if (isinstance(self.embed, Conv2dSubsampling) | |||||
| or isinstance(self.embed, Conv2dSubsampling2) | |||||
| or isinstance(self.embed, Conv2dSubsampling6) | |||||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||||
| short_status, limit_size = check_short_utt(self.embed, | |||||
| xs_pad.size(1)) | |||||
| if short_status: | |||||
| raise TooShortUttError( | |||||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
| + # noqa: * | |||||
| f'(it needs more than {limit_size} frames), return empty results', | |||||
| xs_pad.size(1), | |||||
| limit_size) # noqa: * | |||||
| xs_pad, masks = self.embed(xs_pad, masks) | |||||
| else: | |||||
| xs_pad = self.embed(xs_pad) | |||||
| intermediate_outs = [] | |||||
| if len(self.interctc_layer_idx) == 0: | |||||
| xs_pad, masks = self.encoders(xs_pad, masks) | |||||
| else: | |||||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
| xs_pad, masks = encoder_layer(xs_pad, masks) | |||||
| if layer_idx + 1 in self.interctc_layer_idx: | |||||
| encoder_out = xs_pad | |||||
| if isinstance(encoder_out, tuple): | |||||
| encoder_out = encoder_out[0] | |||||
| # intermediate outputs are also normalized | |||||
| if self.normalize_before: | |||||
| encoder_out = self.after_norm(encoder_out) | |||||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
| if self.interctc_use_conditioning: | |||||
| ctc_out = ctc.softmax(encoder_out) | |||||
| if isinstance(xs_pad, tuple): | |||||
| x, pos_emb = xs_pad | |||||
| x = x + self.conditioning_layer(ctc_out) | |||||
| xs_pad = (x, pos_emb) | |||||
| else: | |||||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
| if isinstance(xs_pad, tuple): | |||||
| xs_pad = xs_pad[0] | |||||
| if self.normalize_before: | |||||
| xs_pad = self.after_norm(xs_pad) | |||||
| olens = masks.squeeze(1).sum(1) | |||||
| if len(intermediate_outs) > 0: | |||||
| return (xs_pad, intermediate_outs), olens, None | |||||
| return xs_pad, olens, None | |||||
| @@ -0,0 +1,500 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| """Transformer encoder definition.""" | |||||
| import logging | |||||
| from typing import List, Optional, Sequence, Tuple, Union | |||||
| import torch | |||||
| from espnet2.asr.ctc import CTC | |||||
| from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||||
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||||
| PositionalEncoding | |||||
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||||
| from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||||
| Conv1dLinear, MultiLayeredConv1d) | |||||
| from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||||
| PositionwiseFeedForward # noqa: H301 | |||||
| from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||||
| from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||||
| Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||||
| Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||||
| from typeguard import check_argument_types | |||||
| from ...asr.streaming_utilis.chunk_utilis import overlap_chunk | |||||
| from ...nets.pytorch_backend.transformer.attention import ( | |||||
| MultiHeadedAttention, MultiHeadedAttentionSANM) | |||||
| from ...nets.pytorch_backend.transformer.encoder_layer import ( | |||||
| EncoderLayer, EncoderLayerChunk) | |||||
| class SANMEncoder(AbsEncoder): | |||||
| """Transformer encoder module. | |||||
| Args: | |||||
| input_size: input dim | |||||
| output_size: dimension of attention | |||||
| attention_heads: the number of heads of multi head attention | |||||
| linear_units: the number of units of position-wise feed forward | |||||
| num_blocks: the number of encoder blocks | |||||
| dropout_rate: dropout rate | |||||
| attention_dropout_rate: dropout rate in attention | |||||
| positional_dropout_rate: dropout rate after adding positional encoding | |||||
| input_layer: input layer type | |||||
| pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||||
| normalize_before: whether to use layer_norm before the first block | |||||
| concat_after: whether to concat attention layer's input and output | |||||
| if True, additional linear will be applied. | |||||
| i.e. x -> x + linear(concat(x, att(x))) | |||||
| if False, no additional linear will be applied. | |||||
| i.e. x -> x + att(x) | |||||
| positionwise_layer_type: "linear", "conv1d", or "conv1d-linear" | |||||
| positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||||
| padding_idx: padding_idx for input_layer=embed | |||||
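| Example (illustrative shapes; the SANM attention comes from this | |||||
| repo's nets package): | |||||
| >>> encoder = SANMEncoder(input_size=80, output_size=256) | |||||
| >>> feats = torch.randn(2, 100, 80) | |||||
| >>> out, olens, _ = encoder(feats, torch.tensor([100, 80])) | |||||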
| """ | |||||
| def __init__( | |||||
| self, | |||||
| input_size: int, | |||||
| output_size: int = 256, | |||||
| attention_heads: int = 4, | |||||
| linear_units: int = 2048, | |||||
| num_blocks: int = 6, | |||||
| dropout_rate: float = 0.1, | |||||
| positional_dropout_rate: float = 0.1, | |||||
| attention_dropout_rate: float = 0.0, | |||||
| input_layer: Optional[str] = 'conv2d', | |||||
| pos_enc_class=PositionalEncoding, | |||||
| normalize_before: bool = True, | |||||
| concat_after: bool = False, | |||||
| positionwise_layer_type: str = 'linear', | |||||
| positionwise_conv_kernel_size: int = 1, | |||||
| padding_idx: int = -1, | |||||
| interctc_layer_idx: List[int] = [], | |||||
| interctc_use_conditioning: bool = False, | |||||
| kernel_size: int = 11, | |||||
| sanm_shfit: int = 0, | |||||
| selfattention_layer_type: str = 'sanm', | |||||
| ): | |||||
| assert check_argument_types() | |||||
| super().__init__() | |||||
| self._output_size = output_size | |||||
| if input_layer == 'linear': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Linear(input_size, output_size), | |||||
| torch.nn.LayerNorm(output_size), | |||||
| torch.nn.Dropout(dropout_rate), | |||||
| torch.nn.ReLU(), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d': | |||||
| self.embed = Conv2dSubsampling(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'conv2d2': | |||||
| self.embed = Conv2dSubsampling2(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'conv2d6': | |||||
| self.embed = Conv2dSubsampling6(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'conv2d8': | |||||
| self.embed = Conv2dSubsampling8(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'embed': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Embedding( | |||||
| input_size, output_size, padding_idx=padding_idx), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer is None: | |||||
| if input_size == output_size: | |||||
| self.embed = None | |||||
| else: | |||||
| self.embed = torch.nn.Linear(input_size, output_size) | |||||
| else: | |||||
| raise ValueError('unknown input_layer: ' + input_layer) | |||||
| self.normalize_before = normalize_before | |||||
| if positionwise_layer_type == 'linear': | |||||
| positionwise_layer = PositionwiseFeedForward | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| dropout_rate, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d': | |||||
| positionwise_layer = MultiLayeredConv1d | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d-linear': | |||||
| positionwise_layer = Conv1dLinear | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| else: | |||||
| raise NotImplementedError( | |||||
| 'Support only linear, conv1d or conv1d-linear.') | |||||
| if selfattention_layer_type == 'selfattn': | |||||
| encoder_selfattn_layer = MultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| ) | |||||
| elif selfattention_layer_type == 'sanm': | |||||
| encoder_selfattn_layer = MultiHeadedAttentionSANM | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| kernel_size, | |||||
| sanm_shfit, | |||||
| ) | |||||
| self.encoders = repeat( | |||||
| num_blocks, | |||||
| lambda lnum: EncoderLayer( | |||||
| output_size, | |||||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
| positionwise_layer(*positionwise_layer_args), | |||||
| dropout_rate, | |||||
| normalize_before, | |||||
| concat_after, | |||||
| ), | |||||
| ) | |||||
| if self.normalize_before: | |||||
| self.after_norm = LayerNorm(output_size) | |||||
| self.interctc_layer_idx = interctc_layer_idx | |||||
| if len(interctc_layer_idx) > 0: | |||||
| assert 0 < min(interctc_layer_idx) and max( | |||||
| interctc_layer_idx) < num_blocks | |||||
| self.interctc_use_conditioning = interctc_use_conditioning | |||||
| self.conditioning_layer = None | |||||
| def output_size(self) -> int: | |||||
| return self._output_size | |||||
| def forward( | |||||
| self, | |||||
| xs_pad: torch.Tensor, | |||||
| ilens: torch.Tensor, | |||||
| prev_states: torch.Tensor = None, | |||||
| ctc: CTC = None, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
| """Embed positions in tensor. | |||||
| Args: | |||||
| xs_pad: input tensor (B, L, D) | |||||
| ilens: input length (B) | |||||
| prev_states: Not to be used now. | |||||
| Returns: | |||||
| position embedded tensor and mask | |||||
| """ | |||||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
| if self.embed is None: | |||||
| xs_pad = xs_pad | |||||
| elif (isinstance(self.embed, Conv2dSubsampling) | |||||
| or isinstance(self.embed, Conv2dSubsampling2) | |||||
| or isinstance(self.embed, Conv2dSubsampling6) | |||||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||||
| short_status, limit_size = check_short_utt(self.embed, | |||||
| xs_pad.size(1)) | |||||
| if short_status: | |||||
| raise TooShortUttError( | |||||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
| + # noqa: * | |||||
| f'(it needs more than {limit_size} frames), return empty results', | |||||
| xs_pad.size(1), | |||||
| limit_size, | |||||
| ) | |||||
| xs_pad, masks = self.embed(xs_pad, masks) | |||||
| else: | |||||
| xs_pad = self.embed(xs_pad) | |||||
| intermediate_outs = [] | |||||
| if len(self.interctc_layer_idx) == 0: | |||||
| xs_pad, masks = self.encoders(xs_pad, masks) | |||||
| else: | |||||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
| xs_pad, masks = encoder_layer(xs_pad, masks) | |||||
| if layer_idx + 1 in self.interctc_layer_idx: | |||||
| encoder_out = xs_pad | |||||
| # intermediate outputs are also normalized | |||||
| if self.normalize_before: | |||||
| encoder_out = self.after_norm(encoder_out) | |||||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
| if self.interctc_use_conditioning: | |||||
| ctc_out = ctc.softmax(encoder_out) | |||||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
| if self.normalize_before: | |||||
| xs_pad = self.after_norm(xs_pad) | |||||
| olens = masks.squeeze(1).sum(1) | |||||
| if len(intermediate_outs) > 0: | |||||
| return (xs_pad, intermediate_outs), olens, None | |||||
| return xs_pad, olens, None | |||||
| class SANMEncoderChunk(AbsEncoder): | |||||
| """Transformer encoder module. | |||||
| Args: | |||||
| input_size: input dim | |||||
| output_size: dimension of attention | |||||
| attention_heads: the number of heads of multi head attention | |||||
| linear_units: the number of units of position-wise feed forward | |||||
| num_blocks: the number of encoder blocks | |||||
| dropout_rate: dropout rate | |||||
| attention_dropout_rate: dropout rate in attention | |||||
| positional_dropout_rate: dropout rate after adding positional encoding | |||||
| input_layer: input layer type | |||||
| pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||||
| normalize_before: whether to use layer_norm before the first block | |||||
| concat_after: whether to concat attention layer's input and output | |||||
| if True, additional linear will be applied. | |||||
| i.e. x -> x + linear(concat(x, att(x))) | |||||
| if False, no additional linear will be applied. | |||||
| i.e. x -> x + att(x) | |||||
| positionwise_layer_type: "linear", "conv1d", or "conv1d-linear" | |||||
| positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||||
| padding_idx: padding_idx for input_layer=embed | |||||
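| Example (illustrative): 16-frame chunks advancing by 10 frames, with | |||||
| self-attention allowed to look back one chunk: | |||||
| >>> encoder = SANMEncoderChunk( | |||||
| ... input_size=80, | |||||
| ... chunk_size=(16, ), | |||||
| ... stride=(10, ), | |||||
| ... encoder_att_look_back_factor=(1, )) | |||||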
| """ | |||||
| def __init__( | |||||
| self, | |||||
| input_size: int, | |||||
| output_size: int = 256, | |||||
| attention_heads: int = 4, | |||||
| linear_units: int = 2048, | |||||
| num_blocks: int = 6, | |||||
| dropout_rate: float = 0.1, | |||||
| positional_dropout_rate: float = 0.1, | |||||
| attention_dropout_rate: float = 0.0, | |||||
| input_layer: Optional[str] = 'conv2d', | |||||
| pos_enc_class=PositionalEncoding, | |||||
| normalize_before: bool = True, | |||||
| concat_after: bool = False, | |||||
| positionwise_layer_type: str = 'linear', | |||||
| positionwise_conv_kernel_size: int = 1, | |||||
| padding_idx: int = -1, | |||||
| interctc_layer_idx: List[int] = [], | |||||
| interctc_use_conditioning: bool = False, | |||||
| kernel_size: int = 11, | |||||
| sanm_shfit: int = 0, | |||||
| selfattention_layer_type: str = 'sanm', | |||||
| chunk_size: Union[int, Sequence[int]] = (16, ), | |||||
| stride: Union[int, Sequence[int]] = (10, ), | |||||
| pad_left: Union[int, Sequence[int]] = (0, ), | |||||
| encoder_att_look_back_factor: Union[int, Sequence[int]] = (1, ), | |||||
| ): | |||||
| assert check_argument_types() | |||||
| super().__init__() | |||||
| self._output_size = output_size | |||||
| if input_layer == 'linear': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Linear(input_size, output_size), | |||||
| torch.nn.LayerNorm(output_size), | |||||
| torch.nn.Dropout(dropout_rate), | |||||
| torch.nn.ReLU(), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer == 'conv2d': | |||||
| self.embed = Conv2dSubsampling(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'conv2d2': | |||||
| self.embed = Conv2dSubsampling2(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'conv2d6': | |||||
| self.embed = Conv2dSubsampling6(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'conv2d8': | |||||
| self.embed = Conv2dSubsampling8(input_size, output_size, | |||||
| dropout_rate) | |||||
| elif input_layer == 'embed': | |||||
| self.embed = torch.nn.Sequential( | |||||
| torch.nn.Embedding( | |||||
| input_size, output_size, padding_idx=padding_idx), | |||||
| pos_enc_class(output_size, positional_dropout_rate), | |||||
| ) | |||||
| elif input_layer is None: | |||||
| if input_size == output_size: | |||||
| self.embed = None | |||||
| else: | |||||
| self.embed = torch.nn.Linear(input_size, output_size) | |||||
| else: | |||||
| raise ValueError('unknown input_layer: ' + input_layer) | |||||
| self.normalize_before = normalize_before | |||||
| if positionwise_layer_type == 'linear': | |||||
| positionwise_layer = PositionwiseFeedForward | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| dropout_rate, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d': | |||||
| positionwise_layer = MultiLayeredConv1d | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| elif positionwise_layer_type == 'conv1d-linear': | |||||
| positionwise_layer = Conv1dLinear | |||||
| positionwise_layer_args = ( | |||||
| output_size, | |||||
| linear_units, | |||||
| positionwise_conv_kernel_size, | |||||
| dropout_rate, | |||||
| ) | |||||
| else: | |||||
| raise NotImplementedError( | |||||
| 'Support only linear, conv1d or conv1d-linear.') | |||||
| if selfattention_layer_type == 'selfattn': | |||||
| encoder_selfattn_layer = MultiHeadedAttention | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| ) | |||||
| elif selfattention_layer_type == 'sanm': | |||||
| encoder_selfattn_layer = MultiHeadedAttentionSANM | |||||
| encoder_selfattn_layer_args = ( | |||||
| attention_heads, | |||||
| output_size, | |||||
| attention_dropout_rate, | |||||
| kernel_size, | |||||
| sanm_shfit, | |||||
| ) | |||||
| self.encoders = repeat( | |||||
| num_blocks, | |||||
| lambda lnum: EncoderLayerChunk( | |||||
| output_size, | |||||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||||
| positionwise_layer(*positionwise_layer_args), | |||||
| dropout_rate, | |||||
| normalize_before, | |||||
| concat_after, | |||||
| ), | |||||
| ) | |||||
| if self.normalize_before: | |||||
| self.after_norm = LayerNorm(output_size) | |||||
| self.interctc_layer_idx = interctc_layer_idx | |||||
| if len(interctc_layer_idx) > 0: | |||||
| assert 0 < min(interctc_layer_idx) and max( | |||||
| interctc_layer_idx) < num_blocks | |||||
| self.interctc_use_conditioning = interctc_use_conditioning | |||||
| self.conditioning_layer = None | |||||
| shfit_fsmn = (kernel_size - 1) // 2 | |||||
| self.overlap_chunk_cls = overlap_chunk( | |||||
| chunk_size=chunk_size, | |||||
| stride=stride, | |||||
| pad_left=pad_left, | |||||
| shfit_fsmn=shfit_fsmn, | |||||
| encoder_att_look_back_factor=encoder_att_look_back_factor, | |||||
| ) | |||||
| def output_size(self) -> int: | |||||
| return self._output_size | |||||
| def forward( | |||||
| self, | |||||
| xs_pad: torch.Tensor, | |||||
| ilens: torch.Tensor, | |||||
| prev_states: torch.Tensor = None, | |||||
| ctc: CTC = None, | |||||
| ind: int = 0, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
| """Embed positions in tensor. | |||||
| Args: | |||||
| xs_pad: input tensor (B, L, D) | |||||
| ilens: input length (B) | |||||
| prev_states: Not to be used now. | |||||
| Returns: | |||||
| position embedded tensor and mask | |||||
| """ | |||||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
| if self.embed is None: | |||||
| xs_pad = xs_pad | |||||
| elif (isinstance(self.embed, Conv2dSubsampling) | |||||
| or isinstance(self.embed, Conv2dSubsampling2) | |||||
| or isinstance(self.embed, Conv2dSubsampling6) | |||||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||||
| short_status, limit_size = check_short_utt(self.embed, | |||||
| xs_pad.size(1)) | |||||
| if short_status: | |||||
| raise TooShortUttError( | |||||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||||
| + # noqa: * | |||||
| f'(it needs more than {limit_size} frames), return empty results', | |||||
| xs_pad.size(1), | |||||
| limit_size, | |||||
| ) | |||||
| xs_pad, masks = self.embed(xs_pad, masks) | |||||
| else: | |||||
| xs_pad = self.embed(xs_pad) | |||||
| mask_shfit_chunk, mask_att_chunk_encoder = None, None | |||||
| if self.overlap_chunk_cls is not None: | |||||
| ilens = masks.squeeze(1).sum(1) | |||||
| chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) | |||||
| xs_pad, ilens = self.overlap_chunk_cls.split_chunk( | |||||
| xs_pad, ilens, chunk_outs=chunk_outs) | |||||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||||
| mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk( | |||||
| chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||||
| mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder( | |||||
| chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||||
| intermediate_outs = [] | |||||
| if len(self.interctc_layer_idx) == 0: | |||||
| xs_pad, masks, _, _, _ = self.encoders(xs_pad, masks, None, | |||||
| mask_shfit_chunk, | |||||
| mask_att_chunk_encoder) | |||||
| else: | |||||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||||
| xs_pad, masks, _, _, _ = encoder_layer(xs_pad, masks, None, | |||||
| mask_shfit_chunk, | |||||
| mask_att_chunk_encoder) | |||||
| if layer_idx + 1 in self.interctc_layer_idx: | |||||
| encoder_out = xs_pad | |||||
| # intermediate outputs are also normalized | |||||
| if self.normalize_before: | |||||
| encoder_out = self.after_norm(encoder_out) | |||||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||||
| if self.interctc_use_conditioning: | |||||
| ctc_out = ctc.softmax(encoder_out) | |||||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||||
| if self.normalize_before: | |||||
| xs_pad = self.after_norm(xs_pad) | |||||
| if self.overlap_chunk_cls is not None: | |||||
| xs_pad, olens = self.overlap_chunk_cls.remove_chunk( | |||||
| xs_pad, ilens, chunk_outs) | |||||
| if len(intermediate_outs) > 0: | |||||
| return (xs_pad, intermediate_outs), olens, None | |||||
| return xs_pad, olens, None | |||||
| @@ -0,0 +1,321 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| import logging | |||||
| import math | |||||
| import numpy as np | |||||
| import torch | |||||
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||||
| from ...nets.pytorch_backend.cif_utils.cif import cif_predictor | |||||
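| # Print numpy arrays and torch tensors in full (no truncation). | |||||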
| np.set_printoptions(threshold=np.inf) | |||||
| torch.set_printoptions(profile='full', precision=100000, linewidth=None) | |||||
| def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device='cpu'): | |||||
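| """Build a length mask, e.g. lengths=[2, 3], maxlen=4 gives | |||||
| [[1, 1, 0, 0], [1, 1, 1, 0]], cast to `dtype` on `device`.""" | |||||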
| if maxlen is None: | |||||
| maxlen = lengths.max() | |||||
| row_vector = torch.arange(0, maxlen, 1) | |||||
| matrix = torch.unsqueeze(lengths, dim=-1) | |||||
| mask = row_vector < matrix | |||||
| return mask.type(dtype).to(device) | |||||
| class overlap_chunk(): | |||||
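| """Utility for chunk-based streaming encoders. | |||||
| Splits an utterance into overlapping chunks: each chunk covers | |||||
| `chunk_size` frames and advances by `stride` frames, with `shfit_fsmn` | |||||
| padding frames prepended to every chunk for the FSMN memory and | |||||
| `encoder_att_look_back_factor` bounding how many previous chunks | |||||
| self-attention may look back at. `gen_chunk_mask` precomputes the | |||||
| split/merge and attention masks for a batch of input lengths. | |||||
| """ | |||||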
| def __init__( | |||||
| self, | |||||
| chunk_size: tuple = (16, ), | |||||
| stride: tuple = (10, ), | |||||
| pad_left: tuple = (0, ), | |||||
| encoder_att_look_back_factor: tuple = (1, ), | |||||
| shfit_fsmn: int = 0, | |||||
| ): | |||||
| self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor \ | |||||
| = chunk_size, stride, pad_left, encoder_att_look_back_factor | |||||
| self.shfit_fsmn = shfit_fsmn | |||||
| self.x_add_mask = None | |||||
| self.x_rm_mask = None | |||||
| self.x_len = None | |||||
| self.mask_shfit_chunk = None | |||||
| self.mask_chunk_predictor = None | |||||
| self.mask_att_chunk_encoder = None | |||||
| self.mask_shift_att_chunk_decoder = None | |||||
| self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur \ | |||||
| = None, None, None, None | |||||
| def get_chunk_size(self, ind: int = 0): | |||||
| # Select the configuration for pass `ind` and cache it. | |||||
| chunk_size = self.chunk_size[ind] | |||||
| stride = self.stride[ind] | |||||
| pad_left = self.pad_left[ind] | |||||
| encoder_att_look_back_factor = self.encoder_att_look_back_factor[ind] | |||||
| self.chunk_size_cur = chunk_size | |||||
| self.stride_cur = stride | |||||
| self.pad_left_cur = pad_left | |||||
| self.encoder_att_look_back_factor_cur = encoder_att_look_back_factor | |||||
| self.chunk_size_pad_shift_cur = chunk_size + self.shfit_fsmn | |||||
| return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur | |||||
| def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1): | |||||
| with torch.no_grad(): | |||||
| x_len = x_len.cpu().numpy() | |||||
| x_len_max = x_len.max() | |||||
| chunk_size, stride, pad_left, encoder_att_look_back_factor = self.get_chunk_size( | |||||
| ind) | |||||
| shfit_fsmn = self.shfit_fsmn | |||||
| chunk_size_pad_shift = chunk_size + shfit_fsmn | |||||
| chunk_num_batch = np.ceil(x_len / stride).astype(np.int32) | |||||
| x_len_chunk = ( | |||||
| chunk_num_batch - 1 | |||||
| ) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - ( | |||||
| chunk_num_batch - 1) * stride | |||||
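| # Worked example: chunk_size=16, stride=10, shfit_fsmn=5, pad_left=0 | |||||
| # and x_len=35 give chunk_num_batch=4, so the padded length is | |||||
| # x_len_chunk = 3 * 21 + 5 + 0 + 35 - 3 * 10 = 73 frames. | |||||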
| x_len_chunk = x_len_chunk.astype(x_len.dtype) | |||||
| x_len_chunk_max = x_len_chunk.max() | |||||
| chunk_num = int(math.ceil(x_len_max / stride)) | |||||
| dtype = np.int32 | |||||
| max_len_for_x_mask_tmp = max(chunk_size, x_len_max) | |||||
| x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) | |||||
| x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) | |||||
| mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) | |||||
| mask_chunk_predictor = np.zeros([0, num_units_predictor], | |||||
| dtype=dtype) | |||||
| mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) | |||||
| mask_att_chunk_encoder = np.zeros( | |||||
| [0, chunk_num * chunk_size_pad_shift], dtype=dtype) | |||||
| for chunk_ids in range(chunk_num): | |||||
| # x_mask add | |||||
| fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), | |||||
| dtype=dtype) | |||||
| x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) | |||||
| x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), | |||||
| dtype=dtype) | |||||
| x_mask_pad_right = np.zeros( | |||||
| (chunk_size, max_len_for_x_mask_tmp), dtype=dtype) | |||||
| x_cur_pad = np.concatenate( | |||||
| [x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1) | |||||
| x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp] | |||||
| x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], | |||||
| axis=0) | |||||
| x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], | |||||
| axis=0) | |||||
| # x_mask rm | |||||
| fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), | |||||
| dtype=dtype) | |||||
| x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) | |||||
| x_mask_right = np.zeros((stride, chunk_size - stride), | |||||
| dtype=dtype) | |||||
| x_mask_cur = np.concatenate([x_mask_cur, x_mask_right], axis=1) | |||||
| x_mask_cur_pad_top = np.zeros((chunk_ids * stride, chunk_size), | |||||
| dtype=dtype) | |||||
| x_mask_cur_pad_bottom = np.zeros( | |||||
| (max_len_for_x_mask_tmp, chunk_size), dtype=dtype) | |||||
| x_rm_mask_cur = np.concatenate( | |||||
| [x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], | |||||
| axis=0) | |||||
| x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, : | |||||
| chunk_size] | |||||
| x_rm_mask_cur_fsmn = np.concatenate( | |||||
| [fsmn_padding, x_rm_mask_cur], axis=1) | |||||
| x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], | |||||
| axis=1) | |||||
| # fsmn_padding_mask | |||||
| pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) | |||||
| ones_1 = np.ones([chunk_size, num_units], dtype=dtype) | |||||
| mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], | |||||
| axis=0) | |||||
| mask_shfit_chunk = np.concatenate( | |||||
| [mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) | |||||
| # predictor mask | |||||
| zeros_1 = np.zeros( | |||||
| [shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) | |||||
| ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) | |||||
| zeros_3 = np.zeros( | |||||
| [chunk_size - stride - pad_left, num_units_predictor], | |||||
| dtype=dtype) | |||||
| ones_zeros = np.concatenate([ones_2, zeros_3], axis=0) | |||||
| mask_chunk_predictor_cur = np.concatenate( | |||||
| [zeros_1, ones_zeros], axis=0) | |||||
| mask_chunk_predictor = np.concatenate( | |||||
| [mask_chunk_predictor, mask_chunk_predictor_cur], axis=0) | |||||
| # encoder att mask | |||||
| zeros_1_top = np.zeros( | |||||
| [shfit_fsmn, chunk_num * chunk_size_pad_shift], | |||||
| dtype=dtype) | |||||
| zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) | |||||
| zeros_2 = np.zeros( | |||||
| [chunk_size, zeros_2_num * chunk_size_pad_shift], | |||||
| dtype=dtype) | |||||
| encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0) | |||||
| zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||||
| ones_2_mid = np.ones([stride, stride], dtype=dtype) | |||||
| zeros_2_bottom = np.zeros([chunk_size - stride, stride], | |||||
| dtype=dtype) | |||||
| zeros_2_right = np.zeros([chunk_size, chunk_size - stride], | |||||
| dtype=dtype) | |||||
| ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0) | |||||
| ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], | |||||
| axis=1) | |||||
| ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num]) | |||||
| zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||||
| ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype) | |||||
| ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1) | |||||
| zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0) | |||||
| zeros_remain = np.zeros( | |||||
| [chunk_size, zeros_remain_num * chunk_size_pad_shift], | |||||
| dtype=dtype) | |||||
| ones2_bottom = np.concatenate( | |||||
| [zeros_2, ones_2, ones_3, zeros_remain], axis=1) | |||||
| mask_att_chunk_encoder_cur = np.concatenate( | |||||
| [zeros_1_top, ones2_bottom], axis=0) | |||||
| mask_att_chunk_encoder = np.concatenate( | |||||
| [mask_att_chunk_encoder, mask_att_chunk_encoder_cur], | |||||
| axis=0) | |||||
| # decoder fsmn_shift_att_mask | |||||
| zeros_1 = np.zeros([shfit_fsmn, 1]) | |||||
| ones_1 = np.ones([chunk_size, 1]) | |||||
| mask_shift_att_chunk_decoder_cur = np.concatenate( | |||||
| [zeros_1, ones_1], axis=0) | |||||
| mask_shift_att_chunk_decoder = np.concatenate( | |||||
| [ | |||||
| mask_shift_att_chunk_decoder, | |||||
| mask_shift_att_chunk_decoder_cur | |||||
| ], | |||||
| axis=0) | |||||
| self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max] | |||||
| self.x_len_chunk = x_len_chunk | |||||
| self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max] | |||||
| self.x_len = x_len | |||||
| self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :] | |||||
| self.mask_chunk_predictor = mask_chunk_predictor[: | |||||
| x_len_chunk_max, :] | |||||
| self.mask_att_chunk_encoder = mask_att_chunk_encoder[: | |||||
| x_len_chunk_max, : | |||||
| x_len_chunk_max] | |||||
| self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[: | |||||
| x_len_chunk_max, :] | |||||
| return (self.x_add_mask, self.x_len_chunk, self.x_rm_mask, self.x_len, | |||||
| self.mask_shfit_chunk, self.mask_chunk_predictor, | |||||
| self.mask_att_chunk_encoder, self.mask_shift_att_chunk_decoder) | |||||
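| # Hedged usage sketch for gen_chunk_mask (the default chunk_size=16 / | |||||
| # stride=10 configuration is assumed here): | |||||
| #   >>> oc = overlap_chunk() | |||||
| #   >>> chunk_outs = oc.gen_chunk_mask(torch.tensor([20, 32])) | |||||
| # x_add_mask copies each overlapped window of frames into the chunked | |||||
| # layout, and x_rm_mask is its inverse mapping back to frame order. | |||||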
| def split_chunk(self, x, x_len, chunk_outs): | |||||
| """ | |||||
| :param x: (b, t, d) | |||||
| :param x_length: (b) | |||||
| :param ind: int | |||||
| :return: | |||||
| """ | |||||
| x = x[:, :x_len.max(), :] | |||||
| b, t, d = x.size() | |||||
| x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device) | |||||
| x *= x_len_mask[:, :, None] | |||||
| x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype) | |||||
| x_len_chunk = self.get_x_len_chunk( | |||||
| chunk_outs, x_len.device, dtype=x_len.dtype) | |||||
| x = torch.transpose(x, 1, 0) | |||||
| x = torch.reshape(x, [t, -1]) | |||||
| x_chunk = torch.mm(x_add_mask, x) | |||||
| x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0) | |||||
| return x_chunk, x_len_chunk | |||||
| def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs): | |||||
| x_chunk = x_chunk[:, :x_len_chunk.max(), :] | |||||
| b, t, d = x_chunk.size() | |||||
| x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to( | |||||
| x_chunk.device) | |||||
| x_chunk *= x_len_chunk_mask[:, :, None] | |||||
| x_rm_mask = self.get_x_rm_mask( | |||||
| chunk_outs, x_chunk.device, dtype=x_chunk.dtype) | |||||
| x_len = self.get_x_len( | |||||
| chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype) | |||||
| x_chunk = torch.transpose(x_chunk, 1, 0) | |||||
| x_chunk = torch.reshape(x_chunk, [t, -1]) | |||||
| x = torch.mm(x_rm_mask, x_chunk) | |||||
| x = torch.reshape(x, [-1, b, d]).transpose(1, 0) | |||||
| return x, x_len | |||||
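| # Round-trip sketch (hedged, continuing the example above): | |||||
| #   >>> x_chunk, x_len_chunk = oc.split_chunk(x, x_len, chunk_outs) | |||||
| #   >>> x_rec, x_len_rec = oc.remove_chunk(x_chunk, x_len_chunk, chunk_outs) | |||||
| # x_rec equals the zero-padded input up to x_len frames, because | |||||
| # x_rm_mask undoes the frame duplication introduced by x_add_mask. | |||||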
| def get_x_add_mask(self, chunk_outs, device, idx=0, dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| def get_x_len_chunk(self, chunk_outs, device, idx=1, dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| def get_x_rm_mask(self, chunk_outs, device, idx=2, dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| def get_x_len(self, chunk_outs, device, idx=3, dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| def get_mask_shfit_chunk(self, | |||||
| chunk_outs, | |||||
| device, | |||||
| batch_size=1, | |||||
| num_units=1, | |||||
| idx=4, | |||||
| dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| def get_mask_chunk_predictor(self, | |||||
| chunk_outs, | |||||
| device, | |||||
| batch_size=1, | |||||
| num_units=1, | |||||
| idx=5, | |||||
| dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| def get_mask_att_chunk_encoder(self, | |||||
| chunk_outs, | |||||
| device, | |||||
| batch_size=1, | |||||
| idx=6, | |||||
| dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = np.tile(x[None, :, :, ], [batch_size, 1, 1]) | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| def get_mask_shift_att_chunk_decoder(self, | |||||
| chunk_outs, | |||||
| device, | |||||
| batch_size=1, | |||||
| idx=7, | |||||
| dtype=torch.float32): | |||||
| x = chunk_outs[idx] | |||||
| x = np.tile(x[None, None, :, 0], [batch_size, 1, 1]) | |||||
| x = torch.from_numpy(x).type(dtype).to(device) | |||||
| return x.detach() | |||||
| @@ -0,0 +1,250 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| import torch | |||||
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||||
| from torch import nn | |||||
| class CIF_Model(nn.Module): | |||||
| def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||||
| super(CIF_Model, self).__init__() | |||||
| self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||||
| self.cif_conv1d = nn.Conv1d( | |||||
| idim, idim, l_order + r_order + 1, groups=idim) | |||||
| self.cif_output = nn.Linear(idim, 1) | |||||
| self.dropout = torch.nn.Dropout(p=dropout) | |||||
| self.threshold = threshold | |||||
| def forward(self, hidden, target_label=None, mask=None, ignore_id=-1): | |||||
| h = hidden | |||||
| context = h.transpose(1, 2) | |||||
| queries = self.pad(context) | |||||
| memory = self.cif_conv1d(queries) | |||||
| output = memory + context | |||||
| output = self.dropout(output) | |||||
| output = output.transpose(1, 2) | |||||
| output = torch.relu(output) | |||||
| output = self.cif_output(output) | |||||
| alphas = torch.sigmoid(output) | |||||
| if mask is not None: | |||||
| alphas = alphas * mask.transpose(-1, -2).float() | |||||
| alphas = alphas.squeeze(-1) | |||||
| if target_label is not None: | |||||
| target_length = (target_label != ignore_id).float().sum(-1) | |||||
| else: | |||||
| target_length = None | |||||
| cif_length = alphas.sum(-1) | |||||
| if target_label is not None: | |||||
| alphas *= (target_length / cif_length)[:, None].repeat( | |||||
| 1, alphas.size(1)) | |||||
| cif_output, cif_peak = cif(hidden, alphas, self.threshold) | |||||
| return cif_output, cif_length, target_length, cif_peak | |||||
| def gen_frame_alignments(self, | |||||
| alphas: torch.Tensor = None, | |||||
| memory_sequence_length: torch.Tensor = None, | |||||
| is_training: bool = True, | |||||
| dtype: torch.dtype = torch.float32): | |||||
| batch_size, maximum_length = alphas.size() | |||||
| int_type = torch.int32 | |||||
| token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||||
| max_token_num = torch.max(token_num).item() | |||||
| alphas_cumsum = torch.cumsum(alphas, dim=1) | |||||
| alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||||
| alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||||
| [1, max_token_num, 1]) | |||||
| index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||||
| index = torch.cumsum(index, dim=1) | |||||
| index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||||
| index_div = torch.floor(torch.divide(alphas_cumsum, | |||||
| index)).type(int_type) | |||||
| index_div_bool_zeros = index_div.eq(0) | |||||
| index_div_bool_zeros_count = torch.sum( | |||||
| index_div_bool_zeros, dim=-1) + 1 | |||||
| index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||||
| memory_sequence_length.max()) | |||||
| token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||||
| token_num.device) | |||||
| index_div_bool_zeros_count *= token_num_mask | |||||
| index_div_bool_zeros_count_tile = torch.tile( | |||||
| index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||||
| ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||||
| zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||||
| ones = torch.cumsum(ones, dim=2) | |||||
| cond = index_div_bool_zeros_count_tile == ones | |||||
| index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||||
| index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||||
| torch.bool) | |||||
| index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||||
| int_type) | |||||
| index_div_bool_zeros_count_tile_out = torch.sum( | |||||
| index_div_bool_zeros_count_tile, dim=1) | |||||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||||
| int_type) | |||||
| predictor_mask = (~make_pad_mask( | |||||
| memory_sequence_length, | |||||
| maxlen=memory_sequence_length.max())).type(int_type).to( | |||||
| memory_sequence_length.device) # noqa: * | |||||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||||
| return index_div_bool_zeros_count_tile_out.detach( | |||||
| ), index_div_bool_zeros_count.detach() | |||||
| class cif_predictor(nn.Module): | |||||
| def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||||
| super(cif_predictor, self).__init__() | |||||
| self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||||
| self.cif_conv1d = nn.Conv1d( | |||||
| idim, idim, l_order + r_order + 1, groups=idim) | |||||
| self.cif_output = nn.Linear(idim, 1) | |||||
| self.dropout = torch.nn.Dropout(p=dropout) | |||||
| self.threshold = threshold | |||||
| def forward(self, | |||||
| hidden, | |||||
| target_label=None, | |||||
| mask=None, | |||||
| ignore_id=-1, | |||||
| mask_chunk_predictor=None, | |||||
| target_label_length=None): | |||||
| h = hidden | |||||
| context = h.transpose(1, 2) | |||||
| queries = self.pad(context) | |||||
| memory = self.cif_conv1d(queries) | |||||
| output = memory + context | |||||
| output = self.dropout(output) | |||||
| output = output.transpose(1, 2) | |||||
| output = torch.relu(output) | |||||
| output = self.cif_output(output) | |||||
| alphas = torch.sigmoid(output) | |||||
| if mask is not None: | |||||
| alphas = alphas * mask.transpose(-1, -2).float() | |||||
| if mask_chunk_predictor is not None: | |||||
| alphas = alphas * mask_chunk_predictor | |||||
| alphas = alphas.squeeze(-1) | |||||
| if target_label_length is not None: | |||||
| target_length = target_label_length | |||||
| elif target_label is not None: | |||||
| target_length = (target_label != ignore_id).float().sum(-1) | |||||
| else: | |||||
| target_length = None | |||||
| token_num = alphas.sum(-1) | |||||
| if target_length is not None: | |||||
| alphas *= (target_length / token_num)[:, None].repeat( | |||||
| 1, alphas.size(1)) | |||||
| acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) | |||||
| return acoustic_embeds, token_num, alphas, cif_peak | |||||
| def gen_frame_alignments(self, | |||||
| alphas: torch.Tensor = None, | |||||
| memory_sequence_length: torch.Tensor = None, | |||||
| is_training: bool = True, | |||||
| dtype: torch.dtype = torch.float32): | |||||
| batch_size, maximum_length = alphas.size() | |||||
| int_type = torch.int32 | |||||
| token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||||
| max_token_num = torch.max(token_num).item() | |||||
| alphas_cumsum = torch.cumsum(alphas, dim=1) | |||||
| alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||||
| alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||||
| [1, max_token_num, 1]) | |||||
| index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||||
| index = torch.cumsum(index, dim=1) | |||||
| index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||||
| index_div = torch.floor(torch.divide(alphas_cumsum, | |||||
| index)).type(int_type) | |||||
| index_div_bool_zeros = index_div.eq(0) | |||||
| index_div_bool_zeros_count = torch.sum( | |||||
| index_div_bool_zeros, dim=-1) + 1 | |||||
| index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||||
| memory_sequence_length.max()) | |||||
| token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||||
| token_num.device) | |||||
| index_div_bool_zeros_count *= token_num_mask | |||||
| index_div_bool_zeros_count_tile = torch.tile( | |||||
| index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||||
| ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||||
| zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||||
| ones = torch.cumsum(ones, dim=2) | |||||
| cond = index_div_bool_zeros_count_tile == ones | |||||
| index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||||
| index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||||
| torch.bool) | |||||
| index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||||
| int_type) | |||||
| index_div_bool_zeros_count_tile_out = torch.sum( | |||||
| index_div_bool_zeros_count_tile, dim=1) | |||||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||||
| int_type) | |||||
| predictor_mask = (~make_pad_mask( | |||||
| memory_sequence_length, | |||||
| maxlen=memory_sequence_length.max())).type(int_type).to( | |||||
| memory_sequence_length.device) # noqa: * | |||||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||||
| return index_div_bool_zeros_count_tile_out.detach( | |||||
| ), index_div_bool_zeros_count.detach() | |||||
| def cif(hidden, alphas, threshold): | |||||
| batch_size, len_time, hidden_size = hidden.size() | |||||
| # loop variables | |||||
| integrate = torch.zeros([batch_size], device=hidden.device) | |||||
| frame = torch.zeros([batch_size, hidden_size], device=hidden.device) | |||||
| # intermediate vars along time | |||||
| list_fires = [] | |||||
| list_frames = [] | |||||
| for t in range(len_time): | |||||
| alpha = alphas[:, t] | |||||
| distribution_completion = torch.ones([batch_size], | |||||
| device=hidden.device) - integrate | |||||
| integrate += alpha | |||||
| list_fires.append(integrate) | |||||
| fire_place = integrate >= threshold | |||||
| integrate = torch.where( | |||||
| fire_place, | |||||
| integrate - torch.ones([batch_size], device=hidden.device), | |||||
| integrate) | |||||
| cur = torch.where(fire_place, distribution_completion, alpha) | |||||
| remainds = alpha - cur | |||||
| frame += cur[:, None] * hidden[:, t, :] | |||||
| list_frames.append(frame) | |||||
| frame = torch.where(fire_place[:, None].repeat(1, hidden_size), | |||||
| remainds[:, None] * hidden[:, t, :], frame) | |||||
| fires = torch.stack(list_fires, 1) | |||||
| frames = torch.stack(list_frames, 1) | |||||
| list_ls = [] | |||||
| len_labels = torch.round(alphas.sum(-1)).int() | |||||
| max_label_len = len_labels.max() | |||||
| for b in range(batch_size): | |||||
| fire = fires[b, :] | |||||
| ls = torch.index_select(frames[b, :, :], 0, | |||||
| torch.nonzero(fire >= threshold).squeeze()) | |||||
| pad_l = torch.zeros([max_label_len - ls.size(0), hidden_size], | |||||
| device=hidden.device) | |||||
| list_ls.append(torch.cat([ls, pad_l], 0)) | |||||
| return torch.stack(list_ls, 0), fires | |||||
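| # Minimal sketch of the continuous integrate-and-fire mechanism above; | |||||
| # the helper below is hypothetical and illustrative only (not part of | |||||
| # the model code). Per-frame weights accumulate, and a token embedding | |||||
| # is emitted every time the running integral crosses the threshold. | |||||
| def _demo_cif(): | |||||
|     hidden = torch.randn(1, 4, 8) | |||||
|     # integrate: 0.4 -> 1.1 (fire) -> 0.7 -> 1.2 (fire): two tokens | |||||
|     alphas = torch.tensor([[0.4, 0.7, 0.6, 0.5]]) | |||||
|     acoustic_embeds, fires = cif(hidden, alphas, threshold=1.0) | |||||
|     assert acoustic_embeds.shape == (1, 2, 8) | |||||
|     return acoustic_embeds, fires | |||||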
| @@ -0,0 +1,680 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| """Multi-Head Attention layer definition.""" | |||||
| import math | |||||
| import numpy | |||||
| import torch | |||||
| from torch import nn | |||||
| class MultiHeadedAttention(nn.Module): | |||||
| """Multi-Head Attention layer. | |||||
| Args: | |||||
| n_head (int): The number of heads. | |||||
| n_feat (int): The number of features. | |||||
| dropout_rate (float): Dropout rate. | |||||
| """ | |||||
| def __init__(self, n_head, n_feat, dropout_rate): | |||||
| """Construct an MultiHeadedAttention object.""" | |||||
| super(MultiHeadedAttention, self).__init__() | |||||
| assert n_feat % n_head == 0 | |||||
| # We assume d_v always equals d_k | |||||
| self.d_k = n_feat // n_head | |||||
| self.h = n_head | |||||
| self.linear_q = nn.Linear(n_feat, n_feat) | |||||
| self.linear_k = nn.Linear(n_feat, n_feat) | |||||
| self.linear_v = nn.Linear(n_feat, n_feat) | |||||
| self.linear_out = nn.Linear(n_feat, n_feat) | |||||
| self.attn = None | |||||
| self.dropout = nn.Dropout(p=dropout_rate) | |||||
| def forward_qkv(self, query, key, value): | |||||
| """Transform query, key and value. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| Returns: | |||||
| torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||||
| torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||||
| torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||||
| """ | |||||
| n_batch = query.size(0) | |||||
| q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||||
| k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||||
| v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||||
| q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||||
| k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||||
| v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||||
| return q, k, v | |||||
| def forward_attention(self, value, scores, mask): | |||||
| """Compute attention context vector. | |||||
| Args: | |||||
| value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||||
| scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||||
| mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Transformed value (#batch, time1, d_model) | |||||
| weighted by the attention score (#batch, time1, time2). | |||||
| """ | |||||
| n_batch = value.size(0) | |||||
| if mask is not None: | |||||
| mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||||
| min_value = float( | |||||
| numpy.finfo(torch.tensor( | |||||
| 0, dtype=scores.dtype).numpy().dtype).min) | |||||
| scores = scores.masked_fill(mask, min_value) | |||||
| self.attn = torch.softmax( | |||||
| scores, dim=-1).masked_fill(mask, | |||||
| 0.0) # (batch, head, time1, time2) | |||||
| else: | |||||
| self.attn = torch.softmax( | |||||
| scores, dim=-1) # (batch, head, time1, time2) | |||||
| p_attn = self.dropout(self.attn) | |||||
| x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||||
| x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||||
| self.h * self.d_k) | |||||
| ) # (batch, time1, d_model) | |||||
| return self.linear_out(x) # (batch, time1, d_model) | |||||
| def forward(self, query, key, value, mask): | |||||
| """Compute scaled dot product attention. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
| (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
| """ | |||||
| q, k, v = self.forward_qkv(query, key, value) | |||||
| scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||||
| return self.forward_attention(v, scores, mask) | |||||
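| # Hedged usage sketch for the plain attention layer above (the _demo | |||||
| # name is hypothetical, illustrative only): | |||||
| def _demo_multi_headed_attention(): | |||||
|     mha = MultiHeadedAttention(n_head=4, n_feat=64, dropout_rate=0.0) | |||||
|     x = torch.randn(2, 10, 64) | |||||
|     mask = torch.ones(2, 1, 10, dtype=torch.bool)  # no padded frames | |||||
|     out = mha(x, x, x, mask)  # self-attention: (2, 10, 64) | |||||
|     return out | |||||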
| class MultiHeadedAttentionSANM(nn.Module): | |||||
| """Multi-Head Attention layer. | |||||
| Args: | |||||
| n_head (int): The number of heads. | |||||
| n_feat (int): The number of features. | |||||
| dropout_rate (float): Dropout rate. | |||||
| """ | |||||
| def __init__(self, | |||||
| n_head, | |||||
| n_feat, | |||||
| dropout_rate, | |||||
| kernel_size, | |||||
| sanm_shfit=0): | |||||
| """Construct an MultiHeadedAttention object.""" | |||||
| super(MultiHeadedAttentionSANM, self).__init__() | |||||
| assert n_feat % n_head == 0 | |||||
| # We assume d_v always equals d_k | |||||
| self.d_k = n_feat // n_head | |||||
| self.h = n_head | |||||
| self.linear_q = nn.Linear(n_feat, n_feat) | |||||
| self.linear_k = nn.Linear(n_feat, n_feat) | |||||
| self.linear_v = nn.Linear(n_feat, n_feat) | |||||
| self.linear_out = nn.Linear(n_feat, n_feat) | |||||
| self.attn = None | |||||
| self.dropout = nn.Dropout(p=dropout_rate) | |||||
| self.fsmn_block = nn.Conv1d( | |||||
| n_feat, | |||||
| n_feat, | |||||
| kernel_size, | |||||
| stride=1, | |||||
| padding=0, | |||||
| groups=n_feat, | |||||
| bias=False) | |||||
| # padding | |||||
| left_padding = (kernel_size - 1) // 2 | |||||
| if sanm_shfit > 0: | |||||
| left_padding = left_padding + sanm_shfit | |||||
| right_padding = kernel_size - 1 - left_padding | |||||
| self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) | |||||
| def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): | |||||
| ''' | |||||
| :param inputs: input tensor (#batch, time1, size). | |||||
| :param mask: mask tensor (#batch, 1, time). | |||||
| :param mask_shfit_chunk: optional chunk mask applied in streaming mode. | |||||
| :return: FSMN memory tensor (#batch, time1, size). | |||||
| ''' | |||||
| # b, t, d = inputs.size() | |||||
| mask = mask[:, 0, :, None] | |||||
| if mask_shfit_chunk is not None: | |||||
| mask = mask * mask_shfit_chunk | |||||
| inputs *= mask | |||||
| x = inputs.transpose(1, 2) | |||||
| x = self.pad_fn(x) | |||||
| x = self.fsmn_block(x) | |||||
| x = x.transpose(1, 2) | |||||
| x += inputs | |||||
| x = self.dropout(x) | |||||
| return x * mask | |||||
| def forward_qkv(self, query, key, value): | |||||
| """Transform query, key and value. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| Returns: | |||||
| torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||||
| torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||||
| torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||||
| """ | |||||
| n_batch = query.size(0) | |||||
| q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||||
| k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||||
| v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||||
| q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||||
| k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||||
| v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||||
| return q, k, v | |||||
| def forward_attention(self, | |||||
| value, | |||||
| scores, | |||||
| mask, | |||||
| mask_att_chunk_encoder=None): | |||||
| """Compute attention context vector. | |||||
| Args: | |||||
| value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||||
| scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||||
| mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Transformed value (#batch, time1, d_model) | |||||
| weighted by the attention score (#batch, time1, time2). | |||||
| """ | |||||
| n_batch = value.size(0) | |||||
| if mask is not None: | |||||
| if mask_att_chunk_encoder is not None: | |||||
| mask = mask * mask_att_chunk_encoder | |||||
| mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||||
| min_value = float( | |||||
| numpy.finfo(torch.tensor( | |||||
| 0, dtype=scores.dtype).numpy().dtype).min) | |||||
| scores = scores.masked_fill(mask, min_value) | |||||
| self.attn = torch.softmax( | |||||
| scores, dim=-1).masked_fill(mask, | |||||
| 0.0) # (batch, head, time1, time2) | |||||
| else: | |||||
| self.attn = torch.softmax( | |||||
| scores, dim=-1) # (batch, head, time1, time2) | |||||
| p_attn = self.dropout(self.attn) | |||||
| x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||||
| x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||||
| self.h * self.d_k) | |||||
| ) # (batch, time1, d_model) | |||||
| return self.linear_out(x) # (batch, time1, d_model) | |||||
| def forward(self, | |||||
| query, | |||||
| key, | |||||
| value, | |||||
| mask, | |||||
| mask_shfit_chunk=None, | |||||
| mask_att_chunk_encoder=None): | |||||
| """Compute scaled dot product attention. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
| (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
| """ | |||||
| fsmn_memory = self.forward_fsmn(value, mask, mask_shfit_chunk) | |||||
| q, k, v = self.forward_qkv(query, key, value) | |||||
| scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||||
| att_outs = self.forward_attention(v, scores, mask, | |||||
| mask_att_chunk_encoder) | |||||
| return att_outs + fsmn_memory | |||||
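| # The SANM variant adds a depthwise-conv FSMN memory branch to the | |||||
| # attention output; a hedged sketch (hypothetical helper, illustrative | |||||
| # only): | |||||
| def _demo_sanm_attention(): | |||||
|     attn = MultiHeadedAttentionSANM( | |||||
|         n_head=4, n_feat=64, dropout_rate=0.0, kernel_size=11) | |||||
|     x = torch.randn(2, 10, 64) | |||||
|     mask = torch.ones(2, 1, 10) | |||||
|     # output = scaled dot-product attention + FSMN memory, (2, 10, 64) | |||||
|     return attn(x, x, x, mask) | |||||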
| class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): | |||||
| """Multi-Head Attention layer with relative position encoding (old version). | |||||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
| Paper: https://arxiv.org/abs/1901.02860 | |||||
| Args: | |||||
| n_head (int): The number of heads. | |||||
| n_feat (int): The number of features. | |||||
| dropout_rate (float): Dropout rate. | |||||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
| """ | |||||
| def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||||
| super().__init__(n_head, n_feat, dropout_rate) | |||||
| self.zero_triu = zero_triu | |||||
| # linear transformation for positional encoding | |||||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
| # these two learnable bias are used in matrix c and matrix d | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
| def rel_shift(self, x): | |||||
| """Compute relative positional encoding. | |||||
| Args: | |||||
| x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor. | |||||
| """ | |||||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
| device=x.device, | |||||
| dtype=x.dtype) | |||||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
| x = x_padded[:, :, 1:].view_as(x) | |||||
| if self.zero_triu: | |||||
| ones = torch.ones((x.size(2), x.size(3))) | |||||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
| return x | |||||
| def forward(self, query, key, value, pos_emb, mask): | |||||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
| (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
| """ | |||||
| q, k, v = self.forward_qkv(query, key, value) | |||||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
| n_batch_pos = pos_emb.size(0) | |||||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
| p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
| # compute attention score | |||||
| # first compute matrix a and matrix c | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| # (batch, head, time1, time2) | |||||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
| # compute matrix b and matrix d | |||||
| # (batch, head, time1, time1) | |||||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
| matrix_bd = self.rel_shift(matrix_bd) | |||||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
| self.d_k) # (batch, head, time1, time2) | |||||
| return self.forward_attention(v, scores, mask) | |||||
| class LegacyRelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||||
| """Multi-Head Attention layer with relative position encoding (old version). | |||||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
| Paper: https://arxiv.org/abs/1901.02860 | |||||
| Args: | |||||
| n_head (int): The number of heads. | |||||
| n_feat (int): The number of features. | |||||
| dropout_rate (float): Dropout rate. | |||||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
| """ | |||||
| def __init__(self, | |||||
| n_head, | |||||
| n_feat, | |||||
| dropout_rate, | |||||
| zero_triu=False, | |||||
| kernel_size=15, | |||||
| sanm_shfit=0): | |||||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||||
| super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||||
| self.zero_triu = zero_triu | |||||
| # linear transformation for positional encoding | |||||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
| # these two learnable bias are used in matrix c and matrix d | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
| def rel_shift(self, x): | |||||
| """Compute relative positional encoding. | |||||
| Args: | |||||
| x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor. | |||||
| """ | |||||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
| device=x.device, | |||||
| dtype=x.dtype) | |||||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
| x = x_padded[:, :, 1:].view_as(x) | |||||
| if self.zero_triu: | |||||
| ones = torch.ones((x.size(2), x.size(3))) | |||||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
| return x | |||||
| def forward(self, query, key, value, pos_emb, mask): | |||||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
| (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
| """ | |||||
| fsmn_memory = self.forward_fsmn(value, mask) | |||||
| q, k, v = self.forward_qkv(query, key, value) | |||||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
| n_batch_pos = pos_emb.size(0) | |||||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
| p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
| # compute attention score | |||||
| # first compute matrix a and matrix c | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| # (batch, head, time1, time2) | |||||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
| # compute matrix b and matrix d | |||||
| # (batch, head, time1, time1) | |||||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
| matrix_bd = self.rel_shift(matrix_bd) | |||||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
| self.d_k) # (batch, head, time1, time2) | |||||
| att_outs = self.forward_attention(v, scores, mask) | |||||
| return att_outs + fsmn_memory | |||||
| class RelPositionMultiHeadedAttention(MultiHeadedAttention): | |||||
| """Multi-Head Attention layer with relative position encoding (new implementation). | |||||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
| Paper: https://arxiv.org/abs/1901.02860 | |||||
| Args: | |||||
| n_head (int): The number of heads. | |||||
| n_feat (int): The number of features. | |||||
| dropout_rate (float): Dropout rate. | |||||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
| """ | |||||
| def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||||
| super().__init__(n_head, n_feat, dropout_rate) | |||||
| self.zero_triu = zero_triu | |||||
| # linear transformation for positional encoding | |||||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
| # these two learnable bias are used in matrix c and matrix d | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
| def rel_shift(self, x): | |||||
| """Compute relative positional encoding. | |||||
| Args: | |||||
| x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||||
| time1 means the length of query vector. | |||||
| Returns: | |||||
| torch.Tensor: Output tensor. | |||||
| """ | |||||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
| device=x.device, | |||||
| dtype=x.dtype) | |||||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
| x = x_padded[:, :, 1:].view_as( | |||||
| x)[:, :, :, :x.size(-1) // 2 | |||||
| + 1] # only keep the positions from 0 to time2 | |||||
| if self.zero_triu: | |||||
| ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
| return x | |||||
| def forward(self, query, key, value, pos_emb, mask): | |||||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| pos_emb (torch.Tensor): Positional embedding tensor | |||||
| (#batch, 2*time1-1, size). | |||||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
| (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
| """ | |||||
| q, k, v = self.forward_qkv(query, key, value) | |||||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
| n_batch_pos = pos_emb.size(0) | |||||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
| p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
| # compute attention score | |||||
| # first compute matrix a and matrix c | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| # (batch, head, time1, time2) | |||||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
| # compute matrix b and matrix d | |||||
| # (batch, head, time1, 2*time1-1) | |||||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
| matrix_bd = self.rel_shift(matrix_bd) | |||||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
| self.d_k) # (batch, head, time1, time2) | |||||
| return self.forward_attention(v, scores, mask) | |||||
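| # Hedged sketch for the rel-pos variant: pos_emb carries embeddings for | |||||
| # the relative offsets -(time1-1)..(time1-1), hence length 2*time1-1 | |||||
| # (hypothetical helper, illustrative only): | |||||
| def _demo_rel_position_attention(): | |||||
|     attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=64, dropout_rate=0.0) | |||||
|     t = 10 | |||||
|     x = torch.randn(2, t, 64) | |||||
|     pos_emb = torch.randn(1, 2 * t - 1, 64) | |||||
|     mask = torch.ones(2, 1, t, dtype=torch.bool) | |||||
|     return attn(x, x, x, pos_emb, mask)  # (2, t, 64) | |||||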
| class RelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||||
| """Multi-Head Attention layer with relative position encoding (new implementation). | |||||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||||
| Paper: https://arxiv.org/abs/1901.02860 | |||||
| Args: | |||||
| n_head (int): The number of heads. | |||||
| n_feat (int): The number of features. | |||||
| dropout_rate (float): Dropout rate. | |||||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||||
| """ | |||||
| def __init__(self, | |||||
| n_head, | |||||
| n_feat, | |||||
| dropout_rate, | |||||
| zero_triu=False, | |||||
| kernel_size=15, | |||||
| sanm_shfit=0): | |||||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||||
| super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||||
| self.zero_triu = zero_triu | |||||
| # linear transformation for positional encoding | |||||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||||
| # these two learnable bias are used in matrix c and matrix d | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||||
| def rel_shift(self, x): | |||||
| """Compute relative positional encoding. | |||||
| Args: | |||||
| x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||||
| time1 means the length of query vector. | |||||
| Returns: | |||||
| torch.Tensor: Output tensor. | |||||
| """ | |||||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||||
| device=x.device, | |||||
| dtype=x.dtype) | |||||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||||
| x = x_padded[:, :, 1:].view_as( | |||||
| x)[:, :, :, :x.size(-1) // 2 | |||||
| + 1] # only keep the positions from 0 to time2 | |||||
| if self.zero_triu: | |||||
| ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||||
| return x | |||||
| def forward(self, query, key, value, pos_emb, mask): | |||||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||||
| Args: | |||||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||||
| pos_emb (torch.Tensor): Positional embedding tensor | |||||
| (#batch, 2*time1-1, size). | |||||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||||
| (#batch, time1, time2). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||||
| """ | |||||
| fsmn_memory = self.forward_fsmn(value, mask) | |||||
| q, k, v = self.forward_qkv(query, key, value) | |||||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||||
| n_batch_pos = pos_emb.size(0) | |||||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||||
| p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||||
| # (batch, head, time1, d_k) | |||||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||||
| # compute attention score | |||||
| # first compute matrix a and matrix c | |||||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||||
| # (batch, head, time1, time2) | |||||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||||
| # compute matrix b and matrix d | |||||
| # (batch, head, time1, 2*time1-1) | |||||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||||
| matrix_bd = self.rel_shift(matrix_bd) | |||||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||||
| self.d_k) # (batch, head, time1, time2) | |||||
| att_outs = self.forward_attention(v, scores, mask) | |||||
| return att_outs + fsmn_memory | |||||
| @@ -0,0 +1,239 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| """Encoder self-attention layer definition.""" | |||||
| import torch | |||||
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||||
| from torch import nn | |||||
| class EncoderLayer(nn.Module): | |||||
| """Encoder layer module. | |||||
| Args: | |||||
| size (int): Input dimension. | |||||
| self_attn (torch.nn.Module): Self-attention module instance. | |||||
| `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||||
| can be used as the argument. | |||||
| feed_forward (torch.nn.Module): Feed-forward module instance. | |||||
| `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||||
| can be used as the argument. | |||||
| dropout_rate (float): Dropout rate. | |||||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||||
| concat_after (bool): Whether to concat attention layer's input and output. | |||||
| if True, additional linear will be applied. | |||||
| i.e. x -> x + linear(concat(x, att(x))) | |||||
| if False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
| stochastic_depth_rate (float): Probability of skipping this layer. | |||||
| During training, the layer may skip residual computation and return input | |||||
| as-is with given probability. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| size, | |||||
| self_attn, | |||||
| feed_forward, | |||||
| dropout_rate, | |||||
| normalize_before=True, | |||||
| concat_after=False, | |||||
| stochastic_depth_rate=0.0, | |||||
| ): | |||||
| """Construct an EncoderLayer object.""" | |||||
| super(EncoderLayer, self).__init__() | |||||
| self.self_attn = self_attn | |||||
| self.feed_forward = feed_forward | |||||
| self.norm1 = LayerNorm(size) | |||||
| self.norm2 = LayerNorm(size) | |||||
| self.dropout = nn.Dropout(dropout_rate) | |||||
| self.size = size | |||||
| self.normalize_before = normalize_before | |||||
| self.concat_after = concat_after | |||||
| if self.concat_after: | |||||
| self.concat_linear = nn.Linear(size + size, size) | |||||
| self.stochastic_depth_rate = stochastic_depth_rate | |||||
| def forward(self, x, mask, cache=None): | |||||
| """Compute encoded features. | |||||
| Args: | |||||
| x (torch.Tensor): Input tensor (#batch, time, size). | |||||
| mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||||
| cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time, size). | |||||
| torch.Tensor: Mask tensor (#batch, time). | |||||
| """ | |||||
| skip_layer = False | |||||
| # with stochastic depth, residual connection `x + f(x)` becomes | |||||
| # `x <- x + 1 / (1 - p) * f(x)` at training time. | |||||
| stoch_layer_coeff = 1.0 | |||||
| if self.training and self.stochastic_depth_rate > 0: | |||||
| skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||||
| stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||||
| if skip_layer: | |||||
| if cache is not None: | |||||
| x = torch.cat([cache, x], dim=1) | |||||
| return x, mask | |||||
| residual = x | |||||
| if self.normalize_before: | |||||
| x = self.norm1(x) | |||||
| if cache is None: | |||||
| x_q = x | |||||
| else: | |||||
| assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||||
| x_q = x[:, -1:, :] | |||||
| residual = residual[:, -1:, :] | |||||
| mask = None if mask is None else mask[:, -1:, :] | |||||
| if self.concat_after: | |||||
| x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) | |||||
| x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||||
| else: | |||||
| x = residual + stoch_layer_coeff * self.dropout( | |||||
| self.self_attn(x_q, x, x, mask)) | |||||
| if not self.normalize_before: | |||||
| x = self.norm1(x) | |||||
| residual = x | |||||
| if self.normalize_before: | |||||
| x = self.norm2(x) | |||||
| x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||||
| if not self.normalize_before: | |||||
| x = self.norm2(x) | |||||
| if cache is not None: | |||||
| x = torch.cat([cache, x], dim=1) | |||||
| return x, mask | |||||
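| # Hedged construction sketch: EncoderLayer only assumes self_attn is a | |||||
| # module called as self_attn(x_q, x, x, mask); the stub and the _demo | |||||
| # name below are hypothetical, illustrative only. | |||||
| class _IdentityAttn(nn.Module): | |||||
|     def forward(self, q, k, v, mask): | |||||
|         return v | |||||
| def _demo_encoder_layer(): | |||||
|     layer = EncoderLayer( | |||||
|         size=64, | |||||
|         self_attn=_IdentityAttn(), | |||||
|         feed_forward=nn.Linear(64, 64), | |||||
|         dropout_rate=0.0) | |||||
|     x = torch.randn(2, 10, 64) | |||||
|     mask = torch.ones(2, 1, 10, dtype=torch.bool) | |||||
|     y, mask = layer(x, mask)  # (2, 10, 64), mask unchanged | |||||
|     return y | |||||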
| class EncoderLayerChunk(nn.Module): | |||||
| """Encoder layer module. | |||||
| Args: | |||||
| size (int): Input dimension. | |||||
| self_attn (torch.nn.Module): Self-attention module instance. | |||||
| `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||||
| can be used as the argument. | |||||
| feed_forward (torch.nn.Module): Feed-forward module instance. | |||||
| `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||||
| can be used as the argument. | |||||
| dropout_rate (float): Dropout rate. | |||||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||||
| concat_after (bool): Whether to concat attention layer's input and output. | |||||
| if True, additional linear will be applied. | |||||
| i.e. x -> x + linear(concat(x, att(x))) | |||||
| if False, no additional linear will be applied. i.e. x -> x + att(x) | |||||
| stochastic_depth_rate (float): Probability of skipping this layer. | |||||
| During training, the layer may skip residual computation and return input | |||||
| as-is with given probability. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| size, | |||||
| self_attn, | |||||
| feed_forward, | |||||
| dropout_rate, | |||||
| normalize_before=True, | |||||
| concat_after=False, | |||||
| stochastic_depth_rate=0.0, | |||||
| ): | |||||
| """Construct an EncoderLayer object.""" | |||||
| super(EncoderLayerChunk, self).__init__() | |||||
| self.self_attn = self_attn | |||||
| self.feed_forward = feed_forward | |||||
| self.norm1 = LayerNorm(size) | |||||
| self.norm2 = LayerNorm(size) | |||||
| self.dropout = nn.Dropout(dropout_rate) | |||||
| self.size = size | |||||
| self.normalize_before = normalize_before | |||||
| self.concat_after = concat_after | |||||
| if self.concat_after: | |||||
| self.concat_linear = nn.Linear(size + size, size) | |||||
| self.stochastic_depth_rate = stochastic_depth_rate | |||||
| def forward(self, | |||||
| x, | |||||
| mask, | |||||
| cache=None, | |||||
| mask_shfit_chunk=None, | |||||
| mask_att_chunk_encoder=None): | |||||
| """Compute encoded features. | |||||
| Args: | |||||
| x (torch.Tensor): Input tensor (#batch, time, size). | |||||
| mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||||
| cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||||
| Returns: | |||||
| torch.Tensor: Output tensor (#batch, time, size). | |||||
| torch.Tensor: Mask tensor (#batch, time). | |||||
| The chunk masks (and a None cache slot) are also returned so they | |||||
| can be threaded through a stack of chunk layers. | |||||
| """ | |||||
| skip_layer = False | |||||
| # with stochastic depth, residual connection `x + f(x)` becomes | |||||
| # `x <- x + 1 / (1 - p) * f(x)` at training time. | |||||
| stoch_layer_coeff = 1.0 | |||||
| if self.training and self.stochastic_depth_rate > 0: | |||||
| skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||||
| stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||||
| if skip_layer: | |||||
| if cache is not None: | |||||
| x = torch.cat([cache, x], dim=1) | |||||
| return x, mask | |||||
| residual = x | |||||
| if self.normalize_before: | |||||
| x = self.norm1(x) | |||||
| if cache is None: | |||||
| x_q = x | |||||
| else: | |||||
| assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||||
| x_q = x[:, -1:, :] | |||||
| residual = residual[:, -1:, :] | |||||
| mask = None if mask is None else mask[:, -1:, :] | |||||
| if self.concat_after: | |||||
| x_concat = torch.cat( | |||||
| (x, | |||||
| self.self_attn( | |||||
| x_q, | |||||
| x, | |||||
| x, | |||||
| mask, | |||||
| mask_shfit_chunk=mask_shfit_chunk, | |||||
| mask_att_chunk_encoder=mask_att_chunk_encoder)), | |||||
| dim=-1) | |||||
| x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||||
| else: | |||||
| x = residual + stoch_layer_coeff * self.dropout( | |||||
| self.self_attn( | |||||
| x_q, | |||||
| x, | |||||
| x, | |||||
| mask, | |||||
| mask_shfit_chunk=mask_shfit_chunk, | |||||
| mask_att_chunk_encoder=mask_att_chunk_encoder)) | |||||
| if not self.normalize_before: | |||||
| x = self.norm1(x) | |||||
| residual = x | |||||
| if self.normalize_before: | |||||
| x = self.norm2(x) | |||||
| x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||||
| if not self.normalize_before: | |||||
| x = self.norm2(x) | |||||
| if cache is not None: | |||||
| x = torch.cat([cache, x], dim=1) | |||||
| return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder | |||||
| @@ -0,0 +1,890 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed from espnet/espnet. | |||||
| import argparse | |||||
| import logging | |||||
| import os | |||||
| from pathlib import Path | |||||
| from typing import Callable, Collection, Dict, List, Optional, Tuple, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| import yaml | |||||
| from espnet2.asr.ctc import CTC | |||||
| from espnet2.asr.decoder.abs_decoder import AbsDecoder | |||||
| from espnet2.asr.decoder.mlm_decoder import MLMDecoder | |||||
| from espnet2.asr.decoder.rnn_decoder import RNNDecoder | |||||
| from espnet2.asr.decoder.transformer_decoder import \ | |||||
| DynamicConvolution2DTransformerDecoder # noqa: H301 | |||||
| from espnet2.asr.decoder.transformer_decoder import \ | |||||
| LightweightConvolution2DTransformerDecoder # noqa: H301 | |||||
| from espnet2.asr.decoder.transformer_decoder import \ | |||||
| LightweightConvolutionTransformerDecoder # noqa: H301 | |||||
| from espnet2.asr.decoder.transformer_decoder import ( | |||||
| DynamicConvolutionTransformerDecoder, TransformerDecoder) | |||||
| from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||||
| from espnet2.asr.encoder.contextual_block_conformer_encoder import \ | |||||
| ContextualBlockConformerEncoder # noqa: H301 | |||||
| from espnet2.asr.encoder.contextual_block_transformer_encoder import \ | |||||
| ContextualBlockTransformerEncoder # noqa: H301 | |||||
| from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder, | |||||
| FairseqHubertPretrainEncoder) | |||||
| from espnet2.asr.encoder.longformer_encoder import LongformerEncoder | |||||
| from espnet2.asr.encoder.rnn_encoder import RNNEncoder | |||||
| from espnet2.asr.encoder.transformer_encoder import TransformerEncoder | |||||
| from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder | |||||
| from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder | |||||
| from espnet2.asr.espnet_model import ESPnetASRModel | |||||
| from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||||
| from espnet2.asr.frontend.default import DefaultFrontend | |||||
| from espnet2.asr.frontend.fused import FusedFrontends | |||||
| from espnet2.asr.frontend.s3prl import S3prlFrontend | |||||
| from espnet2.asr.frontend.windowing import SlidingWindow | |||||
| from espnet2.asr.maskctc_model import MaskCTCModel | |||||
| from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder | |||||
| from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ | |||||
| HuggingFaceTransformersPostEncoder # noqa: H301 | |||||
| from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder | |||||
| from espnet2.asr.preencoder.linear import LinearProjection | |||||
| from espnet2.asr.preencoder.sinc import LightweightSincConvs | |||||
| from espnet2.asr.specaug.abs_specaug import AbsSpecAug | |||||
| from espnet2.asr.specaug.specaug import SpecAug | |||||
| from espnet2.asr.transducer.joint_network import JointNetwork | |||||
| from espnet2.asr.transducer.transducer_decoder import TransducerDecoder | |||||
| from espnet2.layers.abs_normalize import AbsNormalize | |||||
| from espnet2.layers.global_mvn import GlobalMVN | |||||
| from espnet2.layers.utterance_mvn import UtteranceMVN | |||||
| from espnet2.tasks.abs_task import AbsTask | |||||
| from espnet2.text.phoneme_tokenizer import g2p_choices | |||||
| from espnet2.torch_utils.initialize import initialize | |||||
| from espnet2.train.abs_espnet_model import AbsESPnetModel | |||||
| from espnet2.train.class_choices import ClassChoices | |||||
| from espnet2.train.collate_fn import CommonCollateFn | |||||
| from espnet2.train.preprocessor import CommonPreprocessor | |||||
| from espnet2.train.trainer import Trainer | |||||
| from espnet2.utils.get_default_kwargs import get_default_kwargs | |||||
| from espnet2.utils.nested_dict_action import NestedDictAction | |||||
| from espnet2.utils.types import (float_or_none, int_or_none, str2bool, | |||||
| str_or_none) | |||||
| from typeguard import check_argument_types, check_return_type | |||||
| from ..asr.decoder.transformer_decoder import (ParaformerDecoder, | |||||
| ParaformerDecoderBertEmbed) | |||||
| from ..asr.encoder.conformer_encoder import ConformerEncoder, SANMEncoder_v2 | |||||
| from ..asr.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunk | |||||
| from ..asr.espnet_model import AEDStreaming | |||||
| from ..asr.espnet_model_paraformer import Paraformer, ParaformerBertEmbed | |||||
| from ..nets.pytorch_backend.cif_utils.cif import cif_predictor | |||||
| # FIXME(wjm): as suggested by fairseq, we need to set up the root logger before importing any fairseq libraries. | |||||
| logging.basicConfig( | |||||
| level='INFO', | |||||
| format=f"[{os.uname()[1].split('.')[0]}]" | |||||
| f' %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', | |||||
| ) | |||||
| # FIXME(wjm): create the root logger (without __name__) so that different files share the same logger and level | |||||
| logger = logging.getLogger() | |||||
| frontend_choices = ClassChoices( | |||||
| name='frontend', | |||||
| classes=dict( | |||||
| default=DefaultFrontend, | |||||
| sliding_window=SlidingWindow, | |||||
| s3prl=S3prlFrontend, | |||||
| fused=FusedFrontends, | |||||
| ), | |||||
| type_check=AbsFrontend, | |||||
| default='default', | |||||
| ) | |||||
| specaug_choices = ClassChoices( | |||||
| name='specaug', | |||||
| classes=dict(specaug=SpecAug, ), | |||||
| type_check=AbsSpecAug, | |||||
| default=None, | |||||
| optional=True, | |||||
| ) | |||||
| normalize_choices = ClassChoices( | |||||
| 'normalize', | |||||
| classes=dict( | |||||
| global_mvn=GlobalMVN, | |||||
| utterance_mvn=UtteranceMVN, | |||||
| ), | |||||
| type_check=AbsNormalize, | |||||
| default='utterance_mvn', | |||||
| optional=True, | |||||
| ) | |||||
| model_choices = ClassChoices( | |||||
| 'model', | |||||
| classes=dict( | |||||
| espnet=ESPnetASRModel, | |||||
| maskctc=MaskCTCModel, | |||||
| paraformer=Paraformer, | |||||
| paraformer_bert_embed=ParaformerBertEmbed, | |||||
| aedstreaming=AEDStreaming, | |||||
| ), | |||||
| type_check=AbsESPnetModel, | |||||
| default='espnet', | |||||
| ) | |||||
| preencoder_choices = ClassChoices( | |||||
| name='preencoder', | |||||
| classes=dict( | |||||
| sinc=LightweightSincConvs, | |||||
| linear=LinearProjection, | |||||
| ), | |||||
| type_check=AbsPreEncoder, | |||||
| default=None, | |||||
| optional=True, | |||||
| ) | |||||
| encoder_choices = ClassChoices( | |||||
| 'encoder', | |||||
| classes=dict( | |||||
| conformer=ConformerEncoder, | |||||
| transformer=TransformerEncoder, | |||||
| contextual_block_transformer=ContextualBlockTransformerEncoder, | |||||
| contextual_block_conformer=ContextualBlockConformerEncoder, | |||||
| vgg_rnn=VGGRNNEncoder, | |||||
| rnn=RNNEncoder, | |||||
| wav2vec2=FairSeqWav2Vec2Encoder, | |||||
| hubert=FairseqHubertEncoder, | |||||
| hubert_pretrain=FairseqHubertPretrainEncoder, | |||||
| longformer=LongformerEncoder, | |||||
| sanm=SANMEncoder, | |||||
| sanm_v2=SANMEncoder_v2, | |||||
| sanm_chunk=SANMEncoderChunk, | |||||
| ), | |||||
| type_check=AbsEncoder, | |||||
| default='rnn', | |||||
| ) | |||||
| postencoder_choices = ClassChoices( | |||||
| name='postencoder', | |||||
| classes=dict( | |||||
| hugging_face_transformers=HuggingFaceTransformersPostEncoder, ), | |||||
| type_check=AbsPostEncoder, | |||||
| default=None, | |||||
| optional=True, | |||||
| ) | |||||
| decoder_choices = ClassChoices( | |||||
| 'decoder', | |||||
| classes=dict( | |||||
| transformer=TransformerDecoder, | |||||
| lightweight_conv=LightweightConvolutionTransformerDecoder, | |||||
| lightweight_conv2d=LightweightConvolution2DTransformerDecoder, | |||||
| dynamic_conv=DynamicConvolutionTransformerDecoder, | |||||
| dynamic_conv2d=DynamicConvolution2DTransformerDecoder, | |||||
| rnn=RNNDecoder, | |||||
| transducer=TransducerDecoder, | |||||
| mlm=MLMDecoder, | |||||
| paraformer_decoder=ParaformerDecoder, | |||||
| paraformer_decoder_bert_embed=ParaformerDecoderBertEmbed, | |||||
| ), | |||||
| type_check=AbsDecoder, | |||||
| default='rnn', | |||||
| ) | |||||
| predictor_choices = ClassChoices( | |||||
| name='predictor', | |||||
| classes=dict( | |||||
| cif_predictor=cif_predictor, | |||||
| ctc_predictor=None, | |||||
| ), | |||||
| type_check=None, | |||||
| default='cif_predictor', | |||||
| optional=True, | |||||
| ) | |||||
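| # Illustrative sketch (not part of this change): each ClassChoices above | |||||
| # contributes --<name> and --<name>_conf CLI options, and get_class() maps | |||||
| # the chosen string back to its class, e.g.: | |||||
| # | |||||
| #   parser = argparse.ArgumentParser() | |||||
| #   encoder_choices.add_arguments(parser) | |||||
| #   args, _ = parser.parse_known_args(['--encoder', 'sanm']) | |||||
| #   encoder_cls = encoder_choices.get_class(args.encoder)  # -> SANMEncoder | |||||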
| class ASRTask(AbsTask): | |||||
| # If you need more than one optimizers, change this value | |||||
| num_optimizers: int = 1 | |||||
| # Add variable objects configurations | |||||
| class_choices_list = [ | |||||
| # --frontend and --frontend_conf | |||||
| frontend_choices, | |||||
| # --specaug and --specaug_conf | |||||
| specaug_choices, | |||||
| # --normalize and --normalize_conf | |||||
| normalize_choices, | |||||
| # --model and --model_conf | |||||
| model_choices, | |||||
| # --preencoder and --preencoder_conf | |||||
| preencoder_choices, | |||||
| # --encoder and --encoder_conf | |||||
| encoder_choices, | |||||
| # --postencoder and --postencoder_conf | |||||
| postencoder_choices, | |||||
| # --decoder and --decoder_conf | |||||
| decoder_choices, | |||||
| ] | |||||
| # If you need to modify train() or eval() procedures, change Trainer class here | |||||
| trainer = Trainer | |||||
| @classmethod | |||||
| def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||||
| group = parser.add_argument_group(description='Task related') | |||||
| # NOTE(kamo): add_arguments(..., required=True) can't be used | |||||
| # together with the --print_config mode; instead, append to the | |||||
| # parser's 'required' list as follows | |||||
| required = parser.get_default('required') | |||||
| required += ['token_list'] | |||||
| group.add_argument( | |||||
| '--token_list', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='A text mapping int-id to token', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--init', | |||||
| type=lambda x: str_or_none(x.lower()), | |||||
| default=None, | |||||
| help='The initialization method', | |||||
| choices=[ | |||||
| 'chainer', | |||||
| 'xavier_uniform', | |||||
| 'xavier_normal', | |||||
| 'kaiming_uniform', | |||||
| 'kaiming_normal', | |||||
| None, | |||||
| ], | |||||
| ) | |||||
| group.add_argument( | |||||
| '--input_size', | |||||
| type=int_or_none, | |||||
| default=None, | |||||
| help='The number of input dimension of the feature', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--ctc_conf', | |||||
| action=NestedDictAction, | |||||
| default=get_default_kwargs(CTC), | |||||
| help='The keyword arguments for CTC class.', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--joint_net_conf', | |||||
| action=NestedDictAction, | |||||
| default=None, | |||||
| help='The keyword arguments for joint network class.', | |||||
| ) | |||||
| group = parser.add_argument_group(description='Preprocess related') | |||||
| group.add_argument( | |||||
| '--use_preprocessor', | |||||
| type=str2bool, | |||||
| default=True, | |||||
| help='Apply preprocessing to data or not', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--token_type', | |||||
| type=str, | |||||
| default='bpe', | |||||
| choices=['bpe', 'char', 'word', 'phn'], | |||||
| help='The text will be tokenized ' | |||||
| 'at the specified token level', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--bpemodel', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='The model file of sentencepiece', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--non_linguistic_symbols', | |||||
| type=str_or_none, | |||||
| help='non_linguistic_symbols file path', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--cleaner', | |||||
| type=str_or_none, | |||||
| choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||||
| default=None, | |||||
| help='Apply text cleaning', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--g2p', | |||||
| type=str_or_none, | |||||
| choices=g2p_choices, | |||||
| default=None, | |||||
| help='Specify g2p method if --token_type=phn', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--speech_volume_normalize', | |||||
| type=float_or_none, | |||||
| default=None, | |||||
| help='Scale the maximum amplitude to the given value.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--rir_scp', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='The file path of rir scp file.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--rir_apply_prob', | |||||
| type=float, | |||||
| default=1.0, | |||||
| help='The probability of applying RIR convolution.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--noise_scp', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='The file path of noise scp file.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--noise_apply_prob', | |||||
| type=float, | |||||
| default=1.0, | |||||
| help='The probability of applying noise addition.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--noise_db_range', | |||||
| type=str, | |||||
| default='13_15', | |||||
| help='The range of noise decibel level.', | |||||
| ) | |||||
| for class_choices in cls.class_choices_list: | |||||
| # Append --<name> and --<name>_conf. | |||||
| # e.g. --encoder and --encoder_conf | |||||
| class_choices.add_arguments(group) | |||||
| @classmethod | |||||
| def build_collate_fn( | |||||
| cls, args: argparse.Namespace, train: bool | |||||
| ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||||
| List[str], Dict[str, torch.Tensor]], ]: | |||||
| assert check_argument_types() | |||||
| # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||||
| return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||||
| @classmethod | |||||
| def build_preprocess_fn( | |||||
| cls, args: argparse.Namespace, train: bool | |||||
| ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]: | |||||
| assert check_argument_types() | |||||
| if args.use_preprocessor: | |||||
| retval = CommonPreprocessor( | |||||
| train=train, | |||||
| token_type=args.token_type, | |||||
| token_list=args.token_list, | |||||
| bpemodel=args.bpemodel, | |||||
| non_linguistic_symbols=args.non_linguistic_symbols, | |||||
| text_cleaner=args.cleaner, | |||||
| g2p_type=args.g2p, | |||||
| # NOTE(kamo): Check attribute existence for backward compatibility | |||||
| rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||||
| rir_apply_prob=args.rir_apply_prob if hasattr( | |||||
| args, 'rir_apply_prob') else 1.0, | |||||
| noise_scp=args.noise_scp | |||||
| if hasattr(args, 'noise_scp') else None, | |||||
| noise_apply_prob=args.noise_apply_prob if hasattr( | |||||
| args, 'noise_apply_prob') else 1.0, | |||||
| noise_db_range=args.noise_db_range if hasattr( | |||||
| args, 'noise_db_range') else '13_15', | |||||
| speech_volume_normalize=args.speech_volume_normalize | |||||
| if hasattr(args, 'speech_volume_normalize') else None, | |||||
| ) | |||||
| else: | |||||
| retval = None | |||||
| assert check_return_type(retval) | |||||
| return retval | |||||
| @classmethod | |||||
| def required_data_names(cls, | |||||
| train: bool = True, | |||||
| inference: bool = False) -> Tuple[str, ...]: | |||||
| if not inference: | |||||
| retval = ('speech', 'text') | |||||
| else: | |||||
| # Recognition mode | |||||
| retval = ('speech', ) | |||||
| return retval | |||||
| @classmethod | |||||
| def optional_data_names(cls, | |||||
| train: bool = True, | |||||
| inference: bool = False) -> Tuple[str, ...]: | |||||
| retval = () | |||||
| assert check_return_type(retval) | |||||
| return retval | |||||
| @classmethod | |||||
| def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel: | |||||
| assert check_argument_types() | |||||
| if isinstance(args.token_list, str): | |||||
| with open(args.token_list, encoding='utf-8') as f: | |||||
| token_list = [line.rstrip() for line in f] | |||||
| # Overwriting token_list to keep it as "portable". | |||||
| args.token_list = list(token_list) | |||||
| elif isinstance(args.token_list, (tuple, list)): | |||||
| token_list = list(args.token_list) | |||||
| else: | |||||
| raise RuntimeError('token_list must be str or list') | |||||
| vocab_size = len(token_list) | |||||
| logger.info(f'Vocabulary size: {vocab_size}') | |||||
| # 1. frontend | |||||
| if args.input_size is None: | |||||
| # Extract features in the model | |||||
| frontend_class = frontend_choices.get_class(args.frontend) | |||||
| frontend = frontend_class(**args.frontend_conf) | |||||
| input_size = frontend.output_size() | |||||
| else: | |||||
| # Give features from data-loader | |||||
| args.frontend = None | |||||
| args.frontend_conf = {} | |||||
| frontend = None | |||||
| input_size = args.input_size | |||||
| # 2. Data augmentation for spectrogram | |||||
| if args.specaug is not None: | |||||
| specaug_class = specaug_choices.get_class(args.specaug) | |||||
| specaug = specaug_class(**args.specaug_conf) | |||||
| else: | |||||
| specaug = None | |||||
| # 3. Normalization layer | |||||
| if args.normalize is not None: | |||||
| normalize_class = normalize_choices.get_class(args.normalize) | |||||
| normalize = normalize_class(**args.normalize_conf) | |||||
| else: | |||||
| normalize = None | |||||
| # 4. Pre-encoder input block | |||||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
| if getattr(args, 'preencoder', None) is not None: | |||||
| preencoder_class = preencoder_choices.get_class(args.preencoder) | |||||
| preencoder = preencoder_class(**args.preencoder_conf) | |||||
| input_size = preencoder.output_size() | |||||
| else: | |||||
| preencoder = None | |||||
| # 4. Encoder | |||||
| encoder_class = encoder_choices.get_class(args.encoder) | |||||
| encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||||
| # 5. Post-encoder block | |||||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
| encoder_output_size = encoder.output_size() | |||||
| if getattr(args, 'postencoder', None) is not None: | |||||
| postencoder_class = postencoder_choices.get_class(args.postencoder) | |||||
| postencoder = postencoder_class( | |||||
| input_size=encoder_output_size, **args.postencoder_conf) | |||||
| encoder_output_size = postencoder.output_size() | |||||
| else: | |||||
| postencoder = None | |||||
| # 5. Decoder | |||||
| decoder_class = decoder_choices.get_class(args.decoder) | |||||
| if args.decoder == 'transducer': | |||||
| decoder = decoder_class( | |||||
| vocab_size, | |||||
| embed_pad=0, | |||||
| **args.decoder_conf, | |||||
| ) | |||||
| joint_network = JointNetwork( | |||||
| vocab_size, | |||||
| encoder.output_size(), | |||||
| decoder.dunits, | |||||
| **args.joint_net_conf, | |||||
| ) | |||||
| else: | |||||
| decoder = decoder_class( | |||||
| vocab_size=vocab_size, | |||||
| encoder_output_size=encoder_output_size, | |||||
| **args.decoder_conf, | |||||
| ) | |||||
| joint_network = None | |||||
| # 6. CTC | |||||
| ctc = CTC( | |||||
| odim=vocab_size, | |||||
| encoder_output_size=encoder_output_size, | |||||
| **args.ctc_conf) | |||||
| # 7. Build model | |||||
| try: | |||||
| model_class = model_choices.get_class(args.model) | |||||
| except AttributeError: | |||||
| model_class = model_choices.get_class('espnet') | |||||
| model = model_class( | |||||
| vocab_size=vocab_size, | |||||
| frontend=frontend, | |||||
| specaug=specaug, | |||||
| normalize=normalize, | |||||
| preencoder=preencoder, | |||||
| encoder=encoder, | |||||
| postencoder=postencoder, | |||||
| decoder=decoder, | |||||
| ctc=ctc, | |||||
| joint_network=joint_network, | |||||
| token_list=token_list, | |||||
| **args.model_conf, | |||||
| ) | |||||
| # FIXME(kamo): Should be done in model? | |||||
| # 8. Initialize | |||||
| if args.init is not None: | |||||
| initialize(model, args.init) | |||||
| assert check_return_type(model) | |||||
| return model | |||||
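| # Illustrative sketch (not part of this change): in real usage the parser | |||||
| # comes from AbsTask.get_parser(), which pre-populates the 'required' default | |||||
| # consumed by add_task_arguments() above; build_model() then turns the parsed | |||||
| # Namespace into a concrete model, e.g. (hypothetical token list path): | |||||
| # | |||||
| #   parser = ASRTask.get_parser() | |||||
| #   args = parser.parse_args(['--token_list', 'tokens.txt', | |||||
| #                             '--encoder', 'sanm', '--model', 'paraformer']) | |||||
| #   model = ASRTask.build_model(args) | |||||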
| class ASRTaskNAR(AbsTask): | |||||
| # If you need more than one optimizers, change this value | |||||
| num_optimizers: int = 1 | |||||
| # Add variable objects configurations | |||||
| class_choices_list = [ | |||||
| # --frontend and --frontend_conf | |||||
| frontend_choices, | |||||
| # --specaug and --specaug_conf | |||||
| specaug_choices, | |||||
| # --normalize and --normalize_conf | |||||
| normalize_choices, | |||||
| # --model and --model_conf | |||||
| model_choices, | |||||
| # --preencoder and --preencoder_conf | |||||
| preencoder_choices, | |||||
| # --encoder and --encoder_conf | |||||
| encoder_choices, | |||||
| # --postencoder and --postencoder_conf | |||||
| postencoder_choices, | |||||
| # --decoder and --decoder_conf | |||||
| decoder_choices, | |||||
| # --predictor and --predictor_conf | |||||
| predictor_choices, | |||||
| ] | |||||
| # If you need to modify train() or eval() procedures, change Trainer class here | |||||
| trainer = Trainer | |||||
| @classmethod | |||||
| def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||||
| group = parser.add_argument_group(description='Task related') | |||||
| # NOTE(kamo): add_arguments(..., required=True) can't be used | |||||
| # together with the --print_config mode; instead, append to the | |||||
| # parser's 'required' list as follows | |||||
| required = parser.get_default('required') | |||||
| required += ['token_list'] | |||||
| group.add_argument( | |||||
| '--token_list', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='A text mapping int-id to token', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--init', | |||||
| type=lambda x: str_or_none(x.lower()), | |||||
| default=None, | |||||
| help='The initialization method', | |||||
| choices=[ | |||||
| 'chainer', | |||||
| 'xavier_uniform', | |||||
| 'xavier_normal', | |||||
| 'kaiming_uniform', | |||||
| 'kaiming_normal', | |||||
| None, | |||||
| ], | |||||
| ) | |||||
| group.add_argument( | |||||
| '--input_size', | |||||
| type=int_or_none, | |||||
| default=None, | |||||
| help='The number of input dimension of the feature', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--ctc_conf', | |||||
| action=NestedDictAction, | |||||
| default=get_default_kwargs(CTC), | |||||
| help='The keyword arguments for CTC class.', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--joint_net_conf', | |||||
| action=NestedDictAction, | |||||
| default=None, | |||||
| help='The keyword arguments for joint network class.', | |||||
| ) | |||||
| group = parser.add_argument_group(description='Preprocess related') | |||||
| group.add_argument( | |||||
| '--use_preprocessor', | |||||
| type=str2bool, | |||||
| default=True, | |||||
| help='Apply preprocessing to data or not', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--token_type', | |||||
| type=str, | |||||
| default='bpe', | |||||
| choices=['bpe', 'char', 'word', 'phn'], | |||||
| help='The text will be tokenized ' | |||||
| 'at the specified token level', | |||||
| ) | |||||
| group.add_argument( | |||||
| '--bpemodel', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='The model file of sentencepiece', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--non_linguistic_symbols', | |||||
| type=str_or_none, | |||||
| help='non_linguistic_symbols file path', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--cleaner', | |||||
| type=str_or_none, | |||||
| choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||||
| default=None, | |||||
| help='Apply text cleaning', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--g2p', | |||||
| type=str_or_none, | |||||
| choices=g2p_choices, | |||||
| default=None, | |||||
| help='Specify g2p method if --token_type=phn', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--speech_volume_normalize', | |||||
| type=float_or_none, | |||||
| default=None, | |||||
| help='Scale the maximum amplitude to the given value.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--rir_scp', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='The file path of rir scp file.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--rir_apply_prob', | |||||
| type=float, | |||||
| default=1.0, | |||||
| help='The probability of applying RIR convolution.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--noise_scp', | |||||
| type=str_or_none, | |||||
| default=None, | |||||
| help='The file path of noise scp file.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--noise_apply_prob', | |||||
| type=float, | |||||
| default=1.0, | |||||
| help='The probability of applying noise addition.', | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--noise_db_range', | |||||
| type=str, | |||||
| default='13_15', | |||||
| help='The range of noise decibel level.', | |||||
| ) | |||||
| for class_choices in cls.class_choices_list: | |||||
| # Append --<name> and --<name>_conf. | |||||
| # e.g. --encoder and --encoder_conf | |||||
| class_choices.add_arguments(group) | |||||
| @classmethod | |||||
| def build_collate_fn( | |||||
| cls, args: argparse.Namespace, train: bool | |||||
| ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||||
| List[str], Dict[str, torch.Tensor]], ]: | |||||
| assert check_argument_types() | |||||
| # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||||
| return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||||
| @classmethod | |||||
| def build_preprocess_fn( | |||||
| cls, args: argparse.Namespace, train: bool | |||||
| ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]: | |||||
| assert check_argument_types() | |||||
| if args.use_preprocessor: | |||||
| retval = CommonPreprocessor( | |||||
| train=train, | |||||
| token_type=args.token_type, | |||||
| token_list=args.token_list, | |||||
| bpemodel=args.bpemodel, | |||||
| non_linguistic_symbols=args.non_linguistic_symbols, | |||||
| text_cleaner=args.cleaner, | |||||
| g2p_type=args.g2p, | |||||
| # NOTE(kamo): Check attribute existence for backward compatibility | |||||
| rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||||
| rir_apply_prob=args.rir_apply_prob if hasattr( | |||||
| args, 'rir_apply_prob') else 1.0, | |||||
| noise_scp=args.noise_scp | |||||
| if hasattr(args, 'noise_scp') else None, | |||||
| noise_apply_prob=args.noise_apply_prob if hasattr( | |||||
| args, 'noise_apply_prob') else 1.0, | |||||
| noise_db_range=args.noise_db_range if hasattr( | |||||
| args, 'noise_db_range') else '13_15', | |||||
| speech_volume_normalize=args.speech_volume_normalize | |||||
| if hasattr(args, 'speech_volume_normalize') else None, | |||||
| ) | |||||
| else: | |||||
| retval = None | |||||
| assert check_return_type(retval) | |||||
| return retval | |||||
| @classmethod | |||||
| def required_data_names(cls, | |||||
| train: bool = True, | |||||
| inference: bool = False) -> Tuple[str, ...]: | |||||
| if not inference: | |||||
| retval = ('speech', 'text') | |||||
| else: | |||||
| # Recognition mode | |||||
| retval = ('speech', ) | |||||
| return retval | |||||
| @classmethod | |||||
| def optional_data_names(cls, | |||||
| train: bool = True, | |||||
| inference: bool = False) -> Tuple[str, ...]: | |||||
| retval = () | |||||
| assert check_return_type(retval) | |||||
| return retval | |||||
| @classmethod | |||||
| def build_model(cls, args: argparse.Namespace): | |||||
| assert check_argument_types() | |||||
| if isinstance(args.token_list, str): | |||||
| with open(args.token_list, encoding='utf-8') as f: | |||||
| token_list = [line.rstrip() for line in f] | |||||
| # Overwriting token_list to keep it as "portable". | |||||
| args.token_list = list(token_list) | |||||
| elif isinstance(args.token_list, (tuple, list)): | |||||
| token_list = list(args.token_list) | |||||
| else: | |||||
| raise RuntimeError('token_list must be str or list') | |||||
| vocab_size = len(token_list) | |||||
| logger.info(f'Vocabulary size: {vocab_size}') | |||||
| # 1. frontend | |||||
| if args.input_size is None: | |||||
| # Extract features in the model | |||||
| frontend_class = frontend_choices.get_class(args.frontend) | |||||
| frontend = frontend_class(**args.frontend_conf) | |||||
| input_size = frontend.output_size() | |||||
| else: | |||||
| # Give features from data-loader | |||||
| args.frontend = None | |||||
| args.frontend_conf = {} | |||||
| frontend = None | |||||
| input_size = args.input_size | |||||
| # 2. Data augmentation for spectrogram | |||||
| if args.specaug is not None: | |||||
| specaug_class = specaug_choices.get_class(args.specaug) | |||||
| specaug = specaug_class(**args.specaug_conf) | |||||
| else: | |||||
| specaug = None | |||||
| # 3. Normalization layer | |||||
| if args.normalize is not None: | |||||
| normalize_class = normalize_choices.get_class(args.normalize) | |||||
| normalize = normalize_class(**args.normalize_conf) | |||||
| else: | |||||
| normalize = None | |||||
| # 4. Pre-encoder input block | |||||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
| if getattr(args, 'preencoder', None) is not None: | |||||
| preencoder_class = preencoder_choices.get_class(args.preencoder) | |||||
| preencoder = preencoder_class(**args.preencoder_conf) | |||||
| input_size = preencoder.output_size() | |||||
| else: | |||||
| preencoder = None | |||||
| # 4. Encoder | |||||
| encoder_class = encoder_choices.get_class(args.encoder) | |||||
| encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||||
| # 5. Post-encoder block | |||||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||||
| encoder_output_size = encoder.output_size() | |||||
| if getattr(args, 'postencoder', None) is not None: | |||||
| postencoder_class = postencoder_choices.get_class(args.postencoder) | |||||
| postencoder = postencoder_class( | |||||
| input_size=encoder_output_size, **args.postencoder_conf) | |||||
| encoder_output_size = postencoder.output_size() | |||||
| else: | |||||
| postencoder = None | |||||
| # 5. Decoder | |||||
| decoder_class = decoder_choices.get_class(args.decoder) | |||||
| if args.decoder == 'transducer': | |||||
| decoder = decoder_class( | |||||
| vocab_size, | |||||
| embed_pad=0, | |||||
| **args.decoder_conf, | |||||
| ) | |||||
| joint_network = JointNetwork( | |||||
| vocab_size, | |||||
| encoder.output_size(), | |||||
| decoder.dunits, | |||||
| **args.joint_net_conf, | |||||
| ) | |||||
| else: | |||||
| decoder = decoder_class( | |||||
| vocab_size=vocab_size, | |||||
| encoder_output_size=encoder_output_size, | |||||
| **args.decoder_conf, | |||||
| ) | |||||
| joint_network = None | |||||
| # 6. CTC | |||||
| ctc = CTC( | |||||
| odim=vocab_size, | |||||
| encoder_output_size=encoder_output_size, | |||||
| **args.ctc_conf) | |||||
| predictor_class = predictor_choices.get_class(args.predictor) | |||||
| predictor = predictor_class(**args.predictor_conf) | |||||
| # 7. Build model | |||||
| try: | |||||
| model_class = model_choices.get_class(args.model) | |||||
| except AttributeError: | |||||
| model_class = model_choices.get_class('espnet') | |||||
| model = model_class( | |||||
| vocab_size=vocab_size, | |||||
| frontend=frontend, | |||||
| specaug=specaug, | |||||
| normalize=normalize, | |||||
| preencoder=preencoder, | |||||
| encoder=encoder, | |||||
| postencoder=postencoder, | |||||
| decoder=decoder, | |||||
| ctc=ctc, | |||||
| joint_network=joint_network, | |||||
| token_list=token_list, | |||||
| predictor=predictor, | |||||
| **args.model_conf, | |||||
| ) | |||||
| # FIXME(kamo): Should be done in model? | |||||
| # 8. Initialize | |||||
| if args.init is not None: | |||||
| initialize(model, args.init) | |||||
| assert check_return_type(model) | |||||
| return model | |||||
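| # Note (editorial): ASRTaskNAR differs from ASRTask only in the build step -- | |||||
| # it additionally instantiates a predictor (default: the CIF predictor) from | |||||
| # --predictor / --predictor_conf and passes it to the model, which the | |||||
| # non-autoregressive Paraformer uses to predict output token counts instead | |||||
| # of decoding autoregressively. | |||||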
| @@ -0,0 +1,217 @@ | |||||
| import io | |||||
| import os | |||||
| import shutil | |||||
| import threading | |||||
| from typing import Any, Dict, List, Sequence, Tuple, Union | |||||
| import yaml | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import WavToScp | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .asr_engine import asr_env_checking, asr_inference_paraformer_espnet | |||||
| from .asr_engine.common import asr_utils | |||||
| __all__ = ['AutomaticSpeechRecognitionPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) | |||||
| class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| """ASR Pipeline | |||||
| """ | |||||
| def __init__(self, | |||||
| model: Union[List[Model], List[str]] = None, | |||||
| preprocessor: WavToScp = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create an asr pipeline for prediction | |||||
| """ | |||||
| assert model is not None, 'asr model should be provided' | |||||
| model_list: List = [] | |||||
| if isinstance(model[0], Model): | |||||
| model_list = model | |||||
| else: | |||||
| model_list.append(Model.from_pretrained(model[0])) | |||||
| if len(model) == 2 and model[1] is not None: | |||||
| model_list.append(Model.from_pretrained(model[1])) | |||||
| super().__init__(model=model_list, preprocessor=preprocessor, **kwargs) | |||||
| self._preprocessor = preprocessor | |||||
| self._am_model = model_list[0] | |||||
| if len(model_list) == 2 and model_list[1] is not None: | |||||
| self._lm_model = model_list[1] | |||||
| def __call__(self, | |||||
| wav_path: str, | |||||
| recog_type: str = None, | |||||
| audio_format: str = None, | |||||
| workspace: str = None) -> Dict[str, Any]: | |||||
| assert len(wav_path) > 0, 'wav_path should be provided' | |||||
| self._recog_type = recog_type | |||||
| self._audio_format = audio_format | |||||
| self._workspace = workspace | |||||
| self._wav_path = wav_path | |||||
| if recog_type is None or audio_format is None or workspace is None: | |||||
| self._recog_type, self._audio_format, self._workspace, self._wav_path = asr_utils.type_checking( | |||||
| wav_path, recog_type, audio_format, workspace) | |||||
| if self._preprocessor is None: | |||||
| self._preprocessor = WavToScp(workspace=self._workspace) | |||||
| output = self._preprocessor.forward(self._am_model.forward(), | |||||
| self._recog_type, | |||||
| self._audio_format, self._wav_path) | |||||
| output = self.forward(output) | |||||
| rst = self.postprocess(output) | |||||
| return rst | |||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """Decoding | |||||
| """ | |||||
| j: int = 0 | |||||
| process = [] | |||||
| while j < inputs['thread_count']: | |||||
| data_cmd: Sequence[Tuple[str, str, str]] | |||||
| if inputs['audio_format'] == 'wav': | |||||
| data_cmd = [(os.path.join(inputs['workspace'], | |||||
| 'data.' + str(j) + '.scp'), 'speech', | |||||
| 'sound')] | |||||
| elif inputs['audio_format'] == 'kaldi_ark': | |||||
| data_cmd = [(os.path.join(inputs['workspace'], | |||||
| 'data.' + str(j) + '.scp'), 'speech', | |||||
| 'kaldi_ark')] | |||||
| output_dir: str = os.path.join(inputs['output'], | |||||
| 'output.' + str(j)) | |||||
| if not os.path.exists(output_dir): | |||||
| os.mkdir(output_dir) | |||||
| config_file = open(inputs['asr_model_config']) | |||||
| root = yaml.full_load(config_file) | |||||
| config_file.close() | |||||
| frontend_conf = None | |||||
| if 'frontend_conf' in root: | |||||
| frontend_conf = root['frontend_conf'] | |||||
| cmd = { | |||||
| 'model_type': inputs['model_type'], | |||||
| 'beam_size': root['beam_size'], | |||||
| 'penalty': root['penalty'], | |||||
| 'maxlenratio': root['maxlenratio'], | |||||
| 'minlenratio': root['minlenratio'], | |||||
| 'ctc_weight': root['ctc_weight'], | |||||
| 'lm_weight': root['lm_weight'], | |||||
| 'output_dir': output_dir, | |||||
| 'ngpu': 0, | |||||
| 'log_level': 'ERROR', | |||||
| 'data_path_and_name_and_type': data_cmd, | |||||
| 'asr_train_config': inputs['am_model_config'], | |||||
| 'asr_model_file': inputs['am_model_path'], | |||||
| 'batch_size': inputs['model_config']['batch_size'], | |||||
| 'frontend_conf': frontend_conf | |||||
| } | |||||
| thread = AsrInferenceThread(j, cmd) | |||||
| thread.start() | |||||
| j += 1 | |||||
| process.append(thread) | |||||
| for p in process: | |||||
| p.join() | |||||
| return inputs | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the asr results | |||||
| """ | |||||
| rst = {'rec_result': 'None'} | |||||
| # single wav task | |||||
| if inputs['recog_type'] == 'wav' and inputs['audio_format'] == 'wav': | |||||
| text_file: str = os.path.join(inputs['output'], 'output.0', | |||||
| '1best_recog', 'text') | |||||
| if os.path.exists(text_file): | |||||
| with open(text_file, 'r') as f: | |||||
| result_str: str = f.readline() | |||||
| if len(result_str) > 0: | |||||
| result_list = result_str.split() | |||||
| if len(result_list) >= 2: | |||||
| rst['rec_result'] = result_list[1] | |||||
| # run with datasets, where the audio format is waveform or kaldi_ark: | |||||
| # merge the per-thread outputs and score them against the transcript. | |||||
| # NOTE: the naming is inverted w.r.t. the usual ASR convention here -- | |||||
| # 'hypothesis_text' holds the dataset transcript and 'reference_text' | |||||
| # the merged recognizer output. | |||||
| elif inputs['recog_type'] != 'wav': | |||||
| inputs['reference_text'] = self._ref_text_tidy(inputs) | |||||
| inputs['datasets_result'] = asr_utils.compute_wer( | |||||
| inputs['hypothesis_text'], inputs['reference_text']) | |||||
| else: | |||||
| raise ValueError('recog_type and audio_format do not match') | |||||
| if 'datasets_result' in inputs: | |||||
| rst['datasets_result'] = inputs['datasets_result'] | |||||
| # remove workspace dir (.tmp) | |||||
| if os.path.exists(self._workspace): | |||||
| shutil.rmtree(self._workspace) | |||||
| return rst | |||||
| def _ref_text_tidy(self, inputs: Dict[str, Any]) -> str: | |||||
| """Merge the per-thread recognition outputs into a single text file.""" | |||||
| ref_text: str = os.path.join(inputs['output'], 'text.ref') | |||||
| for k in range(inputs['thread_count']): | |||||
| output_text = os.path.join(inputs['output'], 'output.' + str(k), | |||||
| '1best_recog', 'text') | |||||
| if os.path.exists(output_text): | |||||
| with open(output_text, 'r', encoding='utf-8') as i: | |||||
| lines = i.readlines() | |||||
| with open(ref_text, 'a', encoding='utf-8') as o: | |||||
| for line in lines: | |||||
| o.write(line) | |||||
| return ref_text | |||||
| class AsrInferenceThread(threading.Thread): | |||||
| def __init__(self, threadID, cmd): | |||||
| threading.Thread.__init__(self) | |||||
| self._threadID = threadID | |||||
| self._cmd = cmd | |||||
| def run(self): | |||||
| if self._cmd['model_type'] == 'pytorch': | |||||
| asr_inference_paraformer_espnet.asr_inference( | |||||
| batch_size=self._cmd['batch_size'], | |||||
| output_dir=self._cmd['output_dir'], | |||||
| maxlenratio=self._cmd['maxlenratio'], | |||||
| minlenratio=self._cmd['minlenratio'], | |||||
| beam_size=self._cmd['beam_size'], | |||||
| ngpu=self._cmd['ngpu'], | |||||
| ctc_weight=self._cmd['ctc_weight'], | |||||
| lm_weight=self._cmd['lm_weight'], | |||||
| penalty=self._cmd['penalty'], | |||||
| log_level=self._cmd['log_level'], | |||||
| data_path_and_name_and_type=self. | |||||
| _cmd['data_path_and_name_and_type'], | |||||
| asr_train_config=self._cmd['asr_train_config'], | |||||
| asr_model_file=self._cmd['asr_model_file'], | |||||
| frontend_conf=self._cmd['frontend_conf']) | |||||
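| # Illustrative usage of this pipeline (mirrors the unit tests below; the | |||||
| # model id is the one used there and must exist on ModelScope): | |||||
| # | |||||
| #   from modelscope.pipelines import pipeline | |||||
| #   from modelscope.utils.constant import Tasks | |||||
| # | |||||
| #   asr = pipeline(task=Tasks.auto_speech_recognition, | |||||
| #                  model=['damo/speech_paraformer_asr_nat-aishell1-pytorch']) | |||||
| #   print(asr('data/test/audios/asr_example.wav')['rec_result']) | |||||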
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from modelscope.utils.error import AUDIO_IMPORT_ERROR, TENSORFLOW_IMPORT_ERROR | from modelscope.utils.error import AUDIO_IMPORT_ERROR, TENSORFLOW_IMPORT_ERROR | ||||
| from .asr import WavToScp | |||||
| from .base import Preprocessor | from .base import Preprocessor | ||||
| from .builder import PREPROCESSORS, build_preprocessor | from .builder import PREPROCESSORS, build_preprocessor | ||||
| from .common import Compose | from .common import Compose | ||||
| @@ -0,0 +1,254 @@ | |||||
| import io | |||||
| import os | |||||
| import shutil | |||||
| from pathlib import Path | |||||
| from typing import Any, Dict, List | |||||
| import yaml | |||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.utils.constant import Fields | |||||
| from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS | |||||
| __all__ = ['WavToScp'] | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.audio, module_name=Preprocessors.wav_to_scp) | |||||
| class WavToScp(Preprocessor): | |||||
| """generate audio scp from wave or ark | |||||
| Args: | |||||
| workspace (str): | |||||
| """ | |||||
| def __init__(self, workspace: str = None): | |||||
| # the workspace path | |||||
| if workspace is None or len(workspace) == 0: | |||||
| self._workspace = os.path.join(os.getcwd(), '.tmp') | |||||
| else: | |||||
| self._workspace = workspace | |||||
| if not os.path.exists(self._workspace): | |||||
| os.mkdir(self._workspace) | |||||
| def __call__(self, | |||||
| model: List[Model] = None, | |||||
| recog_type: str = None, | |||||
| audio_format: str = None, | |||||
| wav_path: str = None) -> Dict[str, Any]: | |||||
| assert len(model) > 0, 'preprocess model is invalid' | |||||
| assert len(recog_type) > 0, 'preprocess recog_type is empty' | |||||
| assert len(audio_format) > 0, 'preprocess audio_format is empty' | |||||
| assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||||
| self._am_model = model[0] | |||||
| if len(model) == 2 and model[1] is not None: | |||||
| self._lm_model = model[1] | |||||
| out = self.forward(self._am_model.forward(), recog_type, audio_format, | |||||
| wav_path) | |||||
| return out | |||||
| def forward(self, model: Dict[str, Any], recog_type: str, | |||||
| audio_format: str, wav_path: str) -> Dict[str, Any]: | |||||
| assert len(recog_type) > 0, 'preprocess recog_type is empty' | |||||
| assert len(audio_format) > 0, 'preprocess audio_format is empty' | |||||
| assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||||
| assert os.path.exists(wav_path), 'preprocess wav_path does not exist' | |||||
| assert len( | |||||
| model['am_model']) > 0, 'preprocess model[am_model] is empty' | |||||
| assert len(model['am_model_path'] | |||||
| ) > 0, 'preprocess model[am_model_path] is empty' | |||||
| assert os.path.exists( | |||||
| model['am_model_path']), 'preprocess am_model_path does not exist' | |||||
| assert len(model['model_workspace'] | |||||
| ) > 0, 'preprocess model[model_workspace] is empty' | |||||
| assert os.path.exists(model['model_workspace'] | |||||
| ), 'preprocess model_workspace does not exist' | |||||
| assert len(model['model_config'] | |||||
| ) > 0, 'preprocess model[model_config] is empty' | |||||
| # the am model name | |||||
| am_model: str = model['am_model'] | |||||
| # the am model file path | |||||
| am_model_path: str = model['am_model_path'] | |||||
| # the recognition model dir path | |||||
| model_workspace: str = model['model_workspace'] | |||||
| # the recognition model config dict | |||||
| global_model_config_dict: Dict[str, Any] = model['model_config'] | |||||
| rst = { | |||||
| 'workspace': os.path.join(self._workspace, recog_type), | |||||
| 'am_model': am_model, | |||||
| 'am_model_path': am_model_path, | |||||
| 'model_workspace': model_workspace, | |||||
| # the asr type setting, eg: test dev train wav | |||||
| 'recog_type': recog_type, | |||||
| # the asr audio format setting, eg: wav, kaldi_ark | |||||
| 'audio_format': audio_format, | |||||
| # the test wav file path or the dataset path | |||||
| 'wav_path': wav_path, | |||||
| 'model_config': global_model_config_dict | |||||
| } | |||||
| out = self._config_checking(rst) | |||||
| out = self._env_setting(out) | |||||
| if audio_format == 'wav': | |||||
| out = self._scp_generation_from_wav(out) | |||||
| elif audio_format == 'kaldi_ark': | |||||
| out = self._scp_generation_from_ark(out) | |||||
| return out | |||||
| def _config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """config checking | |||||
| """ | |||||
| assert 'type' in inputs['model_config'], 'model type does not exist' | |||||
| assert 'batch_size' in inputs[ | |||||
| 'model_config'], 'batch_size does not exist' | |||||
| assert 'am_model_config' in inputs[ | |||||
| 'model_config'], 'am_model_config does not exist' | |||||
| assert 'asr_model_config' in inputs[ | |||||
| 'model_config'], 'asr_model_config does not exist' | |||||
| assert 'asr_model_wav_config' in inputs[ | |||||
| 'model_config'], 'asr_model_wav_config does not exist' | |||||
| am_model_config: str = os.path.join( | |||||
| inputs['model_workspace'], | |||||
| inputs['model_config']['am_model_config']) | |||||
| assert os.path.exists( | |||||
| am_model_config), 'am_model_config does not exist' | |||||
| inputs['am_model_config'] = am_model_config | |||||
| asr_model_config: str = os.path.join( | |||||
| inputs['model_workspace'], | |||||
| inputs['model_config']['asr_model_config']) | |||||
| assert os.path.exists( | |||||
| asr_model_config), 'asr_model_config does not exist' | |||||
| asr_model_wav_config: str = os.path.join( | |||||
| inputs['model_workspace'], | |||||
| inputs['model_config']['asr_model_wav_config']) | |||||
| assert os.path.exists( | |||||
| asr_model_wav_config), 'asr_model_wav_config does not exist' | |||||
| inputs['model_type'] = inputs['model_config']['type'] | |||||
| if inputs['audio_format'] == 'wav': | |||||
| inputs['asr_model_config'] = asr_model_wav_config | |||||
| else: | |||||
| inputs['asr_model_config'] = asr_model_config | |||||
| return inputs | |||||
| def _env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| if not os.path.exists(inputs['workspace']): | |||||
| os.mkdir(inputs['workspace']) | |||||
| inputs['output'] = os.path.join(inputs['workspace'], 'logdir') | |||||
| if not os.path.exists(inputs['output']): | |||||
| os.mkdir(inputs['output']) | |||||
| # run with datasets, should set datasets_path and text_path | |||||
| if inputs['recog_type'] != 'wav': | |||||
| inputs['datasets_path'] = inputs['wav_path'] | |||||
| # run with datasets, and audio format is waveform | |||||
| if inputs['audio_format'] == 'wav': | |||||
| inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||||
| 'wav', inputs['recog_type']) | |||||
| inputs['hypothesis_text'] = os.path.join( | |||||
| inputs['datasets_path'], 'transcript', 'data.text') | |||||
| assert os.path.exists(inputs['hypothesis_text'] | |||||
| ), 'hypothesis text does not exist' | |||||
| elif inputs['audio_format'] == 'kaldi_ark': | |||||
| inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||||
| inputs['recog_type']) | |||||
| inputs['hypothesis_text'] = os.path.join( | |||||
| inputs['wav_path'], 'data.text') | |||||
| assert os.path.exists(inputs['hypothesis_text'] | |||||
| ), 'hypothesis text does not exist' | |||||
| return inputs | |||||
| def _scp_generation_from_wav(self, inputs: Dict[str, | |||||
| Any]) -> Dict[str, Any]: | |||||
| """scp generation from waveform files | |||||
| """ | |||||
| # find all waveform files | |||||
| wav_list = [] | |||||
| if inputs['recog_type'] == 'wav': | |||||
| file_path = inputs['wav_path'] | |||||
| if os.path.isfile(file_path): | |||||
| if file_path.endswith('.wav') or file_path.endswith('.WAV'): | |||||
| wav_list.append(file_path) | |||||
| else: | |||||
| wav_dir: str = inputs['wav_path'] | |||||
| wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||||
| list_count: int = len(wav_list) | |||||
| inputs['wav_count'] = list_count | |||||
| # store all wav files into data.0.scp, one '<utt_id> <path>' entry per line | |||||
| inputs['thread_count'] = 1 | |||||
| wav_list_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||||
| with open(wav_list_path, 'a') as f: | |||||
| for wav_file in wav_list: | |||||
| wave_scp_content: str = os.path.splitext( | |||||
| os.path.basename(wav_file))[0] | |||||
| wave_scp_content += ' ' + wav_file + '\n' | |||||
| f.write(wave_scp_content) | |||||
| return inputs | |||||
| def _scp_generation_from_ark(self, inputs: Dict[str, | |||||
| Any]) -> Dict[str, Any]: | |||||
| """scp generation from kaldi ark file | |||||
| """ | |||||
| inputs['thread_count'] = 1 | |||||
| ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp') | |||||
| ark_file_path = os.path.join(inputs['wav_path'], 'data.ark') | |||||
| assert os.path.exists(ark_scp_path), 'data.scp does not exist' | |||||
| assert os.path.exists(ark_file_path), 'data.ark does not exist' | |||||
| new_ark_scp_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||||
| with open(ark_scp_path, 'r', encoding='utf-8') as f: | |||||
| lines = f.readlines() | |||||
| with open(new_ark_scp_path, 'w', encoding='utf-8') as n: | |||||
| for line in lines: | |||||
| outs = line.strip().split(' ') | |||||
| if len(outs) == 2: | |||||
| key = outs[0] | |||||
| sub = outs[1].split(':') | |||||
| if len(sub) == 2: | |||||
| nums = sub[1] | |||||
| content = key + ' ' + ark_file_path + ':' + nums + '\n' | |||||
| n.write(content) | |||||
| return inputs | |||||
| def _recursion_dir_all_wave(self, wav_list: List[str], | |||||
| dir_path: str) -> List[str]: | |||||
| dir_files = os.listdir(dir_path) | |||||
| for file in dir_files: | |||||
| file_path = os.path.join(dir_path, file) | |||||
| if os.path.isfile(file_path): | |||||
| if file_path.endswith('.wav') or file_path.endswith('.WAV'): | |||||
| wav_list.append(file_path) | |||||
| elif os.path.isdir(file_path): | |||||
| self._recursion_dir_all_wave(wav_list, file_path) | |||||
| return wav_list | |||||
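| # For reference, the generated data.0.scp follows the Kaldi scp convention, | |||||
| # one utterance per line: | |||||
| #   wav input:  <utt_id> /path/to/<utt_id>.wav | |||||
| #   ark input:  <utt_id> /path/to/data.ark:<byte_offset> | |||||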
| @@ -1,3 +1,4 @@ | |||||
| espnet==202204 | |||||
| #tts | #tts | ||||
| h5py | h5py | ||||
| inflect | inflect | ||||
| @@ -0,0 +1,199 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tarfile | |||||
| import unittest | |||||
| import requests | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| WAV_FILE = 'data/test/audios/asr_example.wav' | |||||
| LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz' | |||||
| LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz' | |||||
| AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz' | |||||
| AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz' | |||||
| def un_tar_gz(fname, dirs): | |||||
| with tarfile.open(fname) as t: | |||||
| t.extractall(path=dirs) | |||||
| class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self._am_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||||
| # this temporary workspace dir will store waveform files | |||||
| self._workspace = os.path.join(os.getcwd(), '.tmp') | |||||
| if not os.path.exists(self._workspace): | |||||
| os.mkdir(self._workspace) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_wav(self): | |||||
| '''run with a single waveform file | |||||
| ''' | |||||
| wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||||
| inference_16k_pipeline = pipeline( | |||||
| task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||||
| self.assertIsNotNone(inference_16k_pipeline) | |||||
| rec_result = inference_16k_pipeline(wav_file_path) | |||||
| self.assertTrue(len(rec_result['rec_result']) > 0) | |||||
| self.assertTrue(rec_result['rec_result'] != 'None') | |||||
| ''' | |||||
| result structure: | |||||
| { | |||||
| 'rec_result': '每一天都要快乐喔' | |||||
| } | |||||
| or | |||||
| { | |||||
| 'rec_result': 'None' | |||||
| } | |||||
| ''' | |||||
| print('test_run_with_wav rec result: ' + rec_result['rec_result']) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_wav_dataset(self): | |||||
| '''run with datasets, and audio format is waveform | |||||
| datasets directory: | |||||
| <dataset_path> | |||||
| wav | |||||
| test # testsets | |||||
| xx.wav | |||||
| ... | |||||
| dev # devsets | |||||
| yy.wav | |||||
| ... | |||||
| train # trainsets | |||||
| zz.wav | |||||
| ... | |||||
| transcript | |||||
| data.text # hypothesis text | |||||
| ''' | |||||
| # download the testsets archive if it is not cached yet | |||||
| testsets_file_path = os.path.join(self._workspace, | |||||
| LITTLE_TESTSETS_FILE) | |||||
| if not os.path.exists(testsets_file_path): | |||||
| r = requests.get(LITTLE_TESTSETS_URL) | |||||
| with open(testsets_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| # strip the '.tar.gz' suffix to obtain the extracted dir name ('data_aishell') | |||||
| testsets_dir_name = os.path.basename( | |||||
| LITTLE_TESTSETS_FILE)[:-len('.tar.gz')] | |||||
| # dataset_path = <cwd>/.tmp/data_aishell/wav/test | |||||
| dataset_path = os.path.join(self._workspace, testsets_dir_name, 'wav', | |||||
| 'test') | |||||
| # untar the dataset_path file | |||||
| if not os.path.exists(dataset_path): | |||||
| un_tar_gz(testsets_file_path, self._workspace) | |||||
| inference_16k_pipeline = pipeline( | |||||
| task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||||
| self.assertIsNotNone(inference_16k_pipeline) | |||||
| rec_result = inference_16k_pipeline(wav_path=dataset_path) | |||||
| self.assertTrue(len(rec_result['datasets_result']) > 0) | |||||
| self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||||
| ''' | |||||
| result structure: | |||||
| { | |||||
| 'rec_result': 'None', | |||||
| 'datasets_result': | |||||
| { | |||||
| 'Wrd': 1654, # the number of words | |||||
| 'Snt': 128, # the number of sentences | |||||
| 'Corr': 1573, # the number of correct words | |||||
| 'Ins': 1, # the number of insert words | |||||
| 'Del': 1, # the number of delete words | |||||
| 'Sub': 80, # the number of substitution words | |||||
| 'wrong_words': 82, # the number of wrong words | |||||
| 'wrong_sentences': 47, # the number of wrong sentences | |||||
| 'Err': 4.96, # WER/CER | |||||
| 'S.Err': 36.72 # SER | |||||
| } | |||||
| } | |||||
| ''' | |||||
| print('test_run_with_wav_dataset datasets result: ') | |||||
| print(rec_result['datasets_result']) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_ark_dataset(self): | |||||
| '''run with datasets, and audio format is kaldi_ark | |||||
| datasets directory: | |||||
| <dataset_path> | |||||
| test # testsets | |||||
| data.ark | |||||
| data.scp | |||||
| data.text | |||||
| dev # devsets | |||||
| data.ark | |||||
| data.scp | |||||
| data.text | |||||
| train # trainsets | |||||
| data.ark | |||||
| data.scp | |||||
| data.text | |||||
| ''' | |||||
| # download the testsets archive if it is not cached yet | |||||
| testsets_file_path = os.path.join(self._workspace, | |||||
| AISHELL1_TESTSETS_FILE) | |||||
| if not os.path.exists(testsets_file_path): | |||||
| r = requests.get(AISHELL1_TESTSETS_URL) | |||||
| with open(testsets_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| # strip the '.tar.gz' suffix to obtain the extracted dir name ('aishell1') | |||||
| testsets_dir_name = os.path.basename( | |||||
| AISHELL1_TESTSETS_FILE)[:-len('.tar.gz')] | |||||
| # dataset_path = <cwd>/.tmp/aishell1/test | |||||
| dataset_path = os.path.join(self._workspace, testsets_dir_name, 'test') | |||||
| # untar the dataset_path file | |||||
| if not os.path.exists(dataset_path): | |||||
| un_tar_gz(testsets_file_path, self._workspace) | |||||
| inference_16k_pipeline = pipeline( | |||||
| task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||||
| self.assertIsNotNone(inference_16k_pipeline) | |||||
| rec_result = inference_16k_pipeline(wav_path=dataset_path) | |||||
| self.assertTrue(len(rec_result['datasets_result']) > 0) | |||||
| self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||||
| ''' | |||||
| result structure: | |||||
| { | |||||
| 'rec_result': 'None', | |||||
| 'datasets_result': | |||||
| { | |||||
| 'Wrd': 104816, # the number of words | |||||
| 'Snt': 7176, # the number of sentences | |||||
| 'Corr': 99327, # the number of correct words | |||||
| 'Ins': 104, # the number of insert words | |||||
| 'Del': 155, # the number of delete words | |||||
| 'Sub': 5334, # the number of substitution words | |||||
| 'wrong_words': 5593, # the number of wrong words | |||||
| 'wrong_sentences': 2898, # the number of wrong sentences | |||||
| 'Err': 5.34, # WER/CER | |||||
| 'S.Err': 40.38 # SER | |||||
| } | |||||
| } | |||||
| ''' | |||||
| print('test_run_with_ark_dataset datasets result: ') | |||||
| print(rec_result['datasets_result']) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||