Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9410174master
| @@ -20,8 +20,10 @@ class GenericAutomaticSpeechRecognition(Model): | |||
| Args: | |||
| model_dir (str): the model path. | |||
| am_model_name (str): the am model name from configuration.json | |||
| model_config (Dict[str, Any]): the detail config about model from configuration.json | |||
| """ | |||
| super().__init__(model_dir, am_model_name, model_config, *args, | |||
| **kwargs) | |||
| self.model_cfg = { | |||
| # the recognition model dir path | |||
| 'model_workspace': model_dir, | |||
| @@ -312,5 +312,11 @@ TASK_OUTPUTS = { | |||
| # { | |||
| # "text": "this is the text generated by a model." | |||
| # } | |||
| Tasks.visual_question_answering: [OutputKeys.TEXT] | |||
| Tasks.visual_question_answering: [OutputKeys.TEXT], | |||
| # auto_speech_recognition result for a single sample | |||
| # { | |||
| # "text": "每天都要快乐喔" | |||
| # } | |||
| Tasks.auto_speech_recognition: [OutputKeys.TEXT] | |||
| } | |||
| @@ -3,7 +3,7 @@ | |||
| from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR | |||
| try: | |||
| from .asr.asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | |||
| from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | |||
| from .kws_kwsbp_pipeline import * # noqa F403 | |||
| from .linear_aec_pipeline import LinearAECPipeline | |||
| except ModuleNotFoundError as e: | |||
| @@ -1,21 +0,0 @@ | |||
| import ssl | |||
| import nltk | |||
| try: | |||
| _create_unverified_https_context = ssl._create_unverified_context | |||
| except AttributeError: | |||
| pass | |||
| else: | |||
| ssl._create_default_https_context = _create_unverified_https_context | |||
| try: | |||
| nltk.data.find('taggers/averaged_perceptron_tagger') | |||
| except LookupError: | |||
| nltk.download( | |||
| 'averaged_perceptron_tagger', halt_on_error=False, raise_on_error=True) | |||
| try: | |||
| nltk.data.find('corpora/cmudict') | |||
| except LookupError: | |||
| nltk.download('cmudict', halt_on_error=False, raise_on_error=True) | |||
| @@ -1,690 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| import argparse | |||
| import logging | |||
| import sys | |||
| import time | |||
| from pathlib import Path | |||
| from typing import Any, Optional, Sequence, Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer | |||
| from espnet2.asr.transducer.beam_search_transducer import \ | |||
| ExtendedHypothesis as ExtTransHypothesis # noqa: H301 | |||
| from espnet2.asr.transducer.beam_search_transducer import \ | |||
| Hypothesis as TransHypothesis | |||
| from espnet2.fileio.datadir_writer import DatadirWriter | |||
| from espnet2.tasks.lm import LMTask | |||
| from espnet2.text.build_tokenizer import build_tokenizer | |||
| from espnet2.text.token_id_converter import TokenIDConverter | |||
| from espnet2.torch_utils.device_funcs import to_device | |||
| from espnet2.torch_utils.set_all_random_seed import set_all_random_seed | |||
| from espnet2.utils import config_argparse | |||
| from espnet2.utils.types import str2bool, str2triple_str, str_or_none | |||
| from espnet.nets.batch_beam_search import BatchBeamSearch | |||
| from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim | |||
| from espnet.nets.beam_search import BeamSearch, Hypothesis | |||
| from espnet.nets.pytorch_backend.transformer.subsampling import \ | |||
| TooShortUttError | |||
| from espnet.nets.scorer_interface import BatchScorerInterface | |||
| from espnet.nets.scorers.ctc import CTCPrefixScorer | |||
| from espnet.nets.scorers.length_bonus import LengthBonus | |||
| from espnet.utils.cli_utils import get_commandline_args | |||
| from typeguard import check_argument_types | |||
| from .espnet.asr.frontend.wav_frontend import WavFrontend | |||
| from .espnet.tasks.asr import ASRTaskNAR as ASRTask | |||
| class Speech2Text: | |||
| def __init__(self, | |||
| asr_train_config: Union[Path, str] = None, | |||
| asr_model_file: Union[Path, str] = None, | |||
| transducer_conf: dict = None, | |||
| lm_train_config: Union[Path, str] = None, | |||
| lm_file: Union[Path, str] = None, | |||
| ngram_scorer: str = 'full', | |||
| ngram_file: Union[Path, str] = None, | |||
| token_type: str = None, | |||
| bpemodel: str = None, | |||
| device: str = 'cpu', | |||
| maxlenratio: float = 0.0, | |||
| minlenratio: float = 0.0, | |||
| batch_size: int = 1, | |||
| dtype: str = 'float32', | |||
| beam_size: int = 20, | |||
| ctc_weight: float = 0.5, | |||
| lm_weight: float = 1.0, | |||
| ngram_weight: float = 0.9, | |||
| penalty: float = 0.0, | |||
| nbest: int = 1, | |||
| streaming: bool = False, | |||
| frontend_conf: dict = None): | |||
| assert check_argument_types() | |||
| # 1. Build ASR model | |||
| scorers = {} | |||
| asr_model, asr_train_args = ASRTask.build_model_from_file( | |||
| asr_train_config, asr_model_file, device) | |||
| if asr_model.frontend is None and frontend_conf is not None: | |||
| frontend = WavFrontend(**frontend_conf) | |||
| asr_model.frontend = frontend | |||
| asr_model.to(dtype=getattr(torch, dtype)).eval() | |||
| decoder = asr_model.decoder | |||
| ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) | |||
| token_list = asr_model.token_list | |||
| scorers.update( | |||
| decoder=decoder, | |||
| ctc=ctc, | |||
| length_bonus=LengthBonus(len(token_list)), | |||
| ) | |||
| # 2. Build Language model | |||
| if lm_train_config is not None: | |||
| lm, lm_train_args = LMTask.build_model_from_file( | |||
| lm_train_config, lm_file, device) | |||
| scorers['lm'] = lm.lm | |||
| # 3. Build ngram model | |||
| if ngram_file is not None: | |||
| if ngram_scorer == 'full': | |||
| from espnet.nets.scorers.ngram import NgramFullScorer | |||
| ngram = NgramFullScorer(ngram_file, token_list) | |||
| else: | |||
| from espnet.nets.scorers.ngram import NgramPartScorer | |||
| ngram = NgramPartScorer(ngram_file, token_list) | |||
| else: | |||
| ngram = None | |||
| scorers['ngram'] = ngram | |||
| # 4. Build BeamSearch object | |||
| if asr_model.use_transducer_decoder: | |||
| beam_search_transducer = BeamSearchTransducer( | |||
| decoder=asr_model.decoder, | |||
| joint_network=asr_model.joint_network, | |||
| beam_size=beam_size, | |||
| lm=scorers['lm'] if 'lm' in scorers else None, | |||
| lm_weight=lm_weight, | |||
| **transducer_conf, | |||
| ) | |||
| beam_search = None | |||
| else: | |||
| beam_search_transducer = None | |||
| weights = dict( | |||
| decoder=1.0 - ctc_weight, | |||
| ctc=ctc_weight, | |||
| lm=lm_weight, | |||
| ngram=ngram_weight, | |||
| length_bonus=penalty, | |||
| ) | |||
| beam_search = BeamSearch( | |||
| beam_size=beam_size, | |||
| weights=weights, | |||
| scorers=scorers, | |||
| sos=asr_model.sos, | |||
| eos=asr_model.eos, | |||
| vocab_size=len(token_list), | |||
| token_list=token_list, | |||
| pre_beam_score_key=None if ctc_weight == 1.0 else 'full', | |||
| ) | |||
| # TODO(karita): make all scorers batchfied | |||
| if batch_size == 1: | |||
| non_batch = [ | |||
| k for k, v in beam_search.full_scorers.items() | |||
| if not isinstance(v, BatchScorerInterface) | |||
| ] | |||
| if len(non_batch) == 0: | |||
| if streaming: | |||
| beam_search.__class__ = BatchBeamSearchOnlineSim | |||
| beam_search.set_streaming_config(asr_train_config) | |||
| logging.info( | |||
| 'BatchBeamSearchOnlineSim implementation is selected.' | |||
| ) | |||
| else: | |||
| beam_search.__class__ = BatchBeamSearch | |||
| else: | |||
| logging.warning( | |||
| f'As non-batch scorers {non_batch} are found, ' | |||
| f'fall back to non-batch implementation.') | |||
| beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() | |||
| for scorer in scorers.values(): | |||
| if isinstance(scorer, torch.nn.Module): | |||
| scorer.to( | |||
| device=device, dtype=getattr(torch, dtype)).eval() | |||
| # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text | |||
| if token_type is None: | |||
| token_type = asr_train_args.token_type | |||
| if bpemodel is None: | |||
| bpemodel = asr_train_args.bpemodel | |||
| if token_type is None: | |||
| tokenizer = None | |||
| elif token_type == 'bpe': | |||
| if bpemodel is not None: | |||
| tokenizer = build_tokenizer( | |||
| token_type=token_type, bpemodel=bpemodel) | |||
| else: | |||
| tokenizer = None | |||
| else: | |||
| tokenizer = build_tokenizer(token_type=token_type) | |||
| converter = TokenIDConverter(token_list=token_list) | |||
| self.asr_model = asr_model | |||
| self.asr_train_args = asr_train_args | |||
| self.converter = converter | |||
| self.tokenizer = tokenizer | |||
| self.beam_search = beam_search | |||
| self.beam_search_transducer = beam_search_transducer | |||
| self.maxlenratio = maxlenratio | |||
| self.minlenratio = minlenratio | |||
| self.device = device | |||
| self.dtype = dtype | |||
| self.nbest = nbest | |||
| @torch.no_grad() | |||
| def __call__(self, speech: Union[torch.Tensor, np.ndarray]): | |||
| """Inference | |||
| Args: | |||
| data: Input speech data | |||
| Returns: | |||
| text, token, token_int, hyp | |||
| """ | |||
| assert check_argument_types() | |||
| # Input as audio signal | |||
| if isinstance(speech, np.ndarray): | |||
| speech = torch.tensor(speech) | |||
| # data: (Nsamples,) -> (1, Nsamples) | |||
| speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) | |||
| # lengths: (1,) | |||
| lengths = speech.new_full([1], | |||
| dtype=torch.long, | |||
| fill_value=speech.size(1)) | |||
| batch = {'speech': speech, 'speech_lengths': lengths} | |||
| # a. To device | |||
| batch = to_device(batch, device=self.device) | |||
| # b. Forward Encoder | |||
| enc, enc_len = self.asr_model.encode(**batch) | |||
| if isinstance(enc, tuple): | |||
| enc = enc[0] | |||
| assert len(enc) == 1, len(enc) | |||
| predictor_outs = self.asr_model.calc_predictor(enc, enc_len) | |||
| pre_acoustic_embeds, pre_token_length = predictor_outs[ | |||
| 0], predictor_outs[1] | |||
| pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], | |||
| device=pre_acoustic_embeds.device) | |||
| decoder_outs = self.asr_model.cal_decoder_with_predictor( | |||
| enc, enc_len, pre_acoustic_embeds, pre_token_length) | |||
| decoder_out = decoder_outs[0] | |||
| yseq = decoder_out.argmax(dim=-1) | |||
| score = decoder_out.max(dim=-1)[0] | |||
| score = torch.sum(score, dim=-1) | |||
| # pad with mask tokens to ensure compatibility with sos/eos tokens | |||
| yseq = torch.tensor( | |||
| [self.asr_model.sos] + yseq.tolist()[0] + [self.asr_model.eos], | |||
| device=yseq.device) | |||
| nbest_hyps = [Hypothesis(yseq=yseq, score=score)] | |||
| results = [] | |||
| for hyp in nbest_hyps: | |||
| assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp) | |||
| # remove sos/eos and get results | |||
| last_pos = None if self.asr_model.use_transducer_decoder else -1 | |||
| if isinstance(hyp.yseq, list): | |||
| token_int = hyp.yseq[1:last_pos] | |||
| else: | |||
| token_int = hyp.yseq[1:last_pos].tolist() | |||
| # remove blank symbol id, which is assumed to be 0 | |||
| token_int = list(filter(lambda x: x != 0, token_int)) | |||
| # Change integer-ids to tokens | |||
| token = self.converter.ids2tokens(token_int) | |||
| if self.tokenizer is not None: | |||
| text = self.tokenizer.tokens2text(token) | |||
| else: | |||
| text = None | |||
| results.append((text, token, token_int, hyp, speech.size(1))) | |||
| return results | |||
| @staticmethod | |||
| def from_pretrained( | |||
| model_tag: Optional[str] = None, | |||
| **kwargs: Optional[Any], | |||
| ): | |||
| """Build Speech2Text instance from the pretrained model. | |||
| Args: | |||
| model_tag (Optional[str]): Model tag of the pretrained models. | |||
| Currently, the tags of espnet_model_zoo are supported. | |||
| Returns: | |||
| Speech2Text: Speech2Text instance. | |||
| """ | |||
| if model_tag is not None: | |||
| try: | |||
| from espnet_model_zoo.downloader import ModelDownloader | |||
| except ImportError: | |||
| logging.error( | |||
| '`espnet_model_zoo` is not installed. ' | |||
| 'Please install via `pip install -U espnet_model_zoo`.') | |||
| raise | |||
| d = ModelDownloader() | |||
| kwargs.update(**d.download_and_unpack(model_tag)) | |||
| return Speech2Text(**kwargs) | |||
| def inference( | |||
| output_dir: str, | |||
| maxlenratio: float, | |||
| minlenratio: float, | |||
| batch_size: int, | |||
| dtype: str, | |||
| beam_size: int, | |||
| ngpu: int, | |||
| seed: int, | |||
| ctc_weight: float, | |||
| lm_weight: float, | |||
| ngram_weight: float, | |||
| penalty: float, | |||
| nbest: int, | |||
| num_workers: int, | |||
| log_level: Union[int, str], | |||
| data_path_and_name_and_type: Sequence[Tuple[str, str, str]], | |||
| key_file: Optional[str], | |||
| asr_train_config: Optional[str], | |||
| asr_model_file: Optional[str], | |||
| lm_train_config: Optional[str], | |||
| lm_file: Optional[str], | |||
| word_lm_train_config: Optional[str], | |||
| word_lm_file: Optional[str], | |||
| ngram_file: Optional[str], | |||
| model_tag: Optional[str], | |||
| token_type: Optional[str], | |||
| bpemodel: Optional[str], | |||
| allow_variable_data_keys: bool, | |||
| transducer_conf: Optional[dict], | |||
| streaming: bool, | |||
| frontend_conf: dict = None, | |||
| ): | |||
| assert check_argument_types() | |||
| if batch_size > 1: | |||
| raise NotImplementedError('batch decoding is not implemented') | |||
| if word_lm_train_config is not None: | |||
| raise NotImplementedError('Word LM is not implemented') | |||
| if ngpu > 1: | |||
| raise NotImplementedError('only single GPU decoding is supported') | |||
| logging.basicConfig( | |||
| level=log_level, | |||
| format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', | |||
| ) | |||
| if ngpu >= 1: | |||
| device = 'cuda' | |||
| else: | |||
| device = 'cpu' | |||
| # 1. Set random-seed | |||
| set_all_random_seed(seed) | |||
| # 2. Build speech2text | |||
| speech2text_kwargs = dict( | |||
| asr_train_config=asr_train_config, | |||
| asr_model_file=asr_model_file, | |||
| transducer_conf=transducer_conf, | |||
| lm_train_config=lm_train_config, | |||
| lm_file=lm_file, | |||
| ngram_file=ngram_file, | |||
| token_type=token_type, | |||
| bpemodel=bpemodel, | |||
| device=device, | |||
| maxlenratio=maxlenratio, | |||
| minlenratio=minlenratio, | |||
| dtype=dtype, | |||
| beam_size=beam_size, | |||
| ctc_weight=ctc_weight, | |||
| lm_weight=lm_weight, | |||
| ngram_weight=ngram_weight, | |||
| penalty=penalty, | |||
| nbest=nbest, | |||
| streaming=streaming, | |||
| frontend_conf=frontend_conf, | |||
| ) | |||
| speech2text = Speech2Text.from_pretrained( | |||
| model_tag=model_tag, | |||
| **speech2text_kwargs, | |||
| ) | |||
| # 3. Build data-iterator | |||
| loader = ASRTask.build_streaming_iterator( | |||
| data_path_and_name_and_type, | |||
| dtype=dtype, | |||
| batch_size=batch_size, | |||
| key_file=key_file, | |||
| num_workers=num_workers, | |||
| preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, | |||
| False), | |||
| collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), | |||
| allow_variable_data_keys=allow_variable_data_keys, | |||
| inference=True, | |||
| ) | |||
| forward_time_total = 0.0 | |||
| length_total = 0.0 | |||
| # 7 .Start for-loop | |||
| # FIXME(kamo): The output format should be discussed about | |||
| with DatadirWriter(output_dir) as writer: | |||
| for keys, batch in loader: | |||
| assert isinstance(batch, dict), type(batch) | |||
| assert all(isinstance(s, str) for s in keys), keys | |||
| _bs = len(next(iter(batch.values()))) | |||
| assert len(keys) == _bs, f'{len(keys)} != {_bs}' | |||
| batch = { | |||
| k: v[0] | |||
| for k, v in batch.items() if not k.endswith('_lengths') | |||
| } | |||
| # N-best list of (text, token, token_int, hyp_object) | |||
| try: | |||
| time_beg = time.time() | |||
| results = speech2text(**batch) | |||
| time_end = time.time() | |||
| forward_time = time_end - time_beg | |||
| length = results[0][-1] | |||
| results = [results[0][:-1]] | |||
| forward_time_total += forward_time | |||
| length_total += length | |||
| except TooShortUttError as e: | |||
| logging.warning(f'Utterance {keys} {e}') | |||
| hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) | |||
| results = [[' ', ['<space>'], [2], hyp]] * nbest | |||
| # Only supporting batch_size==1 | |||
| key = keys[0] | |||
| for n, (text, token, token_int, | |||
| hyp) in zip(range(1, nbest + 1), results): | |||
| # Create a directory: outdir/{n}best_recog | |||
| ibest_writer = writer[f'{n}best_recog'] | |||
| # Write the result to each file | |||
| ibest_writer['token'][key] = ' '.join(token) | |||
| ibest_writer['token_int'][key] = ' '.join(map(str, token_int)) | |||
| ibest_writer['score'][key] = str(hyp.score) | |||
| if text is not None: | |||
| ibest_writer['text'][key] = text | |||
| logging.info( | |||
| 'decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}' | |||
| .format(length_total, forward_time_total, | |||
| 100 * forward_time_total / length_total)) | |||
| def get_parser(): | |||
| parser = config_argparse.ArgumentParser( | |||
| description='ASR Decoding', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| # Note(kamo): Use '_' instead of '-' as separator. | |||
| # '-' is confusing if written in yaml. | |||
| parser.add_argument( | |||
| '--log_level', | |||
| type=lambda x: x.upper(), | |||
| default='INFO', | |||
| choices=('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'), | |||
| help='The verbose level of logging', | |||
| ) | |||
| parser.add_argument('--output_dir', type=str, required=True) | |||
| parser.add_argument( | |||
| '--ngpu', | |||
| type=int, | |||
| default=0, | |||
| help='The number of gpus. 0 indicates CPU mode', | |||
| ) | |||
| parser.add_argument('--seed', type=int, default=0, help='Random seed') | |||
| parser.add_argument( | |||
| '--dtype', | |||
| default='float32', | |||
| choices=['float16', 'float32', 'float64'], | |||
| help='Data type', | |||
| ) | |||
| parser.add_argument( | |||
| '--num_workers', | |||
| type=int, | |||
| default=1, | |||
| help='The number of workers used for DataLoader', | |||
| ) | |||
| group = parser.add_argument_group('Input data related') | |||
| group.add_argument( | |||
| '--data_path_and_name_and_type', | |||
| type=str2triple_str, | |||
| required=True, | |||
| action='append', | |||
| ) | |||
| group.add_argument('--key_file', type=str_or_none) | |||
| group.add_argument( | |||
| '--allow_variable_data_keys', type=str2bool, default=False) | |||
| group = parser.add_argument_group('The model configuration related') | |||
| group.add_argument( | |||
| '--asr_train_config', | |||
| type=str, | |||
| help='ASR training configuration', | |||
| ) | |||
| group.add_argument( | |||
| '--asr_model_file', | |||
| type=str, | |||
| help='ASR model parameter file', | |||
| ) | |||
| group.add_argument( | |||
| '--lm_train_config', | |||
| type=str, | |||
| help='LM training configuration', | |||
| ) | |||
| group.add_argument( | |||
| '--lm_file', | |||
| type=str, | |||
| help='LM parameter file', | |||
| ) | |||
| group.add_argument( | |||
| '--word_lm_train_config', | |||
| type=str, | |||
| help='Word LM training configuration', | |||
| ) | |||
| group.add_argument( | |||
| '--word_lm_file', | |||
| type=str, | |||
| help='Word LM parameter file', | |||
| ) | |||
| group.add_argument( | |||
| '--ngram_file', | |||
| type=str, | |||
| help='N-gram parameter file', | |||
| ) | |||
| group.add_argument( | |||
| '--model_tag', | |||
| type=str, | |||
| help='Pretrained model tag. If specify this option, *_train_config and ' | |||
| '*_file will be overwritten', | |||
| ) | |||
| group = parser.add_argument_group('Beam-search related') | |||
| group.add_argument( | |||
| '--batch_size', | |||
| type=int, | |||
| default=1, | |||
| help='The batch size for inference', | |||
| ) | |||
| group.add_argument( | |||
| '--nbest', type=int, default=1, help='Output N-best hypotheses') | |||
| group.add_argument('--beam_size', type=int, default=20, help='Beam size') | |||
| group.add_argument( | |||
| '--penalty', type=float, default=0.0, help='Insertion penalty') | |||
| group.add_argument( | |||
| '--maxlenratio', | |||
| type=float, | |||
| default=0.0, | |||
| help='Input length ratio to obtain max output length. ' | |||
| 'If maxlenratio=0.0 (default), it uses a end-detect ' | |||
| 'function ' | |||
| 'to automatically find maximum hypothesis lengths.' | |||
| 'If maxlenratio<0.0, its absolute value is interpreted' | |||
| 'as a constant max output length', | |||
| ) | |||
| group.add_argument( | |||
| '--minlenratio', | |||
| type=float, | |||
| default=0.0, | |||
| help='Input length ratio to obtain min output length', | |||
| ) | |||
| group.add_argument( | |||
| '--ctc_weight', | |||
| type=float, | |||
| default=0.5, | |||
| help='CTC weight in joint decoding', | |||
| ) | |||
| group.add_argument( | |||
| '--lm_weight', type=float, default=1.0, help='RNNLM weight') | |||
| group.add_argument( | |||
| '--ngram_weight', type=float, default=0.9, help='ngram weight') | |||
| group.add_argument('--streaming', type=str2bool, default=False) | |||
| group.add_argument( | |||
| '--frontend_conf', | |||
| default=None, | |||
| help='', | |||
| ) | |||
| group = parser.add_argument_group('Text converter related') | |||
| group.add_argument( | |||
| '--token_type', | |||
| type=str_or_none, | |||
| default=None, | |||
| choices=['char', 'bpe', None], | |||
| help='The token type for ASR model. ' | |||
| 'If not given, refers from the training args', | |||
| ) | |||
| group.add_argument( | |||
| '--bpemodel', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='The model path of sentencepiece. ' | |||
| 'If not given, refers from the training args', | |||
| ) | |||
| group.add_argument( | |||
| '--transducer_conf', | |||
| default=None, | |||
| help='The keyword arguments for transducer beam search.', | |||
| ) | |||
| return parser | |||
| def asr_inference( | |||
| output_dir: str, | |||
| maxlenratio: float, | |||
| minlenratio: float, | |||
| beam_size: int, | |||
| ngpu: int, | |||
| ctc_weight: float, | |||
| lm_weight: float, | |||
| penalty: float, | |||
| data_path_and_name_and_type: Sequence[Tuple[str, str, str]], | |||
| asr_train_config: Optional[str], | |||
| asr_model_file: Optional[str], | |||
| nbest: int = 1, | |||
| num_workers: int = 1, | |||
| log_level: Union[int, str] = 'INFO', | |||
| batch_size: int = 1, | |||
| dtype: str = 'float32', | |||
| seed: int = 0, | |||
| key_file: Optional[str] = None, | |||
| lm_train_config: Optional[str] = None, | |||
| lm_file: Optional[str] = None, | |||
| word_lm_train_config: Optional[str] = None, | |||
| word_lm_file: Optional[str] = None, | |||
| ngram_file: Optional[str] = None, | |||
| ngram_weight: float = 0.9, | |||
| model_tag: Optional[str] = None, | |||
| token_type: Optional[str] = None, | |||
| bpemodel: Optional[str] = None, | |||
| allow_variable_data_keys: bool = False, | |||
| transducer_conf: Optional[dict] = None, | |||
| streaming: bool = False, | |||
| frontend_conf: dict = None, | |||
| ): | |||
| inference( | |||
| output_dir=output_dir, | |||
| maxlenratio=maxlenratio, | |||
| minlenratio=minlenratio, | |||
| batch_size=batch_size, | |||
| dtype=dtype, | |||
| beam_size=beam_size, | |||
| ngpu=ngpu, | |||
| seed=seed, | |||
| ctc_weight=ctc_weight, | |||
| lm_weight=lm_weight, | |||
| ngram_weight=ngram_weight, | |||
| penalty=penalty, | |||
| nbest=nbest, | |||
| num_workers=num_workers, | |||
| log_level=log_level, | |||
| data_path_and_name_and_type=data_path_and_name_and_type, | |||
| key_file=key_file, | |||
| asr_train_config=asr_train_config, | |||
| asr_model_file=asr_model_file, | |||
| lm_train_config=lm_train_config, | |||
| lm_file=lm_file, | |||
| word_lm_train_config=word_lm_train_config, | |||
| word_lm_file=word_lm_file, | |||
| ngram_file=ngram_file, | |||
| model_tag=model_tag, | |||
| token_type=token_type, | |||
| bpemodel=bpemodel, | |||
| allow_variable_data_keys=allow_variable_data_keys, | |||
| transducer_conf=transducer_conf, | |||
| streaming=streaming, | |||
| frontend_conf=frontend_conf) | |||
| def main(cmd=None): | |||
| print(get_commandline_args(), file=sys.stderr) | |||
| parser = get_parser() | |||
| args = parser.parse_args(cmd) | |||
| kwargs = vars(args) | |||
| kwargs.pop('config', None) | |||
| inference(**kwargs) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -1,193 +0,0 @@ | |||
| import os | |||
| from typing import Any, Dict, List | |||
| import numpy as np | |||
| def type_checking(wav_path: str, | |||
| recog_type: str = None, | |||
| audio_format: str = None, | |||
| workspace: str = None): | |||
| assert os.path.exists(wav_path), f'wav_path:{wav_path} does not exist' | |||
| r_recog_type = recog_type | |||
| r_audio_format = audio_format | |||
| r_workspace = workspace | |||
| r_wav_path = wav_path | |||
| if r_workspace is None or len(r_workspace) == 0: | |||
| r_workspace = os.path.join(os.getcwd(), '.tmp') | |||
| if r_recog_type is None: | |||
| if os.path.isfile(wav_path): | |||
| if wav_path.endswith('.wav') or wav_path.endswith('.WAV'): | |||
| r_recog_type = 'wav' | |||
| r_audio_format = 'wav' | |||
| elif os.path.isdir(wav_path): | |||
| dir_name = os.path.basename(wav_path) | |||
| if 'test' in dir_name: | |||
| r_recog_type = 'test' | |||
| elif 'dev' in dir_name: | |||
| r_recog_type = 'dev' | |||
| elif 'train' in dir_name: | |||
| r_recog_type = 'train' | |||
| if r_audio_format is None: | |||
| if find_file_by_ends(wav_path, '.ark'): | |||
| r_audio_format = 'kaldi_ark' | |||
| elif find_file_by_ends(wav_path, '.wav') or find_file_by_ends( | |||
| wav_path, '.WAV'): | |||
| r_audio_format = 'wav' | |||
| if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav': | |||
| # datasets with kaldi_ark file | |||
| r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) | |||
| elif r_audio_format == 'wav' and r_recog_type != 'wav': | |||
| # datasets with waveform files | |||
| r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../')) | |||
| return r_recog_type, r_audio_format, r_workspace, r_wav_path | |||
| def find_file_by_ends(dir_path: str, ends: str): | |||
| dir_files = os.listdir(dir_path) | |||
| for file in dir_files: | |||
| file_path = os.path.join(dir_path, file) | |||
| if os.path.isfile(file_path): | |||
| if file_path.endswith(ends): | |||
| return True | |||
| elif os.path.isdir(file_path): | |||
| if find_file_by_ends(file_path, ends): | |||
| return True | |||
| return False | |||
| def compute_wer(hyp_text_path: str, ref_text_path: str) -> Dict[str, Any]: | |||
| assert os.path.exists(hyp_text_path), 'hyp_text does not exist' | |||
| assert os.path.exists(ref_text_path), 'ref_text does not exist' | |||
| rst = { | |||
| 'Wrd': 0, | |||
| 'Corr': 0, | |||
| 'Ins': 0, | |||
| 'Del': 0, | |||
| 'Sub': 0, | |||
| 'Snt': 0, | |||
| 'Err': 0.0, | |||
| 'S.Err': 0.0, | |||
| 'wrong_words': 0, | |||
| 'wrong_sentences': 0 | |||
| } | |||
| with open(ref_text_path, 'r', encoding='utf-8') as r: | |||
| r_lines = r.readlines() | |||
| with open(hyp_text_path, 'r', encoding='utf-8') as h: | |||
| h_lines = h.readlines() | |||
| for r_line in r_lines: | |||
| r_line_item = r_line.split() | |||
| r_key = r_line_item[0] | |||
| r_sentence = r_line_item[1] | |||
| for h_line in h_lines: | |||
| # find sentence from hyp text | |||
| if r_key in h_line: | |||
| h_line_item = h_line.split() | |||
| h_sentence = h_line_item[1] | |||
| out_item = compute_wer_by_line(h_sentence, r_sentence) | |||
| rst['Wrd'] += out_item['nwords'] | |||
| rst['Corr'] += out_item['cor'] | |||
| rst['wrong_words'] += out_item['wrong'] | |||
| rst['Ins'] += out_item['ins'] | |||
| rst['Del'] += out_item['del'] | |||
| rst['Sub'] += out_item['sub'] | |||
| rst['Snt'] += 1 | |||
| if out_item['wrong'] > 0: | |||
| rst['wrong_sentences'] += 1 | |||
| break | |||
| if rst['Wrd'] > 0: | |||
| rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) | |||
| if rst['Snt'] > 0: | |||
| rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) | |||
| return rst | |||
| def compute_wer_by_line(hyp: list, ref: list) -> Dict[str, Any]: | |||
| len_hyp = len(hyp) | |||
| len_ref = len(ref) | |||
| cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) | |||
| ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) | |||
| for i in range(len_hyp + 1): | |||
| cost_matrix[i][0] = i | |||
| for j in range(len_ref + 1): | |||
| cost_matrix[0][j] = j | |||
| for i in range(1, len_hyp + 1): | |||
| for j in range(1, len_ref + 1): | |||
| if hyp[i - 1] == ref[j - 1]: | |||
| cost_matrix[i][j] = cost_matrix[i - 1][j - 1] | |||
| else: | |||
| substitution = cost_matrix[i - 1][j - 1] + 1 | |||
| insertion = cost_matrix[i - 1][j] + 1 | |||
| deletion = cost_matrix[i][j - 1] + 1 | |||
| compare_val = [substitution, insertion, deletion] | |||
| min_val = min(compare_val) | |||
| operation_idx = compare_val.index(min_val) + 1 | |||
| cost_matrix[i][j] = min_val | |||
| ops_matrix[i][j] = operation_idx | |||
| match_idx = [] | |||
| i = len_hyp | |||
| j = len_ref | |||
| rst = { | |||
| 'nwords': len_hyp, | |||
| 'cor': 0, | |||
| 'wrong': 0, | |||
| 'ins': 0, | |||
| 'del': 0, | |||
| 'sub': 0 | |||
| } | |||
| while i >= 0 or j >= 0: | |||
| i_idx = max(0, i) | |||
| j_idx = max(0, j) | |||
| if ops_matrix[i_idx][j_idx] == 0: # correct | |||
| if i - 1 >= 0 and j - 1 >= 0: | |||
| match_idx.append((j - 1, i - 1)) | |||
| rst['cor'] += 1 | |||
| i -= 1 | |||
| j -= 1 | |||
| elif ops_matrix[i_idx][j_idx] == 2: # insert | |||
| i -= 1 | |||
| rst['ins'] += 1 | |||
| elif ops_matrix[i_idx][j_idx] == 3: # delete | |||
| j -= 1 | |||
| rst['del'] += 1 | |||
| elif ops_matrix[i_idx][j_idx] == 1: # substitute | |||
| i -= 1 | |||
| j -= 1 | |||
| rst['sub'] += 1 | |||
| if i < 0 and j >= 0: | |||
| rst['del'] += 1 | |||
| elif j < 0 and i >= 0: | |||
| rst['ins'] += 1 | |||
| match_idx.reverse() | |||
| wrong_cnt = cost_matrix[len_hyp][len_ref] | |||
| rst['wrong'] = wrong_cnt | |||
| return rst | |||
| @@ -1,757 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| """Decoder definition.""" | |||
| from typing import Any, List, Sequence, Tuple | |||
| import torch | |||
| from espnet2.asr.decoder.abs_decoder import AbsDecoder | |||
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
| from espnet.nets.pytorch_backend.transformer.attention import \ | |||
| MultiHeadedAttention | |||
| from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer | |||
| from espnet.nets.pytorch_backend.transformer.dynamic_conv import \ | |||
| DynamicConvolution | |||
| from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \ | |||
| DynamicConvolution2D | |||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
| PositionalEncoding | |||
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
| from espnet.nets.pytorch_backend.transformer.lightconv import \ | |||
| LightweightConvolution | |||
| from espnet.nets.pytorch_backend.transformer.lightconv2d import \ | |||
| LightweightConvolution2D | |||
| from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask | |||
| from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||
| PositionwiseFeedForward # noqa: H301 | |||
| from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||
| from espnet.nets.scorer_interface import BatchScorerInterface | |||
| from typeguard import check_argument_types | |||
| class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface): | |||
| """Base class of Transfomer decoder module. | |||
| Args: | |||
| vocab_size: output dim | |||
| encoder_output_size: dimension of attention | |||
| attention_heads: the number of heads of multi head attention | |||
| linear_units: the number of units of position-wise feed forward | |||
| num_blocks: the number of decoder blocks | |||
| dropout_rate: dropout rate | |||
| self_attention_dropout_rate: dropout rate for attention | |||
| input_layer: input layer type | |||
| use_output_layer: whether to use output layer | |||
| pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||
| normalize_before: whether to use layer_norm before the first block | |||
| concat_after: whether to concat attention layer's input and output | |||
| if True, additional linear will be applied. | |||
| i.e. x -> x + linear(concat(x, att(x))) | |||
| if False, no additional linear will be applied. | |||
| i.e. x -> x + att(x) | |||
| """ | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__() | |||
| attention_dim = encoder_output_size | |||
| if input_layer == 'embed': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Embedding(vocab_size, attention_dim), | |||
| pos_enc_class(attention_dim, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'linear': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Linear(vocab_size, attention_dim), | |||
| torch.nn.LayerNorm(attention_dim), | |||
| torch.nn.Dropout(dropout_rate), | |||
| torch.nn.ReLU(), | |||
| pos_enc_class(attention_dim, positional_dropout_rate), | |||
| ) | |||
| else: | |||
| raise ValueError( | |||
| f"only 'embed' or 'linear' is supported: {input_layer}") | |||
| self.normalize_before = normalize_before | |||
| if self.normalize_before: | |||
| self.after_norm = LayerNorm(attention_dim) | |||
| if use_output_layer: | |||
| self.output_layer = torch.nn.Linear(attention_dim, vocab_size) | |||
| else: | |||
| self.output_layer = None | |||
| # Must set by the inheritance | |||
| self.decoders = None | |||
| def forward( | |||
| self, | |||
| hs_pad: torch.Tensor, | |||
| hlens: torch.Tensor, | |||
| ys_in_pad: torch.Tensor, | |||
| ys_in_lens: torch.Tensor, | |||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |||
| """Forward decoder. | |||
| Args: | |||
| hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||
| hlens: (batch) | |||
| ys_in_pad: | |||
| input token ids, int64 (batch, maxlen_out) | |||
| if input_layer == "embed" | |||
| input tensor (batch, maxlen_out, #mels) in the other cases | |||
| ys_in_lens: (batch) | |||
| Returns: | |||
| (tuple): tuple containing: | |||
| x: decoded token score before softmax (batch, maxlen_out, token) | |||
| if use_output_layer is True, | |||
| olens: (batch, ) | |||
| """ | |||
| tgt = ys_in_pad | |||
| # tgt_mask: (B, 1, L) | |||
| tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||
| # m: (1, L, L) | |||
| m = subsequent_mask( | |||
| tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||
| # tgt_mask: (B, L, L) | |||
| tgt_mask = tgt_mask & m | |||
| memory = hs_pad | |||
| memory_mask = ( | |||
| ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||
| memory.device) | |||
| # Padding for Longformer | |||
| if memory_mask.shape[-1] != memory.shape[1]: | |||
| padlen = memory.shape[1] - memory_mask.shape[-1] | |||
| memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||
| 'constant', False) | |||
| x = self.embed(tgt) | |||
| x, tgt_mask, memory, memory_mask = self.decoders( | |||
| x, tgt_mask, memory, memory_mask) | |||
| if self.normalize_before: | |||
| x = self.after_norm(x) | |||
| if self.output_layer is not None: | |||
| x = self.output_layer(x) | |||
| olens = tgt_mask.sum(1) | |||
| return x, olens | |||
| def forward_one_step( | |||
| self, | |||
| tgt: torch.Tensor, | |||
| tgt_mask: torch.Tensor, | |||
| memory: torch.Tensor, | |||
| cache: List[torch.Tensor] = None, | |||
| ) -> Tuple[torch.Tensor, List[torch.Tensor]]: | |||
| """Forward one step. | |||
| Args: | |||
| tgt: input token ids, int64 (batch, maxlen_out) | |||
| tgt_mask: input token mask, (batch, maxlen_out) | |||
| dtype=torch.uint8 in PyTorch 1.2- | |||
| dtype=torch.bool in PyTorch 1.2+ (include 1.2) | |||
| memory: encoded memory, float32 (batch, maxlen_in, feat) | |||
| cache: cached output list of (batch, max_time_out-1, size) | |||
| Returns: | |||
| y, cache: NN output value and cache per `self.decoders`. | |||
| y.shape` is (batch, maxlen_out, token) | |||
| """ | |||
| x = self.embed(tgt) | |||
| if cache is None: | |||
| cache = [None] * len(self.decoders) | |||
| new_cache = [] | |||
| for c, decoder in zip(cache, self.decoders): | |||
| x, tgt_mask, memory, memory_mask = decoder( | |||
| x, tgt_mask, memory, None, cache=c) | |||
| new_cache.append(x) | |||
| if self.normalize_before: | |||
| y = self.after_norm(x[:, -1]) | |||
| else: | |||
| y = x[:, -1] | |||
| if self.output_layer is not None: | |||
| y = torch.log_softmax(self.output_layer(y), dim=-1) | |||
| return y, new_cache | |||
| def score(self, ys, state, x): | |||
| """Score.""" | |||
| ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) | |||
| logp, state = self.forward_one_step( | |||
| ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state) | |||
| return logp.squeeze(0), state | |||
| def batch_score(self, ys: torch.Tensor, states: List[Any], | |||
| xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: | |||
| """Score new token batch. | |||
| Args: | |||
| ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). | |||
| states (List[Any]): Scorer states for prefix tokens. | |||
| xs (torch.Tensor): | |||
| The encoder feature that generates ys (n_batch, xlen, n_feat). | |||
| Returns: | |||
| tuple[torch.Tensor, List[Any]]: Tuple of | |||
| batchfied scores for next token with shape of `(n_batch, n_vocab)` | |||
| and next state list for ys. | |||
| """ | |||
| # merge states | |||
| n_batch = len(ys) | |||
| n_layers = len(self.decoders) | |||
| if states[0] is None: | |||
| batch_state = None | |||
| else: | |||
| # transpose state of [batch, layer] into [layer, batch] | |||
| batch_state = [ | |||
| torch.stack([states[b][i] for b in range(n_batch)]) | |||
| for i in range(n_layers) | |||
| ] | |||
| # batch decoding | |||
| ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) | |||
| logp, states = self.forward_one_step( | |||
| ys, ys_mask, xs, cache=batch_state) | |||
| # transpose state of [layer, batch] into [batch, layer] | |||
| state_list = [[states[i][b] for i in range(n_layers)] | |||
| for b in range(n_batch)] | |||
| return logp, state_list | |||
| class TransformerDecoder(BaseTransformerDecoder): | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| self_attention_dropout_rate: float = 0.0, | |||
| src_attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| dropout_rate=dropout_rate, | |||
| positional_dropout_rate=positional_dropout_rate, | |||
| input_layer=input_layer, | |||
| use_output_layer=use_output_layer, | |||
| pos_enc_class=pos_enc_class, | |||
| normalize_before=normalize_before, | |||
| ) | |||
| attention_dim = encoder_output_size | |||
| self.decoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| self_attention_dropout_rate), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| class ParaformerDecoder(TransformerDecoder): | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| self_attention_dropout_rate: float = 0.0, | |||
| src_attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| dropout_rate=dropout_rate, | |||
| positional_dropout_rate=positional_dropout_rate, | |||
| input_layer=input_layer, | |||
| use_output_layer=use_output_layer, | |||
| pos_enc_class=pos_enc_class, | |||
| normalize_before=normalize_before, | |||
| ) | |||
| attention_dim = encoder_output_size | |||
| self.decoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| self_attention_dropout_rate), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| def forward( | |||
| self, | |||
| hs_pad: torch.Tensor, | |||
| hlens: torch.Tensor, | |||
| ys_in_pad: torch.Tensor, | |||
| ys_in_lens: torch.Tensor, | |||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |||
| """Forward decoder. | |||
| Args: | |||
| hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||
| hlens: (batch) | |||
| ys_in_pad: | |||
| input token ids, int64 (batch, maxlen_out) | |||
| if input_layer == "embed" | |||
| input tensor (batch, maxlen_out, #mels) in the other cases | |||
| ys_in_lens: (batch) | |||
| Returns: | |||
| (tuple): tuple containing: | |||
| x: decoded token score before softmax (batch, maxlen_out, token) | |||
| if use_output_layer is True, | |||
| olens: (batch, ) | |||
| """ | |||
| tgt = ys_in_pad | |||
| # tgt_mask: (B, 1, L) | |||
| tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||
| # m: (1, L, L) | |||
| # m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||
| # tgt_mask: (B, L, L) | |||
| # tgt_mask = tgt_mask & m | |||
| memory = hs_pad | |||
| memory_mask = ( | |||
| ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||
| memory.device) | |||
| # Padding for Longformer | |||
| if memory_mask.shape[-1] != memory.shape[1]: | |||
| padlen = memory.shape[1] - memory_mask.shape[-1] | |||
| memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||
| 'constant', False) | |||
| # x = self.embed(tgt) | |||
| x = tgt | |||
| x, tgt_mask, memory, memory_mask = self.decoders( | |||
| x, tgt_mask, memory, memory_mask) | |||
| if self.normalize_before: | |||
| x = self.after_norm(x) | |||
| if self.output_layer is not None: | |||
| x = self.output_layer(x) | |||
| olens = tgt_mask.sum(1) | |||
| return x, olens | |||
| class ParaformerDecoderBertEmbed(TransformerDecoder): | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| self_attention_dropout_rate: float = 0.0, | |||
| src_attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| embeds_id: int = 2, | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| dropout_rate=dropout_rate, | |||
| positional_dropout_rate=positional_dropout_rate, | |||
| input_layer=input_layer, | |||
| use_output_layer=use_output_layer, | |||
| pos_enc_class=pos_enc_class, | |||
| normalize_before=normalize_before, | |||
| ) | |||
| attention_dim = encoder_output_size | |||
| self.decoders = repeat( | |||
| embeds_id, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| self_attention_dropout_rate), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| if embeds_id == num_blocks: | |||
| self.decoders2 = None | |||
| else: | |||
| self.decoders2 = repeat( | |||
| num_blocks - embeds_id, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| self_attention_dropout_rate), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| def forward( | |||
| self, | |||
| hs_pad: torch.Tensor, | |||
| hlens: torch.Tensor, | |||
| ys_in_pad: torch.Tensor, | |||
| ys_in_lens: torch.Tensor, | |||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |||
| """Forward decoder. | |||
| Args: | |||
| hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | |||
| hlens: (batch) | |||
| ys_in_pad: | |||
| input token ids, int64 (batch, maxlen_out) | |||
| if input_layer == "embed" | |||
| input tensor (batch, maxlen_out, #mels) in the other cases | |||
| ys_in_lens: (batch) | |||
| Returns: | |||
| (tuple): tuple containing: | |||
| x: decoded token score before softmax (batch, maxlen_out, token) | |||
| if use_output_layer is True, | |||
| olens: (batch, ) | |||
| """ | |||
| tgt = ys_in_pad | |||
| # tgt_mask: (B, 1, L) | |||
| tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) | |||
| # m: (1, L, L) | |||
| # m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) | |||
| # tgt_mask: (B, L, L) | |||
| # tgt_mask = tgt_mask & m | |||
| memory = hs_pad | |||
| memory_mask = ( | |||
| ~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | |||
| memory.device) | |||
| # Padding for Longformer | |||
| if memory_mask.shape[-1] != memory.shape[1]: | |||
| padlen = memory.shape[1] - memory_mask.shape[-1] | |||
| memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), | |||
| 'constant', False) | |||
| # x = self.embed(tgt) | |||
| x = tgt | |||
| x, tgt_mask, memory, memory_mask = self.decoders( | |||
| x, tgt_mask, memory, memory_mask) | |||
| embeds_outputs = x | |||
| if self.decoders2 is not None: | |||
| x, tgt_mask, memory, memory_mask = self.decoders2( | |||
| x, tgt_mask, memory, memory_mask) | |||
| if self.normalize_before: | |||
| x = self.after_norm(x) | |||
| if self.output_layer is not None: | |||
| x = self.output_layer(x) | |||
| olens = tgt_mask.sum(1) | |||
| return x, olens, embeds_outputs | |||
| class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| self_attention_dropout_rate: float = 0.0, | |||
| src_attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| conv_wshare: int = 4, | |||
| conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
| conv_usebias: int = False, | |||
| ): | |||
| assert check_argument_types() | |||
| if len(conv_kernel_length) != num_blocks: | |||
| raise ValueError( | |||
| 'conv_kernel_length must have equal number of values to num_blocks: ' | |||
| f'{len(conv_kernel_length)} != {num_blocks}') | |||
| super().__init__( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| dropout_rate=dropout_rate, | |||
| positional_dropout_rate=positional_dropout_rate, | |||
| input_layer=input_layer, | |||
| use_output_layer=use_output_layer, | |||
| pos_enc_class=pos_enc_class, | |||
| normalize_before=normalize_before, | |||
| ) | |||
| attention_dim = encoder_output_size | |||
| self.decoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| LightweightConvolution( | |||
| wshare=conv_wshare, | |||
| n_feat=attention_dim, | |||
| dropout_rate=self_attention_dropout_rate, | |||
| kernel_size=conv_kernel_length[lnum], | |||
| use_kernel_mask=True, | |||
| use_bias=conv_usebias, | |||
| ), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| self_attention_dropout_rate: float = 0.0, | |||
| src_attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| conv_wshare: int = 4, | |||
| conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
| conv_usebias: int = False, | |||
| ): | |||
| assert check_argument_types() | |||
| if len(conv_kernel_length) != num_blocks: | |||
| raise ValueError( | |||
| 'conv_kernel_length must have equal number of values to num_blocks: ' | |||
| f'{len(conv_kernel_length)} != {num_blocks}') | |||
| super().__init__( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| dropout_rate=dropout_rate, | |||
| positional_dropout_rate=positional_dropout_rate, | |||
| input_layer=input_layer, | |||
| use_output_layer=use_output_layer, | |||
| pos_enc_class=pos_enc_class, | |||
| normalize_before=normalize_before, | |||
| ) | |||
| attention_dim = encoder_output_size | |||
| self.decoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| LightweightConvolution2D( | |||
| wshare=conv_wshare, | |||
| n_feat=attention_dim, | |||
| dropout_rate=self_attention_dropout_rate, | |||
| kernel_size=conv_kernel_length[lnum], | |||
| use_kernel_mask=True, | |||
| use_bias=conv_usebias, | |||
| ), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| self_attention_dropout_rate: float = 0.0, | |||
| src_attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| conv_wshare: int = 4, | |||
| conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
| conv_usebias: int = False, | |||
| ): | |||
| assert check_argument_types() | |||
| if len(conv_kernel_length) != num_blocks: | |||
| raise ValueError( | |||
| 'conv_kernel_length must have equal number of values to num_blocks: ' | |||
| f'{len(conv_kernel_length)} != {num_blocks}') | |||
| super().__init__( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| dropout_rate=dropout_rate, | |||
| positional_dropout_rate=positional_dropout_rate, | |||
| input_layer=input_layer, | |||
| use_output_layer=use_output_layer, | |||
| pos_enc_class=pos_enc_class, | |||
| normalize_before=normalize_before, | |||
| ) | |||
| attention_dim = encoder_output_size | |||
| self.decoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| DynamicConvolution( | |||
| wshare=conv_wshare, | |||
| n_feat=attention_dim, | |||
| dropout_rate=self_attention_dropout_rate, | |||
| kernel_size=conv_kernel_length[lnum], | |||
| use_kernel_mask=True, | |||
| use_bias=conv_usebias, | |||
| ), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): | |||
| def __init__( | |||
| self, | |||
| vocab_size: int, | |||
| encoder_output_size: int, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| self_attention_dropout_rate: float = 0.0, | |||
| src_attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'embed', | |||
| use_output_layer: bool = True, | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| conv_wshare: int = 4, | |||
| conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), | |||
| conv_usebias: int = False, | |||
| ): | |||
| assert check_argument_types() | |||
| if len(conv_kernel_length) != num_blocks: | |||
| raise ValueError( | |||
| 'conv_kernel_length must have equal number of values to num_blocks: ' | |||
| f'{len(conv_kernel_length)} != {num_blocks}') | |||
| super().__init__( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| dropout_rate=dropout_rate, | |||
| positional_dropout_rate=positional_dropout_rate, | |||
| input_layer=input_layer, | |||
| use_output_layer=use_output_layer, | |||
| pos_enc_class=pos_enc_class, | |||
| normalize_before=normalize_before, | |||
| ) | |||
| attention_dim = encoder_output_size | |||
| self.decoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: DecoderLayer( | |||
| attention_dim, | |||
| DynamicConvolution2D( | |||
| wshare=conv_wshare, | |||
| n_feat=attention_dim, | |||
| dropout_rate=self_attention_dropout_rate, | |||
| kernel_size=conv_kernel_length[lnum], | |||
| use_kernel_mask=True, | |||
| use_bias=conv_usebias, | |||
| ), | |||
| MultiHeadedAttention(attention_heads, attention_dim, | |||
| src_attention_dropout_rate), | |||
| PositionwiseFeedForward(attention_dim, linear_units, | |||
| dropout_rate), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| @@ -1,710 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| """Conformer encoder definition.""" | |||
| import logging | |||
| from typing import List, Optional, Tuple, Union | |||
| import torch | |||
| from espnet2.asr.ctc import CTC | |||
| from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||
| from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule | |||
| from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer | |||
| from espnet.nets.pytorch_backend.nets_utils import (get_activation, | |||
| make_pad_mask) | |||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
| LegacyRelPositionalEncoding # noqa: H301 | |||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
| PositionalEncoding # noqa: H301 | |||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
| RelPositionalEncoding # noqa: H301 | |||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
| ScaledPositionalEncoding # noqa: H301 | |||
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
| from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||
| Conv1dLinear, MultiLayeredConv1d) | |||
| from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||
| PositionwiseFeedForward # noqa: H301 | |||
| from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||
| from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||
| Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||
| Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||
| from typeguard import check_argument_types | |||
| from ...nets.pytorch_backend.transformer.attention import \ | |||
| LegacyRelPositionMultiHeadedAttention # noqa: H301 | |||
| from ...nets.pytorch_backend.transformer.attention import \ | |||
| MultiHeadedAttention # noqa: H301 | |||
| from ...nets.pytorch_backend.transformer.attention import \ | |||
| RelPositionMultiHeadedAttention # noqa: H301 | |||
| from ...nets.pytorch_backend.transformer.attention import ( | |||
| LegacyRelPositionMultiHeadedAttentionSANM, | |||
| RelPositionMultiHeadedAttentionSANM) | |||
| class ConformerEncoder(AbsEncoder): | |||
| """Conformer encoder module. | |||
| Args: | |||
| input_size (int): Input dimension. | |||
| output_size (int): Dimension of attention. | |||
| attention_heads (int): The number of heads of multi head attention. | |||
| linear_units (int): The number of units of position-wise feed forward. | |||
| num_blocks (int): The number of decoder blocks. | |||
| dropout_rate (float): Dropout rate. | |||
| attention_dropout_rate (float): Dropout rate in attention. | |||
| positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||
| input_layer (Union[str, torch.nn.Module]): Input layer type. | |||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||
| concat_after (bool): Whether to concat attention layer's input and output. | |||
| If True, additional linear will be applied. | |||
| i.e. x -> x + linear(concat(x, att(x))) | |||
| If False, no additional linear will be applied. i.e. x -> x + att(x) | |||
| positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||
| positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||
| rel_pos_type (str): Whether to use the latest relative positional encoding or | |||
| the legacy one. The legacy relative positional encoding will be deprecated | |||
| in the future. More Details can be found in | |||
| https://github.com/espnet/espnet/pull/2816. | |||
| encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. | |||
| encoder_attn_layer_type (str): Encoder attention layer type. | |||
| activation_type (str): Encoder activation function type. | |||
| macaron_style (bool): Whether to use macaron style for positionwise layer. | |||
| use_cnn_module (bool): Whether to use convolution module. | |||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
| cnn_module_kernel (int): Kernerl size of convolution module. | |||
| padding_idx (int): Padding idx for input_layer=embed. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| input_size: int, | |||
| output_size: int = 256, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'conv2d', | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| positionwise_layer_type: str = 'linear', | |||
| positionwise_conv_kernel_size: int = 3, | |||
| macaron_style: bool = False, | |||
| rel_pos_type: str = 'legacy', | |||
| pos_enc_layer_type: str = 'rel_pos', | |||
| selfattention_layer_type: str = 'rel_selfattn', | |||
| activation_type: str = 'swish', | |||
| use_cnn_module: bool = True, | |||
| zero_triu: bool = False, | |||
| cnn_module_kernel: int = 31, | |||
| padding_idx: int = -1, | |||
| interctc_layer_idx: List[int] = [], | |||
| interctc_use_conditioning: bool = False, | |||
| stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__() | |||
| self._output_size = output_size | |||
| if rel_pos_type == 'legacy': | |||
| if pos_enc_layer_type == 'rel_pos': | |||
| pos_enc_layer_type = 'legacy_rel_pos' | |||
| if selfattention_layer_type == 'rel_selfattn': | |||
| selfattention_layer_type = 'legacy_rel_selfattn' | |||
| elif rel_pos_type == 'latest': | |||
| assert selfattention_layer_type != 'legacy_rel_selfattn' | |||
| assert pos_enc_layer_type != 'legacy_rel_pos' | |||
| else: | |||
| raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||
| activation = get_activation(activation_type) | |||
| if pos_enc_layer_type == 'abs_pos': | |||
| pos_enc_class = PositionalEncoding | |||
| elif pos_enc_layer_type == 'scaled_abs_pos': | |||
| pos_enc_class = ScaledPositionalEncoding | |||
| elif pos_enc_layer_type == 'rel_pos': | |||
| assert selfattention_layer_type == 'rel_selfattn' | |||
| pos_enc_class = RelPositionalEncoding | |||
| elif pos_enc_layer_type == 'legacy_rel_pos': | |||
| assert selfattention_layer_type == 'legacy_rel_selfattn' | |||
| pos_enc_class = LegacyRelPositionalEncoding | |||
| else: | |||
| raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||
| if input_layer == 'linear': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Linear(input_size, output_size), | |||
| torch.nn.LayerNorm(output_size), | |||
| torch.nn.Dropout(dropout_rate), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d': | |||
| self.embed = Conv2dSubsampling( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d2': | |||
| self.embed = Conv2dSubsampling2( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d6': | |||
| self.embed = Conv2dSubsampling6( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d8': | |||
| self.embed = Conv2dSubsampling8( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'embed': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Embedding( | |||
| input_size, output_size, padding_idx=padding_idx), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif isinstance(input_layer, torch.nn.Module): | |||
| self.embed = torch.nn.Sequential( | |||
| input_layer, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer is None: | |||
| self.embed = torch.nn.Sequential( | |||
| pos_enc_class(output_size, positional_dropout_rate)) | |||
| else: | |||
| raise ValueError('unknown input_layer: ' + input_layer) | |||
| self.normalize_before = normalize_before | |||
| if positionwise_layer_type == 'linear': | |||
| positionwise_layer = PositionwiseFeedForward | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| dropout_rate, | |||
| activation, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d': | |||
| positionwise_layer = MultiLayeredConv1d | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d-linear': | |||
| positionwise_layer = Conv1dLinear | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| else: | |||
| raise NotImplementedError('Support only linear or conv1d.') | |||
| if selfattention_layer_type == 'selfattn': | |||
| encoder_selfattn_layer = MultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| ) | |||
| elif selfattention_layer_type == 'legacy_rel_selfattn': | |||
| assert pos_enc_layer_type == 'legacy_rel_pos' | |||
| encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| ) | |||
| elif selfattention_layer_type == 'rel_selfattn': | |||
| assert pos_enc_layer_type == 'rel_pos' | |||
| encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| zero_triu, | |||
| ) | |||
| else: | |||
| raise ValueError('unknown encoder_attn_layer: ' | |||
| + selfattention_layer_type) | |||
| convolution_layer = ConvolutionModule | |||
| convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||
| if isinstance(stochastic_depth_rate, float): | |||
| stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||
| if len(stochastic_depth_rate) != num_blocks: | |||
| raise ValueError( | |||
| f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||
| f'should be equal to num_blocks ({num_blocks})') | |||
| self.encoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: EncoderLayer( | |||
| output_size, | |||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
| positionwise_layer(*positionwise_layer_args), | |||
| positionwise_layer(*positionwise_layer_args) | |||
| if macaron_style else None, | |||
| convolution_layer(*convolution_layer_args) | |||
| if use_cnn_module else None, | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| stochastic_depth_rate[lnum], | |||
| ), | |||
| ) | |||
| if self.normalize_before: | |||
| self.after_norm = LayerNorm(output_size) | |||
| self.interctc_layer_idx = interctc_layer_idx | |||
| if len(interctc_layer_idx) > 0: | |||
| assert 0 < min(interctc_layer_idx) and max( | |||
| interctc_layer_idx) < num_blocks | |||
| self.interctc_use_conditioning = interctc_use_conditioning | |||
| self.conditioning_layer = None | |||
| def output_size(self) -> int: | |||
| return self._output_size | |||
| def forward( | |||
| self, | |||
| xs_pad: torch.Tensor, | |||
| ilens: torch.Tensor, | |||
| prev_states: torch.Tensor = None, | |||
| ctc: CTC = None, | |||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
| """Calculate forward propagation. | |||
| Args: | |||
| xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||
| ilens (torch.Tensor): Input length (#batch). | |||
| prev_states (torch.Tensor): Not to be used now. | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, L, output_size). | |||
| torch.Tensor: Output length (#batch). | |||
| torch.Tensor: Not to be used now. | |||
| """ | |||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
| if (isinstance(self.embed, Conv2dSubsampling) | |||
| or isinstance(self.embed, Conv2dSubsampling2) | |||
| or isinstance(self.embed, Conv2dSubsampling6) | |||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||
| short_status, limit_size = check_short_utt(self.embed, | |||
| xs_pad.size(1)) | |||
| if short_status: | |||
| raise TooShortUttError( | |||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
| + # noqa: * | |||
| f'(it needs more than {limit_size} frames), return empty results', # noqa: * | |||
| xs_pad.size(1), | |||
| limit_size) # noqa: * | |||
| xs_pad, masks = self.embed(xs_pad, masks) | |||
| else: | |||
| xs_pad = self.embed(xs_pad) | |||
| intermediate_outs = [] | |||
| if len(self.interctc_layer_idx) == 0: | |||
| xs_pad, masks = self.encoders(xs_pad, masks) | |||
| else: | |||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||
| xs_pad, masks = encoder_layer(xs_pad, masks) | |||
| if layer_idx + 1 in self.interctc_layer_idx: | |||
| encoder_out = xs_pad | |||
| if isinstance(encoder_out, tuple): | |||
| encoder_out = encoder_out[0] | |||
| # intermediate outputs are also normalized | |||
| if self.normalize_before: | |||
| encoder_out = self.after_norm(encoder_out) | |||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
| if self.interctc_use_conditioning: | |||
| ctc_out = ctc.softmax(encoder_out) | |||
| if isinstance(xs_pad, tuple): | |||
| x, pos_emb = xs_pad | |||
| x = x + self.conditioning_layer(ctc_out) | |||
| xs_pad = (x, pos_emb) | |||
| else: | |||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
| if isinstance(xs_pad, tuple): | |||
| xs_pad = xs_pad[0] | |||
| if self.normalize_before: | |||
| xs_pad = self.after_norm(xs_pad) | |||
| olens = masks.squeeze(1).sum(1) | |||
| if len(intermediate_outs) > 0: | |||
| return (xs_pad, intermediate_outs), olens, None | |||
| return xs_pad, olens, None | |||
| class SANMEncoder_v2(AbsEncoder): | |||
| """Conformer encoder module. | |||
| Args: | |||
| input_size (int): Input dimension. | |||
| output_size (int): Dimension of attention. | |||
| attention_heads (int): The number of heads of multi head attention. | |||
| linear_units (int): The number of units of position-wise feed forward. | |||
| num_blocks (int): The number of decoder blocks. | |||
| dropout_rate (float): Dropout rate. | |||
| attention_dropout_rate (float): Dropout rate in attention. | |||
| positional_dropout_rate (float): Dropout rate after adding positional encoding. | |||
| input_layer (Union[str, torch.nn.Module]): Input layer type. | |||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||
| concat_after (bool): Whether to concat attention layer's input and output. | |||
| If True, additional linear will be applied. | |||
| i.e. x -> x + linear(concat(x, att(x))) | |||
| If False, no additional linear will be applied. i.e. x -> x + att(x) | |||
| positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". | |||
| positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. | |||
| rel_pos_type (str): Whether to use the latest relative positional encoding or | |||
| the legacy one. The legacy relative positional encoding will be deprecated | |||
| in the future. More Details can be found in | |||
| https://github.com/espnet/espnet/pull/2816. | |||
| encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. | |||
| encoder_attn_layer_type (str): Encoder attention layer type. | |||
| activation_type (str): Encoder activation function type. | |||
| macaron_style (bool): Whether to use macaron style for positionwise layer. | |||
| use_cnn_module (bool): Whether to use convolution module. | |||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
| cnn_module_kernel (int): Kernerl size of convolution module. | |||
| padding_idx (int): Padding idx for input_layer=embed. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| input_size: int, | |||
| output_size: int = 256, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| attention_dropout_rate: float = 0.0, | |||
| input_layer: str = 'conv2d', | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| positionwise_layer_type: str = 'linear', | |||
| positionwise_conv_kernel_size: int = 3, | |||
| macaron_style: bool = False, | |||
| rel_pos_type: str = 'legacy', | |||
| pos_enc_layer_type: str = 'rel_pos', | |||
| selfattention_layer_type: str = 'rel_selfattn', | |||
| activation_type: str = 'swish', | |||
| use_cnn_module: bool = False, | |||
| sanm_shfit: int = 0, | |||
| zero_triu: bool = False, | |||
| cnn_module_kernel: int = 31, | |||
| padding_idx: int = -1, | |||
| interctc_layer_idx: List[int] = [], | |||
| interctc_use_conditioning: bool = False, | |||
| stochastic_depth_rate: Union[float, List[float]] = 0.0, | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__() | |||
| self._output_size = output_size | |||
| if rel_pos_type == 'legacy': | |||
| if pos_enc_layer_type == 'rel_pos': | |||
| pos_enc_layer_type = 'legacy_rel_pos' | |||
| if selfattention_layer_type == 'rel_selfattn': | |||
| selfattention_layer_type = 'legacy_rel_selfattn' | |||
| if selfattention_layer_type == 'rel_selfattnsanm': | |||
| selfattention_layer_type = 'legacy_rel_selfattnsanm' | |||
| elif rel_pos_type == 'latest': | |||
| assert selfattention_layer_type != 'legacy_rel_selfattn' | |||
| assert pos_enc_layer_type != 'legacy_rel_pos' | |||
| else: | |||
| raise ValueError('unknown rel_pos_type: ' + rel_pos_type) | |||
| activation = get_activation(activation_type) | |||
| if pos_enc_layer_type == 'abs_pos': | |||
| pos_enc_class = PositionalEncoding | |||
| elif pos_enc_layer_type == 'scaled_abs_pos': | |||
| pos_enc_class = ScaledPositionalEncoding | |||
| elif pos_enc_layer_type == 'rel_pos': | |||
| # assert selfattention_layer_type == "rel_selfattn" | |||
| pos_enc_class = RelPositionalEncoding | |||
| elif pos_enc_layer_type == 'legacy_rel_pos': | |||
| # assert selfattention_layer_type == "legacy_rel_selfattn" | |||
| pos_enc_class = LegacyRelPositionalEncoding | |||
| logging.warning( | |||
| 'Using legacy_rel_pos and it will be deprecated in the future.' | |||
| ) | |||
| else: | |||
| raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type) | |||
| if input_layer == 'linear': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Linear(input_size, output_size), | |||
| torch.nn.LayerNorm(output_size), | |||
| torch.nn.Dropout(dropout_rate), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d': | |||
| self.embed = Conv2dSubsampling( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d2': | |||
| self.embed = Conv2dSubsampling2( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d6': | |||
| self.embed = Conv2dSubsampling6( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d8': | |||
| self.embed = Conv2dSubsampling8( | |||
| input_size, | |||
| output_size, | |||
| dropout_rate, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'embed': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Embedding( | |||
| input_size, output_size, padding_idx=padding_idx), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif isinstance(input_layer, torch.nn.Module): | |||
| self.embed = torch.nn.Sequential( | |||
| input_layer, | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer is None: | |||
| self.embed = torch.nn.Sequential( | |||
| pos_enc_class(output_size, positional_dropout_rate)) | |||
| else: | |||
| raise ValueError('unknown input_layer: ' + input_layer) | |||
| self.normalize_before = normalize_before | |||
| if positionwise_layer_type == 'linear': | |||
| positionwise_layer = PositionwiseFeedForward | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| dropout_rate, | |||
| activation, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d': | |||
| positionwise_layer = MultiLayeredConv1d | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d-linear': | |||
| positionwise_layer = Conv1dLinear | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| else: | |||
| raise NotImplementedError('Support only linear or conv1d.') | |||
| if selfattention_layer_type == 'selfattn': | |||
| encoder_selfattn_layer = MultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| ) | |||
| elif selfattention_layer_type == 'legacy_rel_selfattn': | |||
| assert pos_enc_layer_type == 'legacy_rel_pos' | |||
| encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| ) | |||
| logging.warning( | |||
| 'Using legacy_rel_selfattn and it will be deprecated in the future.' | |||
| ) | |||
| elif selfattention_layer_type == 'legacy_rel_selfattnsanm': | |||
| assert pos_enc_layer_type == 'legacy_rel_pos' | |||
| encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttentionSANM | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| ) | |||
| logging.warning( | |||
| 'Using legacy_rel_selfattn and it will be deprecated in the future.' | |||
| ) | |||
| elif selfattention_layer_type == 'rel_selfattn': | |||
| assert pos_enc_layer_type == 'rel_pos' | |||
| encoder_selfattn_layer = RelPositionMultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| zero_triu, | |||
| ) | |||
| elif selfattention_layer_type == 'rel_selfattnsanm': | |||
| assert pos_enc_layer_type == 'rel_pos' | |||
| encoder_selfattn_layer = RelPositionMultiHeadedAttentionSANM | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| zero_triu, | |||
| cnn_module_kernel, | |||
| sanm_shfit, | |||
| ) | |||
| else: | |||
| raise ValueError('unknown encoder_attn_layer: ' | |||
| + selfattention_layer_type) | |||
| convolution_layer = ConvolutionModule | |||
| convolution_layer_args = (output_size, cnn_module_kernel, activation) | |||
| if isinstance(stochastic_depth_rate, float): | |||
| stochastic_depth_rate = [stochastic_depth_rate] * num_blocks | |||
| if len(stochastic_depth_rate) != num_blocks: | |||
| raise ValueError( | |||
| f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) ' | |||
| f'should be equal to num_blocks ({num_blocks})') | |||
| self.encoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: EncoderLayer( | |||
| output_size, | |||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
| positionwise_layer(*positionwise_layer_args), | |||
| positionwise_layer(*positionwise_layer_args) | |||
| if macaron_style else None, | |||
| convolution_layer(*convolution_layer_args) | |||
| if use_cnn_module else None, | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| stochastic_depth_rate[lnum], | |||
| ), | |||
| ) | |||
| if self.normalize_before: | |||
| self.after_norm = LayerNorm(output_size) | |||
| self.interctc_layer_idx = interctc_layer_idx | |||
| if len(interctc_layer_idx) > 0: | |||
| assert 0 < min(interctc_layer_idx) and max( | |||
| interctc_layer_idx) < num_blocks | |||
| self.interctc_use_conditioning = interctc_use_conditioning | |||
| self.conditioning_layer = None | |||
| def output_size(self) -> int: | |||
| return self._output_size | |||
| def forward( | |||
| self, | |||
| xs_pad: torch.Tensor, | |||
| ilens: torch.Tensor, | |||
| prev_states: torch.Tensor = None, | |||
| ctc: CTC = None, | |||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
| """Calculate forward propagation. | |||
| Args: | |||
| xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). | |||
| ilens (torch.Tensor): Input length (#batch). | |||
| prev_states (torch.Tensor): Not to be used now. | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, L, output_size). | |||
| torch.Tensor: Output length (#batch). | |||
| torch.Tensor: Not to be used now. | |||
| """ | |||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
| if (isinstance(self.embed, Conv2dSubsampling) | |||
| or isinstance(self.embed, Conv2dSubsampling2) | |||
| or isinstance(self.embed, Conv2dSubsampling6) | |||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||
| short_status, limit_size = check_short_utt(self.embed, | |||
| xs_pad.size(1)) | |||
| if short_status: | |||
| raise TooShortUttError( | |||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
| + # noqa: * | |||
| f'(it needs more than {limit_size} frames), return empty results', | |||
| xs_pad.size(1), | |||
| limit_size) # noqa: * | |||
| xs_pad, masks = self.embed(xs_pad, masks) | |||
| else: | |||
| xs_pad = self.embed(xs_pad) | |||
| intermediate_outs = [] | |||
| if len(self.interctc_layer_idx) == 0: | |||
| xs_pad, masks = self.encoders(xs_pad, masks) | |||
| else: | |||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||
| xs_pad, masks = encoder_layer(xs_pad, masks) | |||
| if layer_idx + 1 in self.interctc_layer_idx: | |||
| encoder_out = xs_pad | |||
| if isinstance(encoder_out, tuple): | |||
| encoder_out = encoder_out[0] | |||
| # intermediate outputs are also normalized | |||
| if self.normalize_before: | |||
| encoder_out = self.after_norm(encoder_out) | |||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
| if self.interctc_use_conditioning: | |||
| ctc_out = ctc.softmax(encoder_out) | |||
| if isinstance(xs_pad, tuple): | |||
| x, pos_emb = xs_pad | |||
| x = x + self.conditioning_layer(ctc_out) | |||
| xs_pad = (x, pos_emb) | |||
| else: | |||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
| if isinstance(xs_pad, tuple): | |||
| xs_pad = xs_pad[0] | |||
| if self.normalize_before: | |||
| xs_pad = self.after_norm(xs_pad) | |||
| olens = masks.squeeze(1).sum(1) | |||
| if len(intermediate_outs) > 0: | |||
| return (xs_pad, intermediate_outs), olens, None | |||
| return xs_pad, olens, None | |||
| @@ -1,500 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| """Transformer encoder definition.""" | |||
| import logging | |||
| from typing import List, Optional, Sequence, Tuple, Union | |||
| import torch | |||
| from espnet2.asr.ctc import CTC | |||
| from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
| from espnet.nets.pytorch_backend.transformer.embedding import \ | |||
| PositionalEncoding | |||
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
| from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( | |||
| Conv1dLinear, MultiLayeredConv1d) | |||
| from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \ | |||
| PositionwiseFeedForward # noqa: H301 | |||
| from espnet.nets.pytorch_backend.transformer.repeat import repeat | |||
| from espnet.nets.pytorch_backend.transformer.subsampling import ( | |||
| Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, | |||
| Conv2dSubsampling8, TooShortUttError, check_short_utt) | |||
| from typeguard import check_argument_types | |||
| from ...asr.streaming_utilis.chunk_utilis import overlap_chunk | |||
| from ...nets.pytorch_backend.transformer.attention import ( | |||
| MultiHeadedAttention, MultiHeadedAttentionSANM) | |||
| from ...nets.pytorch_backend.transformer.encoder_layer import ( | |||
| EncoderLayer, EncoderLayerChunk) | |||
| class SANMEncoder(AbsEncoder): | |||
| """Transformer encoder module. | |||
| Args: | |||
| input_size: input dim | |||
| output_size: dimension of attention | |||
| attention_heads: the number of heads of multi head attention | |||
| linear_units: the number of units of position-wise feed forward | |||
| num_blocks: the number of decoder blocks | |||
| dropout_rate: dropout rate | |||
| attention_dropout_rate: dropout rate in attention | |||
| positional_dropout_rate: dropout rate after adding positional encoding | |||
| input_layer: input layer type | |||
| pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||
| normalize_before: whether to use layer_norm before the first block | |||
| concat_after: whether to concat attention layer's input and output | |||
| if True, additional linear will be applied. | |||
| i.e. x -> x + linear(concat(x, att(x))) | |||
| if False, no additional linear will be applied. | |||
| i.e. x -> x + att(x) | |||
| positionwise_layer_type: linear of conv1d | |||
| positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||
| padding_idx: padding_idx for input_layer=embed | |||
| """ | |||
| def __init__( | |||
| self, | |||
| input_size: int, | |||
| output_size: int = 256, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| attention_dropout_rate: float = 0.0, | |||
| input_layer: Optional[str] = 'conv2d', | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| positionwise_layer_type: str = 'linear', | |||
| positionwise_conv_kernel_size: int = 1, | |||
| padding_idx: int = -1, | |||
| interctc_layer_idx: List[int] = [], | |||
| interctc_use_conditioning: bool = False, | |||
| kernel_size: int = 11, | |||
| sanm_shfit: int = 0, | |||
| selfattention_layer_type: str = 'sanm', | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__() | |||
| self._output_size = output_size | |||
| if input_layer == 'linear': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Linear(input_size, output_size), | |||
| torch.nn.LayerNorm(output_size), | |||
| torch.nn.Dropout(dropout_rate), | |||
| torch.nn.ReLU(), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d': | |||
| self.embed = Conv2dSubsampling(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'conv2d2': | |||
| self.embed = Conv2dSubsampling2(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'conv2d6': | |||
| self.embed = Conv2dSubsampling6(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'conv2d8': | |||
| self.embed = Conv2dSubsampling8(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'embed': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Embedding( | |||
| input_size, output_size, padding_idx=padding_idx), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer is None: | |||
| if input_size == output_size: | |||
| self.embed = None | |||
| else: | |||
| self.embed = torch.nn.Linear(input_size, output_size) | |||
| else: | |||
| raise ValueError('unknown input_layer: ' + input_layer) | |||
| self.normalize_before = normalize_before | |||
| if positionwise_layer_type == 'linear': | |||
| positionwise_layer = PositionwiseFeedForward | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| dropout_rate, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d': | |||
| positionwise_layer = MultiLayeredConv1d | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d-linear': | |||
| positionwise_layer = Conv1dLinear | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| else: | |||
| raise NotImplementedError('Support only linear or conv1d.') | |||
| if selfattention_layer_type == 'selfattn': | |||
| encoder_selfattn_layer = MultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| ) | |||
| elif selfattention_layer_type == 'sanm': | |||
| encoder_selfattn_layer = MultiHeadedAttentionSANM | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| kernel_size, | |||
| sanm_shfit, | |||
| ) | |||
| self.encoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: EncoderLayer( | |||
| output_size, | |||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
| positionwise_layer(*positionwise_layer_args), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| if self.normalize_before: | |||
| self.after_norm = LayerNorm(output_size) | |||
| self.interctc_layer_idx = interctc_layer_idx | |||
| if len(interctc_layer_idx) > 0: | |||
| assert 0 < min(interctc_layer_idx) and max( | |||
| interctc_layer_idx) < num_blocks | |||
| self.interctc_use_conditioning = interctc_use_conditioning | |||
| self.conditioning_layer = None | |||
| def output_size(self) -> int: | |||
| return self._output_size | |||
| def forward( | |||
| self, | |||
| xs_pad: torch.Tensor, | |||
| ilens: torch.Tensor, | |||
| prev_states: torch.Tensor = None, | |||
| ctc: CTC = None, | |||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
| """Embed positions in tensor. | |||
| Args: | |||
| xs_pad: input tensor (B, L, D) | |||
| ilens: input length (B) | |||
| prev_states: Not to be used now. | |||
| Returns: | |||
| position embedded tensor and mask | |||
| """ | |||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
| if self.embed is None: | |||
| xs_pad = xs_pad | |||
| elif (isinstance(self.embed, Conv2dSubsampling) | |||
| or isinstance(self.embed, Conv2dSubsampling2) | |||
| or isinstance(self.embed, Conv2dSubsampling6) | |||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||
| short_status, limit_size = check_short_utt(self.embed, | |||
| xs_pad.size(1)) | |||
| if short_status: | |||
| raise TooShortUttError( | |||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
| + # noqa: * | |||
| f'(it needs more than {limit_size} frames), return empty results', | |||
| xs_pad.size(1), | |||
| limit_size, | |||
| ) | |||
| xs_pad, masks = self.embed(xs_pad, masks) | |||
| else: | |||
| xs_pad = self.embed(xs_pad) | |||
| intermediate_outs = [] | |||
| if len(self.interctc_layer_idx) == 0: | |||
| xs_pad, masks = self.encoders(xs_pad, masks) | |||
| else: | |||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||
| xs_pad, masks = encoder_layer(xs_pad, masks) | |||
| if layer_idx + 1 in self.interctc_layer_idx: | |||
| encoder_out = xs_pad | |||
| # intermediate outputs are also normalized | |||
| if self.normalize_before: | |||
| encoder_out = self.after_norm(encoder_out) | |||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
| if self.interctc_use_conditioning: | |||
| ctc_out = ctc.softmax(encoder_out) | |||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
| if self.normalize_before: | |||
| xs_pad = self.after_norm(xs_pad) | |||
| olens = masks.squeeze(1).sum(1) | |||
| if len(intermediate_outs) > 0: | |||
| return (xs_pad, intermediate_outs), olens, None | |||
| return xs_pad, olens, None | |||
| class SANMEncoderChunk(AbsEncoder): | |||
| """Transformer encoder module. | |||
| Args: | |||
| input_size: input dim | |||
| output_size: dimension of attention | |||
| attention_heads: the number of heads of multi head attention | |||
| linear_units: the number of units of position-wise feed forward | |||
| num_blocks: the number of decoder blocks | |||
| dropout_rate: dropout rate | |||
| attention_dropout_rate: dropout rate in attention | |||
| positional_dropout_rate: dropout rate after adding positional encoding | |||
| input_layer: input layer type | |||
| pos_enc_class: PositionalEncoding or ScaledPositionalEncoding | |||
| normalize_before: whether to use layer_norm before the first block | |||
| concat_after: whether to concat attention layer's input and output | |||
| if True, additional linear will be applied. | |||
| i.e. x -> x + linear(concat(x, att(x))) | |||
| if False, no additional linear will be applied. | |||
| i.e. x -> x + att(x) | |||
| positionwise_layer_type: linear of conv1d | |||
| positionwise_conv_kernel_size: kernel size of positionwise conv1d layer | |||
| padding_idx: padding_idx for input_layer=embed | |||
| """ | |||
| def __init__( | |||
| self, | |||
| input_size: int, | |||
| output_size: int = 256, | |||
| attention_heads: int = 4, | |||
| linear_units: int = 2048, | |||
| num_blocks: int = 6, | |||
| dropout_rate: float = 0.1, | |||
| positional_dropout_rate: float = 0.1, | |||
| attention_dropout_rate: float = 0.0, | |||
| input_layer: Optional[str] = 'conv2d', | |||
| pos_enc_class=PositionalEncoding, | |||
| normalize_before: bool = True, | |||
| concat_after: bool = False, | |||
| positionwise_layer_type: str = 'linear', | |||
| positionwise_conv_kernel_size: int = 1, | |||
| padding_idx: int = -1, | |||
| interctc_layer_idx: List[int] = [], | |||
| interctc_use_conditioning: bool = False, | |||
| kernel_size: int = 11, | |||
| sanm_shfit: int = 0, | |||
| selfattention_layer_type: str = 'sanm', | |||
| chunk_size: Union[int, Sequence[int]] = (16, ), | |||
| stride: Union[int, Sequence[int]] = (10, ), | |||
| pad_left: Union[int, Sequence[int]] = (0, ), | |||
| encoder_att_look_back_factor: Union[int, Sequence[int]] = (1, ), | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__() | |||
| self._output_size = output_size | |||
| if input_layer == 'linear': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Linear(input_size, output_size), | |||
| torch.nn.LayerNorm(output_size), | |||
| torch.nn.Dropout(dropout_rate), | |||
| torch.nn.ReLU(), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer == 'conv2d': | |||
| self.embed = Conv2dSubsampling(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'conv2d2': | |||
| self.embed = Conv2dSubsampling2(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'conv2d6': | |||
| self.embed = Conv2dSubsampling6(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'conv2d8': | |||
| self.embed = Conv2dSubsampling8(input_size, output_size, | |||
| dropout_rate) | |||
| elif input_layer == 'embed': | |||
| self.embed = torch.nn.Sequential( | |||
| torch.nn.Embedding( | |||
| input_size, output_size, padding_idx=padding_idx), | |||
| pos_enc_class(output_size, positional_dropout_rate), | |||
| ) | |||
| elif input_layer is None: | |||
| if input_size == output_size: | |||
| self.embed = None | |||
| else: | |||
| self.embed = torch.nn.Linear(input_size, output_size) | |||
| else: | |||
| raise ValueError('unknown input_layer: ' + input_layer) | |||
| self.normalize_before = normalize_before | |||
| if positionwise_layer_type == 'linear': | |||
| positionwise_layer = PositionwiseFeedForward | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| dropout_rate, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d': | |||
| positionwise_layer = MultiLayeredConv1d | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| elif positionwise_layer_type == 'conv1d-linear': | |||
| positionwise_layer = Conv1dLinear | |||
| positionwise_layer_args = ( | |||
| output_size, | |||
| linear_units, | |||
| positionwise_conv_kernel_size, | |||
| dropout_rate, | |||
| ) | |||
| else: | |||
| raise NotImplementedError('Support only linear or conv1d.') | |||
| if selfattention_layer_type == 'selfattn': | |||
| encoder_selfattn_layer = MultiHeadedAttention | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| ) | |||
| elif selfattention_layer_type == 'sanm': | |||
| encoder_selfattn_layer = MultiHeadedAttentionSANM | |||
| encoder_selfattn_layer_args = ( | |||
| attention_heads, | |||
| output_size, | |||
| attention_dropout_rate, | |||
| kernel_size, | |||
| sanm_shfit, | |||
| ) | |||
| self.encoders = repeat( | |||
| num_blocks, | |||
| lambda lnum: EncoderLayerChunk( | |||
| output_size, | |||
| encoder_selfattn_layer(*encoder_selfattn_layer_args), | |||
| positionwise_layer(*positionwise_layer_args), | |||
| dropout_rate, | |||
| normalize_before, | |||
| concat_after, | |||
| ), | |||
| ) | |||
| if self.normalize_before: | |||
| self.after_norm = LayerNorm(output_size) | |||
| self.interctc_layer_idx = interctc_layer_idx | |||
| if len(interctc_layer_idx) > 0: | |||
| assert 0 < min(interctc_layer_idx) and max( | |||
| interctc_layer_idx) < num_blocks | |||
| self.interctc_use_conditioning = interctc_use_conditioning | |||
| self.conditioning_layer = None | |||
| shfit_fsmn = (kernel_size - 1) // 2 | |||
| self.overlap_chunk_cls = overlap_chunk( | |||
| chunk_size=chunk_size, | |||
| stride=stride, | |||
| pad_left=pad_left, | |||
| shfit_fsmn=shfit_fsmn, | |||
| encoder_att_look_back_factor=encoder_att_look_back_factor, | |||
| ) | |||
| def output_size(self) -> int: | |||
| return self._output_size | |||
| def forward( | |||
| self, | |||
| xs_pad: torch.Tensor, | |||
| ilens: torch.Tensor, | |||
| prev_states: torch.Tensor = None, | |||
| ctc: CTC = None, | |||
| ind: int = 0, | |||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
| """Embed positions in tensor. | |||
| Args: | |||
| xs_pad: input tensor (B, L, D) | |||
| ilens: input length (B) | |||
| prev_states: Not to be used now. | |||
| Returns: | |||
| position embedded tensor and mask | |||
| """ | |||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
| if self.embed is None: | |||
| xs_pad = xs_pad | |||
| elif (isinstance(self.embed, Conv2dSubsampling) | |||
| or isinstance(self.embed, Conv2dSubsampling2) | |||
| or isinstance(self.embed, Conv2dSubsampling6) | |||
| or isinstance(self.embed, Conv2dSubsampling8)): | |||
| short_status, limit_size = check_short_utt(self.embed, | |||
| xs_pad.size(1)) | |||
| if short_status: | |||
| raise TooShortUttError( | |||
| f'has {xs_pad.size(1)} frames and is too short for subsampling ' | |||
| + # noqa: * | |||
| f'(it needs more than {limit_size} frames), return empty results', | |||
| xs_pad.size(1), | |||
| limit_size, | |||
| ) | |||
| xs_pad, masks = self.embed(xs_pad, masks) | |||
| else: | |||
| xs_pad = self.embed(xs_pad) | |||
| mask_shfit_chunk, mask_att_chunk_encoder = None, None | |||
| if self.overlap_chunk_cls is not None: | |||
| ilens = masks.squeeze(1).sum(1) | |||
| chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind) | |||
| xs_pad, ilens = self.overlap_chunk_cls.split_chunk( | |||
| xs_pad, ilens, chunk_outs=chunk_outs) | |||
| masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) | |||
| mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk( | |||
| chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||
| mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder( | |||
| chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype) | |||
| intermediate_outs = [] | |||
| if len(self.interctc_layer_idx) == 0: | |||
| xs_pad, masks, _, _, _ = self.encoders(xs_pad, masks, None, | |||
| mask_shfit_chunk, | |||
| mask_att_chunk_encoder) | |||
| else: | |||
| for layer_idx, encoder_layer in enumerate(self.encoders): | |||
| xs_pad, masks, _, _, _ = encoder_layer(xs_pad, masks, None, | |||
| mask_shfit_chunk, | |||
| mask_att_chunk_encoder) | |||
| if layer_idx + 1 in self.interctc_layer_idx: | |||
| encoder_out = xs_pad | |||
| # intermediate outputs are also normalized | |||
| if self.normalize_before: | |||
| encoder_out = self.after_norm(encoder_out) | |||
| intermediate_outs.append((layer_idx + 1, encoder_out)) | |||
| if self.interctc_use_conditioning: | |||
| ctc_out = ctc.softmax(encoder_out) | |||
| xs_pad = xs_pad + self.conditioning_layer(ctc_out) | |||
| if self.normalize_before: | |||
| xs_pad = self.after_norm(xs_pad) | |||
| if self.overlap_chunk_cls is not None: | |||
| xs_pad, olens = self.overlap_chunk_cls.remove_chunk( | |||
| xs_pad, ilens, chunk_outs) | |||
| if len(intermediate_outs) > 0: | |||
| return (xs_pad, intermediate_outs), olens, None | |||
| return xs_pad, olens, None | |||
| @@ -1,113 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| import copy | |||
| from typing import Optional, Tuple, Union | |||
| import humanfriendly | |||
| import numpy as np | |||
| import torch | |||
| import torchaudio | |||
| import torchaudio.compliance.kaldi as kaldi | |||
| from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||
| from espnet2.layers.log_mel import LogMel | |||
| from espnet2.layers.stft import Stft | |||
| from espnet2.utils.get_default_kwargs import get_default_kwargs | |||
| from espnet.nets.pytorch_backend.frontends.frontend import Frontend | |||
| from typeguard import check_argument_types | |||
| class WavFrontend(AbsFrontend): | |||
| """Conventional frontend structure for ASR. | |||
| Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN | |||
| """ | |||
| def __init__( | |||
| self, | |||
| fs: Union[int, str] = 16000, | |||
| n_fft: int = 512, | |||
| win_length: int = 400, | |||
| hop_length: int = 160, | |||
| window: Optional[str] = 'hamming', | |||
| center: bool = True, | |||
| normalized: bool = False, | |||
| onesided: bool = True, | |||
| n_mels: int = 80, | |||
| fmin: int = None, | |||
| fmax: int = None, | |||
| htk: bool = False, | |||
| frontend_conf: Optional[dict] = get_default_kwargs(Frontend), | |||
| apply_stft: bool = True, | |||
| ): | |||
| assert check_argument_types() | |||
| super().__init__() | |||
| if isinstance(fs, str): | |||
| fs = humanfriendly.parse_size(fs) | |||
| # Deepcopy (In general, dict shouldn't be used as default arg) | |||
| frontend_conf = copy.deepcopy(frontend_conf) | |||
| self.hop_length = hop_length | |||
| self.win_length = win_length | |||
| self.window = window | |||
| self.fs = fs | |||
| if apply_stft: | |||
| self.stft = Stft( | |||
| n_fft=n_fft, | |||
| win_length=win_length, | |||
| hop_length=hop_length, | |||
| center=center, | |||
| window=window, | |||
| normalized=normalized, | |||
| onesided=onesided, | |||
| ) | |||
| else: | |||
| self.stft = None | |||
| self.apply_stft = apply_stft | |||
| if frontend_conf is not None: | |||
| self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) | |||
| else: | |||
| self.frontend = None | |||
| self.logmel = LogMel( | |||
| fs=fs, | |||
| n_fft=n_fft, | |||
| n_mels=n_mels, | |||
| fmin=fmin, | |||
| fmax=fmax, | |||
| htk=htk, | |||
| ) | |||
| self.n_mels = n_mels | |||
| self.frontend_type = 'default' | |||
| def output_size(self) -> int: | |||
| return self.n_mels | |||
| def forward( | |||
| self, input: torch.Tensor, | |||
| input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||
| sample_frequency = self.fs | |||
| num_mel_bins = self.n_mels | |||
| frame_length = self.win_length * 1000 / sample_frequency | |||
| frame_shift = self.hop_length * 1000 / sample_frequency | |||
| waveform = input * (1 << 15) | |||
| mat = kaldi.fbank( | |||
| waveform, | |||
| num_mel_bins=num_mel_bins, | |||
| frame_length=frame_length, | |||
| frame_shift=frame_shift, | |||
| dither=1.0, | |||
| energy_floor=0.0, | |||
| window_type=self.window, | |||
| sample_frequency=sample_frequency) | |||
| input_feats = mat[None, :] | |||
| feats_lens = torch.randn(1) | |||
| feats_lens.fill_(input_feats.shape[1]) | |||
| return input_feats, feats_lens | |||
| @@ -1,321 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| import logging | |||
| import math | |||
| import numpy as np | |||
| import torch | |||
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
| from ...nets.pytorch_backend.cif_utils.cif import \ | |||
| cif_predictor as cif_predictor | |||
| np.set_printoptions(threshold=np.inf) | |||
| torch.set_printoptions(profile='full', precision=100000, linewidth=None) | |||
| def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device='cpu'): | |||
| if maxlen is None: | |||
| maxlen = lengths.max() | |||
| row_vector = torch.arange(0, maxlen, 1) | |||
| matrix = torch.unsqueeze(lengths, dim=-1) | |||
| mask = row_vector < matrix | |||
| return mask.type(dtype).to(device) | |||
| class overlap_chunk(): | |||
| def __init__( | |||
| self, | |||
| chunk_size: tuple = (16, ), | |||
| stride: tuple = (10, ), | |||
| pad_left: tuple = (0, ), | |||
| encoder_att_look_back_factor: tuple = (1, ), | |||
| shfit_fsmn: int = 0, | |||
| ): | |||
| self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor \ | |||
| = chunk_size, stride, pad_left, encoder_att_look_back_factor | |||
| self.shfit_fsmn = shfit_fsmn | |||
| self.x_add_mask = None | |||
| self.x_rm_mask = None | |||
| self.x_len = None | |||
| self.mask_shfit_chunk = None | |||
| self.mask_chunk_predictor = None | |||
| self.mask_att_chunk_encoder = None | |||
| self.mask_shift_att_chunk_decoder = None | |||
| self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur \ | |||
| = None, None, None, None | |||
| def get_chunk_size(self, ind: int = 0): | |||
| # with torch.no_grad: | |||
| chunk_size, stride, pad_left, encoder_att_look_back_factor = self.chunk_size[ | |||
| ind], self.stride[ind], self.pad_left[ | |||
| ind], self.encoder_att_look_back_factor[ind] | |||
| self.chunk_size_cur, self.stride_cur, self.pad_left_cur, | |||
| self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \ | |||
| = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn | |||
| return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur | |||
| def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1): | |||
| with torch.no_grad(): | |||
| x_len = x_len.cpu().numpy() | |||
| x_len_max = x_len.max() | |||
| chunk_size, stride, pad_left, encoder_att_look_back_factor = self.get_chunk_size( | |||
| ind) | |||
| shfit_fsmn = self.shfit_fsmn | |||
| chunk_size_pad_shift = chunk_size + shfit_fsmn | |||
| chunk_num_batch = np.ceil(x_len / stride).astype(np.int32) | |||
| x_len_chunk = ( | |||
| chunk_num_batch - 1 | |||
| ) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - ( | |||
| chunk_num_batch - 1) * stride | |||
| x_len_chunk = x_len_chunk.astype(x_len.dtype) | |||
| x_len_chunk_max = x_len_chunk.max() | |||
| chunk_num = int(math.ceil(x_len_max / stride)) | |||
| dtype = np.int32 | |||
| max_len_for_x_mask_tmp = max(chunk_size, x_len_max) | |||
| x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) | |||
| x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) | |||
| mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) | |||
| mask_chunk_predictor = np.zeros([0, num_units_predictor], | |||
| dtype=dtype) | |||
| mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) | |||
| mask_att_chunk_encoder = np.zeros( | |||
| [0, chunk_num * chunk_size_pad_shift], dtype=dtype) | |||
| for chunk_ids in range(chunk_num): | |||
| # x_mask add | |||
| fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), | |||
| dtype=dtype) | |||
| x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) | |||
| x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), | |||
| dtype=dtype) | |||
| x_mask_pad_right = np.zeros( | |||
| (chunk_size, max_len_for_x_mask_tmp), dtype=dtype) | |||
| x_cur_pad = np.concatenate( | |||
| [x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1) | |||
| x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp] | |||
| x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], | |||
| axis=0) | |||
| x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], | |||
| axis=0) | |||
| # x_mask rm | |||
| fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), | |||
| dtype=dtype) | |||
| x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) | |||
| x_mask_right = np.zeros((stride, chunk_size - stride), | |||
| dtype=dtype) | |||
| x_mask_cur = np.concatenate([x_mask_cur, x_mask_right], axis=1) | |||
| x_mask_cur_pad_top = np.zeros((chunk_ids * stride, chunk_size), | |||
| dtype=dtype) | |||
| x_mask_cur_pad_bottom = np.zeros( | |||
| (max_len_for_x_mask_tmp, chunk_size), dtype=dtype) | |||
| x_rm_mask_cur = np.concatenate( | |||
| [x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], | |||
| axis=0) | |||
| x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, : | |||
| chunk_size] | |||
| x_rm_mask_cur_fsmn = np.concatenate( | |||
| [fsmn_padding, x_rm_mask_cur], axis=1) | |||
| x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], | |||
| axis=1) | |||
| # fsmn_padding_mask | |||
| pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) | |||
| ones_1 = np.ones([chunk_size, num_units], dtype=dtype) | |||
| mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], | |||
| axis=0) | |||
| mask_shfit_chunk = np.concatenate( | |||
| [mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) | |||
| # predictor mask | |||
| zeros_1 = np.zeros( | |||
| [shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) | |||
| ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) | |||
| zeros_3 = np.zeros( | |||
| [chunk_size - stride - pad_left, num_units_predictor], | |||
| dtype=dtype) | |||
| ones_zeros = np.concatenate([ones_2, zeros_3], axis=0) | |||
| mask_chunk_predictor_cur = np.concatenate( | |||
| [zeros_1, ones_zeros], axis=0) | |||
| mask_chunk_predictor = np.concatenate( | |||
| [mask_chunk_predictor, mask_chunk_predictor_cur], axis=0) | |||
| # encoder att mask | |||
| zeros_1_top = np.zeros( | |||
| [shfit_fsmn, chunk_num * chunk_size_pad_shift], | |||
| dtype=dtype) | |||
| zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) | |||
| zeros_2 = np.zeros( | |||
| [chunk_size, zeros_2_num * chunk_size_pad_shift], | |||
| dtype=dtype) | |||
| encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0) | |||
| zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||
| ones_2_mid = np.ones([stride, stride], dtype=dtype) | |||
| zeros_2_bottom = np.zeros([chunk_size - stride, stride], | |||
| dtype=dtype) | |||
| zeros_2_right = np.zeros([chunk_size, chunk_size - stride], | |||
| dtype=dtype) | |||
| ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0) | |||
| ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], | |||
| axis=1) | |||
| ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num]) | |||
| zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) | |||
| ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype) | |||
| ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1) | |||
| zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0) | |||
| zeros_remain = np.zeros( | |||
| [chunk_size, zeros_remain_num * chunk_size_pad_shift], | |||
| dtype=dtype) | |||
| ones2_bottom = np.concatenate( | |||
| [zeros_2, ones_2, ones_3, zeros_remain], axis=1) | |||
| mask_att_chunk_encoder_cur = np.concatenate( | |||
| [zeros_1_top, ones2_bottom], axis=0) | |||
| mask_att_chunk_encoder = np.concatenate( | |||
| [mask_att_chunk_encoder, mask_att_chunk_encoder_cur], | |||
| axis=0) | |||
| # decoder fsmn_shift_att_mask | |||
| zeros_1 = np.zeros([shfit_fsmn, 1]) | |||
| ones_1 = np.ones([chunk_size, 1]) | |||
| mask_shift_att_chunk_decoder_cur = np.concatenate( | |||
| [zeros_1, ones_1], axis=0) | |||
| mask_shift_att_chunk_decoder = np.concatenate( | |||
| [ | |||
| mask_shift_att_chunk_decoder, | |||
| mask_shift_att_chunk_decoder_cur | |||
| ], | |||
| vaxis=0) # noqa: * | |||
| self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max] | |||
| self.x_len_chunk = x_len_chunk | |||
| self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max] | |||
| self.x_len = x_len | |||
| self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :] | |||
| self.mask_chunk_predictor = mask_chunk_predictor[: | |||
| x_len_chunk_max, :] | |||
| self.mask_att_chunk_encoder = mask_att_chunk_encoder[: | |||
| x_len_chunk_max, : | |||
| x_len_chunk_max] | |||
| self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[: | |||
| x_len_chunk_max, :] | |||
| return (self.x_add_mask, self.x_len_chunk, self.x_rm_mask, self.x_len, | |||
| self.mask_shfit_chunk, self.mask_chunk_predictor, | |||
| self.mask_att_chunk_encoder, self.mask_shift_att_chunk_decoder) | |||
| def split_chunk(self, x, x_len, chunk_outs): | |||
| """ | |||
| :param x: (b, t, d) | |||
| :param x_length: (b) | |||
| :param ind: int | |||
| :return: | |||
| """ | |||
| x = x[:, :x_len.max(), :] | |||
| b, t, d = x.size() | |||
| x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device) | |||
| x *= x_len_mask[:, :, None] | |||
| x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype) | |||
| x_len_chunk = self.get_x_len_chunk( | |||
| chunk_outs, x_len.device, dtype=x_len.dtype) | |||
| x = torch.transpose(x, 1, 0) | |||
| x = torch.reshape(x, [t, -1]) | |||
| x_chunk = torch.mm(x_add_mask, x) | |||
| x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0) | |||
| return x_chunk, x_len_chunk | |||
| def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs): | |||
| x_chunk = x_chunk[:, :x_len_chunk.max(), :] | |||
| b, t, d = x_chunk.size() | |||
| x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to( | |||
| x_chunk.device) | |||
| x_chunk *= x_len_chunk_mask[:, :, None] | |||
| x_rm_mask = self.get_x_rm_mask( | |||
| chunk_outs, x_chunk.device, dtype=x_chunk.dtype) | |||
| x_len = self.get_x_len( | |||
| chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype) | |||
| x_chunk = torch.transpose(x_chunk, 1, 0) | |||
| x_chunk = torch.reshape(x_chunk, [t, -1]) | |||
| x = torch.mm(x_rm_mask, x_chunk) | |||
| x = torch.reshape(x, [-1, b, d]).transpose(1, 0) | |||
| return x, x_len | |||
| def get_x_add_mask(self, chunk_outs, device, idx=0, dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| def get_x_len_chunk(self, chunk_outs, device, idx=1, dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| def get_x_rm_mask(self, chunk_outs, device, idx=2, dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| def get_x_len(self, chunk_outs, device, idx=3, dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| def get_mask_shfit_chunk(self, | |||
| chunk_outs, | |||
| device, | |||
| batch_size=1, | |||
| num_units=1, | |||
| idx=4, | |||
| dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| def get_mask_chunk_predictor(self, | |||
| chunk_outs, | |||
| device, | |||
| batch_size=1, | |||
| num_units=1, | |||
| idx=5, | |||
| dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = np.tile(x[None, :, :, ], [batch_size, 1, num_units]) | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| def get_mask_att_chunk_encoder(self, | |||
| chunk_outs, | |||
| device, | |||
| batch_size=1, | |||
| idx=6, | |||
| dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = np.tile(x[None, :, :, ], [batch_size, 1, 1]) | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| def get_mask_shift_att_chunk_decoder(self, | |||
| chunk_outs, | |||
| device, | |||
| batch_size=1, | |||
| idx=7, | |||
| dtype=torch.float32): | |||
| x = chunk_outs[idx] | |||
| x = np.tile(x[None, None, :, 0], [batch_size, 1, 1]) | |||
| x = torch.from_numpy(x).type(dtype).to(device) | |||
| return x.detach() | |||
| @@ -1,250 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| import logging | |||
| import numpy as np | |||
| import torch | |||
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | |||
| from torch import nn | |||
| class CIF_Model(nn.Module): | |||
| def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||
| super(CIF_Model, self).__init__() | |||
| self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||
| self.cif_conv1d = nn.Conv1d( | |||
| idim, idim, l_order + r_order + 1, groups=idim) | |||
| self.cif_output = nn.Linear(idim, 1) | |||
| self.dropout = torch.nn.Dropout(p=dropout) | |||
| self.threshold = threshold | |||
| def forward(self, hidden, target_label=None, mask=None, ignore_id=-1): | |||
| h = hidden | |||
| context = h.transpose(1, 2) | |||
| queries = self.pad(context) | |||
| memory = self.cif_conv1d(queries) | |||
| output = memory + context | |||
| output = self.dropout(output) | |||
| output = output.transpose(1, 2) | |||
| output = torch.relu(output) | |||
| output = self.cif_output(output) | |||
| alphas = torch.sigmoid(output) | |||
| if mask is not None: | |||
| alphas = alphas * mask.transpose(-1, -2).float() | |||
| alphas = alphas.squeeze(-1) | |||
| if target_label is not None: | |||
| target_length = (target_label != ignore_id).float().sum(-1) | |||
| else: | |||
| target_length = None | |||
| cif_length = alphas.sum(-1) | |||
| if target_label is not None: | |||
| alphas *= (target_length / cif_length)[:, None].repeat( | |||
| 1, alphas.size(1)) | |||
| cif_output, cif_peak = cif(hidden, alphas, self.threshold) | |||
| return cif_output, cif_length, target_length, cif_peak | |||
| def gen_frame_alignments(self, | |||
| alphas: torch.Tensor = None, | |||
| memory_sequence_length: torch.Tensor = None, | |||
| is_training: bool = True, | |||
| dtype: torch.dtype = torch.float32): | |||
| batch_size, maximum_length = alphas.size() | |||
| int_type = torch.int32 | |||
| token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||
| max_token_num = torch.max(token_num).item() | |||
| alphas_cumsum = torch.cumsum(alphas, dim=1) | |||
| alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||
| alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||
| [1, max_token_num, 1]) | |||
| index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||
| index = torch.cumsum(index, dim=1) | |||
| index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||
| index_div = torch.floor(torch.divide(alphas_cumsum, | |||
| index)).type(int_type) | |||
| index_div_bool_zeros = index_div.eq(0) | |||
| index_div_bool_zeros_count = torch.sum( | |||
| index_div_bool_zeros, dim=-1) + 1 | |||
| index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||
| memory_sequence_length.max()) | |||
| token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||
| token_num.device) | |||
| index_div_bool_zeros_count *= token_num_mask | |||
| index_div_bool_zeros_count_tile = torch.tile( | |||
| index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||
| ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||
| zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||
| ones = torch.cumsum(ones, dim=2) | |||
| cond = index_div_bool_zeros_count_tile == ones | |||
| index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||
| index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||
| torch.bool) | |||
| index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||
| int_type) | |||
| index_div_bool_zeros_count_tile_out = torch.sum( | |||
| index_div_bool_zeros_count_tile, dim=1) | |||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||
| int_type) | |||
| predictor_mask = (~make_pad_mask( | |||
| memory_sequence_length, | |||
| maxlen=memory_sequence_length.max())).type(int_type).to( | |||
| memory_sequence_length.device) # noqa: * | |||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||
| return index_div_bool_zeros_count_tile_out.detach( | |||
| ), index_div_bool_zeros_count.detach() | |||
| class cif_predictor(nn.Module): | |||
| def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1): | |||
| super(cif_predictor, self).__init__() | |||
| self.pad = nn.ConstantPad1d((l_order, r_order), 0) | |||
| self.cif_conv1d = nn.Conv1d( | |||
| idim, idim, l_order + r_order + 1, groups=idim) | |||
| self.cif_output = nn.Linear(idim, 1) | |||
| self.dropout = torch.nn.Dropout(p=dropout) | |||
| self.threshold = threshold | |||
| def forward(self, | |||
| hidden, | |||
| target_label=None, | |||
| mask=None, | |||
| ignore_id=-1, | |||
| mask_chunk_predictor=None, | |||
| target_label_length=None): | |||
| h = hidden | |||
| context = h.transpose(1, 2) | |||
| queries = self.pad(context) | |||
| memory = self.cif_conv1d(queries) | |||
| output = memory + context | |||
| output = self.dropout(output) | |||
| output = output.transpose(1, 2) | |||
| output = torch.relu(output) | |||
| output = self.cif_output(output) | |||
| alphas = torch.sigmoid(output) | |||
| if mask is not None: | |||
| alphas = alphas * mask.transpose(-1, -2).float() | |||
| if mask_chunk_predictor is not None: | |||
| alphas = alphas * mask_chunk_predictor | |||
| alphas = alphas.squeeze(-1) | |||
| if target_label_length is not None: | |||
| target_length = target_label_length | |||
| elif target_label is not None: | |||
| target_length = (target_label != ignore_id).float().sum(-1) | |||
| else: | |||
| target_length = None | |||
| token_num = alphas.sum(-1) | |||
| if target_length is not None: | |||
| alphas *= (target_length / token_num)[:, None].repeat( | |||
| 1, alphas.size(1)) | |||
| acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) | |||
| return acoustic_embeds, token_num, alphas, cif_peak | |||
| def gen_frame_alignments(self, | |||
| alphas: torch.Tensor = None, | |||
| memory_sequence_length: torch.Tensor = None, | |||
| is_training: bool = True, | |||
| dtype: torch.dtype = torch.float32): | |||
| batch_size, maximum_length = alphas.size() | |||
| int_type = torch.int32 | |||
| token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type) | |||
| max_token_num = torch.max(token_num).item() | |||
| alphas_cumsum = torch.cumsum(alphas, dim=1) | |||
| alphas_cumsum = torch.floor(alphas_cumsum).type(int_type) | |||
| alphas_cumsum = torch.tile(alphas_cumsum[:, None, :], | |||
| [1, max_token_num, 1]) | |||
| index = torch.ones([batch_size, max_token_num], dtype=int_type) | |||
| index = torch.cumsum(index, dim=1) | |||
| index = torch.tile(index[:, :, None], [1, 1, maximum_length]) | |||
| index_div = torch.floor(torch.divide(alphas_cumsum, | |||
| index)).type(int_type) | |||
| index_div_bool_zeros = index_div.eq(0) | |||
| index_div_bool_zeros_count = torch.sum( | |||
| index_div_bool_zeros, dim=-1) + 1 | |||
| index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0, | |||
| memory_sequence_length.max()) | |||
| token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to( | |||
| token_num.device) | |||
| index_div_bool_zeros_count *= token_num_mask | |||
| index_div_bool_zeros_count_tile = torch.tile( | |||
| index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length]) | |||
| ones = torch.ones_like(index_div_bool_zeros_count_tile) | |||
| zeros = torch.zeros_like(index_div_bool_zeros_count_tile) | |||
| ones = torch.cumsum(ones, dim=2) | |||
| cond = index_div_bool_zeros_count_tile == ones | |||
| index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones) | |||
| index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type( | |||
| torch.bool) | |||
| index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type( | |||
| int_type) | |||
| index_div_bool_zeros_count_tile_out = torch.sum( | |||
| index_div_bool_zeros_count_tile, dim=1) | |||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type( | |||
| int_type) | |||
| predictor_mask = (~make_pad_mask( | |||
| memory_sequence_length, | |||
| maxlen=memory_sequence_length.max())).type(int_type).to( | |||
| memory_sequence_length.device) # noqa: * | |||
| index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask | |||
| return index_div_bool_zeros_count_tile_out.detach( | |||
| ), index_div_bool_zeros_count.detach() | |||
| def cif(hidden, alphas, threshold): | |||
| batch_size, len_time, hidden_size = hidden.size() | |||
| # loop varss | |||
| integrate = torch.zeros([batch_size], device=hidden.device) | |||
| frame = torch.zeros([batch_size, hidden_size], device=hidden.device) | |||
| # intermediate vars along time | |||
| list_fires = [] | |||
| list_frames = [] | |||
| for t in range(len_time): | |||
| alpha = alphas[:, t] | |||
| distribution_completion = torch.ones([batch_size], | |||
| device=hidden.device) - integrate | |||
| integrate += alpha | |||
| list_fires.append(integrate) | |||
| fire_place = integrate >= threshold | |||
| integrate = torch.where( | |||
| fire_place, | |||
| integrate - torch.ones([batch_size], device=hidden.device), | |||
| integrate) | |||
| cur = torch.where(fire_place, distribution_completion, alpha) | |||
| remainds = alpha - cur | |||
| frame += cur[:, None] * hidden[:, t, :] | |||
| list_frames.append(frame) | |||
| frame = torch.where(fire_place[:, None].repeat(1, hidden_size), | |||
| remainds[:, None] * hidden[:, t, :], frame) | |||
| fires = torch.stack(list_fires, 1) | |||
| frames = torch.stack(list_frames, 1) | |||
| list_ls = [] | |||
| len_labels = torch.round(alphas.sum(-1)).int() | |||
| max_label_len = len_labels.max() | |||
| for b in range(batch_size): | |||
| fire = fires[b, :] | |||
| ls = torch.index_select(frames[b, :, :], 0, | |||
| torch.nonzero(fire >= threshold).squeeze()) | |||
| pad_l = torch.zeros([max_label_len - ls.size(0), hidden_size], | |||
| device=hidden.device) | |||
| list_ls.append(torch.cat([ls, pad_l], 0)) | |||
| return torch.stack(list_ls, 0), fires | |||
| @@ -1,680 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| """Multi-Head Attention layer definition.""" | |||
| import logging | |||
| import math | |||
| import numpy | |||
| import torch | |||
| from torch import nn | |||
| torch.set_printoptions(profile='full', precision=1) | |||
| class MultiHeadedAttention(nn.Module): | |||
| """Multi-Head Attention layer. | |||
| Args: | |||
| n_head (int): The number of heads. | |||
| n_feat (int): The number of features. | |||
| dropout_rate (float): Dropout rate. | |||
| """ | |||
| def __init__(self, n_head, n_feat, dropout_rate): | |||
| """Construct an MultiHeadedAttention object.""" | |||
| super(MultiHeadedAttention, self).__init__() | |||
| assert n_feat % n_head == 0 | |||
| # We assume d_v always equals d_k | |||
| self.d_k = n_feat // n_head | |||
| self.h = n_head | |||
| self.linear_q = nn.Linear(n_feat, n_feat) | |||
| self.linear_k = nn.Linear(n_feat, n_feat) | |||
| self.linear_v = nn.Linear(n_feat, n_feat) | |||
| self.linear_out = nn.Linear(n_feat, n_feat) | |||
| self.attn = None | |||
| self.dropout = nn.Dropout(p=dropout_rate) | |||
| def forward_qkv(self, query, key, value): | |||
| """Transform query, key and value. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| Returns: | |||
| torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||
| torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||
| torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||
| """ | |||
| n_batch = query.size(0) | |||
| q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||
| k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||
| v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||
| q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||
| k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||
| v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||
| return q, k, v | |||
| def forward_attention(self, value, scores, mask): | |||
| """Compute attention context vector. | |||
| Args: | |||
| value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||
| scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||
| mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Transformed value (#batch, time1, d_model) | |||
| weighted by the attention score (#batch, time1, time2). | |||
| """ | |||
| n_batch = value.size(0) | |||
| if mask is not None: | |||
| mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||
| min_value = float( | |||
| numpy.finfo(torch.tensor( | |||
| 0, dtype=scores.dtype).numpy().dtype).min) | |||
| scores = scores.masked_fill(mask, min_value) | |||
| self.attn = torch.softmax( | |||
| scores, dim=-1).masked_fill(mask, | |||
| 0.0) # (batch, head, time1, time2) | |||
| else: | |||
| self.attn = torch.softmax( | |||
| scores, dim=-1) # (batch, head, time1, time2) | |||
| p_attn = self.dropout(self.attn) | |||
| x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||
| x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||
| self.h * self.d_k) | |||
| ) # (batch, time1, d_model) | |||
| return self.linear_out(x) # (batch, time1, d_model) | |||
| def forward(self, query, key, value, mask): | |||
| """Compute scaled dot product attention. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
| (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||
| """ | |||
| q, k, v = self.forward_qkv(query, key, value) | |||
| scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||
| return self.forward_attention(v, scores, mask) | |||
| class MultiHeadedAttentionSANM(nn.Module): | |||
| """Multi-Head Attention layer. | |||
| Args: | |||
| n_head (int): The number of heads. | |||
| n_feat (int): The number of features. | |||
| dropout_rate (float): Dropout rate. | |||
| """ | |||
| def __init__(self, | |||
| n_head, | |||
| n_feat, | |||
| dropout_rate, | |||
| kernel_size, | |||
| sanm_shfit=0): | |||
| """Construct an MultiHeadedAttention object.""" | |||
| super(MultiHeadedAttentionSANM, self).__init__() | |||
| assert n_feat % n_head == 0 | |||
| # We assume d_v always equals d_k | |||
| self.d_k = n_feat // n_head | |||
| self.h = n_head | |||
| self.linear_q = nn.Linear(n_feat, n_feat) | |||
| self.linear_k = nn.Linear(n_feat, n_feat) | |||
| self.linear_v = nn.Linear(n_feat, n_feat) | |||
| self.linear_out = nn.Linear(n_feat, n_feat) | |||
| self.attn = None | |||
| self.dropout = nn.Dropout(p=dropout_rate) | |||
| self.fsmn_block = nn.Conv1d( | |||
| n_feat, | |||
| n_feat, | |||
| kernel_size, | |||
| stride=1, | |||
| padding=0, | |||
| groups=n_feat, | |||
| bias=False) | |||
| # padding | |||
| left_padding = (kernel_size - 1) // 2 | |||
| if sanm_shfit > 0: | |||
| left_padding = left_padding + sanm_shfit | |||
| right_padding = kernel_size - 1 - left_padding | |||
| self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) | |||
| def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): | |||
| ''' | |||
| :param x: (#batch, time1, size). | |||
| :param mask: Mask tensor (#batch, 1, time) | |||
| :return: | |||
| ''' | |||
| # b, t, d = inputs.size() | |||
| mask = mask[:, 0, :, None] | |||
| if mask_shfit_chunk is not None: | |||
| mask = mask * mask_shfit_chunk | |||
| inputs *= mask | |||
| x = inputs.transpose(1, 2) | |||
| x = self.pad_fn(x) | |||
| x = self.fsmn_block(x) | |||
| x = x.transpose(1, 2) | |||
| x += inputs | |||
| x = self.dropout(x) | |||
| return x * mask | |||
| def forward_qkv(self, query, key, value): | |||
| """Transform query, key and value. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| Returns: | |||
| torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). | |||
| torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). | |||
| torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). | |||
| """ | |||
| n_batch = query.size(0) | |||
| q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | |||
| k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) | |||
| v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) | |||
| q = q.transpose(1, 2) # (batch, head, time1, d_k) | |||
| k = k.transpose(1, 2) # (batch, head, time2, d_k) | |||
| v = v.transpose(1, 2) # (batch, head, time2, d_k) | |||
| return q, k, v | |||
| def forward_attention(self, | |||
| value, | |||
| scores, | |||
| mask, | |||
| mask_att_chunk_encoder=None): | |||
| """Compute attention context vector. | |||
| Args: | |||
| value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). | |||
| scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). | |||
| mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Transformed value (#batch, time1, d_model) | |||
| weighted by the attention score (#batch, time1, time2). | |||
| """ | |||
| n_batch = value.size(0) | |||
| if mask is not None: | |||
| if mask_att_chunk_encoder is not None: | |||
| mask = mask * mask_att_chunk_encoder | |||
| mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | |||
| min_value = float( | |||
| numpy.finfo(torch.tensor( | |||
| 0, dtype=scores.dtype).numpy().dtype).min) | |||
| scores = scores.masked_fill(mask, min_value) | |||
| self.attn = torch.softmax( | |||
| scores, dim=-1).masked_fill(mask, | |||
| 0.0) # (batch, head, time1, time2) | |||
| else: | |||
| self.attn = torch.softmax( | |||
| scores, dim=-1) # (batch, head, time1, time2) | |||
| p_attn = self.dropout(self.attn) | |||
| x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | |||
| x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | |||
| self.h * self.d_k) | |||
| ) # (batch, time1, d_model) | |||
| return self.linear_out(x) # (batch, time1, d_model) | |||
| def forward(self, | |||
| query, | |||
| key, | |||
| value, | |||
| mask, | |||
| mask_shfit_chunk=None, | |||
| mask_att_chunk_encoder=None): | |||
| """Compute scaled dot product attention. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
| (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||
| """ | |||
| fsmn_memory = self.forward_fsmn(value, mask, mask_shfit_chunk) | |||
| q, k, v = self.forward_qkv(query, key, value) | |||
| scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) | |||
| att_outs = self.forward_attention(v, scores, mask, | |||
| mask_att_chunk_encoder) | |||
| return att_outs + fsmn_memory | |||
| class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): | |||
| """Multi-Head Attention layer with relative position encoding (old version). | |||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
| Paper: https://arxiv.org/abs/1901.02860 | |||
| Args: | |||
| n_head (int): The number of heads. | |||
| n_feat (int): The number of features. | |||
| dropout_rate (float): Dropout rate. | |||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
| """ | |||
| def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||
| super().__init__(n_head, n_feat, dropout_rate) | |||
| self.zero_triu = zero_triu | |||
| # linear transformation for positional encoding | |||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
| # these two learnable bias are used in matrix c and matrix d | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
| def rel_shift(self, x): | |||
| """Compute relative positional encoding. | |||
| Args: | |||
| x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor. | |||
| """ | |||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||
| device=x.device, | |||
| dtype=x.dtype) | |||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
| x = x_padded[:, :, 1:].view_as(x) | |||
| if self.zero_triu: | |||
| ones = torch.ones((x.size(2), x.size(3))) | |||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
| return x | |||
| def forward(self, query, key, value, pos_emb, mask): | |||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
| (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||
| """ | |||
| q, k, v = self.forward_qkv(query, key, value) | |||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
| n_batch_pos = pos_emb.size(0) | |||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
| p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
| # compute attention score | |||
| # first compute matrix a and matrix c | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| # (batch, head, time1, time2) | |||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
| # compute matrix b and matrix d | |||
| # (batch, head, time1, time1) | |||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
| matrix_bd = self.rel_shift(matrix_bd) | |||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
| self.d_k) # (batch, head, time1, time2) | |||
| return self.forward_attention(v, scores, mask) | |||
| class LegacyRelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||
| """Multi-Head Attention layer with relative position encoding (old version). | |||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
| Paper: https://arxiv.org/abs/1901.02860 | |||
| Args: | |||
| n_head (int): The number of heads. | |||
| n_feat (int): The number of features. | |||
| dropout_rate (float): Dropout rate. | |||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
| """ | |||
| def __init__(self, | |||
| n_head, | |||
| n_feat, | |||
| dropout_rate, | |||
| zero_triu=False, | |||
| kernel_size=15, | |||
| sanm_shfit=0): | |||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||
| super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||
| self.zero_triu = zero_triu | |||
| # linear transformation for positional encoding | |||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
| # these two learnable bias are used in matrix c and matrix d | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
| def rel_shift(self, x): | |||
| """Compute relative positional encoding. | |||
| Args: | |||
| x (torch.Tensor): Input tensor (batch, head, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor. | |||
| """ | |||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||
| device=x.device, | |||
| dtype=x.dtype) | |||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
| x = x_padded[:, :, 1:].view_as(x) | |||
| if self.zero_triu: | |||
| ones = torch.ones((x.size(2), x.size(3))) | |||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
| return x | |||
| def forward(self, query, key, value, pos_emb, mask): | |||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). | |||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
| (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||
| """ | |||
| fsmn_memory = self.forward_fsmn(value, mask) | |||
| q, k, v = self.forward_qkv(query, key, value) | |||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
| n_batch_pos = pos_emb.size(0) | |||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
| p = p.transpose(1, 2) # (batch, head, time1, d_k) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
| # compute attention score | |||
| # first compute matrix a and matrix c | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| # (batch, head, time1, time2) | |||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
| # compute matrix b and matrix d | |||
| # (batch, head, time1, time1) | |||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
| matrix_bd = self.rel_shift(matrix_bd) | |||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
| self.d_k) # (batch, head, time1, time2) | |||
| att_outs = self.forward_attention(v, scores, mask) | |||
| return att_outs + fsmn_memory | |||
| class RelPositionMultiHeadedAttention(MultiHeadedAttention): | |||
| """Multi-Head Attention layer with relative position encoding (new implementation). | |||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
| Paper: https://arxiv.org/abs/1901.02860 | |||
| Args: | |||
| n_head (int): The number of heads. | |||
| n_feat (int): The number of features. | |||
| dropout_rate (float): Dropout rate. | |||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
| """ | |||
| def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): | |||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||
| super().__init__(n_head, n_feat, dropout_rate) | |||
| self.zero_triu = zero_triu | |||
| # linear transformation for positional encoding | |||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
| # these two learnable bias are used in matrix c and matrix d | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
| def rel_shift(self, x): | |||
| """Compute relative positional encoding. | |||
| Args: | |||
| x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||
| time1 means the length of query vector. | |||
| Returns: | |||
| torch.Tensor: Output tensor. | |||
| """ | |||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||
| device=x.device, | |||
| dtype=x.dtype) | |||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
| x = x_padded[:, :, 1:].view_as( | |||
| x)[:, :, :, :x.size(-1) // 2 | |||
| + 1] # only keep the positions from 0 to time2 | |||
| if self.zero_triu: | |||
| ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
| return x | |||
| def forward(self, query, key, value, pos_emb, mask): | |||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| pos_emb (torch.Tensor): Positional embedding tensor | |||
| (#batch, 2*time1-1, size). | |||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
| (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||
| """ | |||
| q, k, v = self.forward_qkv(query, key, value) | |||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
| n_batch_pos = pos_emb.size(0) | |||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
| p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
| # compute attention score | |||
| # first compute matrix a and matrix c | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| # (batch, head, time1, time2) | |||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
| # compute matrix b and matrix d | |||
| # (batch, head, time1, 2*time1-1) | |||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
| matrix_bd = self.rel_shift(matrix_bd) | |||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
| self.d_k) # (batch, head, time1, time2) | |||
| return self.forward_attention(v, scores, mask) | |||
| class RelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM): | |||
| """Multi-Head Attention layer with relative position encoding (new implementation). | |||
| Details can be found in https://github.com/espnet/espnet/pull/2816. | |||
| Paper: https://arxiv.org/abs/1901.02860 | |||
| Args: | |||
| n_head (int): The number of heads. | |||
| n_feat (int): The number of features. | |||
| dropout_rate (float): Dropout rate. | |||
| zero_triu (bool): Whether to zero the upper triangular part of attention matrix. | |||
| """ | |||
| def __init__(self, | |||
| n_head, | |||
| n_feat, | |||
| dropout_rate, | |||
| zero_triu=False, | |||
| kernel_size=15, | |||
| sanm_shfit=0): | |||
| """Construct an RelPositionMultiHeadedAttention object.""" | |||
| super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit) | |||
| self.zero_triu = zero_triu | |||
| # linear transformation for positional encoding | |||
| self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | |||
| # these two learnable bias are used in matrix c and matrix d | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_u) | |||
| torch.nn.init.xavier_uniform_(self.pos_bias_v) | |||
| def rel_shift(self, x): | |||
| """Compute relative positional encoding. | |||
| Args: | |||
| x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). | |||
| time1 means the length of query vector. | |||
| Returns: | |||
| torch.Tensor: Output tensor. | |||
| """ | |||
| zero_pad = torch.zeros((*x.size()[:3], 1), | |||
| device=x.device, | |||
| dtype=x.dtype) | |||
| x_padded = torch.cat([zero_pad, x], dim=-1) | |||
| x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) | |||
| x = x_padded[:, :, 1:].view_as( | |||
| x)[:, :, :, :x.size(-1) // 2 | |||
| + 1] # only keep the positions from 0 to time2 | |||
| if self.zero_triu: | |||
| ones = torch.ones((x.size(2), x.size(3)), device=x.device) | |||
| x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | |||
| return x | |||
| def forward(self, query, key, value, pos_emb, mask): | |||
| """Compute 'Scaled Dot Product Attention' with rel. positional encoding. | |||
| Args: | |||
| query (torch.Tensor): Query tensor (#batch, time1, size). | |||
| key (torch.Tensor): Key tensor (#batch, time2, size). | |||
| value (torch.Tensor): Value tensor (#batch, time2, size). | |||
| pos_emb (torch.Tensor): Positional embedding tensor | |||
| (#batch, 2*time1-1, size). | |||
| mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | |||
| (#batch, time1, time2). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time1, d_model). | |||
| """ | |||
| fsmn_memory = self.forward_fsmn(value, mask) | |||
| q, k, v = self.forward_qkv(query, key, value) | |||
| q = q.transpose(1, 2) # (batch, time1, head, d_k) | |||
| n_batch_pos = pos_emb.size(0) | |||
| p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | |||
| p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | |||
| # (batch, head, time1, d_k) | |||
| q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | |||
| # compute attention score | |||
| # first compute matrix a and matrix c | |||
| # as described in https://arxiv.org/abs/1901.02860 Section 3.3 | |||
| # (batch, head, time1, time2) | |||
| matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | |||
| # compute matrix b and matrix d | |||
| # (batch, head, time1, 2*time1-1) | |||
| matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | |||
| matrix_bd = self.rel_shift(matrix_bd) | |||
| scores = (matrix_ac + matrix_bd) / math.sqrt( | |||
| self.d_k) # (batch, head, time1, time2) | |||
| att_outs = self.forward_attention(v, scores, mask) | |||
| return att_outs + fsmn_memory | |||
| @@ -1,239 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| """Encoder self-attention layer definition.""" | |||
| import torch | |||
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |||
| from torch import nn | |||
| class EncoderLayer(nn.Module): | |||
| """Encoder layer module. | |||
| Args: | |||
| size (int): Input dimension. | |||
| self_attn (torch.nn.Module): Self-attention module instance. | |||
| `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||
| can be used as the argument. | |||
| feed_forward (torch.nn.Module): Feed-forward module instance. | |||
| `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||
| can be used as the argument. | |||
| dropout_rate (float): Dropout rate. | |||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||
| concat_after (bool): Whether to concat attention layer's input and output. | |||
| if True, additional linear will be applied. | |||
| i.e. x -> x + linear(concat(x, att(x))) | |||
| if False, no additional linear will be applied. i.e. x -> x + att(x) | |||
| stochastic_depth_rate (float): Proability to skip this layer. | |||
| During training, the layer may skip residual computation and return input | |||
| as-is with given probability. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| size, | |||
| self_attn, | |||
| feed_forward, | |||
| dropout_rate, | |||
| normalize_before=True, | |||
| concat_after=False, | |||
| stochastic_depth_rate=0.0, | |||
| ): | |||
| """Construct an EncoderLayer object.""" | |||
| super(EncoderLayer, self).__init__() | |||
| self.self_attn = self_attn | |||
| self.feed_forward = feed_forward | |||
| self.norm1 = LayerNorm(size) | |||
| self.norm2 = LayerNorm(size) | |||
| self.dropout = nn.Dropout(dropout_rate) | |||
| self.size = size | |||
| self.normalize_before = normalize_before | |||
| self.concat_after = concat_after | |||
| if self.concat_after: | |||
| self.concat_linear = nn.Linear(size + size, size) | |||
| self.stochastic_depth_rate = stochastic_depth_rate | |||
| def forward(self, x, mask, cache=None): | |||
| """Compute encoded features. | |||
| Args: | |||
| x_input (torch.Tensor): Input tensor (#batch, time, size). | |||
| mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||
| cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time, size). | |||
| torch.Tensor: Mask tensor (#batch, time). | |||
| """ | |||
| skip_layer = False | |||
| # with stochastic depth, residual connection `x + f(x)` becomes | |||
| # `x <- x + 1 / (1 - p) * f(x)` at training time. | |||
| stoch_layer_coeff = 1.0 | |||
| if self.training and self.stochastic_depth_rate > 0: | |||
| skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||
| stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||
| if skip_layer: | |||
| if cache is not None: | |||
| x = torch.cat([cache, x], dim=1) | |||
| return x, mask | |||
| residual = x | |||
| if self.normalize_before: | |||
| x = self.norm1(x) | |||
| if cache is None: | |||
| x_q = x | |||
| else: | |||
| assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||
| x_q = x[:, -1:, :] | |||
| residual = residual[:, -1:, :] | |||
| mask = None if mask is None else mask[:, -1:, :] | |||
| if self.concat_after: | |||
| x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) | |||
| x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||
| else: | |||
| x = residual + stoch_layer_coeff * self.dropout( | |||
| self.self_attn(x_q, x, x, mask)) | |||
| if not self.normalize_before: | |||
| x = self.norm1(x) | |||
| residual = x | |||
| if self.normalize_before: | |||
| x = self.norm2(x) | |||
| x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||
| if not self.normalize_before: | |||
| x = self.norm2(x) | |||
| if cache is not None: | |||
| x = torch.cat([cache, x], dim=1) | |||
| return x, mask | |||
| class EncoderLayerChunk(nn.Module): | |||
| """Encoder layer module. | |||
| Args: | |||
| size (int): Input dimension. | |||
| self_attn (torch.nn.Module): Self-attention module instance. | |||
| `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance | |||
| can be used as the argument. | |||
| feed_forward (torch.nn.Module): Feed-forward module instance. | |||
| `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance | |||
| can be used as the argument. | |||
| dropout_rate (float): Dropout rate. | |||
| normalize_before (bool): Whether to use layer_norm before the first block. | |||
| concat_after (bool): Whether to concat attention layer's input and output. | |||
| if True, additional linear will be applied. | |||
| i.e. x -> x + linear(concat(x, att(x))) | |||
| if False, no additional linear will be applied. i.e. x -> x + att(x) | |||
| stochastic_depth_rate (float): Proability to skip this layer. | |||
| During training, the layer may skip residual computation and return input | |||
| as-is with given probability. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| size, | |||
| self_attn, | |||
| feed_forward, | |||
| dropout_rate, | |||
| normalize_before=True, | |||
| concat_after=False, | |||
| stochastic_depth_rate=0.0, | |||
| ): | |||
| """Construct an EncoderLayer object.""" | |||
| super(EncoderLayerChunk, self).__init__() | |||
| self.self_attn = self_attn | |||
| self.feed_forward = feed_forward | |||
| self.norm1 = LayerNorm(size) | |||
| self.norm2 = LayerNorm(size) | |||
| self.dropout = nn.Dropout(dropout_rate) | |||
| self.size = size | |||
| self.normalize_before = normalize_before | |||
| self.concat_after = concat_after | |||
| if self.concat_after: | |||
| self.concat_linear = nn.Linear(size + size, size) | |||
| self.stochastic_depth_rate = stochastic_depth_rate | |||
| def forward(self, | |||
| x, | |||
| mask, | |||
| cache=None, | |||
| mask_shfit_chunk=None, | |||
| mask_att_chunk_encoder=None): | |||
| """Compute encoded features. | |||
| Args: | |||
| x_input (torch.Tensor): Input tensor (#batch, time, size). | |||
| mask (torch.Tensor): Mask tensor for the input (#batch, time). | |||
| cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). | |||
| Returns: | |||
| torch.Tensor: Output tensor (#batch, time, size). | |||
| torch.Tensor: Mask tensor (#batch, time). | |||
| """ | |||
| skip_layer = False | |||
| # with stochastic depth, residual connection `x + f(x)` becomes | |||
| # `x <- x + 1 / (1 - p) * f(x)` at training time. | |||
| stoch_layer_coeff = 1.0 | |||
| if self.training and self.stochastic_depth_rate > 0: | |||
| skip_layer = torch.rand(1).item() < self.stochastic_depth_rate | |||
| stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) | |||
| if skip_layer: | |||
| if cache is not None: | |||
| x = torch.cat([cache, x], dim=1) | |||
| return x, mask | |||
| residual = x | |||
| if self.normalize_before: | |||
| x = self.norm1(x) | |||
| if cache is None: | |||
| x_q = x | |||
| else: | |||
| assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) | |||
| x_q = x[:, -1:, :] | |||
| residual = residual[:, -1:, :] | |||
| mask = None if mask is None else mask[:, -1:, :] | |||
| if self.concat_after: | |||
| x_concat = torch.cat( | |||
| (x, | |||
| self.self_attn( | |||
| x_q, | |||
| x, | |||
| x, | |||
| mask, | |||
| mask_shfit_chunk=mask_shfit_chunk, | |||
| mask_att_chunk_encoder=mask_att_chunk_encoder)), | |||
| dim=-1) | |||
| x = residual + stoch_layer_coeff * self.concat_linear(x_concat) | |||
| else: | |||
| x = residual + stoch_layer_coeff * self.dropout( | |||
| self.self_attn( | |||
| x_q, | |||
| x, | |||
| x, | |||
| mask, | |||
| mask_shfit_chunk=mask_shfit_chunk, | |||
| mask_att_chunk_encoder=mask_att_chunk_encoder)) | |||
| if not self.normalize_before: | |||
| x = self.norm1(x) | |||
| residual = x | |||
| if self.normalize_before: | |||
| x = self.norm2(x) | |||
| x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) | |||
| if not self.normalize_before: | |||
| x = self.norm2(x) | |||
| if cache is not None: | |||
| x = torch.cat([cache, x], dim=1) | |||
| return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder | |||
| @@ -1,890 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from espnet/espnet. | |||
| import argparse | |||
| import logging | |||
| import os | |||
| from pathlib import Path | |||
| from typing import Callable, Collection, Dict, List, Optional, Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| import yaml | |||
| from espnet2.asr.ctc import CTC | |||
| from espnet2.asr.decoder.abs_decoder import AbsDecoder | |||
| from espnet2.asr.decoder.mlm_decoder import MLMDecoder | |||
| from espnet2.asr.decoder.rnn_decoder import RNNDecoder | |||
| from espnet2.asr.decoder.transformer_decoder import \ | |||
| DynamicConvolution2DTransformerDecoder # noqa: H301 | |||
| from espnet2.asr.decoder.transformer_decoder import \ | |||
| LightweightConvolution2DTransformerDecoder # noqa: H301 | |||
| from espnet2.asr.decoder.transformer_decoder import \ | |||
| LightweightConvolutionTransformerDecoder # noqa: H301 | |||
| from espnet2.asr.decoder.transformer_decoder import ( | |||
| DynamicConvolutionTransformerDecoder, TransformerDecoder) | |||
| from espnet2.asr.encoder.abs_encoder import AbsEncoder | |||
| from espnet2.asr.encoder.contextual_block_conformer_encoder import \ | |||
| ContextualBlockConformerEncoder # noqa: H301 | |||
| from espnet2.asr.encoder.contextual_block_transformer_encoder import \ | |||
| ContextualBlockTransformerEncoder # noqa: H301 | |||
| from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder, | |||
| FairseqHubertPretrainEncoder) | |||
| from espnet2.asr.encoder.longformer_encoder import LongformerEncoder | |||
| from espnet2.asr.encoder.rnn_encoder import RNNEncoder | |||
| from espnet2.asr.encoder.transformer_encoder import TransformerEncoder | |||
| from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder | |||
| from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder | |||
| from espnet2.asr.espnet_model import ESPnetASRModel | |||
| from espnet2.asr.frontend.abs_frontend import AbsFrontend | |||
| from espnet2.asr.frontend.default import DefaultFrontend | |||
| from espnet2.asr.frontend.fused import FusedFrontends | |||
| from espnet2.asr.frontend.s3prl import S3prlFrontend | |||
| from espnet2.asr.frontend.windowing import SlidingWindow | |||
| from espnet2.asr.maskctc_model import MaskCTCModel | |||
| from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder | |||
| from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \ | |||
| HuggingFaceTransformersPostEncoder # noqa: H301 | |||
| from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder | |||
| from espnet2.asr.preencoder.linear import LinearProjection | |||
| from espnet2.asr.preencoder.sinc import LightweightSincConvs | |||
| from espnet2.asr.specaug.abs_specaug import AbsSpecAug | |||
| from espnet2.asr.specaug.specaug import SpecAug | |||
| from espnet2.asr.transducer.joint_network import JointNetwork | |||
| from espnet2.asr.transducer.transducer_decoder import TransducerDecoder | |||
| from espnet2.layers.abs_normalize import AbsNormalize | |||
| from espnet2.layers.global_mvn import GlobalMVN | |||
| from espnet2.layers.utterance_mvn import UtteranceMVN | |||
| from espnet2.tasks.abs_task import AbsTask | |||
| from espnet2.text.phoneme_tokenizer import g2p_choices | |||
| from espnet2.torch_utils.initialize import initialize | |||
| from espnet2.train.abs_espnet_model import AbsESPnetModel | |||
| from espnet2.train.class_choices import ClassChoices | |||
| from espnet2.train.collate_fn import CommonCollateFn | |||
| from espnet2.train.preprocessor import CommonPreprocessor | |||
| from espnet2.train.trainer import Trainer | |||
| from espnet2.utils.get_default_kwargs import get_default_kwargs | |||
| from espnet2.utils.nested_dict_action import NestedDictAction | |||
| from espnet2.utils.types import (float_or_none, int_or_none, str2bool, | |||
| str_or_none) | |||
| from typeguard import check_argument_types, check_return_type | |||
| from ..asr.decoder.transformer_decoder import (ParaformerDecoder, | |||
| ParaformerDecoderBertEmbed) | |||
| from ..asr.encoder.conformer_encoder import ConformerEncoder, SANMEncoder_v2 | |||
| from ..asr.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunk | |||
| from ..asr.espnet_model import AEDStreaming | |||
| from ..asr.espnet_model_paraformer import Paraformer, ParaformerBertEmbed | |||
| from ..nets.pytorch_backend.cif_utils.cif import cif_predictor | |||
| # FIXME(wjm): suggested by fairseq, We need to setup root logger before importing any fairseq libraries. | |||
| logging.basicConfig( | |||
| level='INFO', | |||
| format=f"[{os.uname()[1].split('.')[0]}]" | |||
| f' %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', | |||
| ) | |||
| # FIXME(wjm): create logger to set level, unset __name__ for different files to share the same logger | |||
| logger = logging.getLogger() | |||
| frontend_choices = ClassChoices( | |||
| name='frontend', | |||
| classes=dict( | |||
| default=DefaultFrontend, | |||
| sliding_window=SlidingWindow, | |||
| s3prl=S3prlFrontend, | |||
| fused=FusedFrontends, | |||
| ), | |||
| type_check=AbsFrontend, | |||
| default='default', | |||
| ) | |||
| specaug_choices = ClassChoices( | |||
| name='specaug', | |||
| classes=dict(specaug=SpecAug, ), | |||
| type_check=AbsSpecAug, | |||
| default=None, | |||
| optional=True, | |||
| ) | |||
| normalize_choices = ClassChoices( | |||
| 'normalize', | |||
| classes=dict( | |||
| global_mvn=GlobalMVN, | |||
| utterance_mvn=UtteranceMVN, | |||
| ), | |||
| type_check=AbsNormalize, | |||
| default='utterance_mvn', | |||
| optional=True, | |||
| ) | |||
| model_choices = ClassChoices( | |||
| 'model', | |||
| classes=dict( | |||
| espnet=ESPnetASRModel, | |||
| maskctc=MaskCTCModel, | |||
| paraformer=Paraformer, | |||
| paraformer_bert_embed=ParaformerBertEmbed, | |||
| aedstreaming=AEDStreaming, | |||
| ), | |||
| type_check=AbsESPnetModel, | |||
| default='espnet', | |||
| ) | |||
| preencoder_choices = ClassChoices( | |||
| name='preencoder', | |||
| classes=dict( | |||
| sinc=LightweightSincConvs, | |||
| linear=LinearProjection, | |||
| ), | |||
| type_check=AbsPreEncoder, | |||
| default=None, | |||
| optional=True, | |||
| ) | |||
| encoder_choices = ClassChoices( | |||
| 'encoder', | |||
| classes=dict( | |||
| conformer=ConformerEncoder, | |||
| transformer=TransformerEncoder, | |||
| contextual_block_transformer=ContextualBlockTransformerEncoder, | |||
| contextual_block_conformer=ContextualBlockConformerEncoder, | |||
| vgg_rnn=VGGRNNEncoder, | |||
| rnn=RNNEncoder, | |||
| wav2vec2=FairSeqWav2Vec2Encoder, | |||
| hubert=FairseqHubertEncoder, | |||
| hubert_pretrain=FairseqHubertPretrainEncoder, | |||
| longformer=LongformerEncoder, | |||
| sanm=SANMEncoder, | |||
| sanm_v2=SANMEncoder_v2, | |||
| sanm_chunk=SANMEncoderChunk, | |||
| ), | |||
| type_check=AbsEncoder, | |||
| default='rnn', | |||
| ) | |||
| postencoder_choices = ClassChoices( | |||
| name='postencoder', | |||
| classes=dict( | |||
| hugging_face_transformers=HuggingFaceTransformersPostEncoder, ), | |||
| type_check=AbsPostEncoder, | |||
| default=None, | |||
| optional=True, | |||
| ) | |||
| decoder_choices = ClassChoices( | |||
| 'decoder', | |||
| classes=dict( | |||
| transformer=TransformerDecoder, | |||
| lightweight_conv=LightweightConvolutionTransformerDecoder, | |||
| lightweight_conv2d=LightweightConvolution2DTransformerDecoder, | |||
| dynamic_conv=DynamicConvolutionTransformerDecoder, | |||
| dynamic_conv2d=DynamicConvolution2DTransformerDecoder, | |||
| rnn=RNNDecoder, | |||
| transducer=TransducerDecoder, | |||
| mlm=MLMDecoder, | |||
| paraformer_decoder=ParaformerDecoder, | |||
| paraformer_decoder_bert_embed=ParaformerDecoderBertEmbed, | |||
| ), | |||
| type_check=AbsDecoder, | |||
| default='rnn', | |||
| ) | |||
| predictor_choices = ClassChoices( | |||
| name='predictor', | |||
| classes=dict( | |||
| cif_predictor=cif_predictor, | |||
| ctc_predictor=None, | |||
| ), | |||
| type_check=None, | |||
| default='cif_predictor', | |||
| optional=True, | |||
| ) | |||
| class ASRTask(AbsTask): | |||
| # If you need more than one optimizers, change this value | |||
| num_optimizers: int = 1 | |||
| # Add variable objects configurations | |||
| class_choices_list = [ | |||
| # --frontend and --frontend_conf | |||
| frontend_choices, | |||
| # --specaug and --specaug_conf | |||
| specaug_choices, | |||
| # --normalize and --normalize_conf | |||
| normalize_choices, | |||
| # --model and --model_conf | |||
| model_choices, | |||
| # --preencoder and --preencoder_conf | |||
| preencoder_choices, | |||
| # --encoder and --encoder_conf | |||
| encoder_choices, | |||
| # --postencoder and --postencoder_conf | |||
| postencoder_choices, | |||
| # --decoder and --decoder_conf | |||
| decoder_choices, | |||
| ] | |||
| # If you need to modify train() or eval() procedures, change Trainer class here | |||
| trainer = Trainer | |||
| @classmethod | |||
| def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||
| group = parser.add_argument_group(description='Task related') | |||
| # NOTE(kamo): add_arguments(..., required=True) can't be used | |||
| # to provide --print_config mode. Instead of it, do as | |||
| required = parser.get_default('required') | |||
| required += ['token_list'] | |||
| group.add_argument( | |||
| '--token_list', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='A text mapping int-id to token', | |||
| ) | |||
| group.add_argument( | |||
| '--init', | |||
| type=lambda x: str_or_none(x.lower()), | |||
| default=None, | |||
| help='The initialization method', | |||
| choices=[ | |||
| 'chainer', | |||
| 'xavier_uniform', | |||
| 'xavier_normal', | |||
| 'kaiming_uniform', | |||
| 'kaiming_normal', | |||
| None, | |||
| ], | |||
| ) | |||
| group.add_argument( | |||
| '--input_size', | |||
| type=int_or_none, | |||
| default=None, | |||
| help='The number of input dimension of the feature', | |||
| ) | |||
| group.add_argument( | |||
| '--ctc_conf', | |||
| action=NestedDictAction, | |||
| default=get_default_kwargs(CTC), | |||
| help='The keyword arguments for CTC class.', | |||
| ) | |||
| group.add_argument( | |||
| '--joint_net_conf', | |||
| action=NestedDictAction, | |||
| default=None, | |||
| help='The keyword arguments for joint network class.', | |||
| ) | |||
| group = parser.add_argument_group(description='Preprocess related') | |||
| group.add_argument( | |||
| '--use_preprocessor', | |||
| type=str2bool, | |||
| default=True, | |||
| help='Apply preprocessing to data or not', | |||
| ) | |||
| group.add_argument( | |||
| '--token_type', | |||
| type=str, | |||
| default='bpe', | |||
| choices=['bpe', 'char', 'word', 'phn'], | |||
| help='The text will be tokenized ' | |||
| 'in the specified level token', | |||
| ) | |||
| group.add_argument( | |||
| '--bpemodel', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='The model file of sentencepiece', | |||
| ) | |||
| parser.add_argument( | |||
| '--non_linguistic_symbols', | |||
| type=str_or_none, | |||
| help='non_linguistic_symbols file path', | |||
| ) | |||
| parser.add_argument( | |||
| '--cleaner', | |||
| type=str_or_none, | |||
| choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||
| default=None, | |||
| help='Apply text cleaning', | |||
| ) | |||
| parser.add_argument( | |||
| '--g2p', | |||
| type=str_or_none, | |||
| choices=g2p_choices, | |||
| default=None, | |||
| help='Specify g2p method if --token_type=phn', | |||
| ) | |||
| parser.add_argument( | |||
| '--speech_volume_normalize', | |||
| type=float_or_none, | |||
| default=None, | |||
| help='Scale the maximum amplitude to the given value.', | |||
| ) | |||
| parser.add_argument( | |||
| '--rir_scp', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='The file path of rir scp file.', | |||
| ) | |||
| parser.add_argument( | |||
| '--rir_apply_prob', | |||
| type=float, | |||
| default=1.0, | |||
| help='THe probability for applying RIR convolution.', | |||
| ) | |||
| parser.add_argument( | |||
| '--noise_scp', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='The file path of noise scp file.', | |||
| ) | |||
| parser.add_argument( | |||
| '--noise_apply_prob', | |||
| type=float, | |||
| default=1.0, | |||
| help='The probability applying Noise adding.', | |||
| ) | |||
| parser.add_argument( | |||
| '--noise_db_range', | |||
| type=str, | |||
| default='13_15', | |||
| help='The range of noise decibel level.', | |||
| ) | |||
| for class_choices in cls.class_choices_list: | |||
| # Append --<name> and --<name>_conf. | |||
| # e.g. --encoder and --encoder_conf | |||
| class_choices.add_arguments(group) | |||
| @classmethod | |||
| def build_collate_fn( | |||
| cls, args: argparse.Namespace, train: bool | |||
| ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||
| List[str], Dict[str, torch.Tensor]], ]: | |||
| assert check_argument_types() | |||
| # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||
| return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||
| @classmethod | |||
| def build_preprocess_fn( | |||
| cls, args: argparse.Namespace, train: bool | |||
| ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: | |||
| assert check_argument_types() | |||
| if args.use_preprocessor: | |||
| retval = CommonPreprocessor( | |||
| train=train, | |||
| token_type=args.token_type, | |||
| token_list=args.token_list, | |||
| bpemodel=args.bpemodel, | |||
| non_linguistic_symbols=args.non_linguistic_symbols, | |||
| text_cleaner=args.cleaner, | |||
| g2p_type=args.g2p, | |||
| # NOTE(kamo): Check attribute existence for backward compatibility | |||
| rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||
| rir_apply_prob=args.rir_apply_prob if hasattr( | |||
| args, 'rir_apply_prob') else 1.0, | |||
| noise_scp=args.noise_scp | |||
| if hasattr(args, 'noise_scp') else None, | |||
| noise_apply_prob=args.noise_apply_prob if hasattr( | |||
| args, 'noise_apply_prob') else 1.0, | |||
| noise_db_range=args.noise_db_range if hasattr( | |||
| args, 'noise_db_range') else '13_15', | |||
| speech_volume_normalize=args.speech_volume_normalize | |||
| if hasattr(args, 'rir_scp') else None, | |||
| ) | |||
| else: | |||
| retval = None | |||
| assert check_return_type(retval) | |||
| return retval | |||
| @classmethod | |||
| def required_data_names(cls, | |||
| train: bool = True, | |||
| inference: bool = False) -> Tuple[str, ...]: | |||
| if not inference: | |||
| retval = ('speech', 'text') | |||
| else: | |||
| # Recognition mode | |||
| retval = ('speech', ) | |||
| return retval | |||
| @classmethod | |||
| def optional_data_names(cls, | |||
| train: bool = True, | |||
| inference: bool = False) -> Tuple[str, ...]: | |||
| retval = () | |||
| assert check_return_type(retval) | |||
| return retval | |||
| @classmethod | |||
| def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel: | |||
| assert check_argument_types() | |||
| if isinstance(args.token_list, str): | |||
| with open(args.token_list, encoding='utf-8') as f: | |||
| token_list = [line.rstrip() for line in f] | |||
| # Overwriting token_list to keep it as "portable". | |||
| args.token_list = list(token_list) | |||
| elif isinstance(args.token_list, (tuple, list)): | |||
| token_list = list(args.token_list) | |||
| else: | |||
| raise RuntimeError('token_list must be str or list') | |||
| vocab_size = len(token_list) | |||
| logger.info(f'Vocabulary size: {vocab_size }') | |||
| # 1. frontend | |||
| if args.input_size is None: | |||
| # Extract features in the model | |||
| frontend_class = frontend_choices.get_class(args.frontend) | |||
| frontend = frontend_class(**args.frontend_conf) | |||
| input_size = frontend.output_size() | |||
| else: | |||
| # Give features from data-loader | |||
| args.frontend = None | |||
| args.frontend_conf = {} | |||
| frontend = None | |||
| input_size = args.input_size | |||
| # 2. Data augmentation for spectrogram | |||
| if args.specaug is not None: | |||
| specaug_class = specaug_choices.get_class(args.specaug) | |||
| specaug = specaug_class(**args.specaug_conf) | |||
| else: | |||
| specaug = None | |||
| # 3. Normalization layer | |||
| if args.normalize is not None: | |||
| normalize_class = normalize_choices.get_class(args.normalize) | |||
| normalize = normalize_class(**args.normalize_conf) | |||
| else: | |||
| normalize = None | |||
| # 4. Pre-encoder input block | |||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
| if getattr(args, 'preencoder', None) is not None: | |||
| preencoder_class = preencoder_choices.get_class(args.preencoder) | |||
| preencoder = preencoder_class(**args.preencoder_conf) | |||
| input_size = preencoder.output_size() | |||
| else: | |||
| preencoder = None | |||
| # 4. Encoder | |||
| encoder_class = encoder_choices.get_class(args.encoder) | |||
| encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||
| # 5. Post-encoder block | |||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
| encoder_output_size = encoder.output_size() | |||
| if getattr(args, 'postencoder', None) is not None: | |||
| postencoder_class = postencoder_choices.get_class(args.postencoder) | |||
| postencoder = postencoder_class( | |||
| input_size=encoder_output_size, **args.postencoder_conf) | |||
| encoder_output_size = postencoder.output_size() | |||
| else: | |||
| postencoder = None | |||
| # 5. Decoder | |||
| decoder_class = decoder_choices.get_class(args.decoder) | |||
| if args.decoder == 'transducer': | |||
| decoder = decoder_class( | |||
| vocab_size, | |||
| embed_pad=0, | |||
| **args.decoder_conf, | |||
| ) | |||
| joint_network = JointNetwork( | |||
| vocab_size, | |||
| encoder.output_size(), | |||
| decoder.dunits, | |||
| **args.joint_net_conf, | |||
| ) | |||
| else: | |||
| decoder = decoder_class( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| **args.decoder_conf, | |||
| ) | |||
| joint_network = None | |||
| # 6. CTC | |||
| ctc = CTC( | |||
| odim=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| **args.ctc_conf) | |||
| # 7. Build model | |||
| try: | |||
| model_class = model_choices.get_class(args.model) | |||
| except AttributeError: | |||
| model_class = model_choices.get_class('espnet') | |||
| model = model_class( | |||
| vocab_size=vocab_size, | |||
| frontend=frontend, | |||
| specaug=specaug, | |||
| normalize=normalize, | |||
| preencoder=preencoder, | |||
| encoder=encoder, | |||
| postencoder=postencoder, | |||
| decoder=decoder, | |||
| ctc=ctc, | |||
| joint_network=joint_network, | |||
| token_list=token_list, | |||
| **args.model_conf, | |||
| ) | |||
| # FIXME(kamo): Should be done in model? | |||
| # 8. Initialize | |||
| if args.init is not None: | |||
| initialize(model, args.init) | |||
| assert check_return_type(model) | |||
| return model | |||
| class ASRTaskNAR(AbsTask): | |||
| # If you need more than one optimizers, change this value | |||
| num_optimizers: int = 1 | |||
| # Add variable objects configurations | |||
| class_choices_list = [ | |||
| # --frontend and --frontend_conf | |||
| frontend_choices, | |||
| # --specaug and --specaug_conf | |||
| specaug_choices, | |||
| # --normalize and --normalize_conf | |||
| normalize_choices, | |||
| # --model and --model_conf | |||
| model_choices, | |||
| # --preencoder and --preencoder_conf | |||
| preencoder_choices, | |||
| # --encoder and --encoder_conf | |||
| encoder_choices, | |||
| # --postencoder and --postencoder_conf | |||
| postencoder_choices, | |||
| # --decoder and --decoder_conf | |||
| decoder_choices, | |||
| # --predictor and --predictor_conf | |||
| predictor_choices, | |||
| ] | |||
| # If you need to modify train() or eval() procedures, change Trainer class here | |||
| trainer = Trainer | |||
| @classmethod | |||
| def add_task_arguments(cls, parser: argparse.ArgumentParser): | |||
| group = parser.add_argument_group(description='Task related') | |||
| # NOTE(kamo): add_arguments(..., required=True) can't be used | |||
| # to provide --print_config mode. Instead of it, do as | |||
| required = parser.get_default('required') | |||
| required += ['token_list'] | |||
| group.add_argument( | |||
| '--token_list', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='A text mapping int-id to token', | |||
| ) | |||
| group.add_argument( | |||
| '--init', | |||
| type=lambda x: str_or_none(x.lower()), | |||
| default=None, | |||
| help='The initialization method', | |||
| choices=[ | |||
| 'chainer', | |||
| 'xavier_uniform', | |||
| 'xavier_normal', | |||
| 'kaiming_uniform', | |||
| 'kaiming_normal', | |||
| None, | |||
| ], | |||
| ) | |||
| group.add_argument( | |||
| '--input_size', | |||
| type=int_or_none, | |||
| default=None, | |||
| help='The number of input dimension of the feature', | |||
| ) | |||
| group.add_argument( | |||
| '--ctc_conf', | |||
| action=NestedDictAction, | |||
| default=get_default_kwargs(CTC), | |||
| help='The keyword arguments for CTC class.', | |||
| ) | |||
| group.add_argument( | |||
| '--joint_net_conf', | |||
| action=NestedDictAction, | |||
| default=None, | |||
| help='The keyword arguments for joint network class.', | |||
| ) | |||
| group = parser.add_argument_group(description='Preprocess related') | |||
| group.add_argument( | |||
| '--use_preprocessor', | |||
| type=str2bool, | |||
| default=True, | |||
| help='Apply preprocessing to data or not', | |||
| ) | |||
| group.add_argument( | |||
| '--token_type', | |||
| type=str, | |||
| default='bpe', | |||
| choices=['bpe', 'char', 'word', 'phn'], | |||
| help='The text will be tokenized ' | |||
| 'in the specified level token', | |||
| ) | |||
| group.add_argument( | |||
| '--bpemodel', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='The model file of sentencepiece', | |||
| ) | |||
| parser.add_argument( | |||
| '--non_linguistic_symbols', | |||
| type=str_or_none, | |||
| help='non_linguistic_symbols file path', | |||
| ) | |||
| parser.add_argument( | |||
| '--cleaner', | |||
| type=str_or_none, | |||
| choices=[None, 'tacotron', 'jaconv', 'vietnamese'], | |||
| default=None, | |||
| help='Apply text cleaning', | |||
| ) | |||
| parser.add_argument( | |||
| '--g2p', | |||
| type=str_or_none, | |||
| choices=g2p_choices, | |||
| default=None, | |||
| help='Specify g2p method if --token_type=phn', | |||
| ) | |||
| parser.add_argument( | |||
| '--speech_volume_normalize', | |||
| type=float_or_none, | |||
| default=None, | |||
| help='Scale the maximum amplitude to the given value.', | |||
| ) | |||
| parser.add_argument( | |||
| '--rir_scp', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='The file path of rir scp file.', | |||
| ) | |||
| parser.add_argument( | |||
| '--rir_apply_prob', | |||
| type=float, | |||
| default=1.0, | |||
| help='THe probability for applying RIR convolution.', | |||
| ) | |||
| parser.add_argument( | |||
| '--noise_scp', | |||
| type=str_or_none, | |||
| default=None, | |||
| help='The file path of noise scp file.', | |||
| ) | |||
| parser.add_argument( | |||
| '--noise_apply_prob', | |||
| type=float, | |||
| default=1.0, | |||
| help='The probability applying Noise adding.', | |||
| ) | |||
| parser.add_argument( | |||
| '--noise_db_range', | |||
| type=str, | |||
| default='13_15', | |||
| help='The range of noise decibel level.', | |||
| ) | |||
| for class_choices in cls.class_choices_list: | |||
| # Append --<name> and --<name>_conf. | |||
| # e.g. --encoder and --encoder_conf | |||
| class_choices.add_arguments(group) | |||
| @classmethod | |||
| def build_collate_fn( | |||
| cls, args: argparse.Namespace, train: bool | |||
| ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[ | |||
| List[str], Dict[str, torch.Tensor]], ]: | |||
| assert check_argument_types() | |||
| # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol | |||
| return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) | |||
| @classmethod | |||
| def build_preprocess_fn( | |||
| cls, args: argparse.Namespace, train: bool | |||
| ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: | |||
| assert check_argument_types() | |||
| if args.use_preprocessor: | |||
| retval = CommonPreprocessor( | |||
| train=train, | |||
| token_type=args.token_type, | |||
| token_list=args.token_list, | |||
| bpemodel=args.bpemodel, | |||
| non_linguistic_symbols=args.non_linguistic_symbols, | |||
| text_cleaner=args.cleaner, | |||
| g2p_type=args.g2p, | |||
| # NOTE(kamo): Check attribute existence for backward compatibility | |||
| rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None, | |||
| rir_apply_prob=args.rir_apply_prob if hasattr( | |||
| args, 'rir_apply_prob') else 1.0, | |||
| noise_scp=args.noise_scp | |||
| if hasattr(args, 'noise_scp') else None, | |||
| noise_apply_prob=args.noise_apply_prob if hasattr( | |||
| args, 'noise_apply_prob') else 1.0, | |||
| noise_db_range=args.noise_db_range if hasattr( | |||
| args, 'noise_db_range') else '13_15', | |||
| speech_volume_normalize=args.speech_volume_normalize | |||
| if hasattr(args, 'rir_scp') else None, | |||
| ) | |||
| else: | |||
| retval = None | |||
| assert check_return_type(retval) | |||
| return retval | |||
| @classmethod | |||
| def required_data_names(cls, | |||
| train: bool = True, | |||
| inference: bool = False) -> Tuple[str, ...]: | |||
| if not inference: | |||
| retval = ('speech', 'text') | |||
| else: | |||
| # Recognition mode | |||
| retval = ('speech', ) | |||
| return retval | |||
| @classmethod | |||
| def optional_data_names(cls, | |||
| train: bool = True, | |||
| inference: bool = False) -> Tuple[str, ...]: | |||
| retval = () | |||
| assert check_return_type(retval) | |||
| return retval | |||
| @classmethod | |||
| def build_model(cls, args: argparse.Namespace): | |||
| assert check_argument_types() | |||
| if isinstance(args.token_list, str): | |||
| with open(args.token_list, encoding='utf-8') as f: | |||
| token_list = [line.rstrip() for line in f] | |||
| # Overwriting token_list to keep it as "portable". | |||
| args.token_list = list(token_list) | |||
| elif isinstance(args.token_list, (tuple, list)): | |||
| token_list = list(args.token_list) | |||
| else: | |||
| raise RuntimeError('token_list must be str or list') | |||
| vocab_size = len(token_list) | |||
| # logger.info(f'Vocabulary size: {vocab_size }') | |||
| # 1. frontend | |||
| if args.input_size is None: | |||
| # Extract features in the model | |||
| frontend_class = frontend_choices.get_class(args.frontend) | |||
| frontend = frontend_class(**args.frontend_conf) | |||
| input_size = frontend.output_size() | |||
| else: | |||
| # Give features from data-loader | |||
| args.frontend = None | |||
| args.frontend_conf = {} | |||
| frontend = None | |||
| input_size = args.input_size | |||
| # 2. Data augmentation for spectrogram | |||
| if args.specaug is not None: | |||
| specaug_class = specaug_choices.get_class(args.specaug) | |||
| specaug = specaug_class(**args.specaug_conf) | |||
| else: | |||
| specaug = None | |||
| # 3. Normalization layer | |||
| if args.normalize is not None: | |||
| normalize_class = normalize_choices.get_class(args.normalize) | |||
| normalize = normalize_class(**args.normalize_conf) | |||
| else: | |||
| normalize = None | |||
| # 4. Pre-encoder input block | |||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
| if getattr(args, 'preencoder', None) is not None: | |||
| preencoder_class = preencoder_choices.get_class(args.preencoder) | |||
| preencoder = preencoder_class(**args.preencoder_conf) | |||
| input_size = preencoder.output_size() | |||
| else: | |||
| preencoder = None | |||
| # 4. Encoder | |||
| encoder_class = encoder_choices.get_class(args.encoder) | |||
| encoder = encoder_class(input_size=input_size, **args.encoder_conf) | |||
| # 5. Post-encoder block | |||
| # NOTE(kan-bayashi): Use getattr to keep the compatibility | |||
| encoder_output_size = encoder.output_size() | |||
| if getattr(args, 'postencoder', None) is not None: | |||
| postencoder_class = postencoder_choices.get_class(args.postencoder) | |||
| postencoder = postencoder_class( | |||
| input_size=encoder_output_size, **args.postencoder_conf) | |||
| encoder_output_size = postencoder.output_size() | |||
| else: | |||
| postencoder = None | |||
| # 5. Decoder | |||
| decoder_class = decoder_choices.get_class(args.decoder) | |||
| if args.decoder == 'transducer': | |||
| decoder = decoder_class( | |||
| vocab_size, | |||
| embed_pad=0, | |||
| **args.decoder_conf, | |||
| ) | |||
| joint_network = JointNetwork( | |||
| vocab_size, | |||
| encoder.output_size(), | |||
| decoder.dunits, | |||
| **args.joint_net_conf, | |||
| ) | |||
| else: | |||
| decoder = decoder_class( | |||
| vocab_size=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| **args.decoder_conf, | |||
| ) | |||
| joint_network = None | |||
| # 6. CTC | |||
| ctc = CTC( | |||
| odim=vocab_size, | |||
| encoder_output_size=encoder_output_size, | |||
| **args.ctc_conf) | |||
| predictor_class = predictor_choices.get_class(args.predictor) | |||
| predictor = predictor_class(**args.predictor_conf) | |||
| # 7. Build model | |||
| try: | |||
| model_class = model_choices.get_class(args.model) | |||
| except AttributeError: | |||
| model_class = model_choices.get_class('espnet') | |||
| model = model_class( | |||
| vocab_size=vocab_size, | |||
| frontend=frontend, | |||
| specaug=specaug, | |||
| normalize=normalize, | |||
| preencoder=preencoder, | |||
| encoder=encoder, | |||
| postencoder=postencoder, | |||
| decoder=decoder, | |||
| ctc=ctc, | |||
| joint_network=joint_network, | |||
| token_list=token_list, | |||
| predictor=predictor, | |||
| **args.model_conf, | |||
| ) | |||
| # FIXME(kamo): Should be done in model? | |||
| # 8. Initialize | |||
| if args.init is not None: | |||
| initialize(model, args.init) | |||
| assert check_return_type(model) | |||
| return model | |||
| @@ -1,223 +0,0 @@ | |||
| import os | |||
| import shutil | |||
| import threading | |||
| from typing import Any, Dict, List, Sequence, Tuple, Union | |||
| import yaml | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import WavToScp | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from .asr_engine.common import asr_utils | |||
| logger = get_logger() | |||
| __all__ = ['AutomaticSpeechRecognitionPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) | |||
| class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
| """ASR Pipeline | |||
| """ | |||
| def __init__(self, | |||
| model: Union[List[Model], List[str]] = None, | |||
| preprocessor: WavToScp = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create an asr pipeline for prediction | |||
| """ | |||
| from .asr_engine import asr_env_checking | |||
| assert model is not None, 'asr model should be provided' | |||
| model_list: List = [] | |||
| if isinstance(model[0], Model): | |||
| model_list = model | |||
| else: | |||
| model_list.append(Model.from_pretrained(model[0])) | |||
| if len(model) == 2 and model[1] is not None: | |||
| model_list.append(Model.from_pretrained(model[1])) | |||
| super().__init__(model=model_list, preprocessor=preprocessor, **kwargs) | |||
| self._preprocessor = preprocessor | |||
| self._am_model = model_list[0] | |||
| if len(model_list) == 2 and model_list[1] is not None: | |||
| self._lm_model = model_list[1] | |||
| def __call__(self, | |||
| wav_path: str, | |||
| recog_type: str = None, | |||
| audio_format: str = None, | |||
| workspace: str = None) -> Dict[str, Any]: | |||
| assert len(wav_path) > 0, 'wav_path should be provided' | |||
| self._recog_type = recog_type | |||
| self._audio_format = audio_format | |||
| self._workspace = workspace | |||
| self._wav_path = wav_path | |||
| if recog_type is None or audio_format is None or workspace is None: | |||
| self._recog_type, self._audio_format, self._workspace, self._wav_path = asr_utils.type_checking( | |||
| wav_path, recog_type, audio_format, workspace) | |||
| if self._preprocessor is None: | |||
| self._preprocessor = WavToScp(workspace=self._workspace) | |||
| output = self._preprocessor.forward(self._am_model.forward(), | |||
| self._recog_type, | |||
| self._audio_format, self._wav_path) | |||
| output = self.forward(output) | |||
| rst = self.postprocess(output) | |||
| return rst | |||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """Decoding | |||
| """ | |||
| logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||
| j: int = 0 | |||
| process = [] | |||
| while j < inputs['thread_count']: | |||
| data_cmd: Sequence[Tuple[str, str, str]] | |||
| if inputs['audio_format'] == 'wav': | |||
| data_cmd = [(os.path.join(inputs['workspace'], | |||
| 'data.' + str(j) + '.scp'), 'speech', | |||
| 'sound')] | |||
| elif inputs['audio_format'] == 'kaldi_ark': | |||
| data_cmd = [(os.path.join(inputs['workspace'], | |||
| 'data.' + str(j) + '.scp'), 'speech', | |||
| 'kaldi_ark')] | |||
| output_dir: str = os.path.join(inputs['output'], | |||
| 'output.' + str(j)) | |||
| if not os.path.exists(output_dir): | |||
| os.mkdir(output_dir) | |||
| config_file = open(inputs['asr_model_config']) | |||
| root = yaml.full_load(config_file) | |||
| config_file.close() | |||
| frontend_conf = None | |||
| if 'frontend_conf' in root: | |||
| frontend_conf = root['frontend_conf'] | |||
| cmd = { | |||
| 'model_type': inputs['model_type'], | |||
| 'beam_size': root['beam_size'], | |||
| 'penalty': root['penalty'], | |||
| 'maxlenratio': root['maxlenratio'], | |||
| 'minlenratio': root['minlenratio'], | |||
| 'ctc_weight': root['ctc_weight'], | |||
| 'lm_weight': root['lm_weight'], | |||
| 'output_dir': output_dir, | |||
| 'ngpu': 0, | |||
| 'log_level': 'ERROR', | |||
| 'data_path_and_name_and_type': data_cmd, | |||
| 'asr_train_config': inputs['am_model_config'], | |||
| 'asr_model_file': inputs['am_model_path'], | |||
| 'batch_size': inputs['model_config']['batch_size'], | |||
| 'frontend_conf': frontend_conf | |||
| } | |||
| thread = AsrInferenceThread(j, cmd) | |||
| thread.start() | |||
| j += 1 | |||
| process.append(thread) | |||
| for p in process: | |||
| p.join() | |||
| return inputs | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the asr results | |||
| """ | |||
| logger.info('Computing the result of ASR ...') | |||
| rst = {'rec_result': 'None'} | |||
| # single wav task | |||
| if inputs['recog_type'] == 'wav' and inputs['audio_format'] == 'wav': | |||
| text_file: str = os.path.join(inputs['output'], 'output.0', | |||
| '1best_recog', 'text') | |||
| if os.path.exists(text_file): | |||
| f = open(text_file, 'r') | |||
| result_str: str = f.readline() | |||
| f.close() | |||
| if len(result_str) > 0: | |||
| result_list = result_str.split() | |||
| if len(result_list) >= 2: | |||
| rst['rec_result'] = result_list[1] | |||
| # run with datasets, and audio format is waveform or kaldi_ark | |||
| elif inputs['recog_type'] != 'wav': | |||
| inputs['reference_text'] = self._ref_text_tidy(inputs) | |||
| inputs['datasets_result'] = asr_utils.compute_wer( | |||
| inputs['hypothesis_text'], inputs['reference_text']) | |||
| else: | |||
| raise ValueError('recog_type and audio_format are mismatching') | |||
| if 'datasets_result' in inputs: | |||
| rst['datasets_result'] = inputs['datasets_result'] | |||
| # remove workspace dir (.tmp) | |||
| if os.path.exists(self._workspace): | |||
| shutil.rmtree(self._workspace) | |||
| return rst | |||
| def _ref_text_tidy(self, inputs: Dict[str, Any]) -> str: | |||
| ref_text: str = os.path.join(inputs['output'], 'text.ref') | |||
| k: int = 0 | |||
| while k < inputs['thread_count']: | |||
| output_text = os.path.join(inputs['output'], 'output.' + str(k), | |||
| '1best_recog', 'text') | |||
| if os.path.exists(output_text): | |||
| with open(output_text, 'r', encoding='utf-8') as i: | |||
| lines = i.readlines() | |||
| with open(ref_text, 'a', encoding='utf-8') as o: | |||
| for line in lines: | |||
| o.write(line) | |||
| k += 1 | |||
| return ref_text | |||
| class AsrInferenceThread(threading.Thread): | |||
| def __init__(self, threadID, cmd): | |||
| threading.Thread.__init__(self) | |||
| self._threadID = threadID | |||
| self._cmd = cmd | |||
| def run(self): | |||
| if self._cmd['model_type'] == 'pytorch': | |||
| from .asr_engine import asr_inference_paraformer_espnet | |||
| asr_inference_paraformer_espnet.asr_inference( | |||
| batch_size=self._cmd['batch_size'], | |||
| output_dir=self._cmd['output_dir'], | |||
| maxlenratio=self._cmd['maxlenratio'], | |||
| minlenratio=self._cmd['minlenratio'], | |||
| beam_size=self._cmd['beam_size'], | |||
| ngpu=self._cmd['ngpu'], | |||
| ctc_weight=self._cmd['ctc_weight'], | |||
| lm_weight=self._cmd['lm_weight'], | |||
| penalty=self._cmd['penalty'], | |||
| log_level=self._cmd['log_level'], | |||
| data_path_and_name_and_type=self. | |||
| _cmd['data_path_and_name_and_type'], | |||
| asr_train_config=self._cmd['asr_train_config'], | |||
| asr_model_file=self._cmd['asr_model_file'], | |||
| frontend_conf=self._cmd['frontend_conf']) | |||
| @@ -0,0 +1,213 @@ | |||
| import os | |||
| from typing import Any, Dict, List, Sequence, Tuple, Union | |||
| import yaml | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import WavToScp | |||
| from modelscope.utils.constant import Frameworks, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| __all__ = ['AutomaticSpeechRecognitionPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) | |||
| class AutomaticSpeechRecognitionPipeline(Pipeline): | |||
| """ASR Inference Pipeline | |||
| """ | |||
| def __init__(self, | |||
| model: Union[Model, str] = None, | |||
| preprocessor: WavToScp = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create an asr pipeline for prediction | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| def __call__(self, | |||
| audio_in: Union[str, bytes], | |||
| recog_type: str = None, | |||
| audio_format: str = None) -> Dict[str, Any]: | |||
| from easyasr.common import asr_utils | |||
| self.recog_type = recog_type | |||
| self.audio_format = audio_format | |||
| self.audio_in = audio_in | |||
| if recog_type is None or audio_format is None: | |||
| self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | |||
| audio_in, recog_type, audio_format) | |||
| if self.preprocessor is None: | |||
| self.preprocessor = WavToScp() | |||
| output = self.preprocessor.forward(self.model.forward(), | |||
| self.recog_type, self.audio_format, | |||
| self.audio_in) | |||
| output = self.forward(output) | |||
| rst = self.postprocess(output) | |||
| return rst | |||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """Decoding | |||
| """ | |||
| logger.info(f"Decoding with {inputs['audio_format']} files ...") | |||
| data_cmd: Sequence[Tuple[str, str]] | |||
| if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | |||
| data_cmd = ['speech', 'sound'] | |||
| elif inputs['audio_format'] == 'kaldi_ark': | |||
| data_cmd = ['speech', 'kaldi_ark'] | |||
| elif inputs['audio_format'] == 'tfrecord': | |||
| data_cmd = ['speech', 'tfrecord'] | |||
| # generate asr inference command | |||
| cmd = { | |||
| 'model_type': inputs['model_type'], | |||
| 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available | |||
| 'log_level': 'ERROR', | |||
| 'audio_in': inputs['audio_lists'], | |||
| 'name_and_type': data_cmd, | |||
| 'asr_model_file': inputs['am_model_path'], | |||
| 'idx_text': '' | |||
| } | |||
| if self.framework == Frameworks.torch: | |||
| config_file = open(inputs['asr_model_config']) | |||
| root = yaml.full_load(config_file) | |||
| config_file.close() | |||
| frontend_conf = None | |||
| if 'frontend_conf' in root: | |||
| frontend_conf = root['frontend_conf'] | |||
| cmd['beam_size'] = root['beam_size'] | |||
| cmd['penalty'] = root['penalty'] | |||
| cmd['maxlenratio'] = root['maxlenratio'] | |||
| cmd['minlenratio'] = root['minlenratio'] | |||
| cmd['ctc_weight'] = root['ctc_weight'] | |||
| cmd['lm_weight'] = root['lm_weight'] | |||
| cmd['asr_train_config'] = inputs['am_model_config'] | |||
| cmd['batch_size'] = inputs['model_config']['batch_size'] | |||
| cmd['frontend_conf'] = frontend_conf | |||
| elif self.framework == Frameworks.tf: | |||
| cmd['fs'] = inputs['model_config']['fs'] | |||
| cmd['hop_length'] = inputs['model_config']['hop_length'] | |||
| cmd['feature_dims'] = inputs['model_config']['feature_dims'] | |||
| cmd['predictions_file'] = 'text' | |||
| cmd['mvn_file'] = inputs['am_mvn_file'] | |||
| cmd['vocab_file'] = inputs['vocab_file'] | |||
| if 'idx_text' in inputs: | |||
| cmd['idx_text'] = inputs['idx_text'] | |||
| else: | |||
| raise ValueError('model type is mismatching') | |||
| inputs['asr_result'] = self.run_inference(cmd) | |||
| return inputs | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the asr results | |||
| """ | |||
| from easyasr.common import asr_utils | |||
| logger.info('Computing the result of ASR ...') | |||
| rst = {} | |||
| # single wav or pcm task | |||
| if inputs['recog_type'] == 'wav': | |||
| if 'asr_result' in inputs and len(inputs['asr_result']) > 0: | |||
| text = inputs['asr_result'][0]['value'] | |||
| if len(text) > 0: | |||
| rst[OutputKeys.TEXT] = text | |||
| # run with datasets, and audio format is waveform or kaldi_ark or tfrecord | |||
| elif inputs['recog_type'] != 'wav': | |||
| inputs['reference_list'] = self.ref_list_tidy(inputs) | |||
| inputs['datasets_result'] = asr_utils.compute_wer( | |||
| inputs['asr_result'], inputs['reference_list']) | |||
| else: | |||
| raise ValueError('recog_type and audio_format are mismatching') | |||
| if 'datasets_result' in inputs: | |||
| rst[OutputKeys.TEXT] = inputs['datasets_result'] | |||
| return rst | |||
| def ref_list_tidy(self, inputs: Dict[str, Any]) -> List[Any]: | |||
| ref_list = [] | |||
| if inputs['audio_format'] == 'tfrecord': | |||
| # should assemble idx + txt | |||
| with open(inputs['reference_text'], 'r', encoding='utf-8') as r: | |||
| text_lines = r.readlines() | |||
| with open(inputs['idx_text'], 'r', encoding='utf-8') as i: | |||
| idx_lines = i.readlines() | |||
| j: int = 0 | |||
| while j < min(len(text_lines), len(idx_lines)): | |||
| idx_str = idx_lines[j].strip() | |||
| text_str = text_lines[j].strip().replace(' ', '') | |||
| item = {'key': idx_str, 'value': text_str} | |||
| ref_list.append(item) | |||
| j += 1 | |||
| else: | |||
| # text contain idx + sentence | |||
| with open(inputs['reference_text'], 'r', encoding='utf-8') as f: | |||
| lines = f.readlines() | |||
| for line in lines: | |||
| line_item = line.split() | |||
| item = {'key': line_item[0], 'value': line_item[1]} | |||
| ref_list.append(item) | |||
| return ref_list | |||
| def run_inference(self, cmd): | |||
| asr_result = [] | |||
| if self.framework == Frameworks.torch: | |||
| from easyasr import asr_inference_paraformer_espnet | |||
| asr_result = asr_inference_paraformer_espnet.asr_inference( | |||
| batch_size=cmd['batch_size'], | |||
| maxlenratio=cmd['maxlenratio'], | |||
| minlenratio=cmd['minlenratio'], | |||
| beam_size=cmd['beam_size'], | |||
| ngpu=cmd['ngpu'], | |||
| ctc_weight=cmd['ctc_weight'], | |||
| lm_weight=cmd['lm_weight'], | |||
| penalty=cmd['penalty'], | |||
| log_level=cmd['log_level'], | |||
| name_and_type=cmd['name_and_type'], | |||
| audio_lists=cmd['audio_in'], | |||
| asr_train_config=cmd['asr_train_config'], | |||
| asr_model_file=cmd['asr_model_file'], | |||
| frontend_conf=cmd['frontend_conf']) | |||
| elif self.framework == Frameworks.tf: | |||
| from easyasr import asr_inference_paraformer_tf | |||
| asr_result = asr_inference_paraformer_tf.asr_inference( | |||
| ngpu=cmd['ngpu'], | |||
| name_and_type=cmd['name_and_type'], | |||
| audio_lists=cmd['audio_in'], | |||
| idx_text_file=cmd['idx_text'], | |||
| asr_model_file=cmd['asr_model_file'], | |||
| vocab_file=cmd['vocab_file'], | |||
| am_mvn_file=cmd['mvn_file'], | |||
| predictions_file=cmd['predictions_file'], | |||
| fs=cmd['fs'], | |||
| hop_length=cmd['hop_length'], | |||
| feature_dims=cmd['feature_dims']) | |||
| return asr_result | |||
| @@ -1,14 +1,9 @@ | |||
| import io | |||
| import os | |||
| import shutil | |||
| from pathlib import Path | |||
| from typing import Any, Dict, List | |||
| import yaml | |||
| from typing import Any, Dict, List, Union | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.models.base import Model | |||
| from modelscope.utils.constant import Fields | |||
| from modelscope.utils.constant import Fields, Frameworks | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| @@ -19,44 +14,32 @@ __all__ = ['WavToScp'] | |||
| Fields.audio, module_name=Preprocessors.wav_to_scp) | |||
| class WavToScp(Preprocessor): | |||
| """generate audio scp from wave or ark | |||
| Args: | |||
| workspace (str): | |||
| """ | |||
| def __init__(self, workspace: str = None): | |||
| # the workspace path | |||
| if workspace is None or len(workspace) == 0: | |||
| self._workspace = os.path.join(os.getcwd(), '.tmp') | |||
| else: | |||
| self._workspace = workspace | |||
| if not os.path.exists(self._workspace): | |||
| os.mkdir(self._workspace) | |||
| def __init__(self): | |||
| pass | |||
| def __call__(self, | |||
| model: List[Model] = None, | |||
| model: Model = None, | |||
| recog_type: str = None, | |||
| audio_format: str = None, | |||
| wav_path: str = None) -> Dict[str, Any]: | |||
| assert len(model) > 0, 'preprocess model is invalid' | |||
| assert len(recog_type) > 0, 'preprocess recog_type is empty' | |||
| assert len(audio_format) > 0, 'preprocess audio_format is empty' | |||
| assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||
| self._am_model = model[0] | |||
| if len(model) == 2 and model[1] is not None: | |||
| self._lm_model = model[1] | |||
| out = self.forward(self._am_model.forward(), recog_type, audio_format, | |||
| wav_path) | |||
| audio_in: Union[str, bytes] = None) -> Dict[str, Any]: | |||
| assert model is not None, 'preprocess model is empty' | |||
| assert recog_type is not None and len( | |||
| recog_type) > 0, 'preprocess recog_type is empty' | |||
| assert audio_format is not None, 'preprocess audio_format is empty' | |||
| assert audio_in is not None, 'preprocess audio_in is empty' | |||
| self.am_model = model | |||
| out = self.forward(self.am_model.forward(), recog_type, audio_format, | |||
| audio_in) | |||
| return out | |||
| def forward(self, model: Dict[str, Any], recog_type: str, | |||
| audio_format: str, wav_path: str) -> Dict[str, Any]: | |||
| audio_format: str, audio_in: Union[str, | |||
| bytes]) -> Dict[str, Any]: | |||
| assert len(recog_type) > 0, 'preprocess recog_type is empty' | |||
| assert len(audio_format) > 0, 'preprocess audio_format is empty' | |||
| assert len(wav_path) > 0, 'preprocess wav_path is empty' | |||
| assert os.path.exists(wav_path), 'preprocess wav_path does not exist' | |||
| assert len( | |||
| model['am_model']) > 0, 'preprocess model[am_model] is empty' | |||
| assert len(model['am_model_path'] | |||
| @@ -70,90 +53,104 @@ class WavToScp(Preprocessor): | |||
| assert len(model['model_config'] | |||
| ) > 0, 'preprocess model[model_config] is empty' | |||
| # the am model name | |||
| am_model: str = model['am_model'] | |||
| # the am model file path | |||
| am_model_path: str = model['am_model_path'] | |||
| # the recognition model dir path | |||
| model_workspace: str = model['model_workspace'] | |||
| # the recognition model config dict | |||
| global_model_config_dict: str = model['model_config'] | |||
| rst = { | |||
| 'workspace': os.path.join(self._workspace, recog_type), | |||
| 'am_model': am_model, | |||
| 'am_model_path': am_model_path, | |||
| 'model_workspace': model_workspace, | |||
| # the recognition model dir path | |||
| 'model_workspace': model['model_workspace'], | |||
| # the am model name | |||
| 'am_model': model['am_model'], | |||
| # the am model file path | |||
| 'am_model_path': model['am_model_path'], | |||
| # the asr type setting, eg: test dev train wav | |||
| 'recog_type': recog_type, | |||
| # the asr audio format setting, eg: wav, kaldi_ark | |||
| # the asr audio format setting, eg: wav, pcm, kaldi_ark, tfrecord | |||
| 'audio_format': audio_format, | |||
| # the test wav file path or the dataset path | |||
| 'wav_path': wav_path, | |||
| 'model_config': global_model_config_dict | |||
| # the recognition model config dict | |||
| 'model_config': model['model_config'] | |||
| } | |||
| out = self._config_checking(rst) | |||
| out = self._env_setting(out) | |||
| if isinstance(audio_in, str): | |||
| # wav file path or the dataset path | |||
| rst['wav_path'] = audio_in | |||
| out = self.config_checking(rst) | |||
| out = self.env_setting(out) | |||
| if audio_format == 'wav': | |||
| out = self._scp_generation_from_wav(out) | |||
| out['audio_lists'] = self.scp_generation_from_wav(out) | |||
| elif audio_format == 'kaldi_ark': | |||
| out = self._scp_generation_from_ark(out) | |||
| out['audio_lists'] = self.scp_generation_from_ark(out) | |||
| elif audio_format == 'tfrecord': | |||
| out['audio_lists'] = os.path.join(out['wav_path'], 'data.records') | |||
| elif audio_format == 'pcm': | |||
| out['audio_lists'] = audio_in | |||
| return out | |||
| def _config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| def config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """config checking | |||
| """ | |||
| assert inputs['model_config'].__contains__( | |||
| 'type'), 'model type does not exist' | |||
| assert inputs['model_config'].__contains__( | |||
| 'batch_size'), 'batch_size does not exist' | |||
| assert inputs['model_config'].__contains__( | |||
| 'am_model_config'), 'am_model_config does not exist' | |||
| assert inputs['model_config'].__contains__( | |||
| 'asr_model_config'), 'asr_model_config does not exist' | |||
| assert inputs['model_config'].__contains__( | |||
| 'asr_model_wav_config'), 'asr_model_wav_config does not exist' | |||
| am_model_config: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['am_model_config']) | |||
| assert os.path.exists( | |||
| am_model_config), 'am_model_config does not exist' | |||
| inputs['am_model_config'] = am_model_config | |||
| asr_model_config: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['asr_model_config']) | |||
| assert os.path.exists( | |||
| asr_model_config), 'asr_model_config does not exist' | |||
| asr_model_wav_config: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['asr_model_wav_config']) | |||
| assert os.path.exists( | |||
| asr_model_wav_config), 'asr_model_wav_config does not exist' | |||
| inputs['model_type'] = inputs['model_config']['type'] | |||
| if inputs['audio_format'] == 'wav': | |||
| inputs['asr_model_config'] = asr_model_wav_config | |||
| if inputs['model_type'] == Frameworks.torch: | |||
| assert inputs['model_config'].__contains__( | |||
| 'batch_size'), 'batch_size does not exist' | |||
| assert inputs['model_config'].__contains__( | |||
| 'am_model_config'), 'am_model_config does not exist' | |||
| assert inputs['model_config'].__contains__( | |||
| 'asr_model_config'), 'asr_model_config does not exist' | |||
| assert inputs['model_config'].__contains__( | |||
| 'asr_model_wav_config'), 'asr_model_wav_config does not exist' | |||
| am_model_config: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['am_model_config']) | |||
| assert os.path.exists( | |||
| am_model_config), 'am_model_config does not exist' | |||
| inputs['am_model_config'] = am_model_config | |||
| asr_model_config: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['asr_model_config']) | |||
| assert os.path.exists( | |||
| asr_model_config), 'asr_model_config does not exist' | |||
| asr_model_wav_config: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['asr_model_wav_config']) | |||
| assert os.path.exists( | |||
| asr_model_wav_config), 'asr_model_wav_config does not exist' | |||
| if inputs['audio_format'] == 'wav' or inputs[ | |||
| 'audio_format'] == 'pcm': | |||
| inputs['asr_model_config'] = asr_model_wav_config | |||
| else: | |||
| inputs['asr_model_config'] = asr_model_config | |||
| elif inputs['model_type'] == Frameworks.tf: | |||
| assert inputs['model_config'].__contains__( | |||
| 'vocab_file'), 'vocab_file does not exist' | |||
| vocab_file: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['vocab_file']) | |||
| assert os.path.exists(vocab_file), 'vocab file does not exist' | |||
| inputs['vocab_file'] = vocab_file | |||
| assert inputs['model_config'].__contains__( | |||
| 'am_mvn_file'), 'am_mvn_file does not exist' | |||
| am_mvn_file: str = os.path.join( | |||
| inputs['model_workspace'], | |||
| inputs['model_config']['am_mvn_file']) | |||
| assert os.path.exists(am_mvn_file), 'am mvn file does not exist' | |||
| inputs['am_mvn_file'] = am_mvn_file | |||
| else: | |||
| inputs['asr_model_config'] = asr_model_config | |||
| raise ValueError('model type is mismatched') | |||
| return inputs | |||
| def _env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| if not os.path.exists(inputs['workspace']): | |||
| os.mkdir(inputs['workspace']) | |||
| inputs['output'] = os.path.join(inputs['workspace'], 'logdir') | |||
| if not os.path.exists(inputs['output']): | |||
| os.mkdir(inputs['output']) | |||
| def env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| # run with datasets, should set datasets_path and text_path | |||
| if inputs['recog_type'] != 'wav': | |||
| inputs['datasets_path'] = inputs['wav_path'] | |||
| @@ -162,25 +159,39 @@ class WavToScp(Preprocessor): | |||
| if inputs['audio_format'] == 'wav': | |||
| inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||
| 'wav', inputs['recog_type']) | |||
| inputs['hypothesis_text'] = os.path.join( | |||
| inputs['reference_text'] = os.path.join( | |||
| inputs['datasets_path'], 'transcript', 'data.text') | |||
| assert os.path.exists(inputs['hypothesis_text'] | |||
| ), 'hypothesis text does not exist' | |||
| assert os.path.exists( | |||
| inputs['reference_text']), 'reference text does not exist' | |||
| # run with datasets, and audio format is kaldi_ark | |||
| elif inputs['audio_format'] == 'kaldi_ark': | |||
| inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||
| inputs['recog_type']) | |||
| inputs['hypothesis_text'] = os.path.join( | |||
| inputs['reference_text'] = os.path.join( | |||
| inputs['wav_path'], 'data.text') | |||
| assert os.path.exists(inputs['hypothesis_text'] | |||
| ), 'hypothesis text does not exist' | |||
| assert os.path.exists( | |||
| inputs['reference_text']), 'reference text does not exist' | |||
| # run with datasets, and audio format is tfrecord | |||
| elif inputs['audio_format'] == 'tfrecord': | |||
| inputs['wav_path'] = os.path.join(inputs['datasets_path'], | |||
| inputs['recog_type']) | |||
| inputs['reference_text'] = os.path.join( | |||
| inputs['wav_path'], 'data.txt') | |||
| assert os.path.exists( | |||
| inputs['reference_text']), 'reference text does not exist' | |||
| inputs['idx_text'] = os.path.join(inputs['wav_path'], | |||
| 'data.idx') | |||
| assert os.path.exists( | |||
| inputs['idx_text']), 'idx text does not exist' | |||
| return inputs | |||
| def _scp_generation_from_wav(self, inputs: Dict[str, | |||
| Any]) -> Dict[str, Any]: | |||
| def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]: | |||
| """scp generation from waveform files | |||
| """ | |||
| from easyasr.common import asr_utils | |||
| # find all waveform files | |||
| wav_list = [] | |||
| @@ -191,64 +202,46 @@ class WavToScp(Preprocessor): | |||
| wav_list.append(file_path) | |||
| else: | |||
| wav_dir: str = inputs['wav_path'] | |||
| wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||
| wav_list = asr_utils.recursion_dir_all_wav(wav_list, wav_dir) | |||
| list_count: int = len(wav_list) | |||
| inputs['wav_count'] = list_count | |||
| # store all wav into data.0.scp | |||
| inputs['thread_count'] = 1 | |||
| # store all wav into audio list | |||
| audio_lists = [] | |||
| j: int = 0 | |||
| wav_list_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||
| with open(wav_list_path, 'a') as f: | |||
| while j < list_count: | |||
| wav_file = wav_list[j] | |||
| wave_scp_content: str = os.path.splitext( | |||
| os.path.basename(wav_file))[0] | |||
| wave_scp_content += ' ' + wav_file + '\n' | |||
| f.write(wave_scp_content) | |||
| j += 1 | |||
| while j < list_count: | |||
| wav_file = wav_list[j] | |||
| wave_key: str = os.path.splitext(os.path.basename(wav_file))[0] | |||
| item = {'key': wave_key, 'file': wav_file} | |||
| audio_lists.append(item) | |||
| j += 1 | |||
| return inputs | |||
| return audio_lists | |||
| def _scp_generation_from_ark(self, inputs: Dict[str, | |||
| Any]) -> Dict[str, Any]: | |||
| def scp_generation_from_ark(self, inputs: Dict[str, Any]) -> List[Any]: | |||
| """scp generation from kaldi ark file | |||
| """ | |||
| inputs['thread_count'] = 1 | |||
| ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp') | |||
| ark_file_path = os.path.join(inputs['wav_path'], 'data.ark') | |||
| assert os.path.exists(ark_scp_path), 'data.scp does not exist' | |||
| assert os.path.exists(ark_file_path), 'data.ark does not exist' | |||
| new_ark_scp_path = os.path.join(inputs['workspace'], 'data.0.scp') | |||
| with open(ark_scp_path, 'r', encoding='utf-8') as f: | |||
| lines = f.readlines() | |||
| with open(new_ark_scp_path, 'w', encoding='utf-8') as n: | |||
| for line in lines: | |||
| outs = line.strip().split(' ') | |||
| if len(outs) == 2: | |||
| key = outs[0] | |||
| sub = outs[1].split(':') | |||
| if len(sub) == 2: | |||
| nums = sub[1] | |||
| content = key + ' ' + ark_file_path + ':' + nums + '\n' | |||
| n.write(content) | |||
| return inputs | |||
| def _recursion_dir_all_wave(self, wav_list, | |||
| dir_path: str) -> Dict[str, Any]: | |||
| dir_files = os.listdir(dir_path) | |||
| for file in dir_files: | |||
| file_path = os.path.join(dir_path, file) | |||
| if os.path.isfile(file_path): | |||
| if file_path.endswith('.wav') or file_path.endswith('.WAV'): | |||
| wav_list.append(file_path) | |||
| elif os.path.isdir(file_path): | |||
| self._recursion_dir_all_wave(wav_list, file_path) | |||
| return wav_list | |||
| # store all ark item into audio list | |||
| audio_lists = [] | |||
| for line in lines: | |||
| outs = line.strip().split(' ') | |||
| if len(outs) == 2: | |||
| key = outs[0] | |||
| sub = outs[1].split(':') | |||
| if len(sub) == 2: | |||
| nums = sub[1] | |||
| content = ark_file_path + ':' + nums | |||
| item = {'key': key, 'file': content} | |||
| audio_lists.append(item) | |||
| return audio_lists | |||
| @@ -1,3 +1,4 @@ | |||
| easyasr>=0.0.2 | |||
| espnet>=202204 | |||
| #tts | |||
| h5py | |||
| @@ -1,15 +1,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import sys | |||
| import tarfile | |||
| import unittest | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import requests | |||
| import soundfile | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.constant import ColorCodes, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import test_level | |||
| from modelscope.utils.test_utils import download_and_untar, test_level | |||
| logger = get_logger() | |||
| @@ -21,6 +26,9 @@ LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AS | |||
| AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz' | |||
| AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz' | |||
| TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz' | |||
| TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz' | |||
| def un_tar_gz(fname, dirs): | |||
| t = tarfile.open(fname) | |||
| @@ -28,45 +36,168 @@ def un_tar_gz(fname, dirs): | |||
| class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||
| action_info = { | |||
| 'test_run_with_wav_pytorch': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'wav_example' | |||
| }, | |||
| 'test_run_with_pcm_pytorch': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'wav_example' | |||
| }, | |||
| 'test_run_with_wav_tf': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'wav_example' | |||
| }, | |||
| 'test_run_with_pcm_tf': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'wav_example' | |||
| }, | |||
| 'test_run_with_wav_dataset_pytorch': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'dataset_example' | |||
| }, | |||
| 'test_run_with_wav_dataset_tf': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'dataset_example' | |||
| }, | |||
| 'test_run_with_ark_dataset': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'dataset_example' | |||
| }, | |||
| 'test_run_with_tfrecord_dataset': { | |||
| 'checking_item': OutputKeys.TEXT, | |||
| 'example': 'dataset_example' | |||
| }, | |||
| 'dataset_example': { | |||
| 'Wrd': 49532, # the number of words | |||
| 'Snt': 5000, # the number of sentences | |||
| 'Corr': 47276, # the number of correct words | |||
| 'Ins': 49, # the number of insert words | |||
| 'Del': 152, # the number of delete words | |||
| 'Sub': 2207, # the number of substitution words | |||
| 'wrong_words': 2408, # the number of wrong words | |||
| 'wrong_sentences': 1598, # the number of wrong sentences | |||
| 'Err': 4.86, # WER/CER | |||
| 'S.Err': 31.96 # SER | |||
| }, | |||
| 'wav_example': { | |||
| 'text': '每一天都要快乐喔' | |||
| } | |||
| } | |||
| def setUp(self) -> None: | |||
| self._am_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||
| self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | |||
| self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' | |||
| # this temporary workspace dir will store waveform files | |||
| self._workspace = os.path.join(os.getcwd(), '.tmp') | |||
| if not os.path.exists(self._workspace): | |||
| os.mkdir(self._workspace) | |||
| self.workspace = os.path.join(os.getcwd(), '.tmp') | |||
| if not os.path.exists(self.workspace): | |||
| os.mkdir(self.workspace) | |||
| def tearDown(self) -> None: | |||
| # remove workspace dir (.tmp) | |||
| shutil.rmtree(self.workspace, ignore_errors=True) | |||
| def run_pipeline(self, model_id: str, | |||
| audio_in: Union[str, bytes]) -> Dict[str, Any]: | |||
| inference_16k_pipline = pipeline( | |||
| task=Tasks.auto_speech_recognition, model=model_id) | |||
| rec_result = inference_16k_pipline(audio_in) | |||
| return rec_result | |||
| def log_error(self, functions: str, result: Dict[str, Any]) -> None: | |||
| logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' | |||
| + ColorCodes.END) | |||
| logger.error( | |||
| ColorCodes.MAGENTA + functions + ' correct result example:' | |||
| + ColorCodes.YELLOW | |||
| + str(self.action_info[self.action_info[functions]['example']]) | |||
| + ColorCodes.END) | |||
| raise ValueError('asr result is mismatched') | |||
| def check_result(self, functions: str, result: Dict[str, Any]) -> None: | |||
| if result.__contains__(self.action_info[functions]['checking_item']): | |||
| logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' | |||
| + ColorCodes.END) | |||
| logger.info( | |||
| ColorCodes.YELLOW | |||
| + str(result[self.action_info[functions]['checking_item']]) | |||
| + ColorCodes.END) | |||
| else: | |||
| self.log_error(functions, result) | |||
| def wav2bytes(self, wav_file) -> bytes: | |||
| audio, fs = soundfile.read(wav_file) | |||
| # float32 -> int16 | |||
| audio = np.asarray(audio) | |||
| dtype = np.dtype('int16') | |||
| i = np.iinfo(dtype) | |||
| abs_max = 2**(i.bits - 1) | |||
| offset = i.min + abs_max | |||
| audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) | |||
| # int16(PCM_16) -> byte | |||
| audio = audio.tobytes() | |||
| return audio | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_wav(self): | |||
| def test_run_with_wav_pytorch(self): | |||
| '''run with single waveform file | |||
| ''' | |||
| logger.info('Run ASR test with waveform file ...') | |||
| logger.info('Run ASR test with waveform file (pytorch)...') | |||
| wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||
| inference_16k_pipline = pipeline( | |||
| task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||
| self.assertTrue(inference_16k_pipline is not None) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_pytorch_model_id, audio_in=wav_file_path) | |||
| self.check_result('test_run_with_wav_pytorch', rec_result) | |||
| rec_result = inference_16k_pipline(wav_file_path) | |||
| self.assertTrue(len(rec_result['rec_result']) > 0) | |||
| self.assertTrue(rec_result['rec_result'] != 'None') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_pcm_pytorch(self): | |||
| '''run with wav data | |||
| ''' | |||
| result structure: | |||
| { | |||
| 'rec_result': '每一天都要快乐喔' | |||
| } | |||
| or | |||
| { | |||
| 'rec_result': 'None' | |||
| } | |||
| logger.info('Run ASR test with wav data (pytorch)...') | |||
| audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_pytorch_model_id, audio_in=audio) | |||
| self.check_result('test_run_with_pcm_pytorch', rec_result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_wav_tf(self): | |||
| '''run with single waveform file | |||
| ''' | |||
| logger.info('test_run_with_wav rec result: ' | |||
| + rec_result['rec_result']) | |||
| logger.info('Run ASR test with waveform file (tensorflow)...') | |||
| wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_tf_model_id, audio_in=wav_file_path) | |||
| self.check_result('test_run_with_wav_tf', rec_result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_pcm_tf(self): | |||
| '''run with wav data | |||
| ''' | |||
| logger.info('Run ASR test with wav data (tensorflow)...') | |||
| audio = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_tf_model_id, audio_in=audio) | |||
| self.check_result('test_run_with_pcm_tf', rec_result) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_wav_dataset(self): | |||
| def test_run_with_wav_dataset_pytorch(self): | |||
| '''run with datasets, and audio format is waveform | |||
| datasets directory: | |||
| <dataset_path> | |||
| @@ -84,57 +215,48 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||
| data.text # hypothesis text | |||
| ''' | |||
| logger.info('Run ASR test with waveform dataset ...') | |||
| logger.info('Run ASR test with waveform dataset (pytorch)...') | |||
| logger.info('Downloading waveform testsets file ...') | |||
| # downloading pos_testsets file | |||
| testsets_file_path = os.path.join(self._workspace, | |||
| LITTLE_TESTSETS_FILE) | |||
| if not os.path.exists(testsets_file_path): | |||
| r = requests.get(LITTLE_TESTSETS_URL) | |||
| with open(testsets_file_path, 'wb') as f: | |||
| f.write(r.content) | |||
| testsets_dir_name = os.path.splitext( | |||
| os.path.basename( | |||
| os.path.splitext( | |||
| os.path.basename(LITTLE_TESTSETS_FILE))[0]))[0] | |||
| # dataset_path = <cwd>/.tmp/data_aishell/wav/test | |||
| dataset_path = os.path.join(self._workspace, testsets_dir_name, 'wav', | |||
| 'test') | |||
| # untar the dataset_path file | |||
| if not os.path.exists(dataset_path): | |||
| un_tar_gz(testsets_file_path, self._workspace) | |||
| dataset_path = download_and_untar( | |||
| os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||
| LITTLE_TESTSETS_URL, self.workspace) | |||
| dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||
| inference_16k_pipline = pipeline( | |||
| task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||
| self.assertTrue(inference_16k_pipline is not None) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_pytorch_model_id, audio_in=dataset_path) | |||
| self.check_result('test_run_with_wav_dataset_pytorch', rec_result) | |||
| rec_result = inference_16k_pipline(wav_path=dataset_path) | |||
| self.assertTrue(len(rec_result['datasets_result']) > 0) | |||
| self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||
| ''' | |||
| result structure: | |||
| { | |||
| 'rec_result': 'None', | |||
| 'datasets_result': | |||
| { | |||
| 'Wrd': 1654, # the number of words | |||
| 'Snt': 128, # the number of sentences | |||
| 'Corr': 1573, # the number of correct words | |||
| 'Ins': 1, # the number of insert words | |||
| 'Del': 1, # the number of delete words | |||
| 'Sub': 80, # the number of substitution words | |||
| 'wrong_words': 82, # the number of wrong words | |||
| 'wrong_sentences': 47, # the number of wrong sentences | |||
| 'Err': 4.96, # WER/CER | |||
| 'S.Err': 36.72 # SER | |||
| } | |||
| } | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_wav_dataset_tf(self): | |||
| '''run with datasets, and audio format is waveform | |||
| datasets directory: | |||
| <dataset_path> | |||
| wav | |||
| test # testsets | |||
| xx.wav | |||
| ... | |||
| dev # devsets | |||
| yy.wav | |||
| ... | |||
| train # trainsets | |||
| zz.wav | |||
| ... | |||
| transcript | |||
| data.text # hypothesis text | |||
| ''' | |||
| logger.info('test_run_with_wav_dataset datasets result: ') | |||
| logger.info(rec_result['datasets_result']) | |||
| logger.info('Run ASR test with waveform dataset (tensorflow)...') | |||
| logger.info('Downloading waveform testsets file ...') | |||
| dataset_path = download_and_untar( | |||
| os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||
| LITTLE_TESTSETS_URL, self.workspace) | |||
| dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_tf_model_id, audio_in=dataset_path) | |||
| self.check_result('test_run_with_wav_dataset_tf', rec_result) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_ark_dataset(self): | |||
| @@ -155,56 +277,40 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase): | |||
| data.text | |||
| ''' | |||
| logger.info('Run ASR test with ark dataset ...') | |||
| logger.info('Run ASR test with ark dataset (pytorch)...') | |||
| logger.info('Downloading ark testsets file ...') | |||
| # downloading pos_testsets file | |||
| testsets_file_path = os.path.join(self._workspace, | |||
| AISHELL1_TESTSETS_FILE) | |||
| if not os.path.exists(testsets_file_path): | |||
| r = requests.get(AISHELL1_TESTSETS_URL) | |||
| with open(testsets_file_path, 'wb') as f: | |||
| f.write(r.content) | |||
| testsets_dir_name = os.path.splitext( | |||
| os.path.basename( | |||
| os.path.splitext( | |||
| os.path.basename(AISHELL1_TESTSETS_FILE))[0]))[0] | |||
| # dataset_path = <cwd>/.tmp/aishell1/test | |||
| dataset_path = os.path.join(self._workspace, testsets_dir_name, 'test') | |||
| # untar the dataset_path file | |||
| if not os.path.exists(dataset_path): | |||
| un_tar_gz(testsets_file_path, self._workspace) | |||
| dataset_path = download_and_untar( | |||
| os.path.join(self.workspace, AISHELL1_TESTSETS_FILE), | |||
| AISHELL1_TESTSETS_URL, self.workspace) | |||
| dataset_path = os.path.join(dataset_path, 'test') | |||
| inference_16k_pipline = pipeline( | |||
| task=Tasks.auto_speech_recognition, model=[self._am_model_id]) | |||
| self.assertTrue(inference_16k_pipline is not None) | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_pytorch_model_id, audio_in=dataset_path) | |||
| self.check_result('test_run_with_ark_dataset', rec_result) | |||
| rec_result = inference_16k_pipline(wav_path=dataset_path) | |||
| self.assertTrue(len(rec_result['datasets_result']) > 0) | |||
| self.assertTrue(rec_result['datasets_result']['Wrd'] > 0) | |||
| ''' | |||
| result structure: | |||
| { | |||
| 'rec_result': 'None', | |||
| 'datasets_result': | |||
| { | |||
| 'Wrd': 104816, # the number of words | |||
| 'Snt': 7176, # the number of sentences | |||
| 'Corr': 99327, # the number of correct words | |||
| 'Ins': 104, # the number of insert words | |||
| 'Del': 155, # the number of delete words | |||
| 'Sub': 5334, # the number of substitution words | |||
| 'wrong_words': 5593, # the number of wrong words | |||
| 'wrong_sentences': 2898, # the number of wrong sentences | |||
| 'Err': 5.34, # WER/CER | |||
| 'S.Err': 40.38 # SER | |||
| } | |||
| } | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_tfrecord_dataset(self): | |||
| '''run with datasets, and audio format is tfrecord | |||
| datasets directory: | |||
| <dataset_path> | |||
| test # testsets | |||
| data.records | |||
| data.idx | |||
| data.text | |||
| ''' | |||
| logger.info('test_run_with_ark_dataset datasets result: ') | |||
| logger.info(rec_result['datasets_result']) | |||
| logger.info('Run ASR test with tfrecord dataset (tensorflow)...') | |||
| logger.info('Downloading tfrecord testsets file ...') | |||
| dataset_path = download_and_untar( | |||
| os.path.join(self.workspace, TFRECORD_TESTSETS_FILE), | |||
| TFRECORD_TESTSETS_URL, self.workspace) | |||
| dataset_path = os.path.join(dataset_path, 'test') | |||
| rec_result = self.run_pipeline( | |||
| model_id=self.am_tf_model_id, audio_in=dataset_path) | |||
| self.check_result('test_run_with_tfrecord_dataset', rec_result) | |||
| if __name__ == '__main__': | |||