Merge the frontend, AM and vocoder model cards into one model card. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9303837 (master)
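After this change the text frontend, acoustic model and vocoder ship as a single hub model, so the pipeline is created from one model id and consumes a dict mapping a test-case label to the text to synthesize. A minimal usage sketch, distilled from the updated unit test in this diff (model id, task, and output keys are taken from that test; the label 'case_0' is arbitrary):

```python
from scipy.io.wavfile import write

from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks

# One merged model replaces the former frontend/AM/vocoder model triple.
model_id = 'damo/speech_sambert-hifigan_tts_zhitian_emo_zhcn_16k'
tts = pipeline(task=Tasks.text_to_speech, model=model_id)

# Input: {label: text}; output: {label: int16 PCM at 16 kHz} under OUTPUT_PCM.
output = tts({'case_0': '今天北京天气怎么样?'})
write('output.wav', 16000, output[OutputKeys.OUTPUT_PCM]['case_0'])
```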
@@ -20,9 +20,7 @@ class Models(object):
     space = 'space'
     # audio models
-    sambert_hifi_16k = 'sambert-hifi-16k'
-    generic_tts_frontend = 'generic-tts-frontend'
-    hifigan16k = 'hifigan16k'
+    sambert_hifigan = 'sambert-hifigan'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
@@ -66,7 +64,7 @@ class Pipelines(object):
     zero_shot_classification = 'zero-shot-classification'
     # audio tasks
-    sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
+    sambert_hifigan_tts = 'sambert-hifigan-tts'
    speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
@@ -5,8 +5,7 @@ from .base import Model
 from .builder import MODELS, build_model

 try:
-    from .audio.tts.am import SambertNetHifi16k
-    from .audio.tts.vocoder import Hifigan16k
+    from .audio.tts import SambertHifigan
     from .audio.kws import GenericKeyWordSpotting
     from .audio.ans.frcrn import FRCRNModel
 except ModuleNotFoundError as e:
@@ -0,0 +1 @@
+from .sambert_hifi import *  # noqa F403
@@ -1 +0,0 @@
-from .sambert_hifi_16k import *  # noqa F403
@@ -1 +0,0 @@
-from .generic_text_to_speech_frontend import *  # noqa F403
@@ -1,39 +0,0 @@
-import os
-import zipfile
-from typing import Any, Dict, List
-
-from modelscope.metainfo import Models
-from modelscope.models.base import Model
-from modelscope.models.builder import MODELS
-from modelscope.utils.audio.tts_exceptions import (
-    TtsFrontendInitializeFailedException,
-    TtsFrontendLanguageTypeInvalidException)
-from modelscope.utils.constant import Tasks
-
-__all__ = ['GenericTtsFrontend']
-
-
-@MODELS.register_module(
-    Tasks.text_to_speech, module_name=Models.generic_tts_frontend)
-class GenericTtsFrontend(Model):
-
-    def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs):
-        super().__init__(model_dir, *args, **kwargs)
-        import ttsfrd
-        frontend = ttsfrd.TtsFrontendEngine()
-        zip_file = os.path.join(model_dir, 'resource.zip')
-        self._res_path = os.path.join(model_dir, 'resource')
-        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
-            zip_ref.extractall(model_dir)
-        if not frontend.initialize(self._res_path):
-            raise TtsFrontendInitializeFailedException(
-                'resource invalid: {}'.format(self._res_path))
-        if not frontend.set_lang_type(lang_type):
-            raise TtsFrontendLanguageTypeInvalidException(
-                'language type invalid: {}, valid is pinyin and chenmix'.
-                format(lang_type))
-        self._frontend = frontend
-
-    def forward(self, data: str) -> Dict[str, List]:
-        result = self._frontend.gen_tacotron_symbols(data)
-        return {'texts': [s for s in result.splitlines() if s != '']}
@@ -1,7 +1,8 @@
 from .robutrans import RobuTrans
+from .vocoder_models import Generator


-def create_model(name, hparams):
+def create_am_model(name, hparams):
     if name == 'robutrans':
         return RobuTrans(hparams)
     else:
@@ -4,7 +4,7 @@ from tensorflow.contrib.rnn import RNNCell
 from tensorflow.contrib.seq2seq import AttentionWrapperState
 from tensorflow.python.ops import rnn_cell_impl

-from .modules import prenet
+from .am_models import prenet


 class VarPredictorCell(RNNCell):
@@ -3,9 +3,9 @@ from tensorflow.contrib.rnn import LSTMBlockCell, MultiRNNCell
 from tensorflow.contrib.seq2seq import BasicDecoder
 from tensorflow.python.ops.ragged.ragged_util import repeat

+from .am_models import conv_prenet, decoder_prenet, encoder_prenet
 from .fsmn_encoder import FsmnEncoderV2
 from .helpers import VarTestHelper, VarTrainingHelper
-from .modules import conv_prenet, decoder_prenet, encoder_prenet
 from .position import (BatchSinusodalPositionalEncoding,
                        SinusodalPositionalEncoding)
 from .rnn_wrappers import DurPredictorCell, VarPredictorCell
@@ -5,7 +5,7 @@ import sys
 import tensorflow as tf

 from . import compat, transformer
-from .modules import decoder_prenet
+from .am_models import decoder_prenet
 from .position import SinusoidalPositionEncoder
@@ -1,20 +1,31 @@
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
 import io
 import os
-import time
+import zipfile
 from typing import Any, Dict, Optional, Union

+import json
 import numpy as np
 import tensorflow as tf
+import torch
 from sklearn.preprocessing import MultiLabelBinarizer

 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
+from modelscope.utils.audio.tts_exceptions import (
+    TtsFrontendInitializeFailedException,
+    TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationExcetion,
+    TtsVocoderMelspecShapeMismatchException)
 from modelscope.utils.constant import ModelFile, Tasks
-from .models import create_model
+from .models import Generator, create_am_model
 from .text.symbols import load_symbols
 from .text.symbols_dict import SymbolsDict

-__all__ = ['SambertNetHifi16k']
+__all__ = ['SambertHifigan']
+
+MAX_WAV_VALUE = 32768.0
@@ -23,13 +34,25 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol):
     sequences = []
     for token in tokens:
         sequences.append(tuple(token.split('&')))
-    # sequences.append(tuple(['~'])) # sequence length minus 1 to ignore EOS ~
     return one_hot.fit_transform(sequences)


+def load_checkpoint(filepath, device):
+    assert os.path.isfile(filepath)
+    checkpoint_dict = torch.load(filepath, map_location=device)
+    return checkpoint_dict
+
+
+class AttrDict(dict):
+
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
 @MODELS.register_module(
-    Tasks.text_to_speech, module_name=Models.sambert_hifi_16k)
-class SambertNetHifi16k(Model):
+    Tasks.text_to_speech, module_name=Models.sambert_hifigan)
+class SambertHifigan(Model):

     def __init__(self,
                  model_dir,
@@ -38,20 +61,50 @@ class SambertNetHifi16k(Model):
                  energy_control_str='',
                  *args,
                  **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        if 'am' not in kwargs:
+            raise TtsModelConfigurationExcetion(
+                'configuration model field missing am!')
+        if 'vocoder' not in kwargs:
+            raise TtsModelConfigurationExcetion(
+                'configuration model field missing vocoder!')
+        if 'lang_type' not in kwargs:
+            raise TtsModelConfigurationExcetion(
+                'configuration model field missing lang_type!')
+        # initialize frontend
+        import ttsfrd
+        frontend = ttsfrd.TtsFrontendEngine()
+        zip_file = os.path.join(model_dir, 'resource.zip')
+        self._res_path = os.path.join(model_dir, 'resource')
+        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
+            zip_ref.extractall(model_dir)
+        if not frontend.initialize(self._res_path):
+            raise TtsFrontendInitializeFailedException(
+                'resource invalid: {}'.format(self._res_path))
+        if not frontend.set_lang_type(kwargs['lang_type']):
+            raise TtsFrontendLanguageTypeInvalidException(
+                'language type invalid: {}'.format(kwargs['lang_type']))
+        self._frontend = frontend
+        # initialize am
         tf.reset_default_graph()
-        local_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER, 'ckpt')
-        self._ckpt_path = os.path.join(model_dir, local_ckpt_path)
+        local_am_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER,
+                                          'ckpt')
+        self._am_ckpt_path = os.path.join(model_dir, local_am_ckpt_path)
         self._dict_path = os.path.join(model_dir, 'dicts')
-        self._hparams = tf.contrib.training.HParams(**kwargs)
-        values = self._hparams.values()
+        self._am_hparams = tf.contrib.training.HParams(**kwargs['am'])
+        has_mask = True
+        if self._am_hparams.get('has_mask') is not None:
+            has_mask = self._am_hparams.has_mask
+            print('set has_mask to {}'.format(has_mask))
+        values = self._am_hparams.values()
         hp = ['  {}:{}'.format(name, values[name]) for name in sorted(values)]
         print('Hyperparameters:\n' + '\n'.join(hp))
-        super().__init__(self._ckpt_path, *args, **kwargs)
         model_name = 'robutrans'
-        self._lfeat_type_list = self._hparams.lfeat_type_list.strip().split(
+        self._lfeat_type_list = self._am_hparams.lfeat_type_list.strip().split(
             ',')
         sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols(
-            self._dict_path)
+            self._dict_path, has_mask)
         self._sy = sy
         self._tone = tone
         self._syllable_flag = syllable_flag
@@ -86,7 +139,6 @@ class SambertNetHifi16k(Model):
         inputs_speaker = tf.placeholder(tf.float32,
                                         [1, None, self._inputs_dim['speaker']],
                                         'inputs_speaker')
-
         input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
         pitch_contours_scale = tf.placeholder(tf.float32, [1, None],
                                               'pitch_contours_scale')
@@ -94,9 +146,8 @@
                                                'energy_contours_scale')
         duration_scale = tf.placeholder(tf.float32, [1, None],
                                         'duration_scale')
-
         with tf.variable_scope('model') as _:
-            self._model = create_model(model_name, self._hparams)
+            self._model = create_am_model(model_name, self._am_hparams)
             self._model.initialize(
                 inputs,
                 inputs_emotion,
@@ -123,14 +174,14 @@
         self._attention_h = self._model.attention_h
         self._attention_x = self._model.attention_x

-        print('Loading checkpoint: %s' % self._ckpt_path)
+        print('Loading checkpoint: %s' % self._am_ckpt_path)
         config = tf.ConfigProto()
         config.gpu_options.allow_growth = True
         self._session = tf.Session(config=config)
         self._session.run(tf.global_variables_initializer())
         saver = tf.train.Saver()
-        saver.restore(self._session, self._ckpt_path)
+        saver.restore(self._session, self._am_ckpt_path)

         duration_cfg_lst = []
         if len(duration_control_str) != 0:
@@ -158,8 +209,26 @@
         self._energy_contours_cfg_lst = energy_contours_cfg_lst

-    def forward(self, text):
-        cleaner_names = [x.strip() for x in self._hparams.cleaners.split(',')]
+        # initialize vocoder
+        self._voc_ckpt_path = os.path.join(model_dir,
+                                           ModelFile.TORCH_MODEL_BIN_FILE)
+        self._voc_config = AttrDict(**kwargs['vocoder'])
+        print(self._voc_config)
+        if torch.cuda.is_available():
+            torch.manual_seed(self._voc_config.seed)
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        self._generator = Generator(self._voc_config).to(self._device)
+        state_dict_g = load_checkpoint(self._voc_ckpt_path, self._device)
+        self._generator.load_state_dict(state_dict_g['generator'])
+        self._generator.eval()
+        self._generator.remove_weight_norm()
+
+    def am_synthesis_one_sentences(self, text):
+        cleaner_names = [
+            x.strip() for x in self._am_hparams.cleaners.split(',')
+        ]

         lfeat_symbol = text.strip().split(' ')
         lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list))
@@ -255,3 +324,31 @@
             self._energy_embeddings, self._attention_x, self._attention_h
         ], feed_dict=feed_dict)  # yapf:disable
         return result[0]
+
+    def vocoder_process(self, melspec):
+        dim0 = list(melspec.shape)[-1]
+        if dim0 != self._voc_config.num_mels:
+            raise TtsVocoderMelspecShapeMismatchException(
+                'input melspec mismatch require {} but {}'.format(
+                    self._voc_config.num_mels, dim0))
+        with torch.no_grad():
+            x = melspec.T
+            x = torch.FloatTensor(x).to(self._device)
+            if len(x.shape) == 2:
+                x = x.unsqueeze(0)
+            y_g_hat = self._generator(x)
+            audio = y_g_hat.squeeze()
+            audio = audio * MAX_WAV_VALUE
+            audio = audio.cpu().numpy().astype('int16')
+        return audio
+
+    def forward(self, text):
+        result = self._frontend.gen_tacotron_symbols(text)
+        texts = [s for s in result.splitlines() if s != '']
+        audio_total = np.empty((0), dtype='int16')
+        for line in texts:
+            line = line.strip().split('\t')
+            audio = self.vocoder_process(
+                self.am_synthesis_one_sentences(line[1]))
+            audio_total = np.append(audio_total, audio, axis=0)
+        return audio_total
@@ -12,7 +12,7 @@ _eos = '~'
 _mask = '@[MASK]'


-def load_symbols(dict_path):
+def load_symbols(dict_path, has_mask=True):
     _characters = ''
     _ch_symbols = []
     sy_dict_name = 'sy_dict.txt'
@@ -25,7 +25,9 @@ def load_symbols(dict_path):
     _arpabet = ['@' + s for s in _ch_symbols]

     # Export all symbols:
-    sy = list(_characters) + _arpabet + [_pad, _eos, _mask]
+    sy = list(_characters) + _arpabet + [_pad, _eos]
+    if has_mask:
+        sy.append(_mask)

     _characters = ''
@@ -38,7 +40,9 @@ def load_symbols(dict_path):
             _ch_tones.append(line)

     # Export all tones:
-    tone = list(_characters) + _ch_tones + [_pad, _eos, _mask]
+    tone = list(_characters) + _ch_tones + [_pad, _eos]
+    if has_mask:
+        tone.append(_mask)

     _characters = ''
@@ -51,9 +55,9 @@ def load_symbols(dict_path):
             _ch_syllable_flags.append(line)

     # Export all syllable_flags:
-    syllable_flag = list(_characters) + _ch_syllable_flags + [
-        _pad, _eos, _mask
-    ]
+    syllable_flag = list(_characters) + _ch_syllable_flags + [_pad, _eos]
+    if has_mask:
+        syllable_flag.append(_mask)

     _characters = ''
@@ -66,7 +70,9 @@ def load_symbols(dict_path):
             _ch_word_segments.append(line)

     # Export all syllable_flags:
-    word_segment = list(_characters) + _ch_word_segments + [_pad, _eos, _mask]
+    word_segment = list(_characters) + _ch_word_segments + [_pad, _eos]
+    if has_mask:
+        word_segment.append(_mask)

     _characters = ''
@@ -78,7 +84,9 @@ def load_symbols(dict_path):
         line = line.strip('\r\n')
         _ch_emo_types.append(line)

-    emo_category = list(_characters) + _ch_emo_types + [_pad, _eos, _mask]
+    emo_category = list(_characters) + _ch_emo_types + [_pad, _eos]
+    if has_mask:
+        emo_category.append(_mask)

     _characters = ''
@@ -91,5 +99,7 @@ def load_symbols(dict_path):
         _ch_speakers.append(line)

     # Export all syllable_flags:
-    speaker = list(_characters) + _ch_speakers + [_pad, _eos, _mask]
+    speaker = list(_characters) + _ch_speakers + [_pad, _eos]
+    if has_mask:
+        speaker.append(_mask)

     return sy, tone, syllable_flag, word_segment, emo_category, speaker
@@ -1 +0,0 @@
-from .hifigan16k import *  # noqa F403
@@ -1,74 +0,0 @@
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
-import argparse
-import glob
-import os
-import time
-
-import json
-import numpy as np
-import torch
-from scipy.io.wavfile import write
-
-from modelscope.metainfo import Models
-from modelscope.models.base import Model
-from modelscope.models.builder import MODELS
-from modelscope.utils.audio.tts_exceptions import \
-    TtsVocoderMelspecShapeMismatchException
-from modelscope.utils.constant import ModelFile, Tasks
-from .models import Generator
-
-__all__ = ['Hifigan16k', 'AttrDict']
-
-MAX_WAV_VALUE = 32768.0
-
-
-def load_checkpoint(filepath, device):
-    assert os.path.isfile(filepath)
-    print("Loading '{}'".format(filepath))
-    checkpoint_dict = torch.load(filepath, map_location=device)
-    print('Complete.')
-    return checkpoint_dict
-
-
-class AttrDict(dict):
-
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-@MODELS.register_module(Tasks.text_to_speech, module_name=Models.hifigan16k)
-class Hifigan16k(Model):
-
-    def __init__(self, model_dir, *args, **kwargs):
-        self._ckpt_path = os.path.join(model_dir,
-                                       ModelFile.TORCH_MODEL_BIN_FILE)
-        self._config = AttrDict(**kwargs)
-        super().__init__(self._ckpt_path, *args, **kwargs)
-        if torch.cuda.is_available():
-            torch.manual_seed(self._config.seed)
-            self._device = torch.device('cuda')
-        else:
-            self._device = torch.device('cpu')
-        self._generator = Generator(self._config).to(self._device)
-        state_dict_g = load_checkpoint(self._ckpt_path, self._device)
-        self._generator.load_state_dict(state_dict_g['generator'])
-        self._generator.eval()
-        self._generator.remove_weight_norm()
-
-    def forward(self, melspec):
-        dim0 = list(melspec.shape)[-1]
-        if dim0 != 80:
-            raise TtsVocoderMelspecShapeMismatchException(
-                'input melspec mismatch 0 dim require 80 but {}'.format(dim0))
-        with torch.no_grad():
-            x = melspec.T
-            x = torch.FloatTensor(x).to(self._device)
-            if len(x.shape) == 2:
-                x = x.unsqueeze(0)
-            y_g_hat = self._generator(x)
-            audio = y_g_hat.squeeze()
-            audio = audio * MAX_WAV_VALUE
-            audio = audio.cpu().numpy().astype('int16')
-        return audio
@@ -1 +0,0 @@
-from .models import Generator
@@ -3,46 +3,45 @@ from typing import Any, Dict, List
 import numpy as np

 from modelscope.metainfo import Pipelines
-from modelscope.pipelines.base import Pipeline
+from modelscope.models import Model
+from modelscope.models.audio.tts import SambertHifigan
+from modelscope.pipelines.base import Input, InputModel, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import Preprocessor, TextToTacotronSymbols
-from modelscope.utils.constant import Tasks
+from modelscope.pipelines.outputs import OutputKeys
+from modelscope.utils.constant import Fields, Tasks

-__all__ = ['TextToSpeechSambertHifigan16kPipeline']
+__all__ = ['TextToSpeechSambertHifiganPipeline']


 @PIPELINES.register_module(
-    Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_16k_tts)
-class TextToSpeechSambertHifigan16kPipeline(Pipeline):
+    Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_tts)
+class TextToSpeechSambertHifiganPipeline(Pipeline):

-    def __init__(self,
-                 model: List[str] = None,
-                 preprocessor: Preprocessor = None,
-                 **kwargs):
-        """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
-        Args:
-            model: model id on modelscope hub.
-        """
-        assert len(model) == 3, 'model number should be 3'
-        if preprocessor is None:
-            lang_type = 'pinyin'
-            if 'lang_type' in kwargs:
-                lang_type = kwargs.lang_type
-            preprocessor = TextToTacotronSymbols(model[0], lang_type=lang_type)
-        models = [model[1], model[2]]
-        super().__init__(model=models, preprocessor=preprocessor, **kwargs)
-        self._am = self.models[0]
-        self._vocoder = self.models[1]
+    def __init__(self, model: InputModel, **kwargs):
+        """use `model` to create a text-to-speech pipeline for prediction
+
+        Args:
+            model (SambertHifigan or str): a model instance or a valid official model id
+        """
+        super().__init__(model=model, **kwargs)

-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, np.ndarray]:
-        texts = inputs['texts']
-        audio_total = np.empty((0), dtype='int16')
-        for line in texts:
-            line = line.strip().split('\t')
-            audio = self._vocoder.forward(self._am.forward(line[1]))
-            audio_total = np.append(audio_total, audio, axis=0)
-        return {'output': audio_total}
+    def forward(self, inputs: Dict[str, str]) -> Dict[str, np.ndarray]:
+        """synthesize text from inputs with the pipeline
+        Args:
+            inputs (Dict[str, str]): a dictionary whose keys name the test
+                cases and whose values are the texts to synthesize.
+        Returns:
+            Dict[str, np.ndarray]: a dictionary keyed like the inputs; each
+                value is the synthesized PCM audio data.
+        """
+        output_wav = {}
+        for label, text in inputs.items():
+            output_wav[label] = self.model.forward(text)
+        return {OutputKeys.OUTPUT_PCM: output_wav}

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         return inputs
+
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
+        return inputs
@@ -263,5 +263,11 @@ TASK_OUTPUTS = {
     # {
     #   "output_img": np.ndarray with shape [height, width, 3]
     # }
-    Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG]
+    Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG],
+
+    # text_to_speech result for a single sample
+    # {
+    #   "output_pcm": {"input_label": np.ndarray with shape [D]}
+    # }
+    Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM]
 }
@@ -6,7 +6,6 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .kws import WavToLists
-from .text_to_speech import *  # noqa F403

 try:
     from .audio import LinearAECAndFbank
@@ -1,52 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import io
-from typing import Any, Dict, Union
-
-from modelscope.fileio import File
-from modelscope.metainfo import Preprocessors
-from modelscope.models.audio.tts.frontend import GenericTtsFrontend
-from modelscope.models.base import Model
-from modelscope.utils.audio.tts_exceptions import *  # noqa F403
-from modelscope.utils.constant import Fields
-from .base import Preprocessor
-from .builder import PREPROCESSORS
-
-__all__ = ['TextToTacotronSymbols']
-
-
-@PREPROCESSORS.register_module(
-    Fields.audio, module_name=Preprocessors.text_to_tacotron_symbols)
-class TextToTacotronSymbols(Preprocessor):
-    """extract tacotron symbols from text.
-
-    Args:
-        res_path (str): TTS frontend resource url
-        lang_type (str): language type, valid values are "pinyin" and "chenmix"
-    """
-
-    def __init__(self, model_name, lang_type='pinyin'):
-        self._frontend_model = Model.from_pretrained(
-            model_name, lang_type=lang_type)
-        assert self._frontend_model is not None, 'load model from pretained failed'
-
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """Call functions to load text and get tacotron symbols.
-
-        Args:
-            input (str): text with utf-8
-        Returns:
-            symbos (list[str]): texts in tacotron symbols format.
-        """
-        return self._frontend_model.forward(data)
-
-
-def text_to_tacotron_symbols(text='', path='./', lang='pinyin'):
-    """ simple interface to transform text to tacotron symbols
-
-    Args:
-        text (str): input text
-        path (str): resource path
-        lang (str): language type from one of "pinyin" and "chenmix"
-    """
-    transform = TextToTacotronSymbols(path, lang)
-    return transform(text)
@@ -10,6 +10,13 @@ class TtsException(Exception):
     pass


+class TtsModelConfigurationExcetion(TtsException):
+    """
+    TTS model configuration exceptions.
+    """
+    pass
+
+
 class TtsFrontendException(TtsException):
     """
     TTS frontend module level exceptions.
@@ -177,7 +177,6 @@ def build_from_cfg(cfg,
                          f'but got {type(default_args)}')

     args = cfg.copy()
-
     if default_args is not None:
         for name, value in default_args.items():
             args.setdefault(name, value)
@@ -20,5 +20,5 @@ torch
 torchaudio
 torchvision
 tqdm
-ttsfrd==0.0.2
+ttsfrd==0.0.3
 unidecode
@@ -7,10 +7,10 @@ import unittest
 import torch
 from scipy.io.wavfile import write

-from modelscope.metainfo import Pipelines, Preprocessors
+from modelscope.metainfo import Pipelines
 from modelscope.models import Model
 from modelscope.pipelines import pipeline
-from modelscope.preprocessors import build_preprocessor
+from modelscope.pipelines.outputs import OutputKeys
 from modelscope.utils.constant import Fields, Tasks
 from modelscope.utils.logger import get_logger
 from modelscope.utils.test_utils import test_level
@@ -24,17 +24,18 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_pipeline(self):
-        text = '明天天气怎么样'
-        preprocessor_model_id = 'damo/speech_binary_tts_frontend_resource'
-        am_model_id = 'damo/speech_sambert16k_tts_zhitian_emo'
-        voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo'
-        sambert_tts = pipeline(
-            task=Tasks.text_to_speech,
-            model=[preprocessor_model_id, am_model_id, voc_model_id])
-        self.assertTrue(sambert_tts is not None)
-        output = sambert_tts(text)
-        self.assertTrue(len(output['output']) > 0)
-        write('output.wav', 16000, output['output'])
+        single_test_case_label = 'test_case_label_0'
+        text = '今天北京天气怎么样?'
+        model_id = 'damo/speech_sambert-hifigan_tts_zhitian_emo_zhcn_16k'
+        sambert_hifigan_tts = pipeline(
+            task=Tasks.text_to_speech, model=model_id)
+        self.assertTrue(sambert_hifigan_tts is not None)
+        test_cases = {single_test_case_label: text}
+        output = sambert_hifigan_tts(test_cases)
+        self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
+        pcm = output[OutputKeys.OUTPUT_PCM][single_test_case_label]
+        write('output.wav', 16000, pcm)


 if __name__ == '__main__':
@@ -1,29 +0,0 @@
-import shutil
-import unittest
-
-from modelscope.metainfo import Preprocessors
-from modelscope.preprocessors import build_preprocessor
-from modelscope.utils.constant import Fields, InputFields
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
-
-
-class TtsPreprocessorTest(unittest.TestCase):
-
-    def test_preprocess(self):
-        lang_type = 'pinyin'
-        text = '今天天气不错,我们去散步吧。'
-        cfg = dict(
-            type=Preprocessors.text_to_tacotron_symbols,
-            model_name='damo/speech_binary_tts_frontend_resource',
-            lang_type=lang_type)
-        preprocessor = build_preprocessor(cfg, Fields.audio)
-        output = preprocessor(text)
-        self.assertTrue(output)
-        for line in output['texts']:
-            print(line)
-
-
-if __name__ == '__main__':
-    unittest.main()