From d313c440c4e40f638d812fef381f75358217d616 Mon Sep 17 00:00:00 2001
From: "jiaqi.sjq"
Date: Fri, 8 Jul 2022 14:26:18 +0800
Subject: [PATCH] [to #9303837] Merge frontend, am and vocoder into one model
 card

Merge the frontend, am and vocoder model cards into one model card.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9303837
---
 modelscope/metainfo.py                        |   6 +-
 modelscope/models/__init__.py                 |   3 +-
 modelscope/models/audio/tts/__init__.py       |   1 +
 modelscope/models/audio/tts/am/__init__.py    |   1 -
 .../models/audio/tts/frontend/__init__.py     |   1 -
 .../generic_text_to_speech_frontend.py        |  39 -----
 .../audio/tts/{am => }/models/__init__.py     |   3 +-
 .../models/modules.py => models/am_models.py} |   0
 .../audio/tts/{am => }/models/compat.py       |   0
 .../models/audio/tts/{am => }/models/fsmn.py  |   0
 .../audio/tts/{am => }/models/fsmn_encoder.py |   0
 .../audio/tts/{am => }/models/helpers.py      |   0
 .../audio/tts/{am => }/models/position.py     |   0
 .../audio/tts/{am => }/models/reducer.py      |   0
 .../audio/tts/{am => }/models/rnn_wrappers.py |   2 +-
 .../audio/tts/{am => }/models/robutrans.py    |   2 +-
 .../{am => }/models/self_attention_decoder.py |   2 +-
 .../{am => }/models/self_attention_encoder.py |   0
 .../audio/tts/{am => }/models/transformer.py  |   0
 .../audio/tts/{vocoder => }/models/utils.py   |   0
 .../models.py => models/vocoder_models.py}    |   0
 .../sambert_hifi_16k.py => sambert_hifi.py}   | 135 +++++++++++++++---
 .../audio/tts/{am => }/text/__init__.py       |   0
 .../audio/tts/{am => }/text/cleaners.py       |   0
 .../models/audio/tts/{am => }/text/cmudict.py |   0
 .../models/audio/tts/{am => }/text/numbers.py |   0
 .../models/audio/tts/{am => }/text/symbols.py |  28 ++--
 .../audio/tts/{am => }/text/symbols_dict.py   |   0
 .../models/audio/tts/vocoder/__init__.py      |   1 -
 .../models/audio/tts/vocoder/hifigan16k.py    |  74 ----------
 .../audio/tts/vocoder/models/__init__.py      |   1 -
 .../audio/text_to_speech_pipeline.py          |  61 ++++----
 modelscope/pipelines/outputs.py               |   8 +-
 modelscope/preprocessors/__init__.py          |   1 -
 modelscope/preprocessors/text_to_speech.py    |  52 -------
 modelscope/utils/audio/tts_exceptions.py      |   7 +
 modelscope/utils/registry.py                  |   1 -
 requirements/audio.txt                        |   2 +-
 tests/pipelines/test_text_to_speech.py        |  27 ++--
 tests/preprocessors/test_text_to_speech.py    |  29 ----
 40 files changed, 203 insertions(+), 284 deletions(-)
 delete mode 100644 modelscope/models/audio/tts/am/__init__.py
 delete mode 100644 modelscope/models/audio/tts/frontend/__init__.py
 delete mode 100644 modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py
 rename modelscope/models/audio/tts/{am => }/models/__init__.py (67%)
 rename modelscope/models/audio/tts/{am/models/modules.py => models/am_models.py} (100%)
 rename modelscope/models/audio/tts/{am => }/models/compat.py (100%)
 rename modelscope/models/audio/tts/{am => }/models/fsmn.py (100%)
 rename modelscope/models/audio/tts/{am => }/models/fsmn_encoder.py (100%)
 rename modelscope/models/audio/tts/{am => }/models/helpers.py (100%)
 rename modelscope/models/audio/tts/{am => }/models/position.py (100%)
 rename modelscope/models/audio/tts/{am => }/models/reducer.py (100%)
 rename modelscope/models/audio/tts/{am => }/models/rnn_wrappers.py (99%)
 rename modelscope/models/audio/tts/{am => }/models/robutrans.py (99%)
 rename modelscope/models/audio/tts/{am => }/models/self_attention_decoder.py (99%)
 rename modelscope/models/audio/tts/{am => }/models/self_attention_encoder.py (100%)
 rename modelscope/models/audio/tts/{am => }/models/transformer.py (100%)
 rename modelscope/models/audio/tts/{vocoder => }/models/utils.py (100%)
 rename modelscope/models/audio/tts/{vocoder/models/models.py => models/vocoder_models.py} (100%)
 rename modelscope/models/audio/tts/{am/sambert_hifi_16k.py => sambert_hifi.py} (68%)
 rename modelscope/models/audio/tts/{am => }/text/__init__.py (100%)
 rename modelscope/models/audio/tts/{am => }/text/cleaners.py (100%)
 rename modelscope/models/audio/tts/{am => }/text/cmudict.py (100%)
 rename modelscope/models/audio/tts/{am => }/text/numbers.py (100%)
 rename modelscope/models/audio/tts/{am => }/text/symbols.py (80%)
 rename modelscope/models/audio/tts/{am => }/text/symbols_dict.py (100%)
 delete mode 100644 modelscope/models/audio/tts/vocoder/__init__.py
 delete mode 100644 modelscope/models/audio/tts/vocoder/hifigan16k.py
 delete mode 100644 modelscope/models/audio/tts/vocoder/models/__init__.py
 delete mode 100644 modelscope/preprocessors/text_to_speech.py
 delete mode 100644 tests/preprocessors/test_text_to_speech.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 555de643..a1dbc95e 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -20,9 +20,7 @@ class Models(object):
     space = 'space'
 
     # audio models
-    sambert_hifi_16k = 'sambert-hifi-16k'
-    generic_tts_frontend = 'generic-tts-frontend'
-    hifigan16k = 'hifigan16k'
+    sambert_hifigan = 'sambert-hifigan'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
 
@@ -66,7 +64,7 @@ class Pipelines(object):
     zero_shot_classification = 'zero-shot-classification'
 
     # audio tasks
-    sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
+    sambert_hifigan_tts = 'sambert-hifigan-tts'
     speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py
index 2e12d6ad..b5913d2c 100644
--- a/modelscope/models/__init__.py
+++ b/modelscope/models/__init__.py
@@ -5,8 +5,7 @@ from .base import Model
 from .builder import MODELS, build_model
 
 try:
-    from .audio.tts.am import SambertNetHifi16k
-    from .audio.tts.vocoder import Hifigan16k
+    from .audio.tts import SambertHifigan
     from .audio.kws import GenericKeyWordSpotting
     from .audio.ans.frcrn import FRCRNModel
 except ModuleNotFoundError as e:
diff --git a/modelscope/models/audio/tts/__init__.py b/modelscope/models/audio/tts/__init__.py
index e69de29b..12e5029b 100644
--- a/modelscope/models/audio/tts/__init__.py
+++ b/modelscope/models/audio/tts/__init__.py
@@ -0,0 +1 @@
+from .sambert_hifi import *  # noqa F403
diff --git a/modelscope/models/audio/tts/am/__init__.py b/modelscope/models/audio/tts/am/__init__.py
deleted file mode 100644
index 2ebbda1c..00000000
--- a/modelscope/models/audio/tts/am/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .sambert_hifi_16k import *  # noqa F403
diff --git a/modelscope/models/audio/tts/frontend/__init__.py b/modelscope/models/audio/tts/frontend/__init__.py
deleted file mode 100644
index d7b1015d..00000000
--- a/modelscope/models/audio/tts/frontend/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .generic_text_to_speech_frontend import *  # noqa F403
diff --git a/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py b/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py
deleted file mode 100644
index 757e4db9..00000000
--- a/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import os
-import zipfile
-from typing import Any, Dict, List
-
-from modelscope.metainfo import Models
-from modelscope.models.base import Model
-from modelscope.models.builder import MODELS
-from modelscope.utils.audio.tts_exceptions import (
-    TtsFrontendInitializeFailedException,
-    TtsFrontendLanguageTypeInvalidException)
-from modelscope.utils.constant import Tasks
-
-__all__ = ['GenericTtsFrontend']
-
-
-@MODELS.register_module(
-    Tasks.text_to_speech, module_name=Models.generic_tts_frontend)
-class GenericTtsFrontend(Model):
-
-    def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs):
-        super().__init__(model_dir, *args, **kwargs)
-        import ttsfrd
-        frontend = ttsfrd.TtsFrontendEngine()
-        zip_file = os.path.join(model_dir, 'resource.zip')
-        self._res_path = os.path.join(model_dir, 'resource')
-        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
-            zip_ref.extractall(model_dir)
-        if not frontend.initialize(self._res_path):
-            raise TtsFrontendInitializeFailedException(
-                'resource invalid: {}'.format(self._res_path))
-        if not frontend.set_lang_type(lang_type):
-            raise TtsFrontendLanguageTypeInvalidException(
-                'language type invalid: {}, valid is pinyin and chenmix'.
-                format(lang_type))
-        self._frontend = frontend
-
-    def forward(self, data: str) -> Dict[str, List]:
-        result = self._frontend.gen_tacotron_symbols(data)
-        return {'texts': [s for s in result.splitlines() if s != '']}
diff --git a/modelscope/models/audio/tts/am/models/__init__.py b/modelscope/models/audio/tts/models/__init__.py
similarity index 67%
rename from modelscope/models/audio/tts/am/models/__init__.py
rename to modelscope/models/audio/tts/models/__init__.py
index 9e198e7a..c260d4fe 100755
--- a/modelscope/models/audio/tts/am/models/__init__.py
+++ b/modelscope/models/audio/tts/models/__init__.py
@@ -1,7 +1,8 @@
 from .robutrans import RobuTrans
+from .vocoder_models import Generator
 
 
-def create_model(name, hparams):
+def create_am_model(name, hparams):
     if name == 'robutrans':
         return RobuTrans(hparams)
     else:
diff --git a/modelscope/models/audio/tts/am/models/modules.py b/modelscope/models/audio/tts/models/am_models.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/modules.py
rename to modelscope/models/audio/tts/models/am_models.py
diff --git a/modelscope/models/audio/tts/am/models/compat.py b/modelscope/models/audio/tts/models/compat.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/compat.py
rename to modelscope/models/audio/tts/models/compat.py
diff --git a/modelscope/models/audio/tts/am/models/fsmn.py b/modelscope/models/audio/tts/models/fsmn.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/fsmn.py
rename to modelscope/models/audio/tts/models/fsmn.py
diff --git a/modelscope/models/audio/tts/am/models/fsmn_encoder.py b/modelscope/models/audio/tts/models/fsmn_encoder.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/fsmn_encoder.py
rename to modelscope/models/audio/tts/models/fsmn_encoder.py
diff --git a/modelscope/models/audio/tts/am/models/helpers.py b/modelscope/models/audio/tts/models/helpers.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/helpers.py
rename to modelscope/models/audio/tts/models/helpers.py
diff --git a/modelscope/models/audio/tts/am/models/position.py b/modelscope/models/audio/tts/models/position.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/position.py
rename to modelscope/models/audio/tts/models/position.py
diff --git a/modelscope/models/audio/tts/am/models/reducer.py b/modelscope/models/audio/tts/models/reducer.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/reducer.py
rename to modelscope/models/audio/tts/models/reducer.py
diff --git a/modelscope/models/audio/tts/am/models/rnn_wrappers.py b/modelscope/models/audio/tts/models/rnn_wrappers.py
similarity index 99%
rename from modelscope/models/audio/tts/am/models/rnn_wrappers.py
rename to modelscope/models/audio/tts/models/rnn_wrappers.py
index 8f0d612b..85a6b335 100755
--- a/modelscope/models/audio/tts/am/models/rnn_wrappers.py
+++ b/modelscope/models/audio/tts/models/rnn_wrappers.py
@@ -4,7 +4,7 @@ from tensorflow.contrib.rnn import RNNCell
 from tensorflow.contrib.seq2seq import AttentionWrapperState
 from tensorflow.python.ops import rnn_cell_impl
 
-from .modules import prenet
+from .am_models import prenet
 
 
 class VarPredictorCell(RNNCell):
diff --git a/modelscope/models/audio/tts/am/models/robutrans.py b/modelscope/models/audio/tts/models/robutrans.py
similarity index 99%
rename from modelscope/models/audio/tts/am/models/robutrans.py
rename to modelscope/models/audio/tts/models/robutrans.py
index 34b4da7a..d5bafcec 100755
--- a/modelscope/models/audio/tts/am/models/robutrans.py
+++ b/modelscope/models/audio/tts/models/robutrans.py
@@ -3,9 +3,9 @@
 from tensorflow.contrib.rnn import LSTMBlockCell, MultiRNNCell
 from tensorflow.contrib.seq2seq import BasicDecoder
 from tensorflow.python.ops.ragged.ragged_util import repeat
 
+from .am_models import conv_prenet, decoder_prenet, encoder_prenet
 from .fsmn_encoder import FsmnEncoderV2
 from .helpers import VarTestHelper, VarTrainingHelper
-from .modules import conv_prenet, decoder_prenet, encoder_prenet
 from .position import (BatchSinusodalPositionalEncoding,
                        SinusodalPositionalEncoding)
 from .rnn_wrappers import DurPredictorCell, VarPredictorCell
diff --git a/modelscope/models/audio/tts/am/models/self_attention_decoder.py b/modelscope/models/audio/tts/models/self_attention_decoder.py
similarity index 99%
rename from modelscope/models/audio/tts/am/models/self_attention_decoder.py
rename to modelscope/models/audio/tts/models/self_attention_decoder.py
index 4e64342c..9cf3fcaa 100755
--- a/modelscope/models/audio/tts/am/models/self_attention_decoder.py
+++ b/modelscope/models/audio/tts/models/self_attention_decoder.py
@@ -5,7 +5,7 @@ import sys
 
 import tensorflow as tf
 
 from . import compat, transformer
-from .modules import decoder_prenet
+from .am_models import decoder_prenet
 from .position import SinusoidalPositionEncoder
diff --git a/modelscope/models/audio/tts/am/models/self_attention_encoder.py b/modelscope/models/audio/tts/models/self_attention_encoder.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/self_attention_encoder.py
rename to modelscope/models/audio/tts/models/self_attention_encoder.py
diff --git a/modelscope/models/audio/tts/am/models/transformer.py b/modelscope/models/audio/tts/models/transformer.py
similarity index 100%
rename from modelscope/models/audio/tts/am/models/transformer.py
rename to modelscope/models/audio/tts/models/transformer.py
diff --git a/modelscope/models/audio/tts/vocoder/models/utils.py b/modelscope/models/audio/tts/models/utils.py
similarity index 100%
rename from modelscope/models/audio/tts/vocoder/models/utils.py
rename to modelscope/models/audio/tts/models/utils.py
diff --git a/modelscope/models/audio/tts/vocoder/models/models.py b/modelscope/models/audio/tts/models/vocoder_models.py
similarity index 100%
rename from modelscope/models/audio/tts/vocoder/models/models.py
rename to modelscope/models/audio/tts/models/vocoder_models.py
diff --git a/modelscope/models/audio/tts/am/sambert_hifi_16k.py b/modelscope/models/audio/tts/sambert_hifi.py
similarity index 68%
rename from modelscope/models/audio/tts/am/sambert_hifi_16k.py
rename to modelscope/models/audio/tts/sambert_hifi.py
index fc6d519a..72c5b80c 100644
--- a/modelscope/models/audio/tts/am/sambert_hifi_16k.py
+++ b/modelscope/models/audio/tts/sambert_hifi.py
@@ -1,20 +1,31 @@
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
 import io
 import os
+import time
+import zipfile
 from typing import Any, Dict, Optional, Union
 
+import json
 import numpy as np
 import tensorflow as tf
+import torch
 from sklearn.preprocessing import MultiLabelBinarizer
 
 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
+from modelscope.utils.audio.tts_exceptions import (
+    TtsFrontendInitializeFailedException,
+    TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException,
+    TtsVocoderMelspecShapeMismatchException)
 from modelscope.utils.constant import ModelFile, Tasks
-from .models import create_model
+from .models import Generator, create_am_model
 from .text.symbols import load_symbols
 from .text.symbols_dict import SymbolsDict
 
-__all__ = ['SambertNetHifi16k']
+__all__ = ['SambertHifigan']
+MAX_WAV_VALUE = 32768.0
 
 
 def multi_label_symbol_to_sequence(my_classes, my_symbol):
@@ -23,13 +34,25 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol):
     sequences = []
     for token in tokens:
         sequences.append(tuple(token.split('&')))
-    # sequences.append(tuple(['~'])) # sequence length minus 1 to ignore EOS ~
     return one_hot.fit_transform(sequences)
 
 
+def load_checkpoint(filepath, device):
+    assert os.path.isfile(filepath)
+    checkpoint_dict = torch.load(filepath, map_location=device)
+    return checkpoint_dict
+
+
+class AttrDict(dict):
+
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
 @MODELS.register_module(
-    Tasks.text_to_speech, module_name=Models.sambert_hifi_16k)
-class SambertNetHifi16k(Model):
+    Tasks.text_to_speech, module_name=Models.sambert_hifigan)
+class SambertHifigan(Model):
 
     def __init__(self,
                  model_dir,
@@ -38,20 +61,50 @@ class SambertNetHifi16k(Model):
                 energy_control_str='',
                 *args,
                 **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        if 'am' not in kwargs:
+            raise TtsModelConfigurationException(
+                'model configuration is missing the am field!')
+        if 'vocoder' not in kwargs:
+            raise TtsModelConfigurationException(
+                'model configuration is missing the vocoder field!')
+        if 'lang_type' not in kwargs:
+            raise TtsModelConfigurationException(
+                'model configuration is missing the lang_type field!')
+        # initialize frontend
+        import ttsfrd
+        frontend = ttsfrd.TtsFrontendEngine()
+        zip_file = os.path.join(model_dir, 'resource.zip')
+        self._res_path = os.path.join(model_dir, 'resource')
+        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
+            zip_ref.extractall(model_dir)
+        if not frontend.initialize(self._res_path):
+            raise TtsFrontendInitializeFailedException(
+                'resource invalid: {}'.format(self._res_path))
+        if not frontend.set_lang_type(kwargs['lang_type']):
+            raise TtsFrontendLanguageTypeInvalidException(
+                'language type invalid: {}'.format(kwargs['lang_type']))
+        self._frontend = frontend
+
+        # initialize am
         tf.reset_default_graph()
-        local_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER, 'ckpt')
-        self._ckpt_path = os.path.join(model_dir, local_ckpt_path)
+        local_am_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER,
+                                          'ckpt')
+        self._am_ckpt_path = os.path.join(model_dir, local_am_ckpt_path)
         self._dict_path = os.path.join(model_dir, 'dicts')
-        self._hparams = tf.contrib.training.HParams(**kwargs)
-        values = self._hparams.values()
+        self._am_hparams = tf.contrib.training.HParams(**kwargs['am'])
+        has_mask = True
+        if self._am_hparams.get('has_mask') is not None:
+            has_mask = self._am_hparams.has_mask
+        print('set has_mask to {}'.format(has_mask))
+        values = self._am_hparams.values()
         hp = [' {}:{}'.format(name, values[name]) for name in sorted(values)]
         print('Hyperparameters:\n' + '\n'.join(hp))
-        super().__init__(self._ckpt_path, *args, **kwargs)
         model_name = 'robutrans'
-        self._lfeat_type_list = self._hparams.lfeat_type_list.strip().split(
+        self._lfeat_type_list = self._am_hparams.lfeat_type_list.strip().split(
             ',')
         sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols(
-            self._dict_path)
+            self._dict_path, has_mask)
         self._sy = sy
         self._tone = tone
         self._syllable_flag = syllable_flag
@@ -86,7 +139,6 @@ class SambertNetHifi16k(Model):
         inputs_speaker = tf.placeholder(tf.float32,
                                         [1, None, self._inputs_dim['speaker']],
                                         'inputs_speaker')
-        input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
         pitch_contours_scale = tf.placeholder(tf.float32, [1, None],
                                               'pitch_contours_scale')
@@ -94,9 +146,8 @@
                                               'energy_contours_scale')
         duration_scale = tf.placeholder(tf.float32, [1, None],
                                         'duration_scale')
-
         with tf.variable_scope('model') as _:
-            self._model = create_model(model_name, self._hparams)
+            self._model = create_am_model(model_name, self._am_hparams)
             self._model.initialize(
                 inputs,
                 inputs_emotion,
@@ -123,14 +174,14 @@ class SambertNetHifi16k(Model):
         self._attention_h = self._model.attention_h
         self._attention_x = self._model.attention_x
 
-        print('Loading checkpoint: %s' % self._ckpt_path)
+        print('Loading checkpoint: %s' % self._am_ckpt_path)
         config = tf.ConfigProto()
         config.gpu_options.allow_growth = True
         self._session = tf.Session(config=config)
         self._session.run(tf.global_variables_initializer())
         saver = tf.train.Saver()
-        saver.restore(self._session, self._ckpt_path)
+        saver.restore(self._session, self._am_ckpt_path)
 
         duration_cfg_lst = []
         if len(duration_control_str) != 0:
@@ -158,8 +209,26 @@ class SambertNetHifi16k(Model):
 
         self._energy_contours_cfg_lst = energy_contours_cfg_lst
 
-    def forward(self, text):
-        cleaner_names = [x.strip() for x in self._hparams.cleaners.split(',')]
+        # initialize vocoder
+        self._voc_ckpt_path = os.path.join(model_dir,
+                                           ModelFile.TORCH_MODEL_BIN_FILE)
+        self._voc_config = AttrDict(**kwargs['vocoder'])
+        print(self._voc_config)
+        if torch.cuda.is_available():
+            torch.manual_seed(self._voc_config.seed)
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        self._generator = Generator(self._voc_config).to(self._device)
+        state_dict_g = load_checkpoint(self._voc_ckpt_path, self._device)
+        self._generator.load_state_dict(state_dict_g['generator'])
+        self._generator.eval()
+        self._generator.remove_weight_norm()
+
+    def am_synthesis_one_sentence(self, text):
+        cleaner_names = [
+            x.strip() for x in self._am_hparams.cleaners.split(',')
+        ]
 
         lfeat_symbol = text.strip().split(' ')
         lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list))
@@ -255,3 +324,31 @@ class SambertNetHifi16k(Model):
             self._energy_embeddings, self._attention_x, self._attention_h
         ], feed_dict=feed_dict)  # yapf:disable
         return result[0]
+
+    def vocoder_process(self, melspec):
+        dim0 = list(melspec.shape)[-1]
+        if dim0 != self._voc_config.num_mels:
+            raise TtsVocoderMelspecShapeMismatchException(
+                'input melspec dim mismatch: require {} but got {}'.format(
+                    self._voc_config.num_mels, dim0))
+        with torch.no_grad():
+            x = melspec.T
+            x = torch.FloatTensor(x).to(self._device)
+            if len(x.shape) == 2:
+                x = x.unsqueeze(0)
+            y_g_hat = self._generator(x)
+            audio = y_g_hat.squeeze()
+            audio = audio * MAX_WAV_VALUE
+            audio = audio.cpu().numpy().astype('int16')
+        return audio
+
+    def forward(self, text):
+        result = self._frontend.gen_tacotron_symbols(text)
+        texts = [s for s in result.splitlines() if s != '']
+        audio_total = np.empty((0), dtype='int16')
+        for line in texts:
+            line = line.strip().split('\t')
+            audio = self.vocoder_process(
+                self.am_synthesis_one_sentence(line[1]))
+            audio_total = np.append(audio_total, audio, axis=0)
+        return audio_total
diff --git a/modelscope/models/audio/tts/am/text/__init__.py b/modelscope/models/audio/tts/text/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/am/text/__init__.py
rename to modelscope/models/audio/tts/text/__init__.py
diff --git a/modelscope/models/audio/tts/am/text/cleaners.py b/modelscope/models/audio/tts/text/cleaners.py
similarity index 100%
rename from modelscope/models/audio/tts/am/text/cleaners.py
rename to modelscope/models/audio/tts/text/cleaners.py
diff --git a/modelscope/models/audio/tts/am/text/cmudict.py b/modelscope/models/audio/tts/text/cmudict.py
similarity index 100%
rename from modelscope/models/audio/tts/am/text/cmudict.py
rename to modelscope/models/audio/tts/text/cmudict.py
diff --git a/modelscope/models/audio/tts/am/text/numbers.py b/modelscope/models/audio/tts/text/numbers.py
similarity index 100%
rename from modelscope/models/audio/tts/am/text/numbers.py
rename to modelscope/models/audio/tts/text/numbers.py
diff --git a/modelscope/models/audio/tts/am/text/symbols.py b/modelscope/models/audio/tts/text/symbols.py
similarity index 80%
rename from modelscope/models/audio/tts/am/text/symbols.py
rename to modelscope/models/audio/tts/text/symbols.py
index a7715cca..63975abb 100644
--- a/modelscope/models/audio/tts/am/text/symbols.py
+++ b/modelscope/models/audio/tts/text/symbols.py
@@ -12,7 +12,7 @@
 _eos = '~'
 _mask = '@[MASK]'
 
-def load_symbols(dict_path):
+def load_symbols(dict_path, has_mask=True):
     _characters = ''
     _ch_symbols = []
     sy_dict_name = 'sy_dict.txt'
@@ -25,7 +25,9 @@
     _arpabet = ['@' + s for s in _ch_symbols]
 
     # Export all symbols:
-    sy = list(_characters) + _arpabet + [_pad, _eos, _mask]
+    sy = list(_characters) + _arpabet + [_pad, _eos]
+    if has_mask:
+        sy.append(_mask)
 
     _characters = ''
 
@@ -38,7 +40,9 @@
         _ch_tones.append(line)
 
     # Export all tones:
-    tone = list(_characters) + _ch_tones + [_pad, _eos, _mask]
+    tone = list(_characters) + _ch_tones + [_pad, _eos]
+    if has_mask:
+        tone.append(_mask)
 
     _characters = ''
 
@@ -51,9 +55,9 @@
         _ch_syllable_flags.append(line)
 
     # Export all syllable_flags:
-    syllable_flag = list(_characters) + _ch_syllable_flags + [
-        _pad, _eos, _mask
-    ]
+    syllable_flag = list(_characters) + _ch_syllable_flags + [_pad, _eos]
+    if has_mask:
+        syllable_flag.append(_mask)
 
     _characters = ''
 
@@ -66,7 +70,9 @@
         _ch_word_segments.append(line)
 
     # Export all syllable_flags:
-    word_segment = list(_characters) + _ch_word_segments + [_pad, _eos, _mask]
+    word_segment = list(_characters) + _ch_word_segments + [_pad, _eos]
+    if has_mask:
+        word_segment.append(_mask)
 
     _characters = ''
 
@@ -78,7 +84,9 @@
         line = line.strip('\r\n')
         _ch_emo_types.append(line)
 
-    emo_category = list(_characters) + _ch_emo_types + [_pad, _eos, _mask]
+    emo_category = list(_characters) + _ch_emo_types + [_pad, _eos]
+    if has_mask:
+        emo_category.append(_mask)
 
     _characters = ''
 
@@ -91,5 +99,7 @@
         _ch_speakers.append(line)
 
     # Export all syllable_flags:
-    speaker = list(_characters) + _ch_speakers + [_pad, _eos, _mask]
+    speaker = list(_characters) + _ch_speakers + [_pad, _eos]
+    if has_mask:
+        speaker.append(_mask)
     return sy, tone, syllable_flag, word_segment, emo_category, speaker
diff --git a/modelscope/models/audio/tts/am/text/symbols_dict.py b/modelscope/models/audio/tts/text/symbols_dict.py
similarity index 100%
rename from modelscope/models/audio/tts/am/text/symbols_dict.py
rename to modelscope/models/audio/tts/text/symbols_dict.py
diff --git a/modelscope/models/audio/tts/vocoder/__init__.py b/modelscope/models/audio/tts/vocoder/__init__.py
deleted file mode 100644
index 94f257f8..00000000
--- a/modelscope/models/audio/tts/vocoder/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .hifigan16k import *  # noqa F403
diff --git a/modelscope/models/audio/tts/vocoder/hifigan16k.py b/modelscope/models/audio/tts/vocoder/hifigan16k.py
deleted file mode 100644
index b3fd9cf6..00000000
--- a/modelscope/models/audio/tts/vocoder/hifigan16k.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
-import argparse
-import glob
-import os
-import time
-
-import json
-import numpy as np
-import torch
-from scipy.io.wavfile import write
-
-from modelscope.metainfo import Models
-from modelscope.models.base import Model
-from modelscope.models.builder import MODELS
-from modelscope.utils.audio.tts_exceptions import \
-    TtsVocoderMelspecShapeMismatchException
-from modelscope.utils.constant import ModelFile, Tasks
-from .models import Generator
-
-__all__ = ['Hifigan16k', 'AttrDict']
-MAX_WAV_VALUE = 32768.0
-
-
-def load_checkpoint(filepath, device):
-    assert os.path.isfile(filepath)
-    print("Loading '{}'".format(filepath))
-    checkpoint_dict = torch.load(filepath, map_location=device)
-    print('Complete.')
-    return checkpoint_dict
-
-
-class AttrDict(dict):
-
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-@MODELS.register_module(Tasks.text_to_speech, module_name=Models.hifigan16k)
-class Hifigan16k(Model):
-
-    def __init__(self, model_dir, *args, **kwargs):
-        self._ckpt_path = os.path.join(model_dir,
-                                       ModelFile.TORCH_MODEL_BIN_FILE)
-        self._config = AttrDict(**kwargs)
-
-        super().__init__(self._ckpt_path, *args, **kwargs)
-        if torch.cuda.is_available():
-            torch.manual_seed(self._config.seed)
-            self._device = torch.device('cuda')
-        else:
-            self._device = torch.device('cpu')
-        self._generator = Generator(self._config).to(self._device)
-        state_dict_g = load_checkpoint(self._ckpt_path, self._device)
-        self._generator.load_state_dict(state_dict_g['generator'])
-        self._generator.eval()
-        self._generator.remove_weight_norm()
-
-    def forward(self, melspec):
-        dim0 = list(melspec.shape)[-1]
-        if dim0 != 80:
-            raise TtsVocoderMelspecShapeMismatchException(
-                'input melspec mismatch 0 dim require 80 but {}'.format(dim0))
-        with torch.no_grad():
-            x = melspec.T
-            x = torch.FloatTensor(x).to(self._device)
-            if len(x.shape) == 2:
-                x = x.unsqueeze(0)
-            y_g_hat = self._generator(x)
-            audio = y_g_hat.squeeze()
-            audio = audio * MAX_WAV_VALUE
-            audio = audio.cpu().numpy().astype('int16')
-        return audio
diff --git a/modelscope/models/audio/tts/vocoder/models/__init__.py b/modelscope/models/audio/tts/vocoder/models/__init__.py
deleted file mode 100644
index b00eec9b..00000000
--- a/modelscope/models/audio/tts/vocoder/models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .models import Generator
diff --git a/modelscope/pipelines/audio/text_to_speech_pipeline.py b/modelscope/pipelines/audio/text_to_speech_pipeline.py
index 142d697d..d8d7ca02 100644
--- a/modelscope/pipelines/audio/text_to_speech_pipeline.py
+++ b/modelscope/pipelines/audio/text_to_speech_pipeline.py
@@ -3,46 +3,45 @@ from typing import Any, Dict, List
 import numpy as np
 
 from modelscope.metainfo import Pipelines
-from modelscope.pipelines.base import Pipeline
+from modelscope.models import Model
+from modelscope.models.audio.tts import SambertHifigan
+from modelscope.pipelines.base import Input, InputModel, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import Preprocessor, TextToTacotronSymbols
-from modelscope.utils.constant import Tasks
+from modelscope.pipelines.outputs import OutputKeys
+from modelscope.utils.constant import Fields, Tasks
 
-__all__ = ['TextToSpeechSambertHifigan16kPipeline']
+__all__ = ['TextToSpeechSambertHifiganPipeline']
 
 
 @PIPELINES.register_module(
-    Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_16k_tts)
-class TextToSpeechSambertHifigan16kPipeline(Pipeline):
+    Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_tts)
+class TextToSpeechSambertHifiganPipeline(Pipeline):
 
-    def __init__(self,
-                 model: List[str] = None,
-                 preprocessor: Preprocessor = None,
-                 **kwargs):
+    def __init__(self, model: InputModel, **kwargs):
+        """use `model` to create a text-to-speech pipeline for prediction
+
+        Args:
+            model (SambertHifigan or str): a model instance or a valid official model id
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        super().__init__(model=model, **kwargs)
+
+    def forward(self, inputs: Dict[str, str]) -> Dict[str, np.ndarray]:
+        """synthesize speech from the input texts with the pipeline
 
         Args:
-            model: model id on modelscope hub.
+            inputs (Dict[str, str]): a dictionary whose keys are the labels
+                of test cases and whose values are the texts to synthesize.
+        Returns:
+            Dict[str, np.ndarray]: a dictionary with the same keys as the
+                input; each value is the synthesized PCM audio data for
+                that test case.
         """
-        assert len(model) == 3, 'model number should be 3'
-        if preprocessor is None:
-            lang_type = 'pinyin'
-            if 'lang_type' in kwargs:
-                lang_type = kwargs.lang_type
-            preprocessor = TextToTacotronSymbols(model[0], lang_type=lang_type)
-        models = [model[1], model[2]]
-        super().__init__(model=models, preprocessor=preprocessor, **kwargs)
-        self._am = self.models[0]
-        self._vocoder = self.models[1]
-
-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, np.ndarray]:
-        texts = inputs['texts']
-        audio_total = np.empty((0), dtype='int16')
-        for line in texts:
-            line = line.strip().split('\t')
-            audio = self._vocoder.forward(self._am.forward(line[1]))
-            audio_total = np.append(audio_total, audio, axis=0)
-        return {'output': audio_total}
+        output_wav = {}
+        for label, text in inputs.items():
+            output_wav[label] = self.model.forward(text)
+        return {OutputKeys.OUTPUT_PCM: output_wav}
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         return inputs
+
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py
index 368586df..b1b0e86c 100644
--- a/modelscope/pipelines/outputs.py
+++ b/modelscope/pipelines/outputs.py
@@ -263,5 +263,11 @@
     # {
     #   "output_img": np.ndarray with shape [height, width, 3]
     # }
-    Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG]
+    Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG],
+
+    # text_to_speech result for a single sample
+    # {
+    #   "output_pcm": {"input_label": np.ndarray with shape [D]}
+    # }
+    Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM]
 }
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 29839c2b..95d1f3b2 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -6,7 +6,6 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .kws import WavToLists
-from .text_to_speech import *  # noqa F403
 
 try:
     from .audio import LinearAECAndFbank
diff --git a/modelscope/preprocessors/text_to_speech.py b/modelscope/preprocessors/text_to_speech.py
deleted file mode 100644
index 9d8af6fa..00000000
--- a/modelscope/preprocessors/text_to_speech.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import io
-from typing import Any, Dict, Union
-
-from modelscope.fileio import File
-from modelscope.metainfo import Preprocessors
-from modelscope.models.audio.tts.frontend import GenericTtsFrontend
-from modelscope.models.base import Model
-from modelscope.utils.audio.tts_exceptions import *  # noqa F403
-from modelscope.utils.constant import Fields
-from .base import Preprocessor
-from .builder import PREPROCESSORS
-
-__all__ = ['TextToTacotronSymbols']
-
-
-@PREPROCESSORS.register_module(
-    Fields.audio, module_name=Preprocessors.text_to_tacotron_symbols)
-class TextToTacotronSymbols(Preprocessor):
-    """extract tacotron symbols from text.
-
-    Args:
-        res_path (str): TTS frontend resource url
-        lang_type (str): language type, valid values are "pinyin" and "chenmix"
-    """
-
-    def __init__(self, model_name, lang_type='pinyin'):
-        self._frontend_model = Model.from_pretrained(
-            model_name, lang_type=lang_type)
-        assert self._frontend_model is not None, 'load model from pretained failed'
-
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """Call functions to load text and get tacotron symbols.
-
-        Args:
-            input (str): text with utf-8
-        Returns:
-            symbos (list[str]): texts in tacotron symbols format.
-        """
-        return self._frontend_model.forward(data)
-
-
-def text_to_tacotron_symbols(text='', path='./', lang='pinyin'):
-    """ simple interface to transform text to tacotron symbols
-
-    Args:
-        text (str): input text
-        path (str): resource path
-        lang (str): language type from one of "pinyin" and "chenmix"
-    """
-    transform = TextToTacotronSymbols(path, lang)
-    return transform(text)
diff --git a/modelscope/utils/audio/tts_exceptions.py b/modelscope/utils/audio/tts_exceptions.py
index 1ca731c3..6204582d 100644
--- a/modelscope/utils/audio/tts_exceptions.py
+++ b/modelscope/utils/audio/tts_exceptions.py
@@ -10,6 +10,13 @@ class TtsException(Exception):
     pass
 
 
+class TtsModelConfigurationException(TtsException):
+    """
+    TTS model configuration exceptions.
+    """
+    pass
+
+
 class TtsFrontendException(TtsException):
     """
     TTS frontend module level exceptions.
diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py
index 2e1f8672..1ace79ba 100644
--- a/modelscope/utils/registry.py
+++ b/modelscope/utils/registry.py
@@ -177,7 +177,6 @@ def build_from_cfg(cfg,
                         f'but got {type(default_args)}')
 
     args = cfg.copy()
-
     if default_args is not None:
         for name, value in default_args.items():
             args.setdefault(name, value)
diff --git a/requirements/audio.txt b/requirements/audio.txt
index 1e1e577b..feb4eb82 100644
--- a/requirements/audio.txt
+++ b/requirements/audio.txt
@@ -20,5 +20,5 @@ torch
 torchaudio
 torchvision
 tqdm
-ttsfrd==0.0.2
+ttsfrd==0.0.3
 unidecode
diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py
index c445f46f..bd9ddb20 100644
--- a/tests/pipelines/test_text_to_speech.py
+++ b/tests/pipelines/test_text_to_speech.py
@@ -7,10 +7,10 @@ import unittest
 import torch
 from scipy.io.wavfile import write
 
-from modelscope.metainfo import Pipelines, Preprocessors
+from modelscope.metainfo import Pipelines
 from modelscope.models import Model
 from modelscope.pipelines import pipeline
-from modelscope.preprocessors import build_preprocessor
+from modelscope.pipelines.outputs import OutputKeys
 from modelscope.utils.constant import Fields, Tasks
 from modelscope.utils.logger import get_logger
 from modelscope.utils.test_utils import test_level
@@ -24,17 +24,18 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_pipeline(self):
-        text = '明天天气怎么样'
-        preprocessor_model_id = 'damo/speech_binary_tts_frontend_resource'
-        am_model_id = 'damo/speech_sambert16k_tts_zhitian_emo'
-        voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo'
-        sambert_tts = pipeline(
-            task=Tasks.text_to_speech,
-            model=[preprocessor_model_id, am_model_id, voc_model_id])
-        self.assertTrue(sambert_tts is not None)
-        output = sambert_tts(text)
-        self.assertTrue(len(output['output']) > 0)
-        write('output.wav', 16000, output['output'])
+        single_test_case_label = 'test_case_label_0'
+        text = '今天北京天气怎么样?'
+        model_id = 'damo/speech_sambert-hifigan_tts_zhitian_emo_zhcn_16k'
+
+        sambert_hifigan_tts = pipeline(
+            task=Tasks.text_to_speech, model=model_id)
+        self.assertTrue(sambert_hifigan_tts is not None)
+        test_cases = {single_test_case_label: text}
+        output = sambert_hifigan_tts(test_cases)
+        self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
+        pcm = output[OutputKeys.OUTPUT_PCM][single_test_case_label]
+        write('output.wav', 16000, pcm)
 
 
 if __name__ == '__main__':
diff --git a/tests/preprocessors/test_text_to_speech.py b/tests/preprocessors/test_text_to_speech.py
deleted file mode 100644
index fd2473fd..00000000
--- a/tests/preprocessors/test_text_to_speech.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import shutil
-import unittest
-
-from modelscope.metainfo import Preprocessors
-from modelscope.preprocessors import build_preprocessor
-from modelscope.utils.constant import Fields, InputFields
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
-
-
-class TtsPreprocessorTest(unittest.TestCase):
-
-    def test_preprocess(self):
-        lang_type = 'pinyin'
-        text = '今天天气不错,我们去散步吧。'
-        cfg = dict(
-            type=Preprocessors.text_to_tacotron_symbols,
-            model_name='damo/speech_binary_tts_frontend_resource',
-            lang_type=lang_type)
-        preprocessor = build_preprocessor(cfg, Fields.audio)
-        output = preprocessor(text)
-        self.assertTrue(output)
-        for line in output['texts']:
-            print(line)
-
-
-if __name__ == '__main__':
-    unittest.main()
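
For reviewers, a minimal usage sketch of the merged pipeline. It mirrors the updated test; the model id and the 16 kHz output rate are taken from that test, and anything beyond them is an assumption rather than part of the patch:

# Sketch only: exercises the merged SambertHifigan pipeline end to end.
# A single model id now resolves the frontend resource, the TensorFlow AM
# checkpoint and the PyTorch HiFi-GAN weights from one model card.
from scipy.io.wavfile import write

from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks

tts = pipeline(
    task=Tasks.text_to_speech,
    model='damo/speech_sambert-hifigan_tts_zhitian_emo_zhcn_16k')
# Inputs map test-case labels to texts; the output maps the same labels
# to int16 PCM arrays under OutputKeys.OUTPUT_PCM.
result = tts({'demo': '今天北京天气怎么样?'})
write('demo.wav', 16000, result[OutputKeys.OUTPUT_PCM]['demo'])

This replaces the previous three-model invocation (frontend, AM and vocoder model ids passed as a list), which is exactly the simplification this change set is after.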