@@ -1,20 +1,31 @@
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 
+import io
 import os
 import time
+import zipfile
+from typing import Any, Dict, Optional, Union
 
+import json
 import numpy as np
 import tensorflow as tf
+import torch
 from sklearn.preprocessing import MultiLabelBinarizer
 
 from modelscope.metainfo import Models
 from modelscope.models.base import Model
 from modelscope.models.builder import MODELS
+from modelscope.utils.audio.tts_exceptions import (
+    TtsFrontendInitializeFailedException,
+    TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationExcetion,
+    TtsVocoderMelspecShapeMismatchException)
 from modelscope.utils.constant import ModelFile, Tasks
-from .models import create_model
+from .models import Generator, create_am_model
 from .text.symbols import load_symbols
 from .text.symbols_dict import SymbolsDict
 
-__all__ = ['SambertNetHifi16k']
+__all__ = ['SambertHifigan']
+
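+# int16 full scale; vocoder_process scales the float waveform by this value
+# before casting to 16-bit PCM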
+MAX_WAV_VALUE = 32768.0
 
@@ -23,13 +34,25 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol):
     sequences = []
     for token in tokens:
         sequences.append(tuple(token.split('&')))
     # sequences.append(tuple(['~'])) # sequence length minus 1 to ignore EOS ~
     return one_hot.fit_transform(sequences)
 
 
+def load_checkpoint(filepath, device):
+    assert os.path.isfile(filepath)
+    checkpoint_dict = torch.load(filepath, map_location=device)
+    return checkpoint_dict
+
+
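+# dict whose keys double as attributes, so the vocoder configuration can be
+# read as e.g. self._voc_config.num_mels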
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
 @MODELS.register_module(
-    Tasks.text_to_speech, module_name=Models.sambert_hifi_16k)
-class SambertNetHifi16k(Model):
+    Tasks.text_to_speech, module_name=Models.sambert_hifigan)
+class SambertHifigan(Model):
 
     def __init__(self,
                  model_dir,
@@ -38,20 +61,50 @@ class SambertNetHifi16k(Model):
                  energy_control_str='',
                  *args,
                  **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        if 'am' not in kwargs:
+            raise TtsModelConfigurationExcetion(
+                'configuration model field missing am!')
+        if 'vocoder' not in kwargs:
+            raise TtsModelConfigurationExcetion(
+                'configuration model field missing vocoder!')
+        if 'lang_type' not in kwargs:
+            raise TtsModelConfigurationExcetion(
+                'configuration model field missing lang_type!')
+        # initialize frontend
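+        # unpack the resource.zip bundled in model_dir and point the ttsfrd
+        # engine at the extracted resource directory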
+        import ttsfrd
+        frontend = ttsfrd.TtsFrontendEngine()
+        zip_file = os.path.join(model_dir, 'resource.zip')
+        self._res_path = os.path.join(model_dir, 'resource')
+        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
+            zip_ref.extractall(model_dir)
+        if not frontend.initialize(self._res_path):
+            raise TtsFrontendInitializeFailedException(
+                'resource invalid: {}'.format(self._res_path))
+        if not frontend.set_lang_type(kwargs['lang_type']):
+            raise TtsFrontendLanguageTypeInvalidException(
+                'language type invalid: {}'.format(kwargs['lang_type']))
+        self._frontend = frontend
+
+        # initialize am
         tf.reset_default_graph()
-        local_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER, 'ckpt')
-        self._ckpt_path = os.path.join(model_dir, local_ckpt_path)
+        local_am_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER,
+                                          'ckpt')
+        self._am_ckpt_path = os.path.join(model_dir, local_am_ckpt_path)
         self._dict_path = os.path.join(model_dir, 'dicts')
-        self._hparams = tf.contrib.training.HParams(**kwargs)
-        values = self._hparams.values()
+        self._am_hparams = tf.contrib.training.HParams(**kwargs['am'])
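+        # has_mask is an optional am hyperparameter; default to True when the
+        # configuration does not set it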
+        has_mask = True
+        if self._am_hparams.get('has_mask') is not None:
+            has_mask = self._am_hparams.has_mask
+        print('set has_mask to {}'.format(has_mask))
+        values = self._am_hparams.values()
         hp = [' {}:{}'.format(name, values[name]) for name in sorted(values)]
         print('Hyperparameters:\n' + '\n'.join(hp))
-        super().__init__(self._ckpt_path, *args, **kwargs)
         model_name = 'robutrans'
-        self._lfeat_type_list = self._hparams.lfeat_type_list.strip().split(
+        self._lfeat_type_list = self._am_hparams.lfeat_type_list.strip().split(
             ',')
         sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols(
-            self._dict_path)
+            self._dict_path, has_mask)
         self._sy = sy
         self._tone = tone
         self._syllable_flag = syllable_flag
@@ -86,7 +139,6 @@ class SambertNetHifi16k(Model):
         inputs_speaker = tf.placeholder(tf.float32,
                                         [1, None, self._inputs_dim['speaker']],
                                         'inputs_speaker')
-
         input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
         pitch_contours_scale = tf.placeholder(tf.float32, [1, None],
                                               'pitch_contours_scale')
@@ -94,9 +146,8 @@ class SambertNetHifi16k(Model):
                                                'energy_contours_scale')
         duration_scale = tf.placeholder(tf.float32, [1, None],
                                         'duration_scale')
-
         with tf.variable_scope('model') as _:
-            self._model = create_model(model_name, self._hparams)
+            self._model = create_am_model(model_name, self._am_hparams)
             self._model.initialize(
                 inputs,
                 inputs_emotion,
@@ -123,14 +174,14 @@ class SambertNetHifi16k(Model):
         self._attention_h = self._model.attention_h
         self._attention_x = self._model.attention_x
 
-        print('Loading checkpoint: %s' % self._ckpt_path)
+        print('Loading checkpoint: %s' % self._am_ckpt_path)
         config = tf.ConfigProto()
         config.gpu_options.allow_growth = True
         self._session = tf.Session(config=config)
         self._session.run(tf.global_variables_initializer())
 
         saver = tf.train.Saver()
-        saver.restore(self._session, self._ckpt_path)
+        saver.restore(self._session, self._am_ckpt_path)
 
         duration_cfg_lst = []
         if len(duration_control_str) != 0:
@@ -158,8 +209,26 @@ class SambertNetHifi16k(Model):
 
         self._energy_contours_cfg_lst = energy_contours_cfg_lst
 
-    def forward(self, text):
-        cleaner_names = [x.strip() for x in self._hparams.cleaners.split(',')]
+        # initialize vocoder
+        self._voc_ckpt_path = os.path.join(model_dir,
+                                           ModelFile.TORCH_MODEL_BIN_FILE)
+        self._voc_config = AttrDict(**kwargs['vocoder'])
+        print(self._voc_config)
+        if torch.cuda.is_available():
+            torch.manual_seed(self._voc_config.seed)
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        self._generator = Generator(self._voc_config).to(self._device)
+        state_dict_g = load_checkpoint(self._voc_ckpt_path, self._device)
+        self._generator.load_state_dict(state_dict_g['generator'])
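+        # inference only: switch to eval mode and fold weight norm into the
+        # convolution weights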
+        self._generator.eval()
+        self._generator.remove_weight_norm()
+
+    def am_synthesis_one_sentences(self, text):
+        cleaner_names = [
+            x.strip() for x in self._am_hparams.cleaners.split(',')
+        ]
 
         lfeat_symbol = text.strip().split(' ')
         lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list))
@@ -255,3 +324,31 @@ class SambertNetHifi16k(Model):
             self._energy_embeddings, self._attention_x, self._attention_h
         ], feed_dict=feed_dict)  # yapf:disable
         return result[0]
+
+    def vocoder_process(self, melspec):
+        dim0 = list(melspec.shape)[-1]
+        if dim0 != self._voc_config.num_mels:
+            raise TtsVocoderMelspecShapeMismatchException(
+                'input melspec mismatch require {} but {}'.format(
+                    self._voc_config.num_mels, dim0))
+        with torch.no_grad():
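+            # the am emits (frames, num_mels); the generator expects
+            # (batch, num_mels, frames)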
+            x = melspec.T
+            x = torch.FloatTensor(x).to(self._device)
+            if len(x.shape) == 2:
+                x = x.unsqueeze(0)
+            y_g_hat = self._generator(x)
+            audio = y_g_hat.squeeze()
+            audio = audio * MAX_WAV_VALUE
+            audio = audio.cpu().numpy().astype('int16')
+        return audio
+
+    def forward(self, text):
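+        # the frontend returns one tab-separated line per sentence; field 1
+        # carries the symbol sequence consumed by the am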
+        result = self._frontend.gen_tacotron_symbols(text)
+        texts = [s for s in result.splitlines() if s != '']
+        audio_total = np.empty((0), dtype='int16')
+        for line in texts:
+            line = line.strip().split('\t')
+            audio = self.vocoder_process(
+                self.am_synthesis_one_sentences(line[1]))
+            audio_total = np.append(audio_total, audio, axis=0)
+        return audio_total