From e90ff9e4795129eb8d64a2c4b67b3833217c7e1b Mon Sep 17 00:00:00 2001 From: "jiaqi.sjq" Date: Tue, 27 Sep 2022 22:09:30 +0800 Subject: [PATCH] [to #42322933] tts sambert am changs from tensorfow to PyTorch and add licenses * [to #41669377] docs and tools refinement and release 1. add build_doc linter script 2. add sphinx-docs support 3. add development doc and api doc 4. change version to 0.1.0 for the first internal release version Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307 --- .../models/audio/tts/models/__init__.py | 9 - .../models/audio/tts/models/am_models.py | 460 ------- modelscope/models/audio/tts/models/compat.py | 82 -- .../tts/{text => models/datasets}/__init__.py | 0 .../tts/models/datasets/kantts_data4fs.py | 238 ++++ .../audio/tts/models/datasets/samplers.py | 131 ++ .../tts/models/datasets/units/__init__.py | 3 + .../tts/models/datasets/units/cleaners.py | 88 ++ .../tts/models/datasets/units/ling_unit.py | 395 ++++++ .../datasets/units}/numbers.py | 3 + modelscope/models/audio/tts/models/fsmn.py | 273 ---- .../models/audio/tts/models/fsmn_encoder.py | 178 --- modelscope/models/audio/tts/models/helpers.py | 159 --- .../audio/tts/models/models/__init__.py | 0 .../tts/models/models/hifigan/__init__.py | 3 + .../tts/models/models/hifigan/hifigan.py | 238 ++++ .../tts/models/models/sambert/__init__.py | 3 + .../tts/models/models/sambert/adaptors.py | 131 ++ .../audio/tts/models/models/sambert/base.py | 369 ++++++ .../audio/tts/models/models/sambert/fsmn.py | 126 ++ .../models/models/sambert/kantts_sambert.py | 718 ++++++++++ .../tts/models/models/sambert/positions.py | 101 ++ .../models/audio/tts/models/position.py | 174 --- modelscope/models/audio/tts/models/reducer.py | 155 --- .../models/audio/tts/models/rnn_wrappers.py | 237 ---- .../models/audio/tts/models/robutrans.py | 760 ----------- .../tts/models/self_attention_decoder.py | 817 ------------ .../tts/models/self_attention_encoder.py | 182 --- .../models/audio/tts/models/transformer.py | 1157 ----------------- modelscope/models/audio/tts/models/utils.py | 59 - .../models/audio/tts/models/utils/__init__.py | 3 + .../models/audio/tts/models/utils/utils.py | 136 ++ .../models/audio/tts/models/vocoder_models.py | 516 -------- modelscope/models/audio/tts/sambert_hifi.py | 34 +- modelscope/models/audio/tts/text/cleaners.py | 89 -- modelscope/models/audio/tts/text/cmudict.py | 64 - modelscope/models/audio/tts/text/symbols.py | 105 -- .../models/audio/tts/text/symbols_dict.py | 200 --- modelscope/models/audio/tts/voice.py | 333 ++--- .../audio/text_to_speech_pipeline.py | 5 + modelscope/utils/audio/tts_exceptions.py | 3 +- requirements/audio.txt | 5 - tests/pipelines/test_text_to_speech.py | 5 +- 43 files changed, 2799 insertions(+), 5948 deletions(-) mode change 100755 => 100644 modelscope/models/audio/tts/models/__init__.py delete mode 100755 modelscope/models/audio/tts/models/am_models.py delete mode 100755 modelscope/models/audio/tts/models/compat.py rename modelscope/models/audio/tts/{text => models/datasets}/__init__.py (100%) mode change 100755 => 100644 create mode 100644 modelscope/models/audio/tts/models/datasets/kantts_data4fs.py create mode 100644 modelscope/models/audio/tts/models/datasets/samplers.py create mode 100644 modelscope/models/audio/tts/models/datasets/units/__init__.py create mode 100644 modelscope/models/audio/tts/models/datasets/units/cleaners.py create mode 100644 modelscope/models/audio/tts/models/datasets/units/ling_unit.py rename modelscope/models/audio/tts/{text => models/datasets/units}/numbers.py (94%) mode change 100755 => 100644 delete mode 100755 modelscope/models/audio/tts/models/fsmn.py delete mode 100755 modelscope/models/audio/tts/models/fsmn_encoder.py delete mode 100755 modelscope/models/audio/tts/models/helpers.py create mode 100644 modelscope/models/audio/tts/models/models/__init__.py create mode 100644 modelscope/models/audio/tts/models/models/hifigan/__init__.py create mode 100755 modelscope/models/audio/tts/models/models/hifigan/hifigan.py create mode 100644 modelscope/models/audio/tts/models/models/sambert/__init__.py create mode 100644 modelscope/models/audio/tts/models/models/sambert/adaptors.py create mode 100644 modelscope/models/audio/tts/models/models/sambert/base.py create mode 100644 modelscope/models/audio/tts/models/models/sambert/fsmn.py create mode 100644 modelscope/models/audio/tts/models/models/sambert/kantts_sambert.py create mode 100644 modelscope/models/audio/tts/models/models/sambert/positions.py delete mode 100755 modelscope/models/audio/tts/models/position.py delete mode 100755 modelscope/models/audio/tts/models/reducer.py delete mode 100755 modelscope/models/audio/tts/models/rnn_wrappers.py delete mode 100755 modelscope/models/audio/tts/models/robutrans.py delete mode 100755 modelscope/models/audio/tts/models/self_attention_decoder.py delete mode 100755 modelscope/models/audio/tts/models/self_attention_encoder.py delete mode 100755 modelscope/models/audio/tts/models/transformer.py delete mode 100755 modelscope/models/audio/tts/models/utils.py create mode 100644 modelscope/models/audio/tts/models/utils/__init__.py create mode 100755 modelscope/models/audio/tts/models/utils/utils.py delete mode 100755 modelscope/models/audio/tts/models/vocoder_models.py delete mode 100755 modelscope/models/audio/tts/text/cleaners.py delete mode 100755 modelscope/models/audio/tts/text/cmudict.py delete mode 100644 modelscope/models/audio/tts/text/symbols.py delete mode 100644 modelscope/models/audio/tts/text/symbols_dict.py diff --git a/modelscope/models/audio/tts/models/__init__.py b/modelscope/models/audio/tts/models/__init__.py old mode 100755 new mode 100644 index c260d4fe..e69de29b --- a/modelscope/models/audio/tts/models/__init__.py +++ b/modelscope/models/audio/tts/models/__init__.py @@ -1,9 +0,0 @@ -from .robutrans import RobuTrans -from .vocoder_models import Generator - - -def create_am_model(name, hparams): - if name == 'robutrans': - return RobuTrans(hparams) - else: - raise Exception('Unknown model: ' + name) diff --git a/modelscope/models/audio/tts/models/am_models.py b/modelscope/models/audio/tts/models/am_models.py deleted file mode 100755 index cd43ff12..00000000 --- a/modelscope/models/audio/tts/models/am_models.py +++ /dev/null @@ -1,460 +0,0 @@ -import tensorflow as tf - - -def encoder_prenet(inputs, - n_conv_layers, - filters, - kernel_size, - dense_units, - is_training, - mask=None, - scope='encoder_prenet'): - x = inputs - with tf.variable_scope(scope): - for i in range(n_conv_layers): - x = conv1d( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - mask=mask, - scope='conv1d_{}'.format(i)) - x = tf.layers.dense( - x, units=dense_units, activation=None, name='dense') - return x - - -def decoder_prenet(inputs, - prenet_units, - dense_units, - is_training, - scope='decoder_prenet'): - x = inputs - with tf.variable_scope(scope): - for i, units in enumerate(prenet_units): - x = tf.layers.dense( - x, - units=units, - activation=tf.nn.relu, - name='dense_{}'.format(i)) - x = tf.layers.dropout( - x, rate=0.5, training=is_training, name='dropout_{}'.format(i)) - x = tf.layers.dense( - x, units=dense_units, activation=None, name='dense') - return x - - -def encoder(inputs, - input_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - embedded_inputs_speaker, - mask=None, - scope='encoder'): - with tf.variable_scope(scope): - x = conv_and_lstm( - inputs, - input_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - embedded_inputs_speaker, - mask=mask) - return x - - -def prenet(inputs, prenet_units, is_training, scope='prenet'): - x = inputs - with tf.variable_scope(scope): - for i, units in enumerate(prenet_units): - x = tf.layers.dense( - x, - units=units, - activation=tf.nn.relu, - name='dense_{}'.format(i)) - x = tf.layers.dropout( - x, rate=0.5, training=is_training, name='dropout_{}'.format(i)) - return x - - -def postnet_residual_ulstm(inputs, - n_conv_layers, - filters, - kernel_size, - lstm_units, - output_units, - is_training, - scope='postnet_residual_ulstm'): - with tf.variable_scope(scope): - x = conv_and_ulstm(inputs, None, n_conv_layers, filters, kernel_size, - lstm_units, is_training) - x = conv1d( - x, - output_units, - kernel_size, - is_training, - activation=None, - dropout=False, - scope='conv1d_{}'.format(n_conv_layers - 1)) - return x - - -def postnet_residual_lstm(inputs, - n_conv_layers, - filters, - kernel_size, - lstm_units, - output_units, - is_training, - scope='postnet_residual_lstm'): - with tf.variable_scope(scope): - x = conv_and_lstm(inputs, None, n_conv_layers, filters, kernel_size, - lstm_units, is_training) - x = conv1d( - x, - output_units, - kernel_size, - is_training, - activation=None, - dropout=False, - scope='conv1d_{}'.format(n_conv_layers - 1)) - return x - - -def postnet_linear_ulstm(inputs, - n_conv_layers, - filters, - kernel_size, - lstm_units, - output_units, - is_training, - scope='postnet_linear'): - with tf.variable_scope(scope): - x = conv_and_ulstm(inputs, None, n_conv_layers, filters, kernel_size, - lstm_units, is_training) - x = tf.layers.dense(x, units=output_units) - return x - - -def postnet_linear_lstm(inputs, - n_conv_layers, - filters, - kernel_size, - lstm_units, - output_units, - output_lengths, - is_training, - embedded_inputs_speaker2, - mask=None, - scope='postnet_linear'): - with tf.variable_scope(scope): - x = conv_and_lstm_dec( - inputs, - output_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - embedded_inputs_speaker2, - mask=mask) - x = tf.layers.dense(x, units=output_units) - return x - - -def postnet_linear(inputs, - n_conv_layers, - filters, - kernel_size, - lstm_units, - output_units, - output_lengths, - is_training, - embedded_inputs_speaker2, - mask=None, - scope='postnet_linear'): - with tf.variable_scope(scope): - x = conv_dec( - inputs, - output_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - embedded_inputs_speaker2, - mask=mask) - return x - - -def conv_and_lstm(inputs, - sequence_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - embedded_inputs_speaker, - mask=None, - scope='conv_and_lstm'): - from tensorflow.contrib.rnn import LSTMBlockCell - x = inputs - with tf.variable_scope(scope): - for i in range(n_conv_layers): - x = conv1d( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - mask=mask, - scope='conv1d_{}'.format(i)) - - x = tf.concat([x, embedded_inputs_speaker], axis=2) - - outputs, states = tf.nn.bidirectional_dynamic_rnn( - LSTMBlockCell(lstm_units), - LSTMBlockCell(lstm_units), - x, - sequence_length=sequence_lengths, - dtype=tf.float32) - x = tf.concat(outputs, axis=-1) - - return x - - -def conv_and_lstm_dec(inputs, - sequence_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - embedded_inputs_speaker2, - mask=None, - scope='conv_and_lstm'): - x = inputs - from tensorflow.contrib.rnn import LSTMBlockCell - with tf.variable_scope(scope): - for i in range(n_conv_layers): - x = conv1d( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - mask=mask, - scope='conv1d_{}'.format(i)) - - x = tf.concat([x, embedded_inputs_speaker2], axis=2) - - outputs, states = tf.nn.bidirectional_dynamic_rnn( - LSTMBlockCell(lstm_units), - LSTMBlockCell(lstm_units), - x, - sequence_length=sequence_lengths, - dtype=tf.float32) - x = tf.concat(outputs, axis=-1) - return x - - -def conv_dec(inputs, - sequence_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - embedded_inputs_speaker2, - mask=None, - scope='conv_and_lstm'): - x = inputs - with tf.variable_scope(scope): - for i in range(n_conv_layers): - x = conv1d( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - mask=mask, - scope='conv1d_{}'.format(i)) - x = tf.concat([x, embedded_inputs_speaker2], axis=2) - return x - - -def conv_and_ulstm(inputs, - sequence_lengths, - n_conv_layers, - filters, - kernel_size, - lstm_units, - is_training, - scope='conv_and_ulstm'): - x = inputs - with tf.variable_scope(scope): - for i in range(n_conv_layers): - x = conv1d( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - scope='conv1d_{}'.format(i)) - - outputs, states = tf.nn.dynamic_rnn( - LSTMBlockCell(lstm_units), - x, - sequence_length=sequence_lengths, - dtype=tf.float32) - - return outputs - - -def conv1d(inputs, - filters, - kernel_size, - is_training, - activation=None, - dropout=False, - mask=None, - scope='conv1d'): - with tf.variable_scope(scope): - if mask is not None: - inputs = inputs * tf.expand_dims(mask, -1) - x = tf.layers.conv1d( - inputs, filters=filters, kernel_size=kernel_size, padding='same') - if mask is not None: - x = x * tf.expand_dims(mask, -1) - - x = tf.layers.batch_normalization(x, training=is_training) - if activation is not None: - x = activation(x) - if dropout: - x = tf.layers.dropout(x, rate=0.5, training=is_training) - return x - - -def conv1d_dp(inputs, - filters, - kernel_size, - is_training, - activation=None, - dropout=False, - dropoutrate=0.5, - mask=None, - scope='conv1d'): - with tf.variable_scope(scope): - if mask is not None: - inputs = inputs * tf.expand_dims(mask, -1) - x = tf.layers.conv1d( - inputs, filters=filters, kernel_size=kernel_size, padding='same') - if mask is not None: - x = x * tf.expand_dims(mask, -1) - - x = tf.contrib.layers.layer_norm(x) - if activation is not None: - x = activation(x) - if dropout: - x = tf.layers.dropout(x, rate=dropoutrate, training=is_training) - return x - - -def duration_predictor(inputs, - n_conv_layers, - filters, - kernel_size, - lstm_units, - input_lengths, - is_training, - embedded_inputs_speaker, - mask=None, - scope='duration_predictor'): - with tf.variable_scope(scope): - x = inputs - for i in range(n_conv_layers): - x = conv1d_dp( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - dropoutrate=0.1, - mask=mask, - scope='conv1d_{}'.format(i)) - - x = tf.concat([x, embedded_inputs_speaker], axis=2) - - outputs, states = tf.nn.bidirectional_dynamic_rnn( - LSTMBlockCell(lstm_units), - LSTMBlockCell(lstm_units), - x, - sequence_length=input_lengths, - dtype=tf.float32) - x = tf.concat(outputs, axis=-1) - - x = tf.layers.dense(x, units=1) - x = tf.nn.relu(x) - return x - - -def duration_predictor2(inputs, - n_conv_layers, - filters, - kernel_size, - input_lengths, - is_training, - mask=None, - scope='duration_predictor'): - with tf.variable_scope(scope): - x = inputs - for i in range(n_conv_layers): - x = conv1d_dp( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - dropoutrate=0.1, - mask=mask, - scope='conv1d_{}'.format(i)) - - x = tf.layers.dense(x, units=1) - x = tf.nn.relu(x) - return x - - -def conv_prenet(inputs, - n_conv_layers, - filters, - kernel_size, - is_training, - mask=None, - scope='conv_prenet'): - x = inputs - with tf.variable_scope(scope): - for i in range(n_conv_layers): - x = conv1d( - x, - filters, - kernel_size, - is_training, - activation=tf.nn.relu, - dropout=True, - mask=mask, - scope='conv1d_{}'.format(i)) - - return x diff --git a/modelscope/models/audio/tts/models/compat.py b/modelscope/models/audio/tts/models/compat.py deleted file mode 100755 index bb810841..00000000 --- a/modelscope/models/audio/tts/models/compat.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Functions for compatibility with different TensorFlow versions.""" - -import tensorflow as tf - - -def is_tf2(): - """Returns ``True`` if running TensorFlow 2.0.""" - return tf.__version__.startswith('2') - - -def tf_supports(symbol): - """Returns ``True`` if TensorFlow defines :obj:`symbol`.""" - return _string_to_tf_symbol(symbol) is not None - - -def tf_any(*symbols): - """Returns the first supported symbol.""" - for symbol in symbols: - module = _string_to_tf_symbol(symbol) - if module is not None: - return module - return None - - -def tf_compat(v2=None, v1=None): # pylint: disable=invalid-name - """Returns the compatible symbol based on the current TensorFlow version. - - Args: - v2: The candidate v2 symbol name. - v1: The candidate v1 symbol name. - - Returns: - A TensorFlow symbol. - - Raises: - ValueError: if no symbol can be found. - """ - candidates = [] - if v2 is not None: - candidates.append(v2) - if v1 is not None: - candidates.append(v1) - candidates.append('compat.v1.%s' % v1) - symbol = tf_any(*candidates) - if symbol is None: - raise ValueError('Failure to resolve the TensorFlow symbol') - return symbol - - -def name_from_variable_scope(name=''): - """Creates a name prefixed by the current variable scope.""" - var_scope = tf_compat(v1='get_variable_scope')().name - compat_name = '' - if name: - compat_name = '%s/' % name - if var_scope: - compat_name = '%s/%s' % (var_scope, compat_name) - return compat_name - - -def reuse(): - """Returns ``True`` if the current variable scope is marked for reuse.""" - return tf_compat(v1='get_variable_scope')().reuse - - -def _string_to_tf_symbol(symbol): - modules = symbol.split('.') - namespace = tf - for module in modules: - namespace = getattr(namespace, module, None) - if namespace is None: - return None - return namespace - - -# pylint: disable=invalid-name -gfile_copy = tf_compat(v2='io.gfile.copy', v1='gfile.Copy') -gfile_exists = tf_compat(v2='io.gfile.exists', v1='gfile.Exists') -gfile_open = tf_compat(v2='io.gfile.GFile', v1='gfile.GFile') -is_tensor = tf_compat(v2='is_tensor', v1='contrib.framework.is_tensor') -logging = tf_compat(v1='logging') -nest = tf_compat(v2='nest', v1='contrib.framework.nest') diff --git a/modelscope/models/audio/tts/text/__init__.py b/modelscope/models/audio/tts/models/datasets/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from modelscope/models/audio/tts/text/__init__.py rename to modelscope/models/audio/tts/models/datasets/__init__.py diff --git a/modelscope/models/audio/tts/models/datasets/kantts_data4fs.py b/modelscope/models/audio/tts/models/datasets/kantts_data4fs.py new file mode 100644 index 00000000..cc47d0c4 --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/kantts_data4fs.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os + +import json +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + +from modelscope.utils.logger import get_logger +from .units import KanTtsLinguisticUnit + +logger = get_logger() + + +class KanTtsText2MelDataset(Dataset): + + def __init__(self, metadata_filename, config_filename, cache=False): + super(KanTtsText2MelDataset, self).__init__() + + self.cache = cache + + with open(config_filename) as f: + self._config = json.loads(f.read()) + + # Load metadata: + self._datadir = os.path.dirname(metadata_filename) + with open(metadata_filename, encoding='utf-8') as f: + self._metadata = [line.strip().split('|') for line in f] + self._length_lst = [int(x[2]) for x in self._metadata] + hours = sum( + self._length_lst) * self._config['audio']['frame_shift_ms'] / ( + 3600 * 1000) + + logger.info('Loaded metadata for %d examples (%.2f hours)' % + (len(self._metadata), hours)) + logger.info('Minimum length: %d, Maximum length: %d' % + (min(self._length_lst), max(self._length_lst))) + + self.ling_unit = KanTtsLinguisticUnit(config_filename) + self.pad_executor = KanTtsText2MelPad() + + self.r = self._config['am']['outputs_per_step'] + self.num_mels = self._config['am']['num_mels'] + + if 'adv' in self._config: + self.feat_window = self._config['adv']['random_window'] + else: + self.feat_window = None + logger.info(self.feat_window) + + self.data_cache = [ + self.cache_load(i) for i in tqdm(range(self.__len__())) + ] if self.cache else [] + + def get_frames_lst(self): + return self._length_lst + + def __getitem__(self, index): + if self.cache: + sample = self.data_cache[index] + return sample + + return self.cache_load(index) + + def cache_load(self, index): + sample = {} + + meta = self._metadata[index] + + sample['utt_id'] = meta[0] + + sample['mel_target'] = np.load(os.path.join( + self._datadir, meta[1]))[:, :self.num_mels] + sample['output_length'] = len(sample['mel_target']) + + lfeat_symbol = meta[3] + sample['ling'] = self.ling_unit.encode_symbol_sequence(lfeat_symbol) + + sample['duration'] = np.load(os.path.join(self._datadir, meta[4])) + + sample['pitch_contour'] = np.load(os.path.join(self._datadir, meta[5])) + + sample['energy_contour'] = np.load( + os.path.join(self._datadir, meta[6])) + + return sample + + def __len__(self): + return len(self._metadata) + + def collate_fn(self, batch): + data_dict = {} + + max_input_length = max((len(x['ling'][0]) for x in batch)) + + # pure linguistic info: sy|tone|syllable_flag|word_segment + + # sy + lfeat_type = self.ling_unit._lfeat_type_list[0] + inputs_sy = self.pad_executor._prepare_scalar_inputs( + [x['ling'][0] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + # tone + lfeat_type = self.ling_unit._lfeat_type_list[1] + inputs_tone = self.pad_executor._prepare_scalar_inputs( + [x['ling'][1] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # syllable_flag + lfeat_type = self.ling_unit._lfeat_type_list[2] + inputs_syllable_flag = self.pad_executor._prepare_scalar_inputs( + [x['ling'][2] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # word_segment + lfeat_type = self.ling_unit._lfeat_type_list[3] + inputs_ws = self.pad_executor._prepare_scalar_inputs( + [x['ling'][3] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # emotion category + lfeat_type = self.ling_unit._lfeat_type_list[4] + data_dict['input_emotions'] = self.pad_executor._prepare_scalar_inputs( + [x['ling'][4] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # speaker category + lfeat_type = self.ling_unit._lfeat_type_list[5] + data_dict['input_speakers'] = self.pad_executor._prepare_scalar_inputs( + [x['ling'][5] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + data_dict['input_lings'] = torch.stack( + [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2) + + data_dict['valid_input_lengths'] = torch.as_tensor( + [len(x['ling'][0]) - 1 for x in batch], dtype=torch.long + ) # There is one '~' in the last of symbol sequence. We put length-1 for calculation. + + data_dict['valid_output_lengths'] = torch.as_tensor( + [x['output_length'] for x in batch], dtype=torch.long) + max_output_length = torch.max(data_dict['valid_output_lengths']).item() + max_output_round_length = self.pad_executor._round_up( + max_output_length, self.r) + + if self.feat_window is not None: + active_feat_len = np.minimum(max_output_round_length, + self.feat_window) + if active_feat_len < self.feat_window: + max_output_round_length = self.pad_executor._round_up( + self.feat_window, self.r) + active_feat_len = self.feat_window + + max_offsets = [x['output_length'] - active_feat_len for x in batch] + feat_offsets = [ + np.random.randint(0, np.maximum(1, offset)) + for offset in max_offsets + ] + feat_offsets = torch.from_numpy( + np.asarray(feat_offsets, dtype=np.int32)).long() + data_dict['feat_offsets'] = feat_offsets + + data_dict['mel_targets'] = self.pad_executor._prepare_targets( + [x['mel_target'] for x in batch], max_output_round_length, 0.0) + data_dict['durations'] = self.pad_executor._prepare_durations( + [x['duration'] for x in batch], max_input_length, + max_output_round_length) + + data_dict['pitch_contours'] = self.pad_executor._prepare_scalar_inputs( + [x['pitch_contour'] for x in batch], max_input_length, + 0.0).float() + data_dict[ + 'energy_contours'] = self.pad_executor._prepare_scalar_inputs( + [x['energy_contour'] for x in batch], max_input_length, + 0.0).float() + + data_dict['utt_ids'] = [x['utt_id'] for x in batch] + + return data_dict + + +class KanTtsText2MelPad(object): + + def __init__(self): + super(KanTtsText2MelPad, self).__init__() + pass + + def _pad1D(self, x, length, pad): + return np.pad( + x, (0, length - x.shape[0]), mode='constant', constant_values=pad) + + def _pad2D(self, x, length, pad): + return np.pad( + x, [(0, length - x.shape[0]), (0, 0)], + mode='constant', + constant_values=pad) + + def _pad_durations(self, duration, max_in_len, max_out_len): + framenum = np.sum(duration) + symbolnum = duration.shape[0] + if framenum < max_out_len: + padframenum = max_out_len - framenum + duration = np.insert( + duration, symbolnum, values=padframenum, axis=0) + duration = np.insert( + duration, + symbolnum + 1, + values=[0] * (max_in_len - symbolnum - 1), + axis=0) + else: + if symbolnum < max_in_len: + duration = np.insert( + duration, + symbolnum, + values=[0] * (max_in_len - symbolnum), + axis=0) + return duration + + def _round_up(self, x, multiple): + remainder = x % multiple + return x if remainder == 0 else x + multiple - remainder + + def _prepare_scalar_inputs(self, inputs, max_len, pad): + return torch.from_numpy( + np.stack([self._pad1D(x, max_len, pad) for x in inputs])) + + def _prepare_targets(self, targets, max_len, pad): + return torch.from_numpy( + np.stack([self._pad2D(t, max_len, pad) for t in targets])).float() + + def _prepare_durations(self, durations, max_in_len, max_out_len): + return torch.from_numpy( + np.stack([ + self._pad_durations(t, max_in_len, max_out_len) + for t in durations + ])).long() diff --git a/modelscope/models/audio/tts/models/datasets/samplers.py b/modelscope/models/audio/tts/models/datasets/samplers.py new file mode 100644 index 00000000..0657fa8a --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/samplers.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import random + +import torch +from torch import distributed as dist +from torch.utils.data import Sampler + + +class LenSortGroupPoolSampler(Sampler): + + def __init__(self, data_source, length_lst, group_size): + super(LenSortGroupPoolSampler, self).__init__(data_source) + + self.data_source = data_source + self.length_lst = length_lst + self.group_size = group_size + + self.num = len(self.length_lst) + self.buckets = self.num // group_size + + def __iter__(self): + + def getkey(item): + return item[1] + + random_lst = torch.randperm(self.num).tolist() + random_len_lst = [(i, self.length_lst[i]) for i in random_lst] + + # Bucket examples based on similar output sequence length for efficiency: + groups = [ + random_len_lst[i:i + self.group_size] + for i in range(0, self.num, self.group_size) + ] + if (self.num % self.group_size): + groups.append(random_len_lst[self.buckets * self.group_size:-1]) + + indices = [] + + for group in groups: + group.sort(key=getkey, reverse=True) + for item in group: + indices.append(item[0]) + + return iter(indices) + + def __len__(self): + return len(self.data_source) + + +class DistributedLenSortGroupPoolSampler(Sampler): + + def __init__(self, + dataset, + length_lst, + group_size, + num_replicas=None, + rank=None, + shuffle=True): + super(DistributedLenSortGroupPoolSampler, self).__init__(dataset) + + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError( + 'modelscope error: Requires distributed package to be available' + ) + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError( + 'modelscope error: Requires distributed package to be available' + ) + rank = dist.get_rank() + self.dataset = dataset + self.length_lst = length_lst + self.group_size = group_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int( + math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.buckets = self.num_samples // group_size + self.shuffle = shuffle + + def __iter__(self): + + def getkey(item): + return item[1] + + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + random_len_lst = [(i, self.length_lst[i]) for i in indices] + + # Bucket examples based on similar output sequence length for efficiency: + groups = [ + random_len_lst[i:i + self.group_size] + for i in range(0, self.num_samples, self.group_size) + ] + if (self.num_samples % self.group_size): + groups.append(random_len_lst[self.buckets * self.group_size:-1]) + + new_indices = [] + + for group in groups: + group.sort(key=getkey, reverse=True) + for item in group: + new_indices.append(item[0]) + + return iter(new_indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/modelscope/models/audio/tts/models/datasets/units/__init__.py b/modelscope/models/audio/tts/models/datasets/units/__init__.py new file mode 100644 index 00000000..4d03df04 --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/units/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .ling_unit import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/datasets/units/cleaners.py b/modelscope/models/audio/tts/models/datasets/units/cleaners.py new file mode 100644 index 00000000..07d4fbdb --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/units/cleaners.py @@ -0,0 +1,88 @@ +# from https://github.com/keithito/tacotron +# Cleaners are transformations that run over the input text at both training and eval time. +# +# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +# hyperparameter. Some cleaners are English-specific. You'll typically want to use: +# 1. "english_cleaners" for English text +# 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using +# the Unidecode library (https://pypi.python.org/pypi/Unidecode) +# 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update +# the symbols in symbols.py to match your data). + +import re + +from unidecode import unidecode + +from .numbers import normalize_numbers + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) + for x in [('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), ]] # yapf:disable + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including number and abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/modelscope/models/audio/tts/models/datasets/units/ling_unit.py b/modelscope/models/audio/tts/models/datasets/units/ling_unit.py new file mode 100644 index 00000000..3c211cc7 --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/units/ling_unit.py @@ -0,0 +1,395 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import abc +import codecs +import os +import re +import shutil + +import json +import numpy as np + +from . import cleaners as cleaners + +# Regular expression matching text enclosed in curly braces: +_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception( + 'modelscope error: configuration cleaner unknown: %s' % name) + text = cleaner(text) + return text + + +class LinguisticBaseUnit(abc.ABC): + + def set_config_params(self, config_params): + self.config_params = config_params + + def save(self, config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) + + +class KanTtsLinguisticUnit(LinguisticBaseUnit): + + def __init__(self, config, path, has_mask=True): + super(KanTtsLinguisticUnit, self).__init__() + + # special symbol + self._pad = '_' + self._eos = '~' + self._mask = '@[MASK]' + self._has_mask = has_mask + self._unit_config = config + self._path = path + + self._cleaner_names = [ + x.strip() for x in self._unit_config['cleaners'].split(',') + ] + self._lfeat_type_list = self._unit_config['lfeat_type_list'].strip( + ).split(',') + + self.build() + + def get_unit_size(self): + ling_unit_size = {} + ling_unit_size['sy'] = len(self.sy) + ling_unit_size['tone'] = len(self.tone) + ling_unit_size['syllable_flag'] = len(self.syllable_flag) + ling_unit_size['word_segment'] = len(self.word_segment) + + if 'emo_category' in self._lfeat_type_list: + ling_unit_size['emotion'] = len(self.emo_category) + if 'speaker_category' in self._lfeat_type_list: + ling_unit_size['speaker'] = len(self.speaker) + + return ling_unit_size + + def build(self): + + self._sub_unit_dim = {} + self._sub_unit_pad = {} + # sy sub-unit + _characters = '' + + _ch_symbols = [] + + sy_path = os.path.join(self._path, self._unit_config['sy']) + f = codecs.open(sy_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_symbols.append(line) + + _arpabet = ['@' + s for s in _ch_symbols] + + # Export all symbols: + self.sy = list(_characters) + _arpabet + [self._pad, self._eos] + if self._has_mask: + self.sy.append(self._mask) + self._sy_to_id = {s: i for i, s in enumerate(self.sy)} + self._id_to_sy = {i: s for i, s in enumerate(self.sy)} + self._sub_unit_dim['sy'] = len(self.sy) + self._sub_unit_pad['sy'] = self._sy_to_id['_'] + + # tone sub-unit + _characters = '' + + _ch_tones = [] + + tone_path = os.path.join(self._path, self._unit_config['tone']) + f = codecs.open(tone_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_tones.append(line) + + # Export all tones: + self.tone = list(_characters) + _ch_tones + [self._pad, self._eos] + if self._has_mask: + self.tone.append(self._mask) + self._tone_to_id = {s: i for i, s in enumerate(self.tone)} + self._id_to_tone = {i: s for i, s in enumerate(self.tone)} + self._sub_unit_dim['tone'] = len(self.tone) + self._sub_unit_pad['tone'] = self._tone_to_id['_'] + + # syllable flag sub-unit + _characters = '' + + _ch_syllable_flags = [] + + sy_flag_path = os.path.join(self._path, + self._unit_config['syllable_flag']) + f = codecs.open(sy_flag_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_syllable_flags.append(line) + + # Export all syllable_flags: + self.syllable_flag = list(_characters) + _ch_syllable_flags + [ + self._pad, self._eos + ] + if self._has_mask: + self.syllable_flag.append(self._mask) + self._syllable_flag_to_id = { + s: i + for i, s in enumerate(self.syllable_flag) + } + self._id_to_syllable_flag = { + i: s + for i, s in enumerate(self.syllable_flag) + } + self._sub_unit_dim['syllable_flag'] = len(self.syllable_flag) + self._sub_unit_pad['syllable_flag'] = self._syllable_flag_to_id['_'] + + # word segment sub-unit + _characters = '' + + _ch_word_segments = [] + + ws_path = os.path.join(self._path, self._unit_config['word_segment']) + f = codecs.open(ws_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_word_segments.append(line) + + # Export all syllable_flags: + self.word_segment = list(_characters) + _ch_word_segments + [ + self._pad, self._eos + ] + if self._has_mask: + self.word_segment.append(self._mask) + self._word_segment_to_id = { + s: i + for i, s in enumerate(self.word_segment) + } + self._id_to_word_segment = { + i: s + for i, s in enumerate(self.word_segment) + } + self._sub_unit_dim['word_segment'] = len(self.word_segment) + self._sub_unit_pad['word_segment'] = self._word_segment_to_id['_'] + + if 'emo_category' in self._lfeat_type_list: + # emotion category sub-unit + _characters = '' + + _ch_emo_types = [] + + emo_path = os.path.join(self._path, + self._unit_config['emo_category']) + f = codecs.open(emo_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_emo_types.append(line) + + self.emo_category = list(_characters) + _ch_emo_types + [ + self._pad, self._eos + ] + if self._has_mask: + self.emo_category.append(self._mask) + self._emo_category_to_id = { + s: i + for i, s in enumerate(self.emo_category) + } + self._id_to_emo_category = { + i: s + for i, s in enumerate(self.emo_category) + } + self._sub_unit_dim['emo_category'] = len(self.emo_category) + self._sub_unit_pad['emo_category'] = self._emo_category_to_id['_'] + + if 'speaker_category' in self._lfeat_type_list: + # speaker category sub-unit + _characters = '' + + _ch_speakers = [] + + speaker_path = os.path.join(self._path, + self._unit_config['speaker_category']) + f = codecs.open(speaker_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_speakers.append(line) + + # Export all syllable_flags: + self.speaker = list(_characters) + _ch_speakers + [ + self._pad, self._eos + ] + if self._has_mask: + self.speaker.append(self._mask) + self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)} + self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)} + self._sub_unit_dim['speaker_category'] = len(self._speaker_to_id) + self._sub_unit_pad['speaker_category'] = self._speaker_to_id['_'] + + def encode_symbol_sequence(self, lfeat_symbol): + lfeat_symbol = lfeat_symbol.strip().split(' ') + + lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list)) + for this_lfeat_symbol in lfeat_symbol: + this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split( + '$') + index = 0 + while index < len(lfeat_symbol_separate): + lfeat_symbol_separate[index] = lfeat_symbol_separate[ + index] + this_lfeat_symbol[index] + ' ' + index = index + 1 + + input_and_label_data = [] + index = 0 + while index < len(self._lfeat_type_list): + sequence = self.encode_sub_unit( + lfeat_symbol_separate[index].strip(), + self._lfeat_type_list[index]) + sequence_array = np.asarray(sequence, dtype=np.int32) + input_and_label_data.append(sequence_array) + index = index + 1 + + return input_and_label_data + + def decode_symbol_sequence(self, sequence): + result = [] + for i, lfeat_type in enumerate(self._lfeat_type_list): + s = '' + sequence_item = sequence[i].tolist() + if lfeat_type == 'sy': + s = self.decode_sy(sequence_item) + elif lfeat_type == 'tone': + s = self.decode_tone(sequence_item) + elif lfeat_type == 'syllable_flag': + s = self.decode_syllable_flag(sequence_item) + elif lfeat_type == 'word_segment': + s = self.decode_word_segment(sequence_item) + elif lfeat_type == 'emo_category': + s = self.decode_emo_category(sequence_item) + elif lfeat_type == 'speaker_category': + s = self.decode_speaker_category(sequence_item) + else: + raise Exception( + 'modelscope error: configuration lfeat type(%s) unknown.' + % lfeat_type) + result.append('%s:%s' % (lfeat_type, s)) + + return result + + def encode_sub_unit(self, this_lfeat_symbol, lfeat_type): + sequence = [] + if lfeat_type == 'sy': + this_lfeat_symbol = this_lfeat_symbol.strip().split(' ') + this_lfeat_symbol_format = '' + index = 0 + while index < len(this_lfeat_symbol): + this_lfeat_symbol_format = this_lfeat_symbol_format + '{' + this_lfeat_symbol[ + index] + '}' + ' ' + index = index + 1 + sequence = self.encode_text(this_lfeat_symbol_format, + self._cleaner_names) + elif lfeat_type == 'tone': + sequence = self.encode_tone(this_lfeat_symbol) + elif lfeat_type == 'syllable_flag': + sequence = self.encode_syllable_flag(this_lfeat_symbol) + elif lfeat_type == 'word_segment': + sequence = self.encode_word_segment(this_lfeat_symbol) + elif lfeat_type == 'emo_category': + sequence = self.encode_emo_category(this_lfeat_symbol) + elif lfeat_type == 'speaker_category': + sequence = self.encode_speaker_category(this_lfeat_symbol) + else: + raise Exception( + 'modelscope error: configuration lfeat type(%s) unknown.' + % lfeat_type) + + return sequence + + def encode_text(self, text, cleaner_names): + sequence = [] + + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + sequence += self.encode_sy(_clean_text(text, cleaner_names)) + break + sequence += self.encode_sy(_clean_text(m.group(1), cleaner_names)) + sequence += self.encode_arpanet(m.group(2)) + text = m.group(3) + + # Append EOS token + sequence.append(self._sy_to_id['~']) + return sequence + + def encode_sy(self, sy): + return [self._sy_to_id[s] for s in sy if self.should_keep_sy(s)] + + def decode_sy(self, id): + s = self._id_to_sy[id] + if len(s) > 1 and s[0] == '@': + s = s[1:] + return s + + def should_keep_sy(self, s): + return s in self._sy_to_id and s != '_' and s != '~' + + def encode_arpanet(self, text): + return self.encode_sy(['@' + s for s in text.split()]) + + def encode_tone(self, tone): + tones = tone.strip().split(' ') + sequence = [] + for this_tone in tones: + sequence.append(self._tone_to_id[this_tone]) + sequence.append(self._tone_to_id['~']) + return sequence + + def decode_tone(self, id): + return self._id_to_tone[id] + + def encode_syllable_flag(self, syllable_flag): + syllable_flags = syllable_flag.strip().split(' ') + sequence = [] + for this_syllable_flag in syllable_flags: + sequence.append(self._syllable_flag_to_id[this_syllable_flag]) + sequence.append(self._syllable_flag_to_id['~']) + return sequence + + def decode_syllable_flag(self, id): + return self._id_to_syllable_flag[id] + + def encode_word_segment(self, word_segment): + word_segments = word_segment.strip().split(' ') + sequence = [] + for this_word_segment in word_segments: + sequence.append(self._word_segment_to_id[this_word_segment]) + sequence.append(self._word_segment_to_id['~']) + return sequence + + def decode_word_segment(self, id): + return self._id_to_word_segment[id] + + def encode_emo_category(self, emo_type): + emo_categories = emo_type.strip().split(' ') + sequence = [] + for this_category in emo_categories: + sequence.append(self._emo_category_to_id[this_category]) + sequence.append(self._emo_category_to_id['~']) + return sequence + + def decode_emo_category(self, id): + return self._id_to_emo_category[id] + + def encode_speaker_category(self, speaker): + speakers = speaker.strip().split(' ') + sequence = [] + for this_speaker in speakers: + sequence.append(self._speaker_to_id[this_speaker]) + sequence.append(self._speaker_to_id['~']) + return sequence + + def decode_speaker_category(self, id): + return self._id_to_speaker[id] diff --git a/modelscope/models/audio/tts/text/numbers.py b/modelscope/models/audio/tts/models/datasets/units/numbers.py old mode 100755 new mode 100644 similarity index 94% rename from modelscope/models/audio/tts/text/numbers.py rename to modelscope/models/audio/tts/models/datasets/units/numbers.py index d9453fee..d8835059 --- a/modelscope/models/audio/tts/text/numbers.py +++ b/modelscope/models/audio/tts/models/datasets/units/numbers.py @@ -1,3 +1,6 @@ +# The implementation is adopted from tacotron, +# made publicly available under the MIT License at https://github.com/keithito/tacotron + import re import inflect diff --git a/modelscope/models/audio/tts/models/fsmn.py b/modelscope/models/audio/tts/models/fsmn.py deleted file mode 100755 index 875c27f0..00000000 --- a/modelscope/models/audio/tts/models/fsmn.py +++ /dev/null @@ -1,273 +0,0 @@ -import tensorflow as tf - - -def build_sequence_mask(sequence_length, - maximum_length=None, - dtype=tf.float32): - """Builds the dot product mask. - - Args: - sequence_length: The sequence length. - maximum_length: Optional size of the returned time dimension. Otherwise - it is the maximum of :obj:`sequence_length`. - dtype: The type of the mask tensor. - - Returns: - A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape - ``[batch_size, max_length]``. - """ - mask = tf.sequence_mask( - sequence_length, maxlen=maximum_length, dtype=dtype) - - return mask - - -def norm(inputs): - """Layer normalizes :obj:`inputs`.""" - return tf.contrib.layers.layer_norm(inputs, begin_norm_axis=-1) - - -def pad_in_time(x, padding_shape): - """Helper function to pad a tensor in the time dimension and retain the static depth dimension. - - Agrs: - x: [Batch, Time, Frequency] - padding_length: padding size of constant value (0) before the time dimension - - return: - padded x - """ - - depth = x.get_shape().as_list()[-1] - x = tf.pad(x, [[0, 0], padding_shape, [0, 0]]) - x.set_shape((None, None, depth)) - - return x - - -def pad_in_time_right(x, padding_length): - """Helper function to pad a tensor in the time dimension and retain the static depth dimension. - - Agrs: - x: [Batch, Time, Frequency] - padding_length: padding size of constant value (0) before the time dimension - - return: - padded x - """ - depth = x.get_shape().as_list()[-1] - x = tf.pad(x, [[0, 0], [0, padding_length], [0, 0]]) - x.set_shape((None, None, depth)) - - return x - - -def feed_forward(x, ffn_dim, memory_units, mode, dropout=0.0): - """Implements the Transformer's "Feed Forward" layer. - - .. math:: - - ffn(x) = max(0, x*W_1 + b_1)*W_2 - - Args: - x: The input. - ffn_dim: The number of units of the nonlinear transformation. - memory_units: the number of units of linear transformation - mode: A ``tf.estimator.ModeKeys`` mode. - dropout: The probability to drop units from the inner transformation. - - Returns: - The transformed input. - """ - inner = tf.layers.conv1d(x, ffn_dim, 1, activation=tf.nn.relu) - inner = tf.layers.dropout( - inner, rate=dropout, training=mode == tf.estimator.ModeKeys.TRAIN) - outer = tf.layers.conv1d(inner, memory_units, 1, use_bias=False) - - return outer - - -def drop_and_add(inputs, outputs, mode, dropout=0.0): - """Drops units in the outputs and adds the previous values. - - Args: - inputs: The input of the previous layer. - outputs: The output of the previous layer. - mode: A ``tf.estimator.ModeKeys`` mode. - dropout: The probability to drop units in :obj:`outputs`. - - Returns: - The residual and normalized output. - """ - outputs = tf.layers.dropout(outputs, rate=dropout, training=mode) - - input_dim = inputs.get_shape().as_list()[-1] - output_dim = outputs.get_shape().as_list()[-1] - - if input_dim == output_dim: - outputs += inputs - - return outputs - - -def MemoryBlock( - inputs, - filter_size, - mode, - mask=None, - dropout=0.0, -): - """ - Define the bidirectional memory block in FSMN - - Agrs: - inputs: The output of the previous layer. [Batch, Time, Frequency] - filter_size: memory block filter size - mode: Training or Evaluation - mask: A ``tf.Tensor`` applied to the memory block output - - return: - output: 3-D tensor ([Batch, Time, Frequency]) - """ - static_shape = inputs.get_shape().as_list() - depth = static_shape[-1] - inputs = tf.expand_dims(inputs, axis=1) # [Batch, 1, Time, Frequency] - depthwise_filter = tf.get_variable( - 'depth_conv_w', - shape=[1, filter_size, depth, 1], - initializer=tf.glorot_uniform_initializer(), - dtype=tf.float32) - memory = tf.nn.depthwise_conv2d( - input=inputs, - filter=depthwise_filter, - strides=[1, 1, 1, 1], - padding='SAME', - rate=[1, 1], - data_format='NHWC') - memory = memory + inputs - output = tf.layers.dropout(memory, rate=dropout, training=mode) - output = tf.reshape( - output, - [tf.shape(output)[0], tf.shape(output)[2], depth]) - if mask is not None: - output = output * tf.expand_dims(mask, -1) - - return output - - -def MemoryBlockV2( - inputs, - filter_size, - mode, - shift=0, - mask=None, - dropout=0.0, -): - """ - Define the bidirectional memory block in FSMN - - Agrs: - inputs: The output of the previous layer. [Batch, Time, Frequency] - filter_size: memory block filter size - mode: Training or Evaluation - shift: left padding, to control delay - mask: A ``tf.Tensor`` applied to the memory block output - - return: - output: 3-D tensor ([Batch, Time, Frequency]) - """ - if mask is not None: - inputs = inputs * tf.expand_dims(mask, -1) - - static_shape = inputs.get_shape().as_list() - depth = static_shape[-1] - # padding - left_padding = int(round((filter_size - 1) / 2)) - right_padding = int((filter_size - 1) / 2) - if shift > 0: - left_padding = left_padding + shift - right_padding = right_padding - shift - pad_inputs = pad_in_time(inputs, [left_padding, right_padding]) - pad_inputs = tf.expand_dims( - pad_inputs, axis=1) # [Batch, 1, Time, Frequency] - depthwise_filter = tf.get_variable( - 'depth_conv_w', - shape=[1, filter_size, depth, 1], - initializer=tf.glorot_uniform_initializer(), - dtype=tf.float32) - memory = tf.nn.depthwise_conv2d( - input=pad_inputs, - filter=depthwise_filter, - strides=[1, 1, 1, 1], - padding='VALID', - rate=[1, 1], - data_format='NHWC') - memory = tf.reshape( - memory, - [tf.shape(memory)[0], tf.shape(memory)[2], depth]) - memory = memory + inputs - output = tf.layers.dropout(memory, rate=dropout, training=mode) - if mask is not None: - output = output * tf.expand_dims(mask, -1) - - return output - - -def UniMemoryBlock( - inputs, - filter_size, - mode, - cache=None, - mask=None, - dropout=0.0, -): - """ - Define the unidirectional memory block in FSMN - - Agrs: - inputs: The output of the previous layer. [Batch, Time, Frequency] - filter_size: memory block filter size - cache: for streaming inference - mode: Training or Evaluation - mask: A ``tf.Tensor`` applied to the memory block output - dropout: dorpout factor - return: - output: 3-D tensor ([Batch, Time, Frequency]) - """ - if cache is not None: - static_shape = cache['queries'].get_shape().as_list() - depth = static_shape[-1] - queries = tf.slice(cache['queries'], [0, 1, 0], [ - tf.shape(cache['queries'])[0], - tf.shape(cache['queries'])[1] - 1, depth - ]) - queries = tf.concat([queries, inputs], axis=1) - cache['queries'] = queries - else: - padding_length = filter_size - 1 - queries = pad_in_time(inputs, [padding_length, 0]) - - queries = tf.expand_dims(queries, axis=1) # [Batch, 1, Time, Frequency] - static_shape = queries.get_shape().as_list() - depth = static_shape[-1] - depthwise_filter = tf.get_variable( - 'depth_conv_w', - shape=[1, filter_size, depth, 1], - initializer=tf.glorot_uniform_initializer(), - dtype=tf.float32) - memory = tf.nn.depthwise_conv2d( - input=queries, - filter=depthwise_filter, - strides=[1, 1, 1, 1], - padding='VALID', - rate=[1, 1], - data_format='NHWC') - memory = tf.reshape( - memory, - [tf.shape(memory)[0], tf.shape(memory)[2], depth]) - memory = memory + inputs - output = tf.layers.dropout(memory, rate=dropout, training=mode) - if mask is not None: - output = output * tf.expand_dims(mask, -1) - - return output diff --git a/modelscope/models/audio/tts/models/fsmn_encoder.py b/modelscope/models/audio/tts/models/fsmn_encoder.py deleted file mode 100755 index 2c650624..00000000 --- a/modelscope/models/audio/tts/models/fsmn_encoder.py +++ /dev/null @@ -1,178 +0,0 @@ -import tensorflow as tf - -from . import fsmn - - -class FsmnEncoder(): - """Encoder using Fsmn - """ - - def __init__(self, - filter_size, - fsmn_num_layers, - dnn_num_layers, - num_memory_units=512, - ffn_inner_dim=2048, - dropout=0.0, - position_encoder=None): - """Initializes the parameters of the encoder. - - Args: - filter_size: the total order of memory block - fsmn_num_layers: The number of fsmn layers. - dnn_num_layers: The number of dnn layers - num_units: The number of memory units. - ffn_inner_dim: The number of units of the inner linear transformation - in the feed forward layer. - dropout: The probability to drop units from the outputs. - position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to - apply on inputs or ``None``. - """ - super(FsmnEncoder, self).__init__() - self.filter_size = filter_size - self.fsmn_num_layers = fsmn_num_layers - self.dnn_num_layers = dnn_num_layers - self.num_memory_units = num_memory_units - self.ffn_inner_dim = ffn_inner_dim - self.dropout = dropout - self.position_encoder = position_encoder - - def encode(self, inputs, sequence_length=None, mode=True): - if self.position_encoder is not None: - inputs = self.position_encoder(inputs) - - inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) - - mask = fsmn.build_sequence_mask( - sequence_length, maximum_length=tf.shape(inputs)[1]) - - state = () - - for layer in range(self.fsmn_num_layers): - with tf.variable_scope('fsmn_layer_{}'.format(layer)): - with tf.variable_scope('ffn'): - context = fsmn.feed_forward( - inputs, - self.ffn_inner_dim, - self.num_memory_units, - mode, - dropout=self.dropout) - - with tf.variable_scope('memory'): - memory = fsmn.MemoryBlock( - context, - self.filter_size, - mode, - mask=mask, - dropout=self.dropout) - - memory = fsmn.drop_and_add( - inputs, memory, mode, dropout=self.dropout) - - inputs = memory - state += (tf.reduce_mean(inputs, axis=1), ) - - for layer in range(self.dnn_num_layers): - with tf.variable_scope('dnn_layer_{}'.format(layer)): - transformed = fsmn.feed_forward( - inputs, - self.ffn_inner_dim, - self.num_memory_units, - mode, - dropout=self.dropout) - - inputs = transformed - state += (tf.reduce_mean(inputs, axis=1), ) - - outputs = inputs - return (outputs, state, sequence_length) - - -class FsmnEncoderV2(): - """Encoder using Fsmn - """ - - def __init__(self, - filter_size, - fsmn_num_layers, - dnn_num_layers, - num_memory_units=512, - ffn_inner_dim=2048, - dropout=0.0, - shift=0, - position_encoder=None): - """Initializes the parameters of the encoder. - - Args: - filter_size: the total order of memory block - fsmn_num_layers: The number of fsmn layers. - dnn_num_layers: The number of dnn layers - num_units: The number of memory units. - ffn_inner_dim: The number of units of the inner linear transformation - in the feed forward layer. - dropout: The probability to drop units from the outputs. - shift: left padding, to control delay - position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to - apply on inputs or ``None``. - """ - super(FsmnEncoderV2, self).__init__() - self.filter_size = filter_size - self.fsmn_num_layers = fsmn_num_layers - self.dnn_num_layers = dnn_num_layers - self.num_memory_units = num_memory_units - self.ffn_inner_dim = ffn_inner_dim - self.dropout = dropout - self.shift = shift - if not isinstance(shift, list): - self.shift = [shift for _ in range(self.fsmn_num_layers)] - self.position_encoder = position_encoder - - def encode(self, inputs, sequence_length=None, mode=True): - if self.position_encoder is not None: - inputs = self.position_encoder(inputs) - - inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) - - mask = fsmn.build_sequence_mask( - sequence_length, maximum_length=tf.shape(inputs)[1]) - - state = () - for layer in range(self.fsmn_num_layers): - with tf.variable_scope('fsmn_layer_{}'.format(layer)): - with tf.variable_scope('ffn'): - context = fsmn.feed_forward( - inputs, - self.ffn_inner_dim, - self.num_memory_units, - mode, - dropout=self.dropout) - - with tf.variable_scope('memory'): - memory = fsmn.MemoryBlockV2( - context, - self.filter_size, - mode, - shift=self.shift[layer], - mask=mask, - dropout=self.dropout) - - memory = fsmn.drop_and_add( - inputs, memory, mode, dropout=self.dropout) - - inputs = memory - state += (tf.reduce_mean(inputs, axis=1), ) - - for layer in range(self.dnn_num_layers): - with tf.variable_scope('dnn_layer_{}'.format(layer)): - transformed = fsmn.feed_forward( - inputs, - self.ffn_inner_dim, - self.num_memory_units, - mode, - dropout=self.dropout) - - inputs = transformed - state += (tf.reduce_mean(inputs, axis=1), ) - - outputs = inputs - return (outputs, state, sequence_length) diff --git a/modelscope/models/audio/tts/models/helpers.py b/modelscope/models/audio/tts/models/helpers.py deleted file mode 100755 index 371000a4..00000000 --- a/modelscope/models/audio/tts/models/helpers.py +++ /dev/null @@ -1,159 +0,0 @@ -import numpy as np -import tensorflow as tf - - -class VarTestHelper(tf.contrib.seq2seq.Helper): - - def __init__(self, batch_size, inputs, dim): - with tf.name_scope('VarTestHelper'): - self._batch_size = batch_size - self._inputs = inputs - self._dim = dim - - num_steps = tf.shape(self._inputs)[1] - self._lengths = tf.tile([num_steps], [self._batch_size]) - - self._inputs = tf.roll(inputs, shift=-1, axis=1) - self._init_inputs = inputs[:, 0, :] - - @property - def batch_size(self): - return self._batch_size - - @property - def sample_ids_shape(self): - return tf.TensorShape([]) - - @property - def sample_ids_dtype(self): - return np.int32 - - def initialize(self, name=None): - return (tf.tile([False], [self._batch_size]), - _go_frames(self._batch_size, self._dim, self._init_inputs)) - - def sample(self, time, outputs, state, name=None): - return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them - - def next_inputs(self, time, outputs, state, sample_ids, name=None): - with tf.name_scope('VarTestHelper'): - finished = (time + 1 >= self._lengths) - next_inputs = tf.concat([outputs, self._inputs[:, time, :]], - axis=-1) - return (finished, next_inputs, state) - - -class VarTrainingHelper(tf.contrib.seq2seq.Helper): - - def __init__(self, targets, inputs, dim): - with tf.name_scope('VarTrainingHelper'): - self._targets = targets # [N, T_in, 1] - self._batch_size = tf.shape(inputs)[0] # N - self._inputs = inputs - self._dim = dim - - num_steps = tf.shape(self._targets)[1] - self._lengths = tf.tile([num_steps], [self._batch_size]) - - self._inputs = tf.roll(inputs, shift=-1, axis=1) - self._init_inputs = inputs[:, 0, :] - - @property - def batch_size(self): - return self._batch_size - - @property - def sample_ids_shape(self): - return tf.TensorShape([]) - - @property - def sample_ids_dtype(self): - return np.int32 - - def initialize(self, name=None): - return (tf.tile([False], [self._batch_size]), - _go_frames(self._batch_size, self._dim, self._init_inputs)) - - def sample(self, time, outputs, state, name=None): - return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them - - def next_inputs(self, time, outputs, state, sample_ids, name=None): - with tf.name_scope(name or 'VarTrainingHelper'): - finished = (time + 1 >= self._lengths) - next_inputs = tf.concat( - [self._targets[:, time, :], self._inputs[:, time, :]], axis=-1) - return (finished, next_inputs, state) - - -class VarTrainingSSHelper(tf.contrib.seq2seq.Helper): - - def __init__(self, targets, inputs, dim, global_step, schedule_begin, - alpha, decay_steps): - with tf.name_scope('VarTrainingSSHelper'): - self._targets = targets # [N, T_in, 1] - self._batch_size = tf.shape(inputs)[0] # N - self._inputs = inputs - self._dim = dim - - num_steps = tf.shape(self._targets)[1] - self._lengths = tf.tile([num_steps], [self._batch_size]) - - self._inputs = tf.roll(inputs, shift=-1, axis=1) - self._init_inputs = inputs[:, 0, :] - - # for schedule sampling - self._global_step = global_step - self._schedule_begin = schedule_begin - self._alpha = alpha - self._decay_steps = decay_steps - - @property - def batch_size(self): - return self._batch_size - - @property - def sample_ids_shape(self): - return tf.TensorShape([]) - - @property - def sample_ids_dtype(self): - return np.int32 - - def initialize(self, name=None): - self._ratio = _tf_decay(self._global_step, self._schedule_begin, - self._alpha, self._decay_steps) - return (tf.tile([False], [self._batch_size]), - _go_frames(self._batch_size, self._dim, self._init_inputs)) - - def sample(self, time, outputs, state, name=None): - return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them - - def next_inputs(self, time, outputs, state, sample_ids, name=None): - with tf.name_scope(name or 'VarTrainingHelper'): - finished = (time + 1 >= self._lengths) - next_inputs_tmp = tf.cond( - tf.less( - tf.random_uniform([], minval=0, maxval=1, - dtype=tf.float32), self._ratio), - lambda: self._targets[:, time, :], lambda: outputs) - next_inputs = tf.concat( - [next_inputs_tmp, self._inputs[:, time, :]], axis=-1) - return (finished, next_inputs, state) - - -def _go_frames(batch_size, dim, init_inputs): - '''Returns all-zero frames for a given batch size and output dimension''' - return tf.concat([tf.tile([[0.0]], [batch_size, dim]), init_inputs], - axis=-1) - - -def _tf_decay(global_step, schedule_begin, alpha, decay_steps): - tfr = tf.train.exponential_decay( - 1.0, - global_step=global_step - schedule_begin, - decay_steps=decay_steps, - decay_rate=alpha, - name='tfr_decay') - final_tfr = tf.cond( - tf.less(global_step, schedule_begin), lambda: 1.0, lambda: tfr) - return final_tfr diff --git a/modelscope/models/audio/tts/models/models/__init__.py b/modelscope/models/audio/tts/models/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/tts/models/models/hifigan/__init__.py b/modelscope/models/audio/tts/models/models/hifigan/__init__.py new file mode 100644 index 00000000..ae9d10ea --- /dev/null +++ b/modelscope/models/audio/tts/models/models/hifigan/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .hifigan import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/models/hifigan/hifigan.py b/modelscope/models/audio/tts/models/models/hifigan/hifigan.py new file mode 100755 index 00000000..0f950539 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/hifigan/hifigan.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from https://github.com/jik876/hifi-gan + +from distutils.version import LooseVersion + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from modelscope.models.audio.tts.models.utils import get_padding, init_weights +from modelscope.utils.logger import get_logger + +logger = get_logger() +is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7') + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + + Returns: + Tensor: Magnitude spectrogram (B). + + """ + if is_pytorch_17plus: + x_stft = torch.stft( + x, fft_size, hop_size, win_length, window, return_complex=False) + else: + x_stft = torch.stft(x, fft_size, hop_size, win_length, window) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1) + + +LRELU_SLOPE = 0.1 + + +def get_padding_casual(kernel_size, dilation=1): + return int(kernel_size * dilation - dilation) + + +class Conv1dCasual(torch.nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros'): + super(Conv1dCasual, self).__init__() + self.pad = padding + self.conv1d = weight_norm( + Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode)) + self.conv1d.apply(init_weights) + + def forward(self, x): # bdt + # described starting from the last dimension and moving forward. + x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), 'constant') + x = self.conv1d(x) + return x + + def remove_weight_norm(self): + remove_weight_norm(self.conv1d) + + +class ConvTranspose1dCausal(torch.nn.Module): + """CausalConvTranspose1d module with customized initialization.""" + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=0): + """Initialize CausalConvTranspose1d module.""" + super(ConvTranspose1dCausal, self).__init__() + self.deconv = weight_norm( + ConvTranspose1d(in_channels, out_channels, kernel_size, stride)) + self.stride = stride + self.deconv.apply(init_weights) + self.pad = kernel_size - stride + + def forward(self, x): + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, in_channels, T_in). + Returns: + Tensor: Output tensor (B, out_channels, T_out). + """ + # x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant") + return self.deconv(x)[:, :, :-self.pad] + + def remove_weight_norm(self): + remove_weight_norm(self.deconv) + + +class ResBlock1(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + Conv1dCasual( + channels, + channels, + kernel_size, + 1, + dilation=dilation[i], + padding=get_padding_casual(kernel_size, dilation[i])) + for i in range(len(dilation)) + ]) + + self.convs2 = nn.ModuleList([ + Conv1dCasual( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding_casual(kernel_size, 1)) + for i in range(len(dilation)) + ]) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for layer in self.convs1: + layer.remove_weight_norm() + for layer in self.convs2: + layer.remove_weight_norm() + + +class Generator(torch.nn.Module): + + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + logger.info('num_kernels={}, num_upsamples={}'.format( + self.num_kernels, self.num_upsamples)) + self.conv_pre = Conv1dCasual( + 80, h.upsample_initial_channel, 7, 1, padding=7 - 1) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + self.repeat_ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip(h.upsample_rates, h.upsample_kernel_sizes)): + upsample = nn.Sequential( + nn.Upsample(mode='nearest', scale_factor=u), + nn.LeakyReLU(LRELU_SLOPE), + Conv1dCasual( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + kernel_size=7, + stride=1, + padding=7 - 1)) + self.repeat_ups.append(upsample) + self.ups.append( + ConvTranspose1dCausal( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = Conv1dCasual(ch, 1, 7, 1, padding=7 - 1) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = torch.sin(x) + x + # transconv + x1 = F.leaky_relu(x, LRELU_SLOPE) + x1 = self.ups[i](x1) + # repeat + x2 = self.repeat_ups[i](x) + x = x1 + x2 + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + logger.info('Removing weight norm...') + for layer in self.ups: + layer.remove_weight_norm() + for layer in self.repeat_ups: + layer[-1].remove_weight_norm() + for layer in self.resblocks: + layer.remove_weight_norm() + self.conv_pre.remove_weight_norm() + self.conv_post.remove_weight_norm() diff --git a/modelscope/models/audio/tts/models/models/sambert/__init__.py b/modelscope/models/audio/tts/models/models/sambert/__init__.py new file mode 100644 index 00000000..f0bf5290 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .kantts_sambert import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/models/sambert/adaptors.py b/modelscope/models/audio/tts/models/models/sambert/adaptors.py new file mode 100644 index 00000000..c171a1db --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/adaptors.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import Prenet +from .fsmn import FsmnEncoderV2 + + +class LengthRegulator(nn.Module): + + def __init__(self, r=1): + super(LengthRegulator, self).__init__() + + self.r = r + + def forward(self, inputs, durations, masks=None): + reps = (durations + 0.5).long() + output_lens = reps.sum(dim=1) + max_len = output_lens.max() + reps_cumsum = torch.cumsum( + F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] + range_ = torch.arange(max_len).to(inputs.device)[None, :, None] + mult = ((reps_cumsum[:, :, :-1] <= range_) + & (reps_cumsum[:, :, 1:] > range_)) # yapf:disable + mult = mult.float() + out = torch.matmul(mult, inputs) + + if masks is not None: + out = out.masked_fill(masks.unsqueeze(-1), 0.0) + + seq_len = out.size(1) + padding = self.r - int(seq_len) % self.r + if (padding < self.r): + out = F.pad( + out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0) + out = out.transpose(1, 2) + + return out, output_lens + + +class VarRnnARPredictor(nn.Module): + + def __init__(self, cond_units, prenet_units, rnn_units): + super(VarRnnARPredictor, self).__init__() + + self.prenet = Prenet(1, prenet_units) + self.lstm = nn.LSTM( + prenet_units[-1] + cond_units, + rnn_units, + num_layers=2, + batch_first=True, + bidirectional=False) + self.fc = nn.Linear(rnn_units, 1) + + def forward(self, inputs, cond, h=None, masks=None): + x = torch.cat([self.prenet(inputs), cond], dim=-1) + # The input can also be a packed variable length sequence, + # here we just omit it for simplicity due to the mask and uni-directional lstm. + x, h_new = self.lstm(x, h) + + x = self.fc(x).squeeze(-1) + x = F.relu(x) + + if masks is not None: + x = x.masked_fill(masks, 0.0) + + return x, h_new + + def infer(self, cond, masks=None): + batch_size, length = cond.size(0), cond.size(1) + + output = [] + x = torch.zeros((batch_size, 1)).to(cond.device) + h = None + + for i in range(length): + x, h = self.forward(x.unsqueeze(1), cond[:, i:i + 1, :], h=h) + output.append(x) + + output = torch.cat(output, dim=-1) + + if masks is not None: + output = output.masked_fill(masks, 0.0) + + return output + + +class VarFsmnRnnNARPredictor(nn.Module): + + def __init__(self, in_dim, filter_size, fsmn_num_layers, num_memory_units, + ffn_inner_dim, dropout, shift, lstm_units): + super(VarFsmnRnnNARPredictor, self).__init__() + + self.fsmn = FsmnEncoderV2(filter_size, fsmn_num_layers, in_dim, + num_memory_units, ffn_inner_dim, dropout, + shift) + self.blstm = nn.LSTM( + num_memory_units, + lstm_units, + num_layers=1, + batch_first=True, + bidirectional=True) + self.fc = nn.Linear(2 * lstm_units, 1) + + def forward(self, inputs, masks=None): + input_lengths = None + if masks is not None: + input_lengths = torch.sum((~masks).float(), dim=1).long() + + x = self.fsmn(inputs, masks) + + if input_lengths is not None: + x = nn.utils.rnn.pack_padded_sequence( + x, + input_lengths.tolist(), + batch_first=True, + enforce_sorted=False) + x, _ = self.blstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence( + x, batch_first=True, total_length=inputs.size(1)) + else: + x, _ = self.blstm(x) + + x = self.fc(x).squeeze(-1) + + if masks is not None: + x = x.masked_fill(masks, 0.0) + + return x diff --git a/modelscope/models/audio/tts/models/models/sambert/base.py b/modelscope/models/audio/tts/models/models/sambert/base.py new file mode 100644 index 00000000..873aecbf --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/base.py @@ -0,0 +1,369 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ScaledDotProductAttention(nn.Module): + """ Scaled Dot-Product Attention """ + + def __init__(self, temperature, dropatt=0.0): + super().__init__() + self.temperature = temperature + self.softmax = nn.Softmax(dim=2) + self.dropatt = nn.Dropout(dropatt) + + def forward(self, q, k, v, mask=None): + + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + + if mask is not None: + attn = attn.masked_fill(mask, -np.inf) + + attn = self.softmax(attn) + attn = self.dropatt(attn) + output = torch.bmm(attn, v) + + return output, attn + + +class Prenet(nn.Module): + + def __init__(self, in_units, prenet_units, out_units=0): + super(Prenet, self).__init__() + + self.fcs = nn.ModuleList() + for in_dim, out_dim in zip([in_units] + prenet_units[:-1], + prenet_units): + self.fcs.append(nn.Linear(in_dim, out_dim)) + self.fcs.append(nn.ReLU()) + self.fcs.append(nn.Dropout(0.5)) + + if (out_units): + self.fcs.append(nn.Linear(prenet_units[-1], out_units)) + + def forward(self, input): + output = input + for layer in self.fcs: + output = layer(output) + return output + + +class MultiHeadSelfAttention(nn.Module): + """ Multi-Head SelfAttention module """ + + def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0): + super().__init__() + + self.n_head = n_head + self.d_head = d_head + self.d_in = d_in + self.d_model = d_model + + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head) + + self.attention = ScaledDotProductAttention( + temperature=np.power(d_head, 0.5), dropatt=dropatt) + + self.fc = nn.Linear(n_head * d_head, d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, input, mask=None): + d_head, n_head = self.d_head, self.n_head + + sz_b, len_in, _ = input.size() + + residual = input + + x = self.layer_norm(input) + qkv = self.w_qkv(x) + q, k, v = qkv.chunk(3, -1) + + q = q.view(sz_b, len_in, n_head, d_head) + k = k.view(sz_b, len_in, n_head, d_head) + v = v.view(sz_b, len_in, n_head, d_head) + + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in, + d_head) # (n*b) x l x d + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in, + d_head) # (n*b) x l x d + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in, + d_head) # (n*b) x l x d + + if mask is not None: + mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. + output, attn = self.attention(q, k, v, mask=mask) + + output = output.view(n_head, sz_b, len_in, d_head) + output = (output.permute(1, 2, 0, + 3).contiguous().view(sz_b, len_in, + -1)) # b x l x (n*d) + + output = self.dropout(self.fc(output)) + if (output.size(-1) == residual.size(-1)): + output = output + residual + + return output, attn + + +class PositionwiseConvFeedForward(nn.Module): + """ A two-feed-forward-layer module """ + + def __init__(self, + d_in, + d_hid, + kernel_size=(3, 1), + dropout_inner=0.1, + dropout=0.1): + super().__init__() + # Use Conv1D + # position-wise + self.w_1 = nn.Conv1d( + d_in, + d_hid, + kernel_size=kernel_size[0], + padding=(kernel_size[0] - 1) // 2, + ) + # position-wise + self.w_2 = nn.Conv1d( + d_hid, + d_in, + kernel_size=kernel_size[1], + padding=(kernel_size[1] - 1) // 2, + ) + + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.dropout_inner = nn.Dropout(dropout_inner) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask=None): + residual = x + x = self.layer_norm(x) + + output = x.transpose(1, 2) + output = F.relu(self.w_1(output)) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(1), 0) + output = self.dropout_inner(output) + output = self.w_2(output) + output = output.transpose(1, 2) + output = self.dropout(output) + + output = output + residual + + return output + + +class FFTBlock(nn.Module): + """FFT Block""" + + def __init__(self, + d_in, + d_model, + n_head, + d_head, + d_inner, + kernel_size, + dropout, + dropout_attn=0.0, + dropout_relu=0.0): + super(FFTBlock, self).__init__() + self.slf_attn = MultiHeadSelfAttention( + n_head, + d_in, + d_model, + d_head, + dropout=dropout, + dropatt=dropout_attn) + self.pos_ffn = PositionwiseConvFeedForward( + d_model, + d_inner, + kernel_size, + dropout_inner=dropout_relu, + dropout=dropout) + + def forward(self, input, mask=None, slf_attn_mask=None): + output, slf_attn = self.slf_attn(input, mask=slf_attn_mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + output = self.pos_ffn(output, mask=mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + return output, slf_attn + + +class MultiHeadPNCAAttention(nn.Module): + """ Multi-Head Attention PNCA module """ + + def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0): + super().__init__() + + self.n_head = n_head + self.d_head = d_head + self.d_model = d_model + self.d_mem = d_mem + + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head) + self.fc_x = nn.Linear(n_head * d_head, d_model) + + self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head) + self.fc_h = nn.Linear(n_head * d_head, d_model) + + self.attention = ScaledDotProductAttention( + temperature=np.power(d_head, 0.5), dropatt=dropatt) + + self.dropout = nn.Dropout(dropout) + + def update_x_state(self, x): + d_head, n_head = self.d_head, self.n_head + + sz_b, len_x, _ = x.size() + + x_qkv = self.w_x_qkv(x) + x_q, x_k, x_v = x_qkv.chunk(3, -1) + + x_q = x_q.view(sz_b, len_x, n_head, d_head) + x_k = x_k.view(sz_b, len_x, n_head, d_head) + x_v = x_v.view(sz_b, len_x, n_head, d_head) + + x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) + x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) + x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) + + if (self.x_state_size): + self.x_k = torch.cat([self.x_k, x_k], dim=1) + self.x_v = torch.cat([self.x_v, x_v], dim=1) + else: + self.x_k = x_k + self.x_v = x_v + + self.x_state_size += len_x + + return x_q, x_k, x_v + + def update_h_state(self, h): + if (self.h_state_size == h.size(1)): + return None, None + + d_head, n_head = self.d_head, self.n_head + + # H + sz_b, len_h, _ = h.size() + + h_kv = self.w_h_kv(h) + h_k, h_v = h_kv.chunk(2, -1) + + h_k = h_k.view(sz_b, len_h, n_head, d_head) + h_v = h_v.view(sz_b, len_h, n_head, d_head) + + self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head) + self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head) + + self.h_state_size += len_h + + return h_k, h_v + + def reset_state(self): + self.h_k = None + self.h_v = None + self.h_state_size = 0 + self.x_k = None + self.x_v = None + self.x_state_size = 0 + + def forward(self, x, h, mask_x=None, mask_h=None): + residual = x + self.update_h_state(h) + x_q, x_k, x_v = self.update_x_state(self.layer_norm(x)) + + d_head, n_head = self.d_head, self.n_head + + sz_b, len_in, _ = x.size() + + # X + if mask_x is not None: + mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x .. + output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x) + + output_x = output_x.view(n_head, sz_b, len_in, d_head) + output_x = (output_x.permute(1, 2, 0, + 3).contiguous().view(sz_b, len_in, + -1)) # b x l x (n*d) + output_x = self.fc_x(output_x) + + # H + if mask_h is not None: + mask_h = mask_h.repeat(n_head, 1, 1) + output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h) + + output_h = output_h.view(n_head, sz_b, len_in, d_head) + output_h = (output_h.permute(1, 2, 0, + 3).contiguous().view(sz_b, len_in, + -1)) # b x l x (n*d) + output_h = self.fc_h(output_h) + + output = output_x + output_h + + output = self.dropout(output) + + output = output + residual + + return output, attn_x, attn_h + + +class PNCABlock(nn.Module): + """PNCA Block""" + + def __init__(self, + d_model, + d_mem, + n_head, + d_head, + d_inner, + kernel_size, + dropout, + dropout_attn=0.0, + dropout_relu=0.0): + super(PNCABlock, self).__init__() + self.pnca_attn = MultiHeadPNCAAttention( + n_head, + d_model, + d_mem, + d_head, + dropout=dropout, + dropatt=dropout_attn) + self.pos_ffn = PositionwiseConvFeedForward( + d_model, + d_inner, + kernel_size, + dropout_inner=dropout_relu, + dropout=dropout) + + def forward(self, + input, + memory, + mask=None, + pnca_x_attn_mask=None, + pnca_h_attn_mask=None): + output, pnca_attn_x, pnca_attn_h = self.pnca_attn( + input, memory, pnca_x_attn_mask, pnca_h_attn_mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + output = self.pos_ffn(output, mask=mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + return output, pnca_attn_x, pnca_attn_h + + def reset_state(self): + self.pnca_attn.reset_state() diff --git a/modelscope/models/audio/tts/models/models/sambert/fsmn.py b/modelscope/models/audio/tts/models/models/sambert/fsmn.py new file mode 100644 index 00000000..c070ef35 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/fsmn.py @@ -0,0 +1,126 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +FSMN Pytorch Version +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FeedForwardNet(nn.Module): + """ A two-feed-forward-layer module """ + + def __init__(self, d_in, d_hid, d_out, kernel_size=[1, 1], dropout=0.1): + super().__init__() + + # Use Conv1D + # position-wise + self.w_1 = nn.Conv1d( + d_in, + d_hid, + kernel_size=kernel_size[0], + padding=(kernel_size[0] - 1) // 2, + ) + # position-wise + self.w_2 = nn.Conv1d( + d_hid, + d_out, + kernel_size=kernel_size[1], + padding=(kernel_size[1] - 1) // 2, + bias=False) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + output = x.transpose(1, 2) + output = F.relu(self.w_1(output)) + output = self.dropout(output) + output = self.w_2(output) + output = output.transpose(1, 2) + + return output + + +class MemoryBlockV2(nn.Module): + + def __init__(self, d, filter_size, shift, dropout=0.0): + super(MemoryBlockV2, self).__init__() + + left_padding = int(round((filter_size - 1) / 2)) + right_padding = int((filter_size - 1) / 2) + if shift > 0: + left_padding += shift + right_padding -= shift + + self.lp, self.rp = left_padding, right_padding + + self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, input, mask=None): + if mask is not None: + input = input.masked_fill(mask.unsqueeze(-1), 0) + + x = F.pad( + input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0) + output = self.conv_dw(x.contiguous().transpose( + 1, 2)).contiguous().transpose(1, 2) + output += input + output = self.dropout(output) + + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + return output + + +class FsmnEncoderV2(nn.Module): + + def __init__(self, + filter_size, + fsmn_num_layers, + input_dim, + num_memory_units, + ffn_inner_dim, + dropout=0.0, + shift=0): + super(FsmnEncoderV2, self).__init__() + + self.filter_size = filter_size + self.fsmn_num_layers = fsmn_num_layers + self.num_memory_units = num_memory_units + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.shift = shift + if not isinstance(shift, list): + self.shift = [shift for _ in range(self.fsmn_num_layers)] + + self.ffn_lst = nn.ModuleList() + self.ffn_lst.append( + FeedForwardNet( + input_dim, ffn_inner_dim, num_memory_units, dropout=dropout)) + for i in range(1, fsmn_num_layers): + self.ffn_lst.append( + FeedForwardNet( + num_memory_units, + ffn_inner_dim, + num_memory_units, + dropout=dropout)) + + self.memory_block_lst = nn.ModuleList() + for i in range(fsmn_num_layers): + self.memory_block_lst.append( + MemoryBlockV2(num_memory_units, filter_size, self.shift[i], + dropout)) + + def forward(self, input, mask=None): + x = F.dropout(input, self.dropout, self.training) + for (ffn, memory_block) in zip(self.ffn_lst, self.memory_block_lst): + context = ffn(x) + memory = memory_block(context, mask) + memory = F.dropout(memory, self.dropout, self.training) + if (memory.size(-1) == x.size(-1)): + memory += x + x = memory + + return x diff --git a/modelscope/models/audio/tts/models/models/sambert/kantts_sambert.py b/modelscope/models/audio/tts/models/models/sambert/kantts_sambert.py new file mode 100644 index 00000000..3837a2e8 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/kantts_sambert.py @@ -0,0 +1,718 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.audio.tts.models.utils import get_mask_from_lengths +from .adaptors import (LengthRegulator, VarFsmnRnnNARPredictor, + VarRnnARPredictor) +from .base import FFTBlock, PNCABlock, Prenet +from .fsmn import FsmnEncoderV2 +from .positions import DurSinusoidalPositionEncoder, SinusoidalPositionEncoder + + +class SelfAttentionEncoder(nn.Module): + + def __init__(self, n_layer, d_in, d_model, n_head, d_head, d_inner, + dropout, dropout_att, dropout_relu, position_encoder): + super(SelfAttentionEncoder, self).__init__() + + self.d_in = d_in + self.d_model = d_model + self.dropout = dropout + d_in_lst = [d_in] + [d_model] * (n_layer - 1) + self.fft = nn.ModuleList([ + FFTBlock(d, d_model, n_head, d_head, d_inner, (3, 1), dropout, + dropout_att, dropout_relu) for d in d_in_lst + ]) + self.ln = nn.LayerNorm(d_model, eps=1e-6) + self.position_enc = position_encoder + + def forward(self, input, mask=None, return_attns=False): + input *= self.d_model**0.5 + if (isinstance(self.position_enc, SinusoidalPositionEncoder)): + input = self.position_enc(input) + else: + raise NotImplementedError('modelscope error: position_enc invalid') + + input = F.dropout(input, p=self.dropout, training=self.training) + + enc_slf_attn_list = [] + max_len = input.size(1) + if mask is not None: + slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) + else: + slf_attn_mask = None + + enc_output = input + for id, layer in enumerate(self.fft): + enc_output, enc_slf_attn = layer( + enc_output, mask=mask, slf_attn_mask=slf_attn_mask) + if return_attns: + enc_slf_attn_list += [enc_slf_attn] + + enc_output = self.ln(enc_output) + + return enc_output, enc_slf_attn_list + + +class HybridAttentionDecoder(nn.Module): + + def __init__(self, d_in, prenet_units, n_layer, d_model, d_mem, n_head, + d_head, d_inner, dropout, dropout_att, dropout_relu, d_out): + super(HybridAttentionDecoder, self).__init__() + + self.d_model = d_model + self.dropout = dropout + self.prenet = Prenet(d_in, prenet_units, d_model) + self.dec_in_proj = nn.Linear(d_model + d_mem, d_model) + self.pnca = nn.ModuleList([ + PNCABlock(d_model, d_mem, n_head, d_head, d_inner, (1, 1), dropout, + dropout_att, dropout_relu) for _ in range(n_layer) + ]) + self.ln = nn.LayerNorm(d_model, eps=1e-6) + self.dec_out_proj = nn.Linear(d_model, d_out) + + def reset_state(self): + for layer in self.pnca: + layer.reset_state() + + def get_pnca_attn_mask(self, + device, + max_len, + x_band_width, + h_band_width, + mask=None): + if mask is not None: + pnca_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) + else: + pnca_attn_mask = None + + range_ = torch.arange(max_len).to(device) + x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :] + x_end = (range_ + 1)[None, None, :] + h_start = range_[None, None, :] + h_end = torch.clamp_max(range_ + h_band_width + 1, + max_len + 1)[None, None, :] + + pnca_x_attn_mask = ~((x_start <= range_[None, :, None]) + & (x_end > range_[None, :, None])).transpose(1, 2) # yapf:disable + pnca_h_attn_mask = ~((h_start <= range_[None, :, None]) + & (h_end > range_[None, :, None])).transpose(1, 2) # yapf:disable + + if pnca_attn_mask is not None: + pnca_x_attn_mask = (pnca_x_attn_mask | pnca_attn_mask) + pnca_h_attn_mask = (pnca_h_attn_mask | pnca_attn_mask) + pnca_x_attn_mask = pnca_x_attn_mask.masked_fill( + pnca_attn_mask.transpose(1, 2), False) + pnca_h_attn_mask = pnca_h_attn_mask.masked_fill( + pnca_attn_mask.transpose(1, 2), False) + + return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask + + # must call reset_state before + def forward(self, + input, + memory, + x_band_width, + h_band_width, + mask=None, + return_attns=False): + input = self.prenet(input) + input = torch.cat([memory, input], dim=-1) + input = self.dec_in_proj(input) + + if mask is not None: + input = input.masked_fill(mask.unsqueeze(-1), 0) + + input *= self.d_model**0.5 + input = F.dropout(input, p=self.dropout, training=self.training) + + max_len = input.size(1) + pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask( + input.device, max_len, x_band_width, h_band_width, mask) + + dec_pnca_attn_x_list = [] + dec_pnca_attn_h_list = [] + dec_output = input + for id, layer in enumerate(self.pnca): + dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer( + dec_output, + memory, + mask=mask, + pnca_x_attn_mask=pnca_x_attn_mask, + pnca_h_attn_mask=pnca_h_attn_mask) + if return_attns: + dec_pnca_attn_x_list += [dec_pnca_attn_x] + dec_pnca_attn_h_list += [dec_pnca_attn_h] + + dec_output = self.ln(dec_output) + dec_output = self.dec_out_proj(dec_output) + + return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list + + # must call reset_state before when step == 0 + def infer(self, + step, + input, + memory, + x_band_width, + h_band_width, + mask=None, + return_attns=False): + max_len = memory.size(1) + + input = self.prenet(input) + input = torch.cat([memory[:, step:step + 1, :], input], dim=-1) + input = self.dec_in_proj(input) + + input *= self.d_model**0.5 + input = F.dropout(input, p=self.dropout, training=self.training) + + pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask( + input.device, max_len, x_band_width, h_band_width, mask) + + dec_pnca_attn_x_list = [] + dec_pnca_attn_h_list = [] + dec_output = input + for id, layer in enumerate(self.pnca): + if mask is not None: + mask_step = mask[:, step:step + 1] + else: + mask_step = None + dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer( + dec_output, + memory, + mask=mask_step, + pnca_x_attn_mask=pnca_x_attn_mask[:, + step:step + 1, :(step + 1)], + pnca_h_attn_mask=pnca_h_attn_mask[:, step:step + 1, :]) + if return_attns: + dec_pnca_attn_x_list += [dec_pnca_attn_x] + dec_pnca_attn_h_list += [dec_pnca_attn_h] + + dec_output = self.ln(dec_output) + dec_output = self.dec_out_proj(dec_output) + + return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list + + +class TextFftEncoder(nn.Module): + + def __init__(self, config, ling_unit_size): + super(TextFftEncoder, self).__init__() + + # linguistic unit lookup table + nb_ling_sy = ling_unit_size['sy'] + nb_ling_tone = ling_unit_size['tone'] + nb_ling_syllable_flag = ling_unit_size['syllable_flag'] + nb_ling_ws = ling_unit_size['word_segment'] + + max_len = config['am']['max_len'] + + d_emb = config['am']['embedding_dim'] + nb_layers = config['am']['encoder_num_layers'] + nb_heads = config['am']['encoder_num_heads'] + d_model = config['am']['encoder_num_units'] + d_head = d_model // nb_heads + d_inner = config['am']['encoder_ffn_inner_dim'] + dropout = config['am']['encoder_dropout'] + dropout_attn = config['am']['encoder_attention_dropout'] + dropout_relu = config['am']['encoder_relu_dropout'] + d_proj = config['am']['encoder_projection_units'] + + self.d_model = d_model + + self.sy_emb = nn.Embedding(nb_ling_sy, d_emb) + self.tone_emb = nn.Embedding(nb_ling_tone, d_emb) + self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb) + self.ws_emb = nn.Embedding(nb_ling_ws, d_emb) + + position_enc = SinusoidalPositionEncoder(max_len, d_emb) + + self.ling_enc = SelfAttentionEncoder(nb_layers, d_emb, d_model, + nb_heads, d_head, d_inner, + dropout, dropout_attn, + dropout_relu, position_enc) + + self.ling_proj = nn.Linear(d_model, d_proj, bias=False) + + def forward(self, inputs_ling, masks=None, return_attns=False): + # Parse inputs_ling_seq + inputs_sy = inputs_ling[:, :, 0] + inputs_tone = inputs_ling[:, :, 1] + inputs_syllable_flag = inputs_ling[:, :, 2] + inputs_ws = inputs_ling[:, :, 3] + + # Lookup table + sy_embedding = self.sy_emb(inputs_sy) + tone_embedding = self.tone_emb(inputs_tone) + syllable_flag_embedding = self.syllable_flag_emb(inputs_syllable_flag) + ws_embedding = self.ws_emb(inputs_ws) + + ling_embedding = sy_embedding + tone_embedding + syllable_flag_embedding + ws_embedding + + enc_output, enc_slf_attn_list = self.ling_enc(ling_embedding, masks, + return_attns) + + enc_output = self.ling_proj(enc_output) + + return enc_output, enc_slf_attn_list + + +class VarianceAdaptor(nn.Module): + + def __init__(self, config): + super(VarianceAdaptor, self).__init__() + + input_dim = config['am']['encoder_projection_units'] + config['am'][ + 'emotion_units'] + config['am']['speaker_units'] + filter_size = config['am']['predictor_filter_size'] + fsmn_num_layers = config['am']['predictor_fsmn_num_layers'] + num_memory_units = config['am']['predictor_num_memory_units'] + ffn_inner_dim = config['am']['predictor_ffn_inner_dim'] + dropout = config['am']['predictor_dropout'] + shift = config['am']['predictor_shift'] + lstm_units = config['am']['predictor_lstm_units'] + + dur_pred_prenet_units = config['am']['dur_pred_prenet_units'] + dur_pred_lstm_units = config['am']['dur_pred_lstm_units'] + + self.pitch_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size, + fsmn_num_layers, + num_memory_units, + ffn_inner_dim, dropout, + shift, lstm_units) + self.energy_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size, + fsmn_num_layers, + num_memory_units, + ffn_inner_dim, dropout, + shift, lstm_units) + self.duration_predictor = VarRnnARPredictor(input_dim, + dur_pred_prenet_units, + dur_pred_lstm_units) + + self.length_regulator = LengthRegulator( + config['am']['outputs_per_step']) + self.dur_position_encoder = DurSinusoidalPositionEncoder( + config['am']['encoder_projection_units'], + config['am']['outputs_per_step']) + + self.pitch_emb = nn.Conv1d( + 1, + config['am']['encoder_projection_units'], + kernel_size=9, + padding=4) + self.energy_emb = nn.Conv1d( + 1, + config['am']['encoder_projection_units'], + kernel_size=9, + padding=4) + + def forward(self, + inputs_text_embedding, + inputs_emo_embedding, + inputs_spk_embedding, + masks=None, + output_masks=None, + duration_targets=None, + pitch_targets=None, + energy_targets=None): + + batch_size = inputs_text_embedding.size(0) + + variance_predictor_inputs = torch.cat([ + inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding + ], dim=-1) # yapf:disable + + pitch_predictions = self.pitch_predictor(variance_predictor_inputs, + masks) + energy_predictions = self.energy_predictor(variance_predictor_inputs, + masks) + + if pitch_targets is not None: + pitch_embeddings = self.pitch_emb( + pitch_targets.unsqueeze(1)).transpose(1, 2) + else: + pitch_embeddings = self.pitch_emb( + pitch_predictions.unsqueeze(1)).transpose(1, 2) + + if energy_targets is not None: + energy_embeddings = self.energy_emb( + energy_targets.unsqueeze(1)).transpose(1, 2) + else: + energy_embeddings = self.energy_emb( + energy_predictions.unsqueeze(1)).transpose(1, 2) + + inputs_text_embedding_aug = inputs_text_embedding + pitch_embeddings + energy_embeddings + duration_predictor_cond = torch.cat([ + inputs_text_embedding_aug, inputs_spk_embedding, + inputs_emo_embedding + ], dim=-1) # yapf:disable + if duration_targets is not None: + duration_predictor_go_frame = torch.zeros(batch_size, 1).to( + inputs_text_embedding.device) + duration_predictor_input = torch.cat([ + duration_predictor_go_frame, duration_targets[:, :-1].float() + ], dim=-1) # yapf:disable + duration_predictor_input = torch.log(duration_predictor_input + 1) + log_duration_predictions, _ = self.duration_predictor( + duration_predictor_input.unsqueeze(-1), + duration_predictor_cond, + masks=masks) + duration_predictions = torch.exp(log_duration_predictions) - 1 + else: + log_duration_predictions = self.duration_predictor.infer( + duration_predictor_cond, masks=masks) + duration_predictions = torch.exp(log_duration_predictions) - 1 + + if duration_targets is not None: + LR_text_outputs, LR_length_rounded = self.length_regulator( + inputs_text_embedding_aug, + duration_targets, + masks=output_masks) + LR_position_embeddings = self.dur_position_encoder( + duration_targets, masks=output_masks) + LR_emo_outputs, _ = self.length_regulator( + inputs_emo_embedding, duration_targets, masks=output_masks) + LR_spk_outputs, _ = self.length_regulator( + inputs_spk_embedding, duration_targets, masks=output_masks) + + else: + LR_text_outputs, LR_length_rounded = self.length_regulator( + inputs_text_embedding_aug, + duration_predictions, + masks=output_masks) + LR_position_embeddings = self.dur_position_encoder( + duration_predictions, masks=output_masks) + LR_emo_outputs, _ = self.length_regulator( + inputs_emo_embedding, duration_predictions, masks=output_masks) + LR_spk_outputs, _ = self.length_regulator( + inputs_spk_embedding, duration_predictions, masks=output_masks) + + LR_text_outputs = LR_text_outputs + LR_position_embeddings + + return (LR_text_outputs, LR_emo_outputs, LR_spk_outputs, + LR_length_rounded, log_duration_predictions, pitch_predictions, + energy_predictions) + + +class MelPNCADecoder(nn.Module): + + def __init__(self, config): + super(MelPNCADecoder, self).__init__() + + prenet_units = config['am']['decoder_prenet_units'] + nb_layers = config['am']['decoder_num_layers'] + nb_heads = config['am']['decoder_num_heads'] + d_model = config['am']['decoder_num_units'] + d_head = d_model // nb_heads + d_inner = config['am']['decoder_ffn_inner_dim'] + dropout = config['am']['decoder_dropout'] + dropout_attn = config['am']['decoder_attention_dropout'] + dropout_relu = config['am']['decoder_relu_dropout'] + outputs_per_step = config['am']['outputs_per_step'] + + d_mem = config['am'][ + 'encoder_projection_units'] * outputs_per_step + config['am'][ + 'emotion_units'] + config['am']['speaker_units'] + d_mel = config['am']['num_mels'] + + self.d_mel = d_mel + self.r = outputs_per_step + self.nb_layers = nb_layers + + self.mel_dec = HybridAttentionDecoder(d_mel, prenet_units, nb_layers, + d_model, d_mem, nb_heads, d_head, + d_inner, dropout, dropout_attn, + dropout_relu, + d_mel * outputs_per_step) + + def forward(self, + memory, + x_band_width, + h_band_width, + target=None, + mask=None, + return_attns=False): + batch_size = memory.size(0) + go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device) + + if target is not None: + self.mel_dec.reset_state() + input = target[:, self.r - 1::self.r, :] + input = torch.cat([go_frame, input], dim=1)[:, :-1, :] + dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec( + input, + memory, + x_band_width, + h_band_width, + mask=mask, + return_attns=return_attns) + + else: + dec_output = [] + dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)] + dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)] + self.mel_dec.reset_state() + input = go_frame + for step in range(memory.size(1)): + dec_output_step, dec_pnca_attn_x_step, dec_pnca_attn_h_step = self.mel_dec.infer( + step, + input, + memory, + x_band_width, + h_band_width, + mask=mask, + return_attns=return_attns) + input = dec_output_step[:, :, -self.d_mel:] + + dec_output.append(dec_output_step) + for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate( + zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)): + left = memory.size(1) - pnca_x_attn.size(-1) + if (left > 0): + padding = torch.zeros( + (pnca_x_attn.size(0), 1, left)).to(pnca_x_attn) + pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1) + dec_pnca_attn_x_list[layer_id].append(pnca_x_attn) + dec_pnca_attn_h_list[layer_id].append(pnca_h_attn) + + dec_output = torch.cat(dec_output, dim=1) + for layer_id in range(self.nb_layers): + dec_pnca_attn_x_list[layer_id] = torch.cat( + dec_pnca_attn_x_list[layer_id], dim=1) + dec_pnca_attn_h_list[layer_id] = torch.cat( + dec_pnca_attn_h_list[layer_id], dim=1) + + return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list + + +class PostNet(nn.Module): + + def __init__(self, config): + super(PostNet, self).__init__() + + self.filter_size = config['am']['postnet_filter_size'] + self.fsmn_num_layers = config['am']['postnet_fsmn_num_layers'] + self.num_memory_units = config['am']['postnet_num_memory_units'] + self.ffn_inner_dim = config['am']['postnet_ffn_inner_dim'] + self.dropout = config['am']['postnet_dropout'] + self.shift = config['am']['postnet_shift'] + self.lstm_units = config['am']['postnet_lstm_units'] + self.num_mels = config['am']['num_mels'] + + self.fsmn = FsmnEncoderV2(self.filter_size, self.fsmn_num_layers, + self.num_mels, self.num_memory_units, + self.ffn_inner_dim, self.dropout, self.shift) + self.lstm = nn.LSTM( + self.num_memory_units, + self.lstm_units, + num_layers=1, + batch_first=True) + self.fc = nn.Linear(self.lstm_units, self.num_mels) + + def forward(self, x, mask=None): + postnet_fsmn_output = self.fsmn(x, mask) + # The input can also be a packed variable length sequence, + # here we just omit it for simpliciy due to the mask and uni-directional lstm. + postnet_lstm_output, _ = self.lstm(postnet_fsmn_output) + mel_residual_output = self.fc(postnet_lstm_output) + + return mel_residual_output + + +def mel_recon_loss_fn(output_lengths, + mel_targets, + dec_outputs, + postnet_outputs=None): + mae_loss = nn.L1Loss(reduction='none') + + output_masks = get_mask_from_lengths( + output_lengths, max_len=mel_targets.size(1)) + output_masks = ~output_masks + valid_outputs = output_masks.sum() + + mel_loss_ = torch.sum( + mae_loss(mel_targets, dec_outputs) * output_masks.unsqueeze(-1)) / ( + valid_outputs * mel_targets.size(-1)) + + if postnet_outputs is not None: + mel_loss = torch.sum( + mae_loss(mel_targets, postnet_outputs) + * output_masks.unsqueeze(-1)) / ( + valid_outputs * mel_targets.size(-1)) + else: + mel_loss = 0.0 + + return mel_loss_, mel_loss + + +def prosody_recon_loss_fn(input_lengths, duration_targets, pitch_targets, + energy_targets, log_duration_predictions, + pitch_predictions, energy_predictions): + mae_loss = nn.L1Loss(reduction='none') + + input_masks = get_mask_from_lengths( + input_lengths, max_len=duration_targets.size(1)) + input_masks = ~input_masks + valid_inputs = input_masks.sum() + + dur_loss = torch.sum( + mae_loss( + torch.log(duration_targets.float() + 1), log_duration_predictions) + * input_masks) / valid_inputs + pitch_loss = torch.sum( + mae_loss(pitch_targets, pitch_predictions) + * input_masks) / valid_inputs + energy_loss = torch.sum( + mae_loss(energy_targets, energy_predictions) + * input_masks) / valid_inputs + + return dur_loss, pitch_loss, energy_loss + + +class KanTtsSAMBERT(nn.Module): + + def __init__(self, config, ling_unit_size): + super(KanTtsSAMBERT, self).__init__() + + self.text_encoder = TextFftEncoder(config, ling_unit_size) + self.spk_tokenizer = nn.Embedding(ling_unit_size['speaker'], + config['am']['speaker_units']) + self.emo_tokenizer = nn.Embedding(ling_unit_size['emotion'], + config['am']['emotion_units']) + self.variance_adaptor = VarianceAdaptor(config) + self.mel_decoder = MelPNCADecoder(config) + self.mel_postnet = PostNet(config) + + def get_lfr_mask_from_lengths(self, lengths, max_len): + batch_size = lengths.size(0) + # padding according to the outputs_per_step + padded_lr_lengths = torch.zeros_like(lengths) + for i in range(batch_size): + len_item = int(lengths[i].item()) + padding = self.mel_decoder.r - len_item % self.mel_decoder.r + if (padding < self.mel_decoder.r): + padded_lr_lengths[i] = (len_item + + padding) // self.mel_decoder.r + else: + padded_lr_lengths[i] = len_item // self.mel_decoder.r + + return get_mask_from_lengths( + padded_lr_lengths, max_len=max_len // self.mel_decoder.r) + + def forward(self, + inputs_ling, + inputs_emotion, + inputs_speaker, + input_lengths, + output_lengths=None, + mel_targets=None, + duration_targets=None, + pitch_targets=None, + energy_targets=None): + + batch_size = inputs_ling.size(0) + + input_masks = get_mask_from_lengths( + input_lengths, max_len=inputs_ling.size(1)) + + text_hid, enc_sla_attn_lst = self.text_encoder( + inputs_ling, input_masks, return_attns=True) + + emo_hid = self.emo_tokenizer(inputs_emotion) + spk_hid = self.spk_tokenizer(inputs_speaker) + + if output_lengths is not None: + output_masks = get_mask_from_lengths( + output_lengths, max_len=mel_targets.size(1)) + else: + output_masks = None + + (LR_text_outputs, LR_emo_outputs, LR_spk_outputs, LR_length_rounded, + log_duration_predictions, pitch_predictions, + energy_predictions) = self.variance_adaptor( + text_hid, + emo_hid, + spk_hid, + masks=input_masks, + output_masks=output_masks, + duration_targets=duration_targets, + pitch_targets=pitch_targets, + energy_targets=energy_targets) + + if output_lengths is not None: + lfr_masks = self.get_lfr_mask_from_lengths( + output_lengths, max_len=LR_text_outputs.size(1)) + else: + output_masks = get_mask_from_lengths( + LR_length_rounded, max_len=LR_text_outputs.size(1)) + lfr_masks = None + + # LFR with the factor of outputs_per_step + LFR_text_inputs = LR_text_outputs.contiguous().view( + batch_size, -1, self.mel_decoder.r * text_hid.shape[-1]) + LFR_emo_inputs = LR_emo_outputs.contiguous().view( + batch_size, -1, + self.mel_decoder.r * emo_hid.shape[-1])[:, :, :emo_hid.shape[-1]] + LFR_spk_inputs = LR_spk_outputs.contiguous().view( + batch_size, -1, + self.mel_decoder.r * spk_hid.shape[-1])[:, :, :spk_hid.shape[-1]] + + memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs], + dim=-1) + + if duration_targets is not None: + x_band_width = int( + duration_targets.float().masked_fill(input_masks, 0).max() + / self.mel_decoder.r + 0.5) + h_band_width = x_band_width + else: + x_band_width = int((torch.exp(log_duration_predictions) - 1).max() + / self.mel_decoder.r + 0.5) + h_band_width = x_band_width + + dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder( + memory, + x_band_width, + h_band_width, + target=mel_targets, + mask=lfr_masks, + return_attns=True) + + # De-LFR with the factor of outputs_per_step + dec_outputs = dec_outputs.contiguous().view(batch_size, -1, + self.mel_decoder.d_mel) + + if output_masks is not None: + dec_outputs = dec_outputs.masked_fill( + output_masks.unsqueeze(-1), 0) + + postnet_outputs = self.mel_postnet(dec_outputs, + output_masks) + dec_outputs + if output_masks is not None: + postnet_outputs = postnet_outputs.masked_fill( + output_masks.unsqueeze(-1), 0) + + res = { + 'x_band_width': x_band_width, + 'h_band_width': h_band_width, + 'enc_slf_attn_lst': enc_sla_attn_lst, + 'pnca_x_attn_lst': pnca_x_attn_lst, + 'pnca_h_attn_lst': pnca_h_attn_lst, + 'dec_outputs': dec_outputs, + 'postnet_outputs': postnet_outputs, + 'LR_length_rounded': LR_length_rounded, + 'log_duration_predictions': log_duration_predictions, + 'pitch_predictions': pitch_predictions, + 'energy_predictions': energy_predictions + } + + res['LR_text_outputs'] = LR_text_outputs + res['LR_emo_outputs'] = LR_emo_outputs + res['LR_spk_outputs'] = LR_spk_outputs + + return res diff --git a/modelscope/models/audio/tts/models/models/sambert/positions.py b/modelscope/models/audio/tts/models/models/sambert/positions.py new file mode 100644 index 00000000..9d1e375d --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/positions.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SinusoidalPositionEncoder(nn.Module): + + def __init__(self, max_len, depth): + super(SinusoidalPositionEncoder, self).__init__() + + self.max_len = max_len + self.depth = depth + self.position_enc = nn.Parameter( + self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0), + requires_grad=False) + + def forward(self, input): + bz_in, len_in, _ = input.size() + if len_in > self.max_len: + self.max_len = len_in + self.position_enc.data = self.get_sinusoid_encoding_table( + self.max_len, self.depth).unsqueeze(0).to(input.device) + + output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1) + + return output + + @staticmethod + def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + """ Sinusoid position encoding table """ + + def cal_angle(position, hid_idx): + return position / np.power(10000, hid_idx / float(d_hid / 2 - 1)) + + def get_posi_angle_vec(position): + return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)] + + scaled_time_table = np.array( + [get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)]) + + sinusoid_table = np.zeros((n_position, d_hid)) + sinusoid_table[:, :d_hid // 2] = np.sin(scaled_time_table) + sinusoid_table[:, d_hid // 2:] = np.cos(scaled_time_table) + + if padding_idx is not None: + # zero vector for padding dimension + sinusoid_table[padding_idx] = 0.0 + + return torch.FloatTensor(sinusoid_table) + + +class DurSinusoidalPositionEncoder(nn.Module): + + def __init__(self, depth, outputs_per_step): + super(DurSinusoidalPositionEncoder, self).__init__() + + self.depth = depth + self.outputs_per_step = outputs_per_step + + inv_timescales = [ + np.power(10000, 2 * (hid_idx // 2) / depth) + for hid_idx in range(depth) + ] + self.inv_timescales = nn.Parameter( + torch.FloatTensor(inv_timescales), requires_grad=False) + + def forward(self, durations, masks=None): + reps = (durations + 0.5).long() + output_lens = reps.sum(dim=1) + max_len = output_lens.max() + reps_cumsum = torch.cumsum( + F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] + range_ = torch.arange(max_len).to(durations.device)[None, :, None] + mult = ((reps_cumsum[:, :, :-1] <= range_) + & (reps_cumsum[:, :, 1:] > range_)) # yapf:disable + mult = mult.float() + offsets = torch.matmul(mult, + reps_cumsum[:, + 0, :-1].unsqueeze(-1)).squeeze(-1) + dur_pos = range_[:, :, 0] - offsets + 1 + + if masks is not None: + assert masks.size(1) == dur_pos.size(1) + dur_pos = dur_pos.masked_fill(masks, 0.0) + + seq_len = dur_pos.size(1) + padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step + if (padding < self.outputs_per_step): + dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0) + + position_embedding = dur_pos[:, :, None] / self.inv_timescales[None, + None, :] + position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :, + 0::2]) + position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :, + 1::2]) + + return position_embedding diff --git a/modelscope/models/audio/tts/models/position.py b/modelscope/models/audio/tts/models/position.py deleted file mode 100755 index bca658dd..00000000 --- a/modelscope/models/audio/tts/models/position.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Define position encoder classes.""" - -import abc -import math - -import tensorflow as tf - -from .reducer import SumReducer - - -class PositionEncoder(tf.keras.layers.Layer): - """Base class for position encoders.""" - - def __init__(self, reducer=None, **kwargs): - """Initializes the position encoder. - Args: - reducer: A :class:`opennmt.layers.Reducer` to merge inputs and position - encodings. Defaults to :class:`opennmt.layers.SumReducer`. - **kwargs: Additional layer keyword arguments. - """ - super(PositionEncoder, self).__init__(**kwargs) - if reducer is None: - reducer = SumReducer(dtype=kwargs.get('dtype')) - self.reducer = reducer - - def call(self, inputs, position=None): # pylint: disable=arguments-differ - """Add position encodings to :obj:`inputs`. - Args: - inputs: The inputs to encode. - position: The single position to encode, to use when this layer is called - step by step. - Returns: - A ``tf.Tensor`` whose shape depends on the configured ``reducer``. - """ - batch_size = tf.shape(inputs)[0] - timesteps = tf.shape(inputs)[1] - input_dim = inputs.shape[-1].value - positions = tf.range(timesteps) + 1 if position is None else [position] - position_encoding = self._encode([positions], input_dim) - position_encoding = tf.tile(position_encoding, [batch_size, 1, 1]) - return self.reducer([inputs, position_encoding]) - - @abc.abstractmethod - def _encode(self, positions, depth): - """Creates position encodings. - Args: - positions: The positions to encode of shape :math:`[B, ...]`. - depth: The encoding depth :math:`D`. - Returns: - A ``tf.Tensor`` of shape :math:`[B, ..., D]`. - """ - raise NotImplementedError() - - -class PositionEmbedder(PositionEncoder): - """Encodes position with a lookup table.""" - - def __init__(self, maximum_position=128, reducer=None, **kwargs): - """Initializes the position encoder. - Args: - maximum_position: The maximum position to embed. Positions greater - than this value will be set to :obj:`maximum_position`. - reducer: A :class:`opennmt.layers.Reducer` to merge inputs and position - encodings. Defaults to :class:`opennmt.layers.SumReducer`. - **kwargs: Additional layer keyword arguments. - """ - super(PositionEmbedder, self).__init__(reducer=reducer, **kwargs) - self.maximum_position = maximum_position - self.embedding = None - - def build(self, input_shape): - shape = [self.maximum_position + 1, input_shape[-1]] - self.embedding = self.add_weight('position_embedding', shape) - super(PositionEmbedder, self).build(input_shape) - - def _encode(self, positions, depth): - positions = tf.minimum(positions, self.maximum_position) - return tf.nn.embedding_lookup(self.embedding, positions) - - -class SinusoidalPositionEncoder(PositionEncoder): - """Encodes positions with sine waves as described in - https://arxiv.org/abs/1706.03762. - """ - - def _encode(self, positions, depth): - if depth % 2 != 0: - raise ValueError( - 'SinusoidalPositionEncoder expects the depth to be divisble ' - 'by 2 but got %d' % depth) - - batch_size = tf.shape(positions)[0] - positions = tf.cast(positions, tf.float32) - - log_timescale_increment = math.log(10000) / (depth / 2 - 1) - inv_timescales = tf.exp( - tf.range(depth / 2, dtype=tf.float32) * -log_timescale_increment) - inv_timescales = tf.reshape( - tf.tile(inv_timescales, [batch_size]), [batch_size, depth // 2]) - scaled_time = tf.expand_dims(positions, -1) * tf.expand_dims( - inv_timescales, 1) - encoding = tf.concat( - [tf.sin(scaled_time), tf.cos(scaled_time)], axis=2) - return tf.cast(encoding, self.dtype) - - -class SinusodalPositionalEncoding(tf.keras.layers.Layer): - - def __init__(self, name='SinusodalPositionalEncoding'): - super(SinusodalPositionalEncoding, self).__init__(name=name) - - @staticmethod - def positional_encoding(len, dim, step=1.): - """ - :param len: int scalar - :param dim: int scalar - :param step: - :return: position embedding - """ - pos_mat = tf.tile( - tf.expand_dims( - tf.range(0, tf.cast(len, dtype=tf.float32), dtype=tf.float32) - * step, - axis=-1), [1, dim]) - dim_mat = tf.tile( - tf.expand_dims( - tf.range(0, tf.cast(dim, dtype=tf.float32), dtype=tf.float32), - axis=0), [len, 1]) - dim_mat_int = tf.cast(dim_mat, dtype=tf.int32) - pos_encoding = tf.where( # [time, dims] - tf.math.equal(tf.math.mod(dim_mat_int, 2), 0), - x=tf.math.sin( - pos_mat / tf.pow(10000., dim_mat / tf.cast(dim, tf.float32))), - y=tf.math.cos(pos_mat - / tf.pow(10000., - (dim_mat - 1) / tf.cast(dim, tf.float32)))) - return pos_encoding - - -class BatchSinusodalPositionalEncoding(tf.keras.layers.Layer): - - def __init__(self, name='BatchSinusodalPositionalEncoding'): - super(BatchSinusodalPositionalEncoding, self).__init__(name=name) - - @staticmethod - def positional_encoding(batch_size, len, dim, pos_mat, step=1.): - """ - :param len: int scalar - :param dim: int scalar - :param step: - :param pos_mat: [B, len] = [len, 1] * dim - :return: position embedding - """ - pos_mat = tf.tile( - tf.expand_dims(tf.cast(pos_mat, dtype=tf.float32) * step, axis=-1), - [1, 1, dim]) # [B, len, dim] - - dim_mat = tf.tile( - tf.expand_dims( - tf.expand_dims( - tf.range( - 0, tf.cast(dim, dtype=tf.float32), dtype=tf.float32), - axis=0), - axis=0), [batch_size, len, 1]) # [B, len, dim] - - dim_mat_int = tf.cast(dim_mat, dtype=tf.int32) - pos_encoding = tf.where( # [B, time, dims] - tf.math.equal(tf.mod(dim_mat_int, 2), 0), - x=tf.math.sin( - pos_mat / tf.pow(10000., dim_mat / tf.cast(dim, tf.float32))), - y=tf.math.cos(pos_mat - / tf.pow(10000., - (dim_mat - 1) / tf.cast(dim, tf.float32)))) - return pos_encoding diff --git a/modelscope/models/audio/tts/models/reducer.py b/modelscope/models/audio/tts/models/reducer.py deleted file mode 100755 index a4c9ae17..00000000 --- a/modelscope/models/audio/tts/models/reducer.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Define reducers: objects that merge inputs.""" - -import abc -import functools - -import tensorflow as tf - - -def pad_in_time(x, padding_length): - """Helper function to pad a tensor in the time dimension and retain the static depth dimension.""" - return tf.pad(x, [[0, 0], [0, padding_length], [0, 0]]) - - -def align_in_time(x, length): - """Aligns the time dimension of :obj:`x` with :obj:`length`.""" - time_dim = tf.shape(x)[1] - return tf.cond( - tf.less(time_dim, length), - true_fn=lambda: pad_in_time(x, length - time_dim), - false_fn=lambda: x[:, :length]) - - -def pad_with_identity(x, - sequence_length, - max_sequence_length, - identity_values=0, - maxlen=None): - """Pads a tensor with identity values up to :obj:`max_sequence_length`. - Args: - x: A ``tf.Tensor`` of shape ``[batch_size, time, depth]``. - sequence_length: The true sequence length of :obj:`x`. - max_sequence_length: The sequence length up to which the tensor must contain - :obj:`identity values`. - identity_values: The identity value. - maxlen: Size of the output time dimension. Default is the maximum value in - obj:`max_sequence_length`. - Returns: - A ``tf.Tensor`` of shape ``[batch_size, maxlen, depth]``. - """ - if maxlen is None: - maxlen = tf.reduce_max(max_sequence_length) - - mask = tf.sequence_mask(sequence_length, maxlen=maxlen, dtype=x.dtype) - mask = tf.expand_dims(mask, axis=-1) - mask_combined = tf.sequence_mask( - max_sequence_length, maxlen=maxlen, dtype=x.dtype) - mask_combined = tf.expand_dims(mask_combined, axis=-1) - - identity_mask = mask_combined * (1.0 - mask) - - x = pad_in_time(x, maxlen - tf.shape(x)[1]) - x = x * mask + (identity_mask * identity_values) - - return x - - -def pad_n_with_identity(inputs, sequence_lengths, identity_values=0): - """Pads each input tensors with identity values up to - ``max(sequence_lengths)`` for each batch. - Args: - inputs: A list of ``tf.Tensor``. - sequence_lengths: A list of sequence length. - identity_values: The identity value. - Returns: - A tuple ``(padded, max_sequence_length)`` which are respectively a list of - ``tf.Tensor`` where each tensor are padded with identity and the combined - sequence length. - """ - max_sequence_length = tf.reduce_max(sequence_lengths, axis=0) - maxlen = tf.reduce_max([tf.shape(x)[1] for x in inputs]) - padded = [ - pad_with_identity( - x, - length, - max_sequence_length, - identity_values=identity_values, - maxlen=maxlen) for x, length in zip(inputs, sequence_lengths) - ] - return padded, max_sequence_length - - -class Reducer(tf.keras.layers.Layer): - """Base class for reducers.""" - - def zip_and_reduce(self, x, y): - """Zips the :obj:`x` with :obj:`y` structures together and reduces all - elements. If the structures are nested, they will be flattened first. - Args: - x: The first structure. - y: The second structure. - Returns: - The same structure as :obj:`x` and :obj:`y` where each element from - :obj:`x` is reduced with the correspond element from :obj:`y`. - Raises: - ValueError: if the two structures are not the same. - """ - tf.nest.assert_same_structure(x, y) - x_flat = tf.nest.flatten(x) - y_flat = tf.nest.flatten(y) - reduced = list(map(self, zip(x_flat, y_flat))) - return tf.nest.pack_sequence_as(x, reduced) - - def call(self, inputs, sequence_length=None): # pylint: disable=arguments-differ - """Reduces all input elements. - Args: - inputs: A list of ``tf.Tensor``. - sequence_length: The length of each input, if reducing sequences. - Returns: - If :obj:`sequence_length` is set, a tuple - ``(reduced_input, reduced_length)``, otherwise a reduced ``tf.Tensor`` - only. - """ - if sequence_length is None: - return self.reduce(inputs) - else: - return self.reduce_sequence( - inputs, sequence_lengths=sequence_length) - - @abc.abstractmethod - def reduce(self, inputs): - """See :meth:`opennmt.layers.Reducer.__call__`.""" - raise NotImplementedError() - - @abc.abstractmethod - def reduce_sequence(self, inputs, sequence_lengths): - """See :meth:`opennmt.layers.Reducer.__call__`.""" - raise NotImplementedError() - - -class SumReducer(Reducer): - """A reducer that sums the inputs.""" - - def reduce(self, inputs): - if len(inputs) == 1: - return inputs[0] - if len(inputs) == 2: - return inputs[0] + inputs[1] - return tf.add_n(inputs) - - def reduce_sequence(self, inputs, sequence_lengths): - padded, combined_length = pad_n_with_identity( - inputs, sequence_lengths, identity_values=0) - return self.reduce(padded), combined_length - - -class MultiplyReducer(Reducer): - """A reducer that multiplies the inputs.""" - - def reduce(self, inputs): - return functools.reduce(lambda a, x: a * x, inputs) - - def reduce_sequence(self, inputs, sequence_lengths): - padded, combined_length = pad_n_with_identity( - inputs, sequence_lengths, identity_values=1) - return self.reduce(padded), combined_length diff --git a/modelscope/models/audio/tts/models/rnn_wrappers.py b/modelscope/models/audio/tts/models/rnn_wrappers.py deleted file mode 100755 index 6c487bab..00000000 --- a/modelscope/models/audio/tts/models/rnn_wrappers.py +++ /dev/null @@ -1,237 +0,0 @@ -import tensorflow as tf -from tensorflow.python.ops import rnn_cell_impl - -from .am_models import prenet - - -class VarPredictorCell(tf.contrib.rnn.RNNCell): - """Wrapper wrapper knock knock.""" - - def __init__(self, var_predictor_cell, is_training, dim, prenet_units): - super(VarPredictorCell, self).__init__() - self._var_predictor_cell = var_predictor_cell - self._is_training = is_training - self._dim = dim - self._prenet_units = prenet_units - - @property - def state_size(self): - return tuple([self.output_size, self._var_predictor_cell.state_size]) - - @property - def output_size(self): - return self._dim - - def zero_state(self, batch_size, dtype): - return tuple([ - rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, - dtype), - self._var_predictor_cell.zero_state(batch_size, dtype) - ]) - - def call(self, inputs, state): - """Run the Tacotron2 super decoder cell.""" - super_cell_out, decoder_state = state - - # split - prenet_input = inputs[:, 0:self._dim] - encoder_output = inputs[:, self._dim:] - - # prenet and concat - prenet_output = prenet( - prenet_input, - self._prenet_units, - self._is_training, - scope='var_prenet') - decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) - - # decoder LSTM/GRU - new_super_cell_out, new_decoder_state = self._var_predictor_cell( - decoder_input, decoder_state) - - # projection - new_super_cell_out = tf.layers.dense( - new_super_cell_out, units=self._dim) - - new_states = tuple([new_super_cell_out, new_decoder_state]) - - return new_super_cell_out, new_states - - -class DurPredictorCell(tf.contrib.rnn.RNNCell): - """Wrapper wrapper knock knock.""" - - def __init__(self, var_predictor_cell, is_training, dim, prenet_units): - super(DurPredictorCell, self).__init__() - self._var_predictor_cell = var_predictor_cell - self._is_training = is_training - self._dim = dim - self._prenet_units = prenet_units - - @property - def state_size(self): - return tuple([self.output_size, self._var_predictor_cell.state_size]) - - @property - def output_size(self): - return self._dim - - def zero_state(self, batch_size, dtype): - return tuple([ - rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, - dtype), - self._var_predictor_cell.zero_state(batch_size, dtype) - ]) - - def call(self, inputs, state): - """Run the Tacotron2 super decoder cell.""" - super_cell_out, decoder_state = state - - # split - prenet_input = inputs[:, 0:self._dim] - encoder_output = inputs[:, self._dim:] - - # prenet and concat - prenet_output = prenet( - prenet_input, - self._prenet_units, - self._is_training, - scope='dur_prenet') - decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) - - # decoder LSTM/GRU - new_super_cell_out, new_decoder_state = self._var_predictor_cell( - decoder_input, decoder_state) - - # projection - new_super_cell_out = tf.layers.dense( - new_super_cell_out, units=self._dim) - new_super_cell_out = tf.nn.relu(new_super_cell_out) - # new_super_cell_out = tf.log(tf.cast(tf.round(tf.exp(new_super_cell_out) - 1), tf.float32) + 1) - - new_states = tuple([new_super_cell_out, new_decoder_state]) - - return new_super_cell_out, new_states - - -class DurPredictorCECell(tf.contrib.rnn.RNNCell): - """Wrapper wrapper knock knock.""" - - def __init__(self, var_predictor_cell, is_training, dim, prenet_units, - max_dur, dur_embedding_dim): - super(DurPredictorCECell, self).__init__() - self._var_predictor_cell = var_predictor_cell - self._is_training = is_training - self._dim = dim - self._prenet_units = prenet_units - self._max_dur = max_dur - self._dur_embedding_dim = dur_embedding_dim - - @property - def state_size(self): - return tuple([self.output_size, self._var_predictor_cell.state_size]) - - @property - def output_size(self): - return self._max_dur - - def zero_state(self, batch_size, dtype): - return tuple([ - rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, - dtype), - self._var_predictor_cell.zero_state(batch_size, dtype) - ]) - - def call(self, inputs, state): - """Run the Tacotron2 super decoder cell.""" - super_cell_out, decoder_state = state - - # split - prenet_input = tf.squeeze( - tf.cast(inputs[:, 0:self._dim], tf.int32), axis=-1) # [N] - prenet_input = tf.one_hot( - prenet_input, self._max_dur, on_value=1.0, off_value=0.0, - axis=-1) # [N, 120] - prenet_input = tf.layers.dense( - prenet_input, units=self._dur_embedding_dim) - encoder_output = inputs[:, self._dim:] - - # prenet and concat - prenet_output = prenet( - prenet_input, - self._prenet_units, - self._is_training, - scope='dur_prenet') - decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) - - # decoder LSTM/GRU - new_super_cell_out, new_decoder_state = self._var_predictor_cell( - decoder_input, decoder_state) - - # projection - new_super_cell_out = tf.layers.dense( - new_super_cell_out, units=self._max_dur) # [N, 120] - new_super_cell_out = tf.nn.softmax(new_super_cell_out) # [N, 120] - - new_states = tuple([new_super_cell_out, new_decoder_state]) - - return new_super_cell_out, new_states - - -class VarPredictorCell2(tf.contrib.rnn.RNNCell): - """Wrapper wrapper knock knock.""" - - def __init__(self, var_predictor_cell, is_training, dim, prenet_units): - super(VarPredictorCell2, self).__init__() - self._var_predictor_cell = var_predictor_cell - self._is_training = is_training - self._dim = dim - self._prenet_units = prenet_units - - @property - def state_size(self): - return tuple([self.output_size, self._var_predictor_cell.state_size]) - - @property - def output_size(self): - return self._dim - - def zero_state(self, batch_size, dtype): - return tuple([ - rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, - dtype), - self._var_predictor_cell.zero_state(batch_size, dtype) - ]) - - def call(self, inputs, state): - '''Run the Tacotron2 super decoder cell.''' - super_cell_out, decoder_state = state - - # split - prenet_input = inputs[:, 0:self._dim] - encoder_output = inputs[:, self._dim:] - - # prenet and concat - prenet_output = prenet( - prenet_input, - self._prenet_units, - self._is_training, - scope='var_prenet') - decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) - - # decoder LSTM/GRU - new_super_cell_out, new_decoder_state = self._var_predictor_cell( - decoder_input, decoder_state) - - # projection - new_super_cell_out = tf.layers.dense( - new_super_cell_out, units=self._dim) - - # split and relu - new_super_cell_out = tf.concat([ - tf.nn.relu(new_super_cell_out[:, 0:1]), new_super_cell_out[:, 1:] - ], axis=-1) # yapf:disable - - new_states = tuple([new_super_cell_out, new_decoder_state]) - - return new_super_cell_out, new_states diff --git a/modelscope/models/audio/tts/models/robutrans.py b/modelscope/models/audio/tts/models/robutrans.py deleted file mode 100755 index ab9fdfcc..00000000 --- a/modelscope/models/audio/tts/models/robutrans.py +++ /dev/null @@ -1,760 +0,0 @@ -import tensorflow as tf -from tensorflow.python.ops.ragged.ragged_util import repeat - -from .fsmn_encoder import FsmnEncoderV2 -from .position import BatchSinusodalPositionalEncoding -from .self_attention_decoder import SelfAttentionDecoder -from .self_attention_encoder import SelfAttentionEncoder - - -class RobuTrans(): - - def __init__(self, hparams): - self._hparams = hparams - - def initialize(self, - inputs, - inputs_emotion, - inputs_speaker, - input_lengths, - output_lengths=None, - mel_targets=None, - durations=None, - pitch_contours=None, - uv_masks=None, - pitch_scales=None, - duration_scales=None, - energy_contours=None, - energy_scales=None): - """Initializes the model for inference. - - Sets "mel_outputs", "linear_outputs", "stop_token_outputs", and "alignments" fields. - - Args: - inputs: int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of - steps in the input time series, and values are character IDs - input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths - of each sequence in inputs. - output_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths - of each sequence in outputs. - mel_targets: float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number - of steps in the output time series, M is num_mels, and values are entries in the mel - spectrogram. Only needed for training. - """ - from tensorflow.contrib.rnn import LSTMBlockCell, MultiRNNCell - from tensorflow.contrib.seq2seq import BasicDecoder - - with tf.variable_scope('inference') as _: - is_training = mel_targets is not None - batch_size = tf.shape(inputs)[0] - hp = self._hparams - - input_mask = None - if input_lengths is not None and is_training: - input_mask = tf.sequence_mask( - input_lengths, tf.shape(inputs)[1], dtype=tf.float32) - - if input_mask is not None: - inputs = inputs * tf.expand_dims(input_mask, -1) - - # speaker embedding - embedded_inputs_speaker = tf.layers.dense( - inputs_speaker, - 32, - activation=None, - use_bias=False, - kernel_initializer=tf.truncated_normal_initializer(stddev=0.5)) - - # emotion embedding - embedded_inputs_emotion = tf.layers.dense( - inputs_emotion, - 32, - activation=None, - use_bias=False, - kernel_initializer=tf.truncated_normal_initializer(stddev=0.5)) - - # symbol embedding - with tf.variable_scope('Embedding'): - embedded_inputs = tf.layers.dense( - inputs, - hp.embedding_dim, - activation=None, - use_bias=False, - kernel_initializer=tf.truncated_normal_initializer( - stddev=0.5)) - - # Encoder - with tf.variable_scope('Encoder'): - Encoder = SelfAttentionEncoder( - num_layers=hp.encoder_num_layers, - num_units=hp.encoder_num_units, - num_heads=hp.encoder_num_heads, - ffn_inner_dim=hp.encoder_ffn_inner_dim, - dropout=hp.encoder_dropout, - attention_dropout=hp.encoder_attention_dropout, - relu_dropout=hp.encoder_relu_dropout) - encoder_outputs, state_mo, sequence_length_mo, attns = Encoder.encode( - embedded_inputs, - sequence_length=input_lengths, - mode=is_training) - encoder_outputs = tf.layers.dense( - encoder_outputs, - hp.encoder_projection_units, - activation=None, - use_bias=False, - kernel_initializer=tf.truncated_normal_initializer( - stddev=0.5)) - - # pitch and energy - var_inputs = tf.concat([ - encoder_outputs, embedded_inputs_speaker, - embedded_inputs_emotion - ], 2) - if input_mask is not None: - var_inputs = var_inputs * tf.expand_dims(input_mask, -1) - - with tf.variable_scope('Pitch_Predictor'): - Pitch_Predictor_FSMN = FsmnEncoderV2( - filter_size=hp.predictor_filter_size, - fsmn_num_layers=hp.predictor_fsmn_num_layers, - dnn_num_layers=hp.predictor_dnn_num_layers, - num_memory_units=hp.predictor_num_memory_units, - ffn_inner_dim=hp.predictor_ffn_inner_dim, - dropout=hp.predictor_dropout, - shift=hp.predictor_shift, - position_encoder=None) - pitch_contour_outputs, _, _ = Pitch_Predictor_FSMN.encode( - tf.concat([ - encoder_outputs, embedded_inputs_speaker, - embedded_inputs_emotion - ], 2), - sequence_length=input_lengths, - mode=is_training) - pitch_contour_outputs, _ = tf.nn.bidirectional_dynamic_rnn( - LSTMBlockCell(hp.predictor_lstm_units), - LSTMBlockCell(hp.predictor_lstm_units), - pitch_contour_outputs, - sequence_length=input_lengths, - dtype=tf.float32) - pitch_contour_outputs = tf.concat( - pitch_contour_outputs, axis=-1) - pitch_contour_outputs = tf.layers.dense( - pitch_contour_outputs, units=1) # [N, T_in, 1] - pitch_contour_outputs = tf.squeeze( - pitch_contour_outputs, axis=2) # [N, T_in] - - with tf.variable_scope('Energy_Predictor'): - Energy_Predictor_FSMN = FsmnEncoderV2( - filter_size=hp.predictor_filter_size, - fsmn_num_layers=hp.predictor_fsmn_num_layers, - dnn_num_layers=hp.predictor_dnn_num_layers, - num_memory_units=hp.predictor_num_memory_units, - ffn_inner_dim=hp.predictor_ffn_inner_dim, - dropout=hp.predictor_dropout, - shift=hp.predictor_shift, - position_encoder=None) - energy_contour_outputs, _, _ = Energy_Predictor_FSMN.encode( - tf.concat([ - encoder_outputs, embedded_inputs_speaker, - embedded_inputs_emotion - ], 2), - sequence_length=input_lengths, - mode=is_training) - energy_contour_outputs, _ = tf.nn.bidirectional_dynamic_rnn( - LSTMBlockCell(hp.predictor_lstm_units), - LSTMBlockCell(hp.predictor_lstm_units), - energy_contour_outputs, - sequence_length=input_lengths, - dtype=tf.float32) - energy_contour_outputs = tf.concat( - energy_contour_outputs, axis=-1) - energy_contour_outputs = tf.layers.dense( - energy_contour_outputs, units=1) # [N, T_in, 1] - energy_contour_outputs = tf.squeeze( - energy_contour_outputs, axis=2) # [N, T_in] - - if is_training: - pitch_embeddings = tf.expand_dims( - pitch_contours, axis=2) # [N, T_in, 1] - pitch_embeddings = tf.layers.conv1d( - pitch_embeddings, - filters=hp.encoder_projection_units, - kernel_size=9, - padding='same', - name='pitch_embeddings') # [N, T_in, 32] - - energy_embeddings = tf.expand_dims( - energy_contours, axis=2) # [N, T_in, 1] - energy_embeddings = tf.layers.conv1d( - energy_embeddings, - filters=hp.encoder_projection_units, - kernel_size=9, - padding='same', - name='energy_embeddings') # [N, T_in, 32] - else: - pitch_contour_outputs *= pitch_scales - pitch_embeddings = tf.expand_dims( - pitch_contour_outputs, axis=2) # [N, T_in, 1] - pitch_embeddings = tf.layers.conv1d( - pitch_embeddings, - filters=hp.encoder_projection_units, - kernel_size=9, - padding='same', - name='pitch_embeddings') # [N, T_in, 32] - - energy_contour_outputs *= energy_scales - energy_embeddings = tf.expand_dims( - energy_contour_outputs, axis=2) # [N, T_in, 1] - energy_embeddings = tf.layers.conv1d( - energy_embeddings, - filters=hp.encoder_projection_units, - kernel_size=9, - padding='same', - name='energy_embeddings') # [N, T_in, 32] - - encoder_outputs_ = encoder_outputs + pitch_embeddings + energy_embeddings - - # duration - dur_inputs = tf.concat([ - encoder_outputs_, embedded_inputs_speaker, - embedded_inputs_emotion - ], 2) - if input_mask is not None: - dur_inputs = dur_inputs * tf.expand_dims(input_mask, -1) - with tf.variable_scope('Duration_Predictor'): - duration_predictor_cell = MultiRNNCell([ - LSTMBlockCell(hp.predictor_lstm_units), - LSTMBlockCell(hp.predictor_lstm_units) - ], state_is_tuple=True) # yapf:disable - from .rnn_wrappers import DurPredictorCell - duration_output_cell = DurPredictorCell( - duration_predictor_cell, is_training, 1, - hp.predictor_prenet_units) - duration_predictor_init_state = duration_output_cell.zero_state( - batch_size=batch_size, dtype=tf.float32) - if is_training: - from .helpers import VarTrainingHelper - duration_helper = VarTrainingHelper( - tf.expand_dims( - tf.log(tf.cast(durations, tf.float32) + 1), - axis=2), dur_inputs, 1) - else: - from .helpers import VarTestHelper - duration_helper = VarTestHelper(batch_size, dur_inputs, 1) - ( - duration_outputs, _ - ), final_duration_predictor_state, _ = tf.contrib.seq2seq.dynamic_decode( - BasicDecoder(duration_output_cell, duration_helper, - duration_predictor_init_state), - maximum_iterations=1000) - duration_outputs = tf.squeeze( - duration_outputs, axis=2) # [N, T_in] - if input_mask is not None: - duration_outputs = duration_outputs * input_mask - duration_outputs_ = tf.exp(duration_outputs) - 1 - - # Length Regulator - with tf.variable_scope('Length_Regulator'): - if is_training: - i = tf.constant(1) - # position embedding - j = tf.constant(1) - dur_len = tf.shape(durations)[-1] - embedded_position_i = tf.range(1, durations[0, 0] + 1) - - def condition_pos(j, e): - return tf.less(j, dur_len) - - def loop_body_pos(j, embedded_position_i): - embedded_position_i = tf.concat([ - embedded_position_i, - tf.range(1, durations[0, j] + 1) - ], axis=0) # yapf:disable - return [j + 1, embedded_position_i] - - j, embedded_position_i = tf.while_loop( - condition_pos, - loop_body_pos, [j, embedded_position_i], - shape_invariants=[ - j.get_shape(), - tf.TensorShape([None]) - ]) - embedded_position = tf.reshape(embedded_position_i, - (1, -1)) - - # others - LR_outputs = repeat( - encoder_outputs_[0:1, :, :], durations[0, :], axis=1) - embedded_outputs_speaker = repeat( - embedded_inputs_speaker[0:1, :, :], - durations[0, :], - axis=1) - embedded_outputs_emotion = repeat( - embedded_inputs_emotion[0:1, :, :], - durations[0, :], - axis=1) - - def condition(i, pos, layer, s, e): - return tf.less(i, tf.shape(mel_targets)[0]) - - def loop_body(i, embedded_position, LR_outputs, - embedded_outputs_speaker, - embedded_outputs_emotion): - # position embedding - jj = tf.constant(1) - embedded_position_i = tf.range(1, durations[i, 0] + 1) - - def condition_pos_i(j, e): - return tf.less(j, dur_len) - - def loop_body_pos_i(j, embedded_position_i): - embedded_position_i = tf.concat([ - embedded_position_i, - tf.range(1, durations[i, j] + 1) - ], axis=0) # yapf:disable - return [j + 1, embedded_position_i] - - jj, embedded_position_i = tf.while_loop( - condition_pos_i, - loop_body_pos_i, [jj, embedded_position_i], - shape_invariants=[ - jj.get_shape(), - tf.TensorShape([None]) - ]) - embedded_position = tf.concat([ - embedded_position, - tf.reshape(embedded_position_i, (1, -1)) - ], 0) - - # others - LR_outputs = tf.concat([ - LR_outputs, - repeat( - encoder_outputs_[i:i + 1, :, :], - durations[i, :], - axis=1) - ], 0) - embedded_outputs_speaker = tf.concat([ - embedded_outputs_speaker, - repeat( - embedded_inputs_speaker[i:i + 1, :, :], - durations[i, :], - axis=1) - ], 0) - embedded_outputs_emotion = tf.concat([ - embedded_outputs_emotion, - repeat( - embedded_inputs_emotion[i:i + 1, :, :], - durations[i, :], - axis=1) - ], 0) - return [ - i + 1, embedded_position, LR_outputs, - embedded_outputs_speaker, embedded_outputs_emotion - ] - - i, embedded_position, LR_outputs, - embedded_outputs_speaker, - embedded_outputs_emotion = tf.while_loop( - condition, - loop_body, [ - i, embedded_position, LR_outputs, - embedded_outputs_speaker, embedded_outputs_emotion - ], - shape_invariants=[ - i.get_shape(), - tf.TensorShape([None, None]), - tf.TensorShape([None, None, None]), - tf.TensorShape([None, None, None]), - tf.TensorShape([None, None, None]) - ], - parallel_iterations=hp.batch_size) - - ori_framenum = tf.shape(mel_targets)[1] - else: - # position - j = tf.constant(1) - dur_len = tf.shape(duration_outputs_)[-1] - embedded_position_i = tf.range( - 1, - tf.cast(tf.round(duration_outputs_)[0, 0], tf.int32) - + 1) - - def condition_pos(j, e): - return tf.less(j, dur_len) - - def loop_body_pos(j, embedded_position_i): - embedded_position_i = tf.concat([ - embedded_position_i, - tf.range( - 1, - tf.cast( - tf.round(duration_outputs_)[0, j], - tf.int32) + 1) - ], axis=0) # yapf:disable - return [j + 1, embedded_position_i] - - j, embedded_position_i = tf.while_loop( - condition_pos, - loop_body_pos, [j, embedded_position_i], - shape_invariants=[ - j.get_shape(), - tf.TensorShape([None]) - ]) - embedded_position = tf.reshape(embedded_position_i, - (1, -1)) - # others - duration_outputs_ *= duration_scales - LR_outputs = repeat( - encoder_outputs_[0:1, :, :], - tf.cast(tf.round(duration_outputs_)[0, :], tf.int32), - axis=1) - embedded_outputs_speaker = repeat( - embedded_inputs_speaker[0:1, :, :], - tf.cast(tf.round(duration_outputs_)[0, :], tf.int32), - axis=1) - embedded_outputs_emotion = repeat( - embedded_inputs_emotion[0:1, :, :], - tf.cast(tf.round(duration_outputs_)[0, :], tf.int32), - axis=1) - ori_framenum = tf.shape(LR_outputs)[1] - - left = hp.outputs_per_step - tf.mod( - ori_framenum, hp.outputs_per_step) - LR_outputs = tf.cond( - tf.equal(left, - hp.outputs_per_step), lambda: LR_outputs, - lambda: tf.pad(LR_outputs, [[0, 0], [0, left], [0, 0]], - 'CONSTANT')) - embedded_outputs_speaker = tf.cond( - tf.equal(left, hp.outputs_per_step), - lambda: embedded_outputs_speaker, lambda: tf.pad( - embedded_outputs_speaker, [[0, 0], [0, left], - [0, 0]], 'CONSTANT')) - embedded_outputs_emotion = tf.cond( - tf.equal(left, hp.outputs_per_step), - lambda: embedded_outputs_emotion, lambda: tf.pad( - embedded_outputs_emotion, [[0, 0], [0, left], - [0, 0]], 'CONSTANT')) - embedded_position = tf.cond( - tf.equal(left, hp.outputs_per_step), - lambda: embedded_position, - lambda: tf.pad(embedded_position, [[0, 0], [0, left]], - 'CONSTANT')) - - # Pos_Embedding - with tf.variable_scope('Position_Embedding'): - Pos_Embedding = BatchSinusodalPositionalEncoding() - position_embeddings = Pos_Embedding.positional_encoding( - batch_size, - tf.shape(LR_outputs)[1], hp.encoder_projection_units, - embedded_position) - LR_outputs += position_embeddings - - # multi-frame - LR_outputs = tf.reshape(LR_outputs, [ - batch_size, -1, - hp.outputs_per_step * hp.encoder_projection_units - ]) - embedded_outputs_speaker = tf.reshape( - embedded_outputs_speaker, - [batch_size, -1, hp.outputs_per_step * 32])[:, :, :32] - embedded_outputs_emotion = tf.reshape( - embedded_outputs_emotion, - [batch_size, -1, hp.outputs_per_step * 32])[:, :, :32] - # [N, T_out, D_LR_outputs] (D_LR_outputs = hp.outputs_per_step * hp.encoder_projection_units + 64) - LR_outputs = tf.concat([ - LR_outputs, embedded_outputs_speaker, embedded_outputs_emotion - ], -1) - - # auto bandwidth - if is_training: - durations_mask = tf.cast(durations, - tf.float32) * input_mask # [N, T_in] - else: - durations_mask = duration_outputs_ - X_band_width = tf.cast( - tf.round(tf.reduce_max(durations_mask) / hp.outputs_per_step), - tf.int32) - H_band_width = X_band_width - - with tf.variable_scope('Decoder'): - Decoder = SelfAttentionDecoder( - num_layers=hp.decoder_num_layers, - num_units=hp.decoder_num_units, - num_heads=hp.decoder_num_heads, - ffn_inner_dim=hp.decoder_ffn_inner_dim, - dropout=hp.decoder_dropout, - attention_dropout=hp.decoder_attention_dropout, - relu_dropout=hp.decoder_relu_dropout, - prenet_units=hp.prenet_units, - dense_units=hp.prenet_proj_units, - num_mels=hp.num_mels, - outputs_per_step=hp.outputs_per_step, - X_band_width=X_band_width, - H_band_width=H_band_width, - position_encoder=None) - if is_training: - if hp.free_run: - r = hp.outputs_per_step - init_decoder_input = tf.expand_dims( - tf.tile([[0.0]], [batch_size, hp.num_mels]), - axis=1) # [N, 1, hp.num_mels] - decoder_input_lengths = tf.cast( - output_lengths / r, tf.int32) - decoder_outputs, attention_x, attention_h = Decoder.dynamic_decode_and_search( - init_decoder_input, - maximum_iterations=tf.shape(LR_outputs)[1], - mode=is_training, - memory=LR_outputs, - memory_sequence_length=decoder_input_lengths) - else: - r = hp.outputs_per_step - decoder_input = mel_targets[:, r - 1:: - r, :] # [N, T_out / r, hp.num_mels] - init_decoder_input = tf.expand_dims( - tf.tile([[0.0]], [batch_size, hp.num_mels]), - axis=1) # [N, 1, hp.num_mels] - decoder_input = tf.concat( - [init_decoder_input, decoder_input], - axis=1) # [N, T_out / r + 1, hp.num_mels] - decoder_input = decoder_input[:, : - -1, :] # [N, T_out / r, hp.num_mels] - decoder_input_lengths = tf.cast( - output_lengths / r, tf.int32) - decoder_outputs, attention_x, attention_h = Decoder.decode_from_inputs( - decoder_input, - decoder_input_lengths, - mode=is_training, - memory=LR_outputs, - memory_sequence_length=decoder_input_lengths) - else: - init_decoder_input = tf.expand_dims( - tf.tile([[0.0]], [batch_size, hp.num_mels]), - axis=1) # [N, 1, hp.num_mels] - decoder_outputs, attention_x, attention_h = Decoder.dynamic_decode_and_search( - init_decoder_input, - maximum_iterations=tf.shape(LR_outputs)[1], - mode=is_training, - memory=LR_outputs, - memory_sequence_length=tf.expand_dims( - tf.shape(LR_outputs)[1], axis=0)) - - if is_training: - mel_outputs_ = tf.reshape(decoder_outputs, - [batch_size, -1, hp.num_mels]) - else: - mel_outputs_ = tf.reshape( - decoder_outputs, - [batch_size, -1, hp.num_mels])[:, :ori_framenum, :] - mel_outputs = mel_outputs_ - - with tf.variable_scope('Postnet'): - Postnet_FSMN = FsmnEncoderV2( - filter_size=hp.postnet_filter_size, - fsmn_num_layers=hp.postnet_fsmn_num_layers, - dnn_num_layers=hp.postnet_dnn_num_layers, - num_memory_units=hp.postnet_num_memory_units, - ffn_inner_dim=hp.postnet_ffn_inner_dim, - dropout=hp.postnet_dropout, - shift=hp.postnet_shift, - position_encoder=None) - if is_training: - postnet_fsmn_outputs, _, _ = Postnet_FSMN.encode( - mel_outputs, - sequence_length=output_lengths, - mode=is_training) - hidden_lstm_outputs, _ = tf.nn.dynamic_rnn( - LSTMBlockCell(hp.postnet_lstm_units), - postnet_fsmn_outputs, - sequence_length=output_lengths, - dtype=tf.float32) - else: - postnet_fsmn_outputs, _, _ = Postnet_FSMN.encode( - mel_outputs, - sequence_length=[tf.shape(mel_outputs_)[1]], - mode=is_training) - hidden_lstm_outputs, _ = tf.nn.dynamic_rnn( - LSTMBlockCell(hp.postnet_lstm_units), - postnet_fsmn_outputs, - sequence_length=[tf.shape(mel_outputs_)[1]], - dtype=tf.float32) - - mel_residual_outputs = tf.layers.dense( - hidden_lstm_outputs, units=hp.num_mels) - mel_outputs += mel_residual_outputs - - self.inputs = inputs - self.inputs_speaker = inputs_speaker - self.inputs_emotion = inputs_emotion - self.input_lengths = input_lengths - self.durations = durations - self.output_lengths = output_lengths - self.mel_outputs_ = mel_outputs_ - self.mel_outputs = mel_outputs - self.mel_targets = mel_targets - self.duration_outputs = duration_outputs - self.duration_outputs_ = duration_outputs_ - self.duration_scales = duration_scales - self.pitch_contour_outputs = pitch_contour_outputs - self.pitch_contours = pitch_contours - self.pitch_scales = pitch_scales - self.energy_contour_outputs = energy_contour_outputs - self.energy_contours = energy_contours - self.energy_scales = energy_scales - self.uv_masks_ = uv_masks - - self.embedded_inputs_emotion = embedded_inputs_emotion - self.embedding_fsmn_outputs = embedded_inputs - self.encoder_outputs = encoder_outputs - self.encoder_outputs_ = encoder_outputs_ - self.LR_outputs = LR_outputs - self.postnet_fsmn_outputs = postnet_fsmn_outputs - - self.pitch_embeddings = pitch_embeddings - self.energy_embeddings = energy_embeddings - - self.attns = attns - self.attention_x = attention_x - self.attention_h = attention_h - self.X_band_width = X_band_width - self.H_band_width = H_band_width - - def add_loss(self): - '''Adds loss to the model. Sets "loss" field. initialize must have been called.''' - with tf.variable_scope('loss') as _: - hp = self._hparams - mask = tf.sequence_mask( - self.output_lengths, - tf.shape(self.mel_targets)[1], - dtype=tf.float32) - valid_outputs = tf.reduce_sum(mask) - - mask_input = tf.sequence_mask( - self.input_lengths, - tf.shape(self.durations)[1], - dtype=tf.float32) - valid_inputs = tf.reduce_sum(mask_input) - - # mel loss - if self.uv_masks_ is not None: - valid_outputs_mask = tf.reduce_sum( - tf.expand_dims(mask, -1) * self.uv_masks_) - self.mel_loss_ = tf.reduce_sum( - tf.abs(self.mel_targets - self.mel_outputs_) - * tf.expand_dims(mask, -1) * self.uv_masks_) / ( - valid_outputs_mask * hp.num_mels) - self.mel_loss = tf.reduce_sum( - tf.abs(self.mel_targets - self.mel_outputs) - * tf.expand_dims(mask, -1) * self.uv_masks_) / ( - valid_outputs_mask * hp.num_mels) - else: - self.mel_loss_ = tf.reduce_sum( - tf.abs(self.mel_targets - self.mel_outputs_) - * tf.expand_dims(mask, -1)) / ( - valid_outputs * hp.num_mels) - self.mel_loss = tf.reduce_sum( - tf.abs(self.mel_targets - self.mel_outputs) - * tf.expand_dims(mask, -1)) / ( - valid_outputs * hp.num_mels) - - # duration loss - self.duration_loss = tf.reduce_sum( - tf.abs( - tf.log(tf.cast(self.durations, tf.float32) + 1) - - self.duration_outputs) * mask_input) / valid_inputs - - # pitch contour loss - self.pitch_contour_loss = tf.reduce_sum( - tf.abs(self.pitch_contours - self.pitch_contour_outputs) - * mask_input) / valid_inputs - - # energy contour loss - self.energy_contour_loss = tf.reduce_sum( - tf.abs(self.energy_contours - self.energy_contour_outputs) - * mask_input) / valid_inputs - - # final loss - self.loss = self.mel_loss_ + self.mel_loss + self.duration_loss \ - + self.pitch_contour_loss + self.energy_contour_loss - - # guided attention loss - self.guided_attention_loss = tf.constant(0.0) - if hp.guided_attention: - i0 = tf.constant(0) - loss0 = tf.constant(0.0) - - def c(i, _): - return tf.less(i, tf.shape(mel_targets)[0]) - - def loop_body(i, loss): - decoder_input_lengths = tf.cast( - self.output_lengths / hp.outputs_per_step, tf.int32) - input_len = decoder_input_lengths[i] - output_len = decoder_input_lengths[i] - input_w = tf.expand_dims( - tf.range(tf.cast(input_len, dtype=tf.float32)), - axis=1) / tf.cast( - input_len, dtype=tf.float32) # [T_in, 1] - output_w = tf.expand_dims( - tf.range(tf.cast(output_len, dtype=tf.float32)), - axis=0) / tf.cast( - output_len, dtype=tf.float32) # [1, T_out] - guided_attention_w = 1.0 - tf.exp( - -(1 / hp.guided_attention_2g_squared) - * tf.square(input_w - output_w)) # [T_in, T_out] - guided_attention_w = tf.expand_dims( - guided_attention_w, axis=0) # [1, T_in, T_out] - # [hp.decoder_num_heads, T_in, T_out] - guided_attention_w = tf.tile(guided_attention_w, - [hp.decoder_num_heads, 1, 1]) - loss_i = tf.constant(0.0) - for j in range(hp.decoder_num_layers): - loss_i += tf.reduce_mean( - self.attention_h[j][i, :, :input_len, :output_len] - * guided_attention_w) - - return [tf.add(i, 1), tf.add(loss, loss_i)] - - _, loss = tf.while_loop( - c, - loop_body, - loop_vars=[i0, loss0], - parallel_iterations=hp.batch_size) - self.guided_attention_loss = loss / hp.batch_size - self.loss += hp.guided_attention_loss_weight * self.guided_attention_loss - - def add_optimizer(self, global_step): - '''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called. - - Args: - global_step: int32 scalar Tensor representing current global step in training - ''' - with tf.variable_scope('optimizer') as _: - hp = self._hparams - if hp.decay_learning_rate: - self.learning_rate = _learning_rate_decay( - hp.initial_learning_rate, global_step) - else: - self.learning_rate = tf.convert_to_tensor( - hp.initial_learning_rate) - optimizer = tf.train.AdamOptimizer(self.learning_rate, - hp.adam_beta1, hp.adam_beta2) - gradients, variables = zip(*optimizer.compute_gradients(self.loss)) - self.gradients = gradients - clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0) - - # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See: - # https://github.com/tensorflow/tensorflow/issues/1122 - with tf.control_dependencies( - tf.get_collection(tf.GraphKeys.UPDATE_OPS)): - self.optimize = optimizer.apply_gradients( - zip(clipped_gradients, variables), global_step=global_step) - - -def _learning_rate_decay(init_lr, global_step): - # Noam scheme from tensor2tensor: - warmup_steps = 4000.0 - step = tf.cast(global_step + 1, dtype=tf.float32) - return init_lr * warmup_steps**0.5 * tf.minimum(step * warmup_steps**-1.5, - step**-0.5) diff --git a/modelscope/models/audio/tts/models/self_attention_decoder.py b/modelscope/models/audio/tts/models/self_attention_decoder.py deleted file mode 100755 index 9cf3fcaa..00000000 --- a/modelscope/models/audio/tts/models/self_attention_decoder.py +++ /dev/null @@ -1,817 +0,0 @@ -"""Define self-attention decoder.""" - -import sys - -import tensorflow as tf - -from . import compat, transformer -from .am_models import decoder_prenet -from .position import SinusoidalPositionEncoder - - -class SelfAttentionDecoder(): - """Decoder using self-attention as described in - https://arxiv.org/abs/1706.03762. - """ - - def __init__(self, - num_layers, - num_units=512, - num_heads=8, - ffn_inner_dim=2048, - dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1, - prenet_units=256, - dense_units=128, - num_mels=80, - outputs_per_step=3, - X_band_width=None, - H_band_width=None, - position_encoder=SinusoidalPositionEncoder(), - self_attention_type='scaled_dot'): - """Initializes the parameters of the decoder. - - Args: - num_layers: The number of layers. - num_units: The number of hidden units. - num_heads: The number of heads in the multi-head attention. - ffn_inner_dim: The number of units of the inner linear transformation - in the feed forward layer. - dropout: The probability to drop units from the outputs. - attention_dropout: The probability to drop units from the attention. - relu_dropout: The probability to drop units from the ReLU activation in - the feed forward layer. - position_encoder: A :class:`opennmt.layers.position.PositionEncoder` to - apply on inputs or ``None``. - self_attention_type: Type of self attention, "scaled_dot" or "average" (case - insensitive). - - Raises: - ValueError: if :obj:`self_attention_type` is invalid. - """ - super(SelfAttentionDecoder, self).__init__() - self.num_layers = num_layers - self.num_units = num_units - self.num_heads = num_heads - self.ffn_inner_dim = ffn_inner_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.relu_dropout = relu_dropout - self.position_encoder = position_encoder - self.self_attention_type = self_attention_type.lower() - if self.self_attention_type not in ('scaled_dot', 'average'): - raise ValueError('invalid attention type %s' - % self.self_attention_type) - if self.self_attention_type == 'average': - tf.logging.warning( - 'Support for average attention network is experimental ' - 'and may change in future versions.') - self.prenet_units = prenet_units - self.dense_units = dense_units - self.num_mels = num_mels - self.outputs_per_step = outputs_per_step - self.X_band_width = X_band_width - self.H_band_width = H_band_width - - @property - def output_size(self): - """Returns the decoder output size.""" - return self.num_units - - @property - def support_alignment_history(self): - return True - - @property - def support_multi_source(self): - return True - - def _init_cache(self, batch_size, dtype=tf.float32, num_sources=1): - cache = {} - - for layer in range(self.num_layers): - proj_cache_shape = [ - batch_size, self.num_heads, 0, self.num_units // self.num_heads - ] - layer_cache = {} - layer_cache['memory'] = [{ - 'memory_keys': - tf.zeros(proj_cache_shape, dtype=dtype), - 'memory_values': - tf.zeros(proj_cache_shape, dtype=dtype) - } for _ in range(num_sources)] - if self.self_attention_type == 'scaled_dot': - layer_cache['self_keys'] = tf.zeros( - proj_cache_shape, dtype=dtype) - layer_cache['self_values'] = tf.zeros( - proj_cache_shape, dtype=dtype) - elif self.self_attention_type == 'average': - layer_cache['prev_g'] = tf.zeros( - [batch_size, 1, self.num_units], dtype=dtype) - cache['layer_{}'.format(layer)] = layer_cache - - return cache - - def _init_attn(self, dtype=tf.float32): - attn = [] - for layer in range(self.num_layers): - attn.append(tf.TensorArray(tf.float32, size=0, dynamic_size=True)) - return attn - - def _self_attention_stack(self, - inputs, - sequence_length=None, - mode=True, - cache=None, - memory=None, - memory_sequence_length=None, - step=None): - - # [N, T_out, self.dense_units] or [N, 1, self.dense_units] - prenet_outputs = decoder_prenet(inputs, self.prenet_units, - self.dense_units, mode) - if step is None: - decoder_inputs = tf.concat( - [memory, prenet_outputs], - axis=-1) # [N, T_out, memory_size + self.dense_units] - else: - decoder_inputs = tf.concat( - [memory[:, step:step + 1, :], prenet_outputs], - axis=-1) # [N, 1, memory_size + self.dense_units] - decoder_inputs = tf.layers.dense( - decoder_inputs, units=self.dense_units) - - inputs = decoder_inputs - inputs *= self.num_units**0.5 - if self.position_encoder is not None: - inputs = self.position_encoder( - inputs, position=step + 1 if step is not None else None) - - inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) - - decoder_mask = None - memory_mask = None - # last_attention = None - - X_band_width_tmp = -1 - H_band_width_tmp = -1 - if self.X_band_width is not None: - X_band_width_tmp = tf.cast( - tf.cond( - tf.less(tf.shape(memory)[1], self.X_band_width), - lambda: -1, lambda: self.X_band_width), - dtype=tf.int64) - if self.H_band_width is not None: - H_band_width_tmp = tf.cast( - tf.cond( - tf.less(tf.shape(memory)[1], self.H_band_width), - lambda: -1, lambda: self.H_band_width), - dtype=tf.int64) - - if self.self_attention_type == 'scaled_dot': - if sequence_length is not None: - decoder_mask = transformer.build_future_mask( - sequence_length, - num_heads=self.num_heads, - maximum_length=tf.shape(inputs)[1], - band=X_band_width_tmp) # [N, 1, T_out, T_out] - elif self.self_attention_type == 'average': - if cache is None: - if sequence_length is None: - sequence_length = tf.fill([tf.shape(inputs)[0]], - tf.shape(inputs)[1]) - decoder_mask = transformer.cumulative_average_mask( - sequence_length, - maximum_length=tf.shape(inputs)[1], - dtype=inputs.dtype) - - if memory is not None and not tf.contrib.framework.nest.is_sequence( - memory): - memory = (memory, ) - if memory_sequence_length is not None: - if not tf.contrib.framework.nest.is_sequence( - memory_sequence_length): - memory_sequence_length = (memory_sequence_length, ) - if step is None: - memory_mask = [ - transformer.build_history_mask( - length, - num_heads=self.num_heads, - maximum_length=tf.shape(m)[1], - band=H_band_width_tmp) - for m, length in zip(memory, memory_sequence_length) - ] - else: - memory_mask = [ - transformer.build_history_mask( - length, - num_heads=self.num_heads, - maximum_length=tf.shape(m)[1], - band=H_band_width_tmp)[:, :, step:step + 1, :] - for m, length in zip(memory, memory_sequence_length) - ] - - # last_attention = None - attns_x = [] - attns_h = [] - for layer in range(self.num_layers): - layer_name = 'layer_{}'.format(layer) - layer_cache = cache[layer_name] if cache is not None else None - with tf.variable_scope(layer_name): - if memory is not None: - for i, (mem, mask) in enumerate(zip(memory, memory_mask)): - memory_cache = None - if layer_cache is not None: - memory_cache = layer_cache['memory'][i] - scope_name = 'multi_head_{}'.format(i) - if i == 0: - scope_name = 'multi_head' - with tf.variable_scope(scope_name): - encoded, attn_x, attn_h = transformer.multi_head_attention_PNCA( - self.num_heads, - transformer.norm(inputs), - mem, - mode, - num_units=self.num_units, - mask=decoder_mask, - mask_h=mask, - cache=layer_cache, - cache_h=memory_cache, - dropout=self.attention_dropout, - return_attention=True, - layer_name=layer_name, - X_band_width=self.X_band_width) - attns_x.append(attn_x) - attns_h.append(attn_h) - context = transformer.drop_and_add( - inputs, encoded, mode, dropout=self.dropout) - - with tf.variable_scope('ffn'): - transformed = transformer.feed_forward_ori( - transformer.norm(context), - self.ffn_inner_dim, - mode, - dropout=self.relu_dropout) - transformed = transformer.drop_and_add( - context, transformed, mode, dropout=self.dropout) - - inputs = transformed - - outputs = transformer.norm(inputs) - outputs = tf.layers.dense( - outputs, units=self.num_mels * self.outputs_per_step) - return outputs, attns_x, attns_h - - def decode_from_inputs(self, - inputs, - sequence_length, - initial_state=None, - mode=True, - memory=None, - memory_sequence_length=None): - outputs, attention_x, attention_h = self._self_attention_stack( - inputs, - sequence_length=sequence_length, - mode=mode, - memory=memory, - memory_sequence_length=memory_sequence_length) - return outputs, attention_x, attention_h - - def step_fn(self, - mode, - batch_size, - initial_state=None, - memory=None, - memory_sequence_length=None, - dtype=tf.float32): - if memory is None: - num_sources = 0 - elif tf.contrib.framework.nest.is_sequence(memory): - num_sources = len(memory) - else: - num_sources = 1 - cache = self._init_cache( - batch_size, dtype=dtype, num_sources=num_sources) - attention_x = self._init_attn(dtype=dtype) - attention_h = self._init_attn(dtype=dtype) - - def _fn(step, inputs, cache): - outputs, attention_x, attention_h = self._self_attention_stack( - inputs, - mode=mode, - cache=cache, - memory=memory, - memory_sequence_length=memory_sequence_length, - step=step) - attention_x_tmp = [] - for layer in range(len(attention_h)): - attention_x_tmp_l = tf.zeros_like(attention_h[layer]) - if self.X_band_width is not None: - pred = tf.less(step, self.X_band_width + 1) - attention_x_tmp_l_1 = tf.cond(pred, # yapf:disable - lambda: attention_x_tmp_l[:, :, :, :step + 1] + attention_x[layer], - lambda: tf.concat([ - attention_x_tmp_l[:, :, :, - :step - self.X_band_width], - attention_x_tmp_l[:, :, :, - step - self.X_band_width:step + 1] - + attention_x[layer]], - axis=-1)) # yapf:disable - attention_x_tmp_l_2 = attention_x_tmp_l[:, :, :, step + 1:] - attention_x_tmp.append( - tf.concat([attention_x_tmp_l_1, attention_x_tmp_l_2], - axis=-1)) - else: - attention_x_tmp_l_1 = attention_x_tmp_l[:, :, :, :step + 1] - attention_x_tmp_l_2 = attention_x_tmp_l[:, :, :, step + 1:] - attention_x_tmp.append( - tf.concat([ - attention_x_tmp_l_1 + attention_x[layer], - attention_x_tmp_l_2 - ], axis=-1)) # yapf:disable - attention_x = attention_x_tmp - return outputs, cache, attention_x, attention_h - - return _fn, cache, attention_x, attention_h - - def dynamic_decode_and_search(self, init_decoder_input, maximum_iterations, - mode, memory, memory_sequence_length): - batch_size = tf.shape(init_decoder_input)[0] - step_fn, init_cache, init_attn_x, init_attn_h = self.step_fn( - mode, - batch_size, - memory=memory, - memory_sequence_length=memory_sequence_length) - - outputs, attention_x, attention_h, cache = self.dynamic_decode( - step_fn, - init_decoder_input, - init_cache=init_cache, - init_attn_x=init_attn_x, - init_attn_h=init_attn_h, - maximum_iterations=maximum_iterations, - batch_size=batch_size) - return outputs, attention_x, attention_h - - def dynamic_decode_and_search_teacher_forcing(self, decoder_input, - maximum_iterations, mode, - memory, - memory_sequence_length): - batch_size = tf.shape(decoder_input)[0] - step_fn, init_cache, init_attn_x, init_attn_h = self.step_fn( - mode, - batch_size, - memory=memory, - memory_sequence_length=memory_sequence_length) - - outputs, attention_x, attention_h, cache = self.dynamic_decode_teacher_forcing( - step_fn, - decoder_input, - init_cache=init_cache, - init_attn_x=init_attn_x, - init_attn_h=init_attn_h, - maximum_iterations=maximum_iterations, - batch_size=batch_size) - return outputs, attention_x, attention_h - - def dynamic_decode(self, - step_fn, - init_decoder_input, - init_cache=None, - init_attn_x=None, - init_attn_h=None, - maximum_iterations=None, - batch_size=None): - - def _cond(step, cache, inputs, outputs, attention_x, attention_h): # pylint: disable=unused-argument - return tf.less(step, maximum_iterations) - - def _body(step, cache, inputs, outputs, attention_x, attention_h): - # output: [1, 1, num_mels * r] - # attn: [1, 1, T_out] - output, cache, attn_x, attn_h = step_fn( - step, inputs, cache) # outputs, cache, attention, attns - for layer in range(len(attention_x)): - attention_x[layer] = attention_x[layer].write( - step, tf.cast(attn_x[layer], tf.float32)) - - for layer in range(len(attention_h)): - attention_h[layer] = attention_h[layer].write( - step, tf.cast(attn_h[layer], tf.float32)) - - outputs = outputs.write(step, tf.cast(output, tf.float32)) - return step + 1, cache, output[:, :, -self. - num_mels:], outputs, attention_x, attention_h - - step = tf.constant(0, dtype=tf.int32) - outputs = tf.TensorArray(tf.float32, size=0, dynamic_size=True) - - _, cache, _, outputs, attention_x, attention_h = tf.while_loop( - _cond, - _body, - loop_vars=(step, init_cache, init_decoder_input, outputs, - init_attn_x, init_attn_h), - shape_invariants=(step.shape, - compat.nest.map_structure( - self._get_shape_invariants, init_cache), - compat.nest.map_structure( - self._get_shape_invariants, - init_decoder_input), tf.TensorShape(None), - compat.nest.map_structure( - self._get_shape_invariants, init_attn_x), - compat.nest.map_structure( - self._get_shape_invariants, init_attn_h)), - parallel_iterations=1, - back_prop=False, - maximum_iterations=maximum_iterations) - # element of outputs: [N, 1, num_mels * r] - outputs_stack = outputs.stack() # [T_out, N, 1, num_mels * r] - outputs_stack = tf.transpose( - outputs_stack, perm=[2, 1, 0, 3]) # [1, N, T_out, num_mels * r] - outputs_stack = tf.squeeze( - outputs_stack, axis=0) # [N, T_out, num_mels * r] - - attention_x_stack = [] - for layer in range(len(attention_x)): - attention_x_stack_tmp = attention_x[layer].stack( - ) # [T_out, N, H, 1, T_out] - attention_x_stack_tmp = tf.transpose( - attention_x_stack_tmp, perm=[3, 1, 2, 0, - 4]) # [1, N, H, T_out, T_out] - attention_x_stack_tmp = tf.squeeze( - attention_x_stack_tmp, axis=0) # [N, H, T_out, T_out] - attention_x_stack.append(attention_x_stack_tmp) - - attention_h_stack = [] - for layer in range(len(attention_h)): - attention_h_stack_tmp = attention_h[layer].stack( - ) # [T_out, N, H, 1, T_out] - attention_h_stack_tmp = tf.transpose( - attention_h_stack_tmp, perm=[3, 1, 2, 0, - 4]) # [1, N, H, T_out, T_out] - attention_h_stack_tmp = tf.squeeze( - attention_h_stack_tmp, axis=0) # [N, H, T_out, T_out] - attention_h_stack.append(attention_h_stack_tmp) - - return outputs_stack, attention_x_stack, attention_h_stack, cache - - def dynamic_decode_teacher_forcing(self, - step_fn, - decoder_input, - init_cache=None, - init_attn_x=None, - init_attn_h=None, - maximum_iterations=None, - batch_size=None): - - def _cond(step, cache, inputs, outputs, attention_x, attention_h): # pylint: disable=unused-argument - return tf.less(step, maximum_iterations) - - def _body(step, cache, inputs, outputs, attention_x, attention_h): - # output: [1, 1, num_mels * r] - # attn: [1, 1, T_out] - output, cache, attn_x, attn_h = step_fn( - step, inputs[:, step:step + 1, :], - cache) # outputs, cache, attention, attns - for layer in range(len(attention_x)): - attention_x[layer] = attention_x[layer].write( - step, tf.cast(attn_x[layer], tf.float32)) - - for layer in range(len(attention_h)): - attention_h[layer] = attention_h[layer].write( - step, tf.cast(attn_h[layer], tf.float32)) - outputs = outputs.write(step, tf.cast(output, tf.float32)) - return step + 1, cache, inputs, outputs, attention_x, attention_h - - step = tf.constant(0, dtype=tf.int32) - outputs = tf.TensorArray(tf.float32, size=0, dynamic_size=True) - - _, cache, _, outputs, attention_x, attention_h = tf.while_loop( - _cond, - _body, - loop_vars=(step, init_cache, decoder_input, outputs, init_attn_x, - init_attn_h), - shape_invariants=(step.shape, - compat.nest.map_structure( - self._get_shape_invariants, - init_cache), decoder_input.shape, - tf.TensorShape(None), - compat.nest.map_structure( - self._get_shape_invariants, init_attn_x), - compat.nest.map_structure( - self._get_shape_invariants, init_attn_h)), - parallel_iterations=1, - back_prop=False, - maximum_iterations=maximum_iterations) - # element of outputs: [N, 1, num_mels * r] - outputs_stack = outputs.stack() # [T_out, N, 1, num_mels * r] - outputs_stack = tf.transpose( - outputs_stack, perm=[2, 1, 0, 3]) # [1, N, T_out, num_mels * r] - outputs_stack = tf.squeeze( - outputs_stack, axis=0) # [N, T_out, num_mels * r] - - attention_x_stack = [] - for layer in range(len(attention_x)): - attention_x_stack_tmp = attention_x[layer].stack( - ) # [T_out, N, H, 1, T_out] - attention_x_stack_tmp = tf.transpose( - attention_x_stack_tmp, perm=[3, 1, 2, 0, - 4]) # [1, N, H, T_out, T_out] - attention_x_stack_tmp = tf.squeeze( - attention_x_stack_tmp, axis=0) # [N, H, T_out, T_out] - attention_x_stack.append(attention_x_stack_tmp) - - attention_h_stack = [] - for layer in range(len(attention_h)): - attention_h_stack_tmp = attention_h[layer].stack( - ) # [T_out, N, H, 1, T_out] - attention_h_stack_tmp = tf.transpose( - attention_h_stack_tmp, perm=[3, 1, 2, 0, - 4]) # [1, N, H, T_out, T_out] - attention_h_stack_tmp = tf.squeeze( - attention_h_stack_tmp, axis=0) # [N, H, T_out, T_out] - attention_h_stack.append(attention_h_stack_tmp) - - return outputs_stack, attention_x_stack, attention_h_stack, cache - - def _get_shape_invariants(self, tensor): - """Returns the shape of the tensor but sets middle dims to None.""" - if isinstance(tensor, tf.TensorArray): - shape = None - else: - shape = tensor.shape.as_list() - for i in range(1, len(shape) - 1): - shape[i] = None - return tf.TensorShape(shape) - - -class SelfAttentionDecoderOri(): - """Decoder using self-attention as described in - https://arxiv.org/abs/1706.03762. - """ - - def __init__(self, - num_layers, - num_units=512, - num_heads=8, - ffn_inner_dim=2048, - dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1, - position_encoder=SinusoidalPositionEncoder(), - self_attention_type='scaled_dot'): - """Initializes the parameters of the decoder. - - Args: - num_layers: The number of layers. - num_units: The number of hidden units. - num_heads: The number of heads in the multi-head attention. - ffn_inner_dim: The number of units of the inner linear transformation - in the feed forward layer. - dropout: The probability to drop units from the outputs. - attention_dropout: The probability to drop units from the attention. - relu_dropout: The probability to drop units from the ReLU activation in - the feed forward layer. - position_encoder: A :class:`opennmt.layers.position.PositionEncoder` to - apply on inputs or ``None``. - self_attention_type: Type of self attention, "scaled_dot" or "average" (case - insensitive). - - Raises: - ValueError: if :obj:`self_attention_type` is invalid. - """ - super(SelfAttentionDecoderOri, self).__init__() - self.num_layers = num_layers - self.num_units = num_units - self.num_heads = num_heads - self.ffn_inner_dim = ffn_inner_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.relu_dropout = relu_dropout - self.position_encoder = position_encoder - self.self_attention_type = self_attention_type.lower() - if self.self_attention_type not in ('scaled_dot', 'average'): - raise ValueError('invalid attention type %s' - % self.self_attention_type) - if self.self_attention_type == 'average': - tf.logging.warning( - 'Support for average attention network is experimental ' - 'and may change in future versions.') - - @property - def output_size(self): - """Returns the decoder output size.""" - return self.num_units - - @property - def support_alignment_history(self): - return True - - @property - def support_multi_source(self): - return True - - def _init_cache(self, batch_size, dtype=tf.float32, num_sources=1): - cache = {} - - for layer in range(self.num_layers): - proj_cache_shape = [ - batch_size, self.num_heads, 0, self.num_units // self.num_heads - ] - layer_cache = {} - layer_cache['memory'] = [{ - 'memory_keys': - tf.zeros(proj_cache_shape, dtype=dtype), - 'memory_values': - tf.zeros(proj_cache_shape, dtype=dtype) - } for _ in range(num_sources)] - if self.self_attention_type == 'scaled_dot': - layer_cache['self_keys'] = tf.zeros( - proj_cache_shape, dtype=dtype) - layer_cache['self_values'] = tf.zeros( - proj_cache_shape, dtype=dtype) - elif self.self_attention_type == 'average': - layer_cache['prev_g'] = tf.zeros( - [batch_size, 1, self.num_units], dtype=dtype) - cache['layer_{}'.format(layer)] = layer_cache - - return cache - - def _self_attention_stack(self, - inputs, - sequence_length=None, - mode=True, - cache=None, - memory=None, - memory_sequence_length=None, - step=None): - inputs *= self.num_units**0.5 - if self.position_encoder is not None: - inputs = self.position_encoder( - inputs, position=step + 1 if step is not None else None) - - inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) - - decoder_mask = None - memory_mask = None - last_attention = None - - if self.self_attention_type == 'scaled_dot': - if sequence_length is not None: - decoder_mask = transformer.build_future_mask( - sequence_length, - num_heads=self.num_heads, - maximum_length=tf.shape(inputs)[1]) - elif self.self_attention_type == 'average': - if cache is None: - if sequence_length is None: - sequence_length = tf.fill([tf.shape(inputs)[0]], - tf.shape(inputs)[1]) - decoder_mask = transformer.cumulative_average_mask( - sequence_length, - maximum_length=tf.shape(inputs)[1], - dtype=inputs.dtype) - - if memory is not None and not tf.contrib.framework.nest.is_sequence( - memory): - memory = (memory, ) - if memory_sequence_length is not None: - if not tf.contrib.framework.nest.is_sequence( - memory_sequence_length): - memory_sequence_length = (memory_sequence_length, ) - memory_mask = [ - transformer.build_sequence_mask( - length, - num_heads=self.num_heads, - maximum_length=tf.shape(m)[1]) - for m, length in zip(memory, memory_sequence_length) - ] - - for layer in range(self.num_layers): - layer_name = 'layer_{}'.format(layer) - layer_cache = cache[layer_name] if cache is not None else None - with tf.variable_scope(layer_name): - if self.self_attention_type == 'scaled_dot': - with tf.variable_scope('masked_multi_head'): - encoded = transformer.multi_head_attention( - self.num_heads, - transformer.norm(inputs), - None, - mode, - num_units=self.num_units, - mask=decoder_mask, - cache=layer_cache, - dropout=self.attention_dropout) - last_context = transformer.drop_and_add( - inputs, encoded, mode, dropout=self.dropout) - elif self.self_attention_type == 'average': - with tf.variable_scope('average_attention'): - # Cumulative average. - x = transformer.norm(inputs) - y = transformer.cumulative_average( - x, - decoder_mask if cache is None else step, - cache=layer_cache) - # FFN. - y = transformer.feed_forward( - y, - self.ffn_inner_dim, - mode, - dropout=self.relu_dropout) - # Gating layer. - z = tf.layers.dense( - tf.concat([x, y], -1), self.num_units * 2) - i, f = tf.split(z, 2, axis=-1) - y = tf.sigmoid(i) * x + tf.sigmoid(f) * y - last_context = transformer.drop_and_add( - inputs, y, mode, dropout=self.dropout) - - if memory is not None: - for i, (mem, mask) in enumerate(zip(memory, memory_mask)): - memory_cache = layer_cache['memory'][i] if layer_cache is not None else None # yapf:disable - with tf.variable_scope('multi_head' if i - == 0 else 'multi_head_%d' % i): # yapf:disable - context, last_attention = transformer.multi_head_attention( - self.num_heads, - transformer.norm(last_context), - mem, - mode, - mask=mask, - cache=memory_cache, - dropout=self.attention_dropout, - return_attention=True) - last_context = transformer.drop_and_add( - last_context, - context, - mode, - dropout=self.dropout) - if i > 0: # Do not return attention in case of multi source. - last_attention = None - - with tf.variable_scope('ffn'): - transformed = transformer.feed_forward_ori( - transformer.norm(last_context), - self.ffn_inner_dim, - mode, - dropout=self.relu_dropout) - transformed = transformer.drop_and_add( - last_context, transformed, mode, dropout=self.dropout) - - inputs = transformed - - if last_attention is not None: - # The first head of the last layer is returned. - first_head_attention = last_attention[:, 0] - else: - first_head_attention = None - - outputs = transformer.norm(inputs) - return outputs, first_head_attention - - def decode_from_inputs(self, - inputs, - sequence_length, - initial_state=None, - mode=True, - memory=None, - memory_sequence_length=None): - outputs, attention = self._self_attention_stack( - inputs, - sequence_length=sequence_length, - mode=mode, - memory=memory, - memory_sequence_length=memory_sequence_length) - return outputs, None, attention - - def step_fn(self, - mode, - batch_size, - initial_state=None, - memory=None, - memory_sequence_length=None, - dtype=tf.float32): - if memory is None: - num_sources = 0 - elif tf.contrib.framework.nest.is_sequence(memory): - num_sources = len(memory) - else: - num_sources = 1 - cache = self._init_cache( - batch_size, dtype=dtype, num_sources=num_sources) - - def _fn(step, inputs, cache, mode): - inputs = tf.expand_dims(inputs, 1) - outputs, attention = self._self_attention_stack( - inputs, - mode=mode, - cache=cache, - memory=memory, - memory_sequence_length=memory_sequence_length, - step=step) - outputs = tf.squeeze(outputs, axis=1) - if attention is not None: - attention = tf.squeeze(attention, axis=1) - return outputs, cache, attention - - return _fn, cache diff --git a/modelscope/models/audio/tts/models/self_attention_encoder.py b/modelscope/models/audio/tts/models/self_attention_encoder.py deleted file mode 100755 index ce4193dc..00000000 --- a/modelscope/models/audio/tts/models/self_attention_encoder.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Define the self-attention encoder.""" - -import tensorflow as tf - -from . import transformer -from .position import SinusoidalPositionEncoder - - -class SelfAttentionEncoder(): - """Encoder using self-attention as described in - https://arxiv.org/abs/1706.03762. - """ - - def __init__(self, - num_layers, - num_units=512, - num_heads=8, - ffn_inner_dim=2048, - dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1, - position_encoder=SinusoidalPositionEncoder()): - """Initializes the parameters of the encoder. - - Args: - num_layers: The number of layers. - num_units: The number of hidden units. - num_heads: The number of heads in the multi-head attention. - ffn_inner_dim: The number of units of the inner linear transformation - in the feed forward layer. - dropout: The probability to drop units from the outputs. - attention_dropout: The probability to drop units from the attention. - relu_dropout: The probability to drop units from the ReLU activation in - the feed forward layer. - position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to - apply on inputs or ``None``. - """ - super(SelfAttentionEncoder, self).__init__() - self.num_layers = num_layers - self.num_units = num_units - self.num_heads = num_heads - self.ffn_inner_dim = ffn_inner_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.relu_dropout = relu_dropout - self.position_encoder = position_encoder - - def encode(self, inputs, sequence_length=None, mode=True): - inputs *= self.num_units**0.5 - if self.position_encoder is not None: - inputs = self.position_encoder(inputs) - - inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) - mask = transformer.build_sequence_mask( - sequence_length, - num_heads=self.num_heads, - maximum_length=tf.shape(inputs)[1]) - - mask_FF = tf.squeeze( - transformer.build_sequence_mask( - sequence_length, maximum_length=tf.shape(inputs)[1]), - axis=1) - - state = () - - attns = [] - for layer in range(self.num_layers): - with tf.variable_scope('layer_{}'.format(layer)): - with tf.variable_scope('multi_head'): - context, attn = transformer.multi_head_attention( - self.num_heads, - transformer.norm(inputs), - None, - mode, - num_units=self.num_units, - mask=mask, - dropout=self.attention_dropout, - return_attention=True) - attns.append(attn) - context = transformer.drop_and_add( - inputs, context, mode, dropout=self.dropout) - - with tf.variable_scope('ffn'): - transformed = transformer.feed_forward( - transformer.norm(context), - self.ffn_inner_dim, - mode, - dropout=self.relu_dropout, - mask=mask_FF) - transformed = transformer.drop_and_add( - context, transformed, mode, dropout=self.dropout) - - inputs = transformed - state += (tf.reduce_mean(inputs, axis=1), ) - - outputs = transformer.norm(inputs) - return (outputs, state, sequence_length, attns) - - -class SelfAttentionEncoderOri(): - """Encoder using self-attention as described in - https://arxiv.org/abs/1706.03762. - """ - - def __init__(self, - num_layers, - num_units=512, - num_heads=8, - ffn_inner_dim=2048, - dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1, - position_encoder=SinusoidalPositionEncoder()): - """Initializes the parameters of the encoder. - - Args: - num_layers: The number of layers. - num_units: The number of hidden units. - num_heads: The number of heads in the multi-head attention. - ffn_inner_dim: The number of units of the inner linear transformation - in the feed forward layer. - dropout: The probability to drop units from the outputs. - attention_dropout: The probability to drop units from the attention. - relu_dropout: The probability to drop units from the ReLU activation in - the feed forward layer. - position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to - apply on inputs or ``None``. - """ - super(SelfAttentionEncoderOri, self).__init__() - self.num_layers = num_layers - self.num_units = num_units - self.num_heads = num_heads - self.ffn_inner_dim = ffn_inner_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.relu_dropout = relu_dropout - self.position_encoder = position_encoder - - def encode(self, inputs, sequence_length=None, mode=True): - inputs *= self.num_units**0.5 - if self.position_encoder is not None: - inputs = self.position_encoder(inputs) - - inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) - mask = transformer.build_sequence_mask( - sequence_length, - num_heads=self.num_heads, - maximum_length=tf.shape(inputs)[1]) # [N, 1, 1, T_out] - - state = () - - attns = [] - for layer in range(self.num_layers): - with tf.variable_scope('layer_{}'.format(layer)): - with tf.variable_scope('multi_head'): - context, attn = transformer.multi_head_attention( - self.num_heads, - transformer.norm(inputs), - None, - mode, - num_units=self.num_units, - mask=mask, - dropout=self.attention_dropout, - return_attention=True) - attns.append(attn) - context = transformer.drop_and_add( - inputs, context, mode, dropout=self.dropout) - - with tf.variable_scope('ffn'): - transformed = transformer.feed_forward_ori( - transformer.norm(context), - self.ffn_inner_dim, - mode, - dropout=self.relu_dropout) - transformed = transformer.drop_and_add( - context, transformed, mode, dropout=self.dropout) - - inputs = transformed - state += (tf.reduce_mean(inputs, axis=1), ) - - outputs = transformer.norm(inputs) - return (outputs, state, sequence_length, attns) diff --git a/modelscope/models/audio/tts/models/transformer.py b/modelscope/models/audio/tts/models/transformer.py deleted file mode 100755 index a9f0bedc..00000000 --- a/modelscope/models/audio/tts/models/transformer.py +++ /dev/null @@ -1,1157 +0,0 @@ -"""Define layers related to the Google's Transformer model.""" - -import tensorflow as tf - -from . import compat, fsmn - - -def tile_sequence_length(sequence_length, num_heads): - """Tiles lengths :obj:`num_heads` times. - - Args: - sequence_length: The sequence length. - num_heads: The number of heads. - - Returns: - A ``tf.Tensor`` where each length is replicated :obj:`num_heads` times. - """ - sequence_length = tf.tile(sequence_length, [num_heads]) - sequence_length = tf.reshape(sequence_length, [num_heads, -1]) - sequence_length = tf.transpose(sequence_length, perm=[1, 0]) - sequence_length = tf.reshape(sequence_length, [-1]) - return sequence_length - - -def build_sequence_mask(sequence_length, - num_heads=None, - maximum_length=None, - dtype=tf.float32): - """Builds the dot product mask. - - Args: - sequence_length: The sequence length. - num_heads: The number of heads. - maximum_length: Optional size of the returned time dimension. Otherwise - it is the maximum of :obj:`sequence_length`. - dtype: The type of the mask tensor. - - Returns: - A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape - ``[batch_size, 1, 1, max_length]``. - """ - mask = tf.sequence_mask( - sequence_length, maxlen=maximum_length, dtype=dtype) - mask = tf.expand_dims(mask, axis=1) - if num_heads is not None: - mask = tf.expand_dims(mask, axis=1) - return mask - - -def build_sequence_mask_window(sequence_length, - left_window_size=-1, - right_window_size=-1, - num_heads=None, - maximum_length=None, - dtype=tf.float32): - """Builds the dot product mask. - - Args: - sequence_length: The sequence length. - num_heads: The number of heads. - maximum_length: Optional size of the returned time dimension. Otherwise - it is the maximum of :obj:`sequence_length`. - dtype: The type of the mask tensor. - - Returns: - A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape - ``[batch_size, 1, 1, max_length]``. - """ - sequence_mask = tf.sequence_mask( - sequence_length, maxlen=maximum_length, dtype=dtype) - mask = _window_mask( - sequence_length, - left_window_size=left_window_size, - right_window_size=right_window_size, - maximum_length=maximum_length, - dtype=dtype) - mask *= tf.expand_dims(sequence_mask, axis=1) - if num_heads is not None: - mask = tf.expand_dims(mask, axis=1) - return mask - - -def _lower_triangle_mask(sequence_length, - maximum_length=None, - dtype=tf.float32, - band=-1): - batch_size = tf.shape(sequence_length)[0] - if maximum_length is None: - maximum_length = tf.reduce_max(sequence_length) - mask = tf.ones([batch_size, maximum_length, maximum_length], dtype=dtype) - mask = compat.tf_compat( - v2='linalg.band_part', v1='matrix_band_part')(mask, band, 0) - return mask - - -def _higher_triangle_mask(sequence_length, - maximum_length=None, - dtype=tf.float32, - band=-1): - batch_size = tf.shape(sequence_length)[0] - if maximum_length is None: - maximum_length = tf.reduce_max(sequence_length) - mask = tf.ones([batch_size, maximum_length, maximum_length], dtype=dtype) - mask = compat.tf_compat( - v2='linalg.band_part', v1='matrix_band_part')(mask, 0, band) - return mask - - -def _window_mask(sequence_length, - left_window_size=-1, - right_window_size=-1, - maximum_length=None, - dtype=tf.float32): - batch_size = tf.shape(sequence_length)[0] - if maximum_length is None: - maximum_length = tf.reduce_max(sequence_length) - mask = tf.ones([batch_size, maximum_length, maximum_length], dtype=dtype) - left_window_size = tf.minimum( - tf.cast(left_window_size, tf.int64), - tf.cast(maximum_length - 1, tf.int64)) - right_window_size = tf.minimum( - tf.cast(right_window_size, tf.int64), - tf.cast(maximum_length - 1, tf.int64)) - mask = tf.matrix_band_part(mask, left_window_size, right_window_size) - return mask - - -def build_future_mask(sequence_length, - num_heads=None, - maximum_length=None, - dtype=tf.float32, - band=-1): - """Builds the dot product mask for future positions. - - Args: - sequence_length: The sequence length. - num_heads: The number of heads. - maximum_length: Optional size of the returned time dimension. Otherwise - it is the maximum of :obj:`sequence_length`. - dtype: The type of the mask tensor. - - Returns: - A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape - ``[batch_size, 1, max_length, max_length]``. - """ - sequence_mask = tf.sequence_mask( - sequence_length, maxlen=maximum_length, dtype=dtype) - mask = _lower_triangle_mask( - sequence_length, maximum_length=maximum_length, dtype=dtype, band=band) - mask *= tf.expand_dims(sequence_mask, axis=1) - if num_heads is not None: - mask = tf.expand_dims(mask, axis=1) - return mask - - -def build_history_mask(sequence_length, - num_heads=None, - maximum_length=None, - dtype=tf.float32, - band=-1): - """Builds the dot product mask for future positions. - - Args: - sequence_length: The sequence length. - num_heads: The number of heads. - maximum_length: Optional size of the returned time dimension. Otherwise - it is the maximum of :obj:`sequence_length`. - dtype: The type of the mask tensor. - - Returns: - A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape - ``[batch_size, 1, max_length, max_length]``. - """ - sequence_mask = tf.sequence_mask( - sequence_length, maxlen=maximum_length, dtype=dtype) - mask = _higher_triangle_mask( - sequence_length, maximum_length=maximum_length, dtype=dtype, band=band) - mask *= tf.expand_dims(sequence_mask, axis=1) - if num_heads is not None: - mask = tf.expand_dims(mask, axis=1) - return mask - - -def cumulative_average_mask(sequence_length, - maximum_length=None, - dtype=tf.float32): - """Builds the mask to compute the cumulative average as described in - https://arxiv.org/abs/1805.00631. - - Args: - sequence_length: The sequence length. - maximum_length: Optional size of the returned time dimension. Otherwise - it is the maximum of :obj:`sequence_length`. - dtype: The type of the mask tensor. - - Returns: - A ``tf.Tensor`` of type :obj:`dtype` and shape - ``[batch_size, max_length, max_length]``. - """ - sequence_mask = tf.sequence_mask( - sequence_length, maxlen=maximum_length, dtype=dtype) - mask = _lower_triangle_mask( - sequence_length, maximum_length=maximum_length, dtype=dtype) - mask *= tf.expand_dims(sequence_mask, axis=2) - weight = tf.range(1, tf.cast(tf.shape(mask)[1] + 1, dtype), dtype=dtype) - mask /= tf.expand_dims(weight, 1) - return mask - - -def cumulative_average(inputs, mask_or_step, cache=None): - """Computes the cumulative average as described in - https://arxiv.org/abs/1805.00631. - - Args: - inputs: The sequence to average. A tensor of shape :math:`[B, T, D]`. - mask_or_step: If :obj:`cache` is set, this is assumed to be the current step - of the dynamic decoding. Otherwise, it is the mask matrix used to compute - the cumulative average. - cache: A dictionnary containing the cumulative average of the previous step. - - Returns: - The cumulative average, a tensor of the same shape and type as :obj:`inputs`. - """ - if cache is not None: - step = tf.cast(mask_or_step, inputs.dtype) - aa = (inputs + step * cache['prev_g']) / (step + 1.0) - cache['prev_g'] = aa - return aa - else: - mask = mask_or_step - return tf.matmul(mask, inputs) - - -def fused_projection(inputs, num_units, num_outputs=1): - """Projects the same input into multiple output spaces. - - Args: - inputs: The inputs to project. - num_units: The number of output units of each space. - num_outputs: The number of output spaces. - - Returns: - :obj:`num_outputs` ``tf.Tensor`` of depth :obj:`num_units`. - """ - return tf.split( - tf.layers.conv1d(inputs, num_units * num_outputs, 1), - num_outputs, - axis=2) - - -def split_heads(inputs, num_heads): - """Splits a tensor in depth. - - Args: - inputs: A ``tf.Tensor`` of shape :math:`[B, T, D]`. - num_heads: The number of heads :math:`H`. - - Returns: - A ``tf.Tensor`` of shape :math:`[B, H, T, D / H]`. - """ - static_shape = inputs.get_shape().as_list() - depth = static_shape[-1] - outputs = tf.reshape(inputs, [ - tf.shape(inputs)[0], - tf.shape(inputs)[1], num_heads, depth // num_heads - ]) - outputs = tf.transpose(outputs, perm=[0, 2, 1, 3]) - return outputs - - -def combine_heads(inputs): - """Concatenates heads. - - Args: - inputs: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`. - - Returns: - A ``tf.Tensor`` of shape :math:`[B, T, D * H]`. - """ - static_shape = inputs.get_shape().as_list() - depth = static_shape[-1] - num_heads = static_shape[1] - outputs = tf.transpose(inputs, perm=[0, 2, 1, 3]) - outputs = tf.reshape( - outputs, - [tf.shape(outputs)[0], - tf.shape(outputs)[1], depth * num_heads]) - return outputs - - -def dot_product_attention(queries, keys, values, mode, mask=None, dropout=0.0): - """Computes the dot product attention. - - Args: - queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - keys: The sequence use to calculate attention scores. A tensor of shape - :math:`[B, T_2, ...]`. - values: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - mode: A ``tf.estimator.ModeKeys`` mode. - mask: A ``tf.Tensor`` applied to the dot product. - dropout: The probability to drop units from the inputs. - - Returns: - A tuple ``(context vector, attention vector)``. - """ - dot = tf.matmul(queries, keys, transpose_b=True) - - if mask is not None: - dot = tf.cast( - tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), - dot.dtype) - - softmax = tf.nn.softmax(tf.cast(dot, tf.float32)) - attn = tf.cast(softmax, dot.dtype) - drop_attn = tf.layers.dropout(attn, rate=dropout, training=mode) - - context = tf.matmul(drop_attn, values) - - return context, attn - - -def dot_product_attention_wpa(num_heads, - queries, - keys, - values, - mode, - attention_left_window=-1, - attention_right_window=0, - mask=None, - max_id_cache=None, - mono=False, - peak_delay=-1, - dropout=0.0): - """ - Computes the dot product attention. - Args: - queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - keys: The sequence use to calculate attention scores. A tensor of shape - :math:`[B, T_2, ...]`. - values: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - mode: A ``tf.estimator.ModeKeys`` mode. - mask: A ``tf.Tensor`` applied to the dot product. - dropout: The probability to drop units from the inputs. - - Returns: - A tuple ``(context vector, attention vector)``. - """ - # Dot product between queries and keys. - dot = tf.matmul(queries, keys, transpose_b=True) - depth = tf.shape(dot)[-1] - if mask is not None: - dot = tf.cast( - tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), - dot.dtype) - # wpa - max_id = tf.math.argmax(input=dot, axis=-1) - # peak delay - if peak_delay > 0: - if max_id_cache is not None: - M = tf.cast(max_id_cache['pre_max_id'], dtype=max_id.dtype) - inputs_len = tf.math.minimum( - M + peak_delay, tf.cast(depth - 1, dtype=max_id.dtype)) - delay_mask = tf.sequence_mask( - inputs_len, maxlen=depth, dtype=tf.float32) - dot = tf.cast( - tf.cast(dot, tf.float32) * delay_mask - + ((1.0 - delay_mask) * tf.float32.min), dot.dtype) # yapf:disable - max_id = tf.math.argmax(input=dot, axis=-1) - # mono - if mono: - if max_id_cache is None: - d = tf.shape(max_id)[-1] - tmp_max_id = tf.reshape(max_id, [-1, num_heads, d]) - tmp_max_id = tf.slice( - tmp_max_id, [0, 0, 0], - [tf.shape(tmp_max_id)[0], - tf.shape(tmp_max_id)[1], d - 1]) - zeros = tf.zeros( - shape=(tf.shape(tmp_max_id)[0], tf.shape(tmp_max_id)[1], 1), - dtype=max_id.dtype) - tmp_max_id = tf.concat([zeros, tmp_max_id], axis=-1) - mask1 = tf.sequence_mask( - tmp_max_id, maxlen=depth, dtype=tf.float32) - dot = tf.cast( - tf.cast(dot, tf.float32) - * (1.0 - mask1) + mask1 * tf.float32.min, dot.dtype) # yapf:disable - max_id = tf.math.argmax(input=dot, axis=-1) - else: - # eval - tmp_max_id = tf.reshape(max_id, [-1, num_heads, 1]) - max_id_cache['pre_max_id'] = tmp_max_id - # right_mask - right_offset = tf.constant(attention_right_window, dtype=max_id.dtype) - right_len = tf.math.minimum(max_id + right_offset, - tf.cast(depth - 1, dtype=max_id.dtype)) - right_mask = tf.sequence_mask(right_len, maxlen=depth, dtype=tf.float32) - dot = tf.cast( - tf.cast(dot, tf.float32) * right_mask - + ((1.0 - right_mask) * tf.float32.min), dot.dtype) # yapf:disable - # left_mask - if attention_left_window > 0: - left_offset = tf.constant(attention_left_window, dtype=max_id.dtype) - left_len = tf.math.maximum(max_id - left_offset, - tf.cast(0, dtype=max_id.dtype)) - left_mask = tf.sequence_mask(left_len, maxlen=depth, dtype=tf.float32) - dot = tf.cast( - tf.cast(dot, tf.float32) * (1.0 - left_mask) - + (left_mask * tf.float32.min), dot.dtype) # yapf:disable - # Compute attention weights. - attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) - drop_attn = tf.layers.dropout(attn, rate=dropout, training=mode) - - # Compute attention context. - context = tf.matmul(drop_attn, values) - - return context, attn - - -def multi_head_attention(num_heads, - queries, - memory, - mode, - num_units=None, - mask=None, - cache=None, - dropout=0.0, - return_attention=False): - """Computes the multi-head attention as described in - https://arxiv.org/abs/1706.03762. - - Args: - num_heads: The number of attention heads. - queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - If ``None``, computes self-attention. - mode: A ``tf.estimator.ModeKeys`` mode. - num_units: The number of hidden units. If not set, it is set to the input - dimension. - mask: A ``tf.Tensor`` applied to the dot product. - cache: A dictionary containing pre-projected keys and values. - dropout: The probability to drop units from the inputs. - return_attention: Return the attention head probabilities in addition to the - context. - - Returns: - The concatenated attention context of each head and the attention - probabilities (if :obj:`return_attention` is set). - """ - num_units = num_units or queries.get_shape().as_list()[-1] - - if num_units % num_heads != 0: - raise ValueError('Multi head attention requires that num_units is a' - ' multiple of {}'.format(num_heads)) - - if memory is None: - queries, keys, values = fused_projection( - queries, num_units, num_outputs=3) - - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - if cache is not None: - keys = tf.concat([cache['self_keys'], keys], axis=2) - values = tf.concat([cache['self_values'], values], axis=2) - cache['self_keys'] = keys - cache['self_values'] = values - else: - queries = tf.layers.conv1d(queries, num_units, 1) - - if cache is not None: - - def _project_and_split(): - k, v = fused_projection(memory, num_units, num_outputs=2) - return split_heads(k, num_heads), split_heads(v, num_heads) - - keys, values = tf.cond( - tf.equal(tf.shape(cache['memory_keys'])[2], 0), - true_fn=_project_and_split, - false_fn=lambda: - (cache['memory_keys'], cache['memory_values'])) - cache['memory_keys'] = keys - cache['memory_values'] = values - else: - keys, values = fused_projection(memory, num_units, num_outputs=2) - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - queries = split_heads(queries, num_heads) - queries *= (num_units // num_heads)**-0.5 - - heads, attn = dot_product_attention( - queries, keys, values, mode, mask=mask, dropout=dropout) - - # Concatenate all heads output. - combined = combine_heads(heads) - outputs = tf.layers.conv1d(combined, num_units, 1) - - if not return_attention: - return outputs - return outputs, attn - - -def multi_head_attention_PNCA(num_heads, - queries, - memory, - mode, - num_units=None, - mask=None, - mask_h=None, - cache=None, - cache_h=None, - dropout=0.0, - return_attention=False, - X_band_width=None, - layer_name='multi_head'): - """Computes the multi-head attention as described in - https://arxiv.org/abs/1706.03762. - - Args: - num_heads: The number of attention heads. - queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - If ``None``, computes self-attention. - mode: A ``tf.estimator.ModeKeys`` mode. - num_units: The number of hidden units. If not set, it is set to the input - dimension. - mask: A ``tf.Tensor`` applied to the dot product. - cache: A dictionary containing pre-projected keys and values. - dropout: The probability to drop units from the inputs. - return_attention: Return the attention head probabilities in addition to the - context. - - Returns: - The concatenated attention context of each head and the attention - probabilities (if :obj:`return_attention` is set). - """ - num_units = num_units or queries.get_shape().as_list()[-1] - - if num_units % num_heads != 0: - raise ValueError('Multi head attention requires that num_units is a' - ' multiple of {}'.format(num_heads)) - - # X - queries, keys, values = fused_projection(queries, num_units, num_outputs=3) - - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - if cache is not None: - keys = tf.concat([cache['self_keys'], keys], axis=2) - values = tf.concat([cache['self_values'], values], axis=2) - if X_band_width is not None: - keys_band = tf.cond( - tf.less(X_band_width, 0), lambda: keys, lambda: tf.cond( - tf.less(tf.shape(keys)[2], X_band_width), lambda: keys, - lambda: keys[:, :, -X_band_width:, :]) - ) # not support X_band_width == 0 - values_band = tf.cond( - tf.less(X_band_width, 0), lambda: values, lambda: tf.cond( - tf.less(tf.shape(values)[2], X_band_width), lambda: values, - lambda: values[:, :, -X_band_width:, :])) - cache['self_keys'] = keys_band - cache['self_values'] = values_band - else: - cache['self_keys'] = keys - cache['self_values'] = values - - queries = split_heads(queries, num_heads) - queries *= (num_units // num_heads)**-0.5 - - heads, attn = dot_product_attention( - queries, keys, values, mode, mask=mask, dropout=dropout) - - # Concatenate all heads output. - combined = combine_heads(heads) - outputs = tf.layers.conv1d(combined, num_units, 1) - - # H - if cache_h is not None: - - def _project_and_split(): - k, v = fused_projection(memory, num_units, num_outputs=2) - return split_heads(k, num_heads), split_heads(v, num_heads) - - keys_h, values_h = tf.cond( - tf.equal(tf.shape(cache_h['memory_keys'])[2], 0), - true_fn=_project_and_split, - false_fn=lambda: - (cache_h['memory_keys'], cache_h['memory_values'])) - cache_h['memory_keys'] = keys_h - cache_h['memory_values'] = values_h - else: - keys_h, values_h = fused_projection(memory, num_units, num_outputs=2) - keys_h = split_heads(keys_h, num_heads) - values_h = split_heads(values_h, num_heads) - - heads_h, attn_h = dot_product_attention( - queries, keys_h, values_h, mode, mask=mask_h, dropout=dropout) - - # Concatenate all heads output. - combined_h = combine_heads(heads_h) - outputs_h = tf.layers.conv1d(combined_h, num_units, 1) - - # ADD - outputs = outputs + outputs_h - - # RETURN - return outputs, attn, attn_h - - -def multi_head_attention_memory(num_heads, - queries, - memory, - mode, - num_memory=None, - num_units=None, - mask=None, - cache=None, - dropout=0.0, - return_attention=False): - """Computes the multi-head attention as described in - https://arxiv.org/abs/1706.03762. - - Args: - num_heads: The number of attention heads. - queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - If ``None``, computes self-attention. - mode: A ``tf.estimator.ModeKeys`` mode. - num_units: The number of hidden units. If not set, it is set to the input - dimension. - mask: A ``tf.Tensor`` applied to the dot product. - cache: A dictionary containing pre-projected keys and values. - dropout: The probability to drop units from the inputs. - return_attention: Return the attention head probabilities in addition to the - context. - - Returns: - The concatenated attention context of each head and the attention - probabilities (if :obj:`return_attention` is set). - """ - num_units = num_units or queries.get_shape().as_list()[-1] - - if num_units % num_heads != 0: - raise ValueError('Multi head attention requires that num_units is a' - ' multiple of {}'.format(num_heads)) - - # PERSISTENT MEMORY - # key memory - if num_memory is not None: - key_m = tf.get_variable( - 'key_m', - shape=[num_memory, num_units], - initializer=tf.glorot_uniform_initializer(), - dtype=tf.float32) - # value memory - value_m = tf.get_variable( - 'value_m', - shape=[num_memory, num_units], - initializer=tf.glorot_uniform_initializer(), - dtype=tf.float32) - if memory is None: - queries, keys, values = fused_projection( - queries, num_units, num_outputs=3) - - # concat memory - if num_memory is not None: - key_m_expand = tf.tile( - tf.expand_dims(key_m, 0), [tf.shape(keys)[0], 1, 1]) - value_m_expand = tf.tile( - tf.expand_dims(value_m, 0), [tf.shape(values)[0], 1, 1]) - keys = tf.concat([key_m_expand, keys], axis=1) - values = tf.concat([value_m_expand, values], axis=1) - - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - if cache is not None: - keys = tf.concat([cache['self_keys'], keys], axis=2) - values = tf.concat([cache['self_values'], values], axis=2) - cache['self_keys'] = keys - cache['self_values'] = values - else: - queries = tf.layers.conv1d(queries, num_units, 1) - - if cache is not None: - - def _project_and_split(): - k, v = fused_projection(memory, num_units, num_outputs=2) - return split_heads(k, num_heads), split_heads(v, num_heads) - - keys, values = tf.cond( - tf.equal(tf.shape(cache['memory_keys'])[2], 0), - true_fn=_project_and_split, - false_fn=lambda: - (cache['memory_keys'], cache['memory_values'])) - cache['memory_keys'] = keys - cache['memory_values'] = values - else: - keys, values = fused_projection(memory, num_units, num_outputs=2) - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - queries = split_heads(queries, num_heads) - queries *= (num_units // num_heads)**-0.5 - - heads, attn = dot_product_attention( - queries, keys, values, mode, mask=mask, dropout=dropout) - - # Concatenate all heads output. - combined = combine_heads(heads) - outputs = tf.layers.conv1d(combined, num_units, 1) - - if not return_attention: - return outputs - return outputs, attn - - -def Ci_Cd_Memory(num_heads, - queries, - mode, - filter_size=None, - num_memory=None, - num_units=None, - fsmn_mask=None, - san_mask=None, - cache=None, - shift=None, - dropout=0.0, - return_attention=False): - """ - Args: - num_heads: The number of attention heads. - queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - If ``None``, computes self-attention. - mode: A ``tf.estimator.ModeKeys`` mode. - num_units: The number of hidden units. If not set, it is set to the input - dimension. - mask: A ``tf.Tensor`` applied to the dot product. - cache: A dictionary containing pre-projected keys and values. - dropout: The probability to drop units from the inputs. - return_attention: Return the attention head probabilities in addition to the - context. - - Returns: - The concatenated attention context of each head and the attention - probabilities (if :obj:`return_attention` is set). - """ - num_units = num_units or queries.get_shape().as_list()[-1] - - if num_units % num_heads != 0: - raise ValueError('Multi head attention requires that num_units is a' - ' multiple of {}'.format(num_heads)) - # PERSISTENT MEMORY - if num_memory is not None: - key_m = tf.get_variable( - 'key_m', - shape=[num_memory, num_units], - initializer=tf.glorot_uniform_initializer(), - dtype=tf.float32) - value_m = tf.get_variable( - 'value_m', - shape=[num_memory, num_units], - initializer=tf.glorot_uniform_initializer(), - dtype=tf.float32) - - queries, keys, values = fused_projection(queries, num_units, num_outputs=3) - # fsmn memory block - if shift is not None: - # encoder - fsmn_memory = fsmn.MemoryBlockV2( - values, - filter_size, - mode, - shift=shift, - mask=fsmn_mask, - dropout=dropout) - else: - # decoder - fsmn_memory = fsmn.UniMemoryBlock( - values, - filter_size, - mode, - cache=cache, - mask=fsmn_mask, - dropout=dropout) - - # concat persistent memory - if num_memory is not None: - key_m_expand = tf.tile( - tf.expand_dims(key_m, 0), [tf.shape(keys)[0], 1, 1]) - value_m_expand = tf.tile( - tf.expand_dims(value_m, 0), [tf.shape(values)[0], 1, 1]) - keys = tf.concat([key_m_expand, keys], axis=1) - values = tf.concat([value_m_expand, values], axis=1) - - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - if cache is not None: - keys = tf.concat([cache['self_keys'], keys], axis=2) - values = tf.concat([cache['self_values'], values], axis=2) - cache['self_keys'] = keys - cache['self_values'] = values - - queries = split_heads(queries, num_heads) - queries *= (num_units // num_heads)**-0.5 - - heads, attn = dot_product_attention( - queries, keys, values, mode, mask=san_mask, dropout=dropout) - - # Concatenate all heads output. - combined = combine_heads(heads) - outputs = tf.layers.conv1d(combined, num_units, 1) - outputs = outputs + fsmn_memory - - if not return_attention: - return outputs - return outputs, attn - - -def multi_head_attention_wpa(num_heads, - queries, - memory, - mode, - attention_left_window=-1, - attention_right_window=0, - num_units=None, - mask=None, - cache=None, - max_id_cache=None, - dropout=0.0, - mono=False, - peak_delay=-1, - return_attention=False): - """Computes the multi-head attention as described in - https://arxiv.org/abs/1706.03762. - - Args: - num_heads: The number of attention heads. - queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - If ``None``, computes self-attention. - mode: A ``tf.estimator.ModeKeys`` mode. - num_units: The number of hidden units. If not set, it is set to the input - dimension. - mask: A ``tf.Tensor`` applied to the dot product. - cache: A dictionary containing pre-projected keys and values. - dropout: The probability to drop units from the inputs. - return_attention: Return the attention head probabilities in addition to the - context. - - Returns: - The concatenated attention context of each head and the attention - probabilities (if :obj:`return_attention` is set). - """ - num_units = num_units or queries.get_shape().as_list()[-1] - - if num_units % num_heads != 0: - raise ValueError('Multi head attention requires that num_units is a' - ' multiple of {}'.format(num_heads)) - - if memory is None: - queries, keys, values = fused_projection( - queries, num_units, num_outputs=3) - - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - if cache is not None: - keys = tf.concat([cache['self_keys'], keys], axis=2) - values = tf.concat([cache['self_values'], values], axis=2) - cache['self_keys'] = keys - cache['self_values'] = values - else: - queries = tf.layers.conv1d(queries, num_units, 1) - - if cache is not None: - - def _project_and_split(): - k, v = fused_projection(memory, num_units, num_outputs=2) - return split_heads(k, num_heads), split_heads(v, num_heads) - - keys, values = tf.cond( - tf.equal(tf.shape(cache['memory_keys'])[2], 0), - true_fn=_project_and_split, - false_fn=lambda: - (cache['memory_keys'], cache['memory_values'])) - cache['memory_keys'] = keys - cache['memory_values'] = values - else: - keys, values = fused_projection(memory, num_units, num_outputs=2) - keys = split_heads(keys, num_heads) - values = split_heads(values, num_heads) - - queries = split_heads(queries, num_heads) - queries *= (num_units // num_heads)**-0.5 - - heads, attn = dot_product_attention_wpa( - num_heads, - queries, - keys, - values, - mode, - attention_left_window=attention_left_window, - attention_right_window=attention_right_window, - mask=mask, - max_id_cache=max_id_cache, - mono=mono, - peak_delay=peak_delay, - dropout=dropout) - - # Concatenate all heads output. - combined = combine_heads(heads) - outputs = tf.layers.conv1d(combined, num_units, 1) - - if not return_attention: - return outputs - return outputs, attn - - -def feed_forward(x, inner_dim, mode, dropout=0.0, mask=None): - """Implements the Transformer's "Feed Forward" layer. - - .. math:: - - ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2 - - Args: - x: The input. - inner_dim: The number of units of the inner linear transformation. - mode: A ``tf.estimator.ModeKeys`` mode. - dropout: The probability to drop units from the inner transformation. - - Returns: - The transformed input. - """ - input_dim = x.get_shape().as_list()[-1] - - if mask is not None: - x = x * tf.expand_dims(mask, -1) - - inner = tf.layers.conv1d( - x, inner_dim, 3, padding='same', activation=tf.nn.relu) - - if mask is not None: - inner = inner * tf.expand_dims(mask, -1) - inner = tf.layers.dropout(inner, rate=dropout, training=mode) - outer = tf.layers.conv1d(inner, input_dim, 1) - - return outer - - -def feed_forward_ori(x, inner_dim, mode, dropout=0.0): - """Implements the Transformer's "Feed Forward" layer. - - .. math:: - - ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2 - - Args: - x: The input. - inner_dim: The number of units of the inner linear transformation. - mode: A ``tf.estimator.ModeKeys`` mode. - dropout: The probability to drop units from the inner transformation. - - Returns: - The transformed input. - """ - input_dim = x.get_shape().as_list()[-1] - - inner = tf.layers.conv1d(x, inner_dim, 1, activation=tf.nn.relu) - inner = tf.layers.dropout(inner, rate=dropout, training=mode) - outer = tf.layers.conv1d(inner, input_dim, 1) - - return outer - - -def norm(inputs): - """Layer normalizes :obj:`inputs`.""" - return tf.contrib.layers.layer_norm(inputs, begin_norm_axis=-1) - - -def drop_and_add(inputs, outputs, mode, dropout=0.1): - """Drops units in the outputs and adds the previous values. - - Args: - inputs: The input of the previous layer. - outputs: The output of the previous layer. - mode: A ``tf.estimator.ModeKeys`` mode. - dropout: The probability to drop units in :obj:`outputs`. - - Returns: - The residual and normalized output. - """ - outputs = tf.layers.dropout(outputs, rate=dropout, training=mode) - - input_dim = inputs.get_shape().as_list()[-1] - output_dim = outputs.get_shape().as_list()[-1] - - if input_dim == output_dim: - outputs += inputs - return outputs - - -class FeedForwardNetwork(tf.keras.layers.Layer): - """Implements the Transformer's "Feed Forward" layer. - - .. math:: - - ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2 - - Note: - Object-oriented implementation for TensorFlow 2.0. - """ - - def __init__(self, - inner_dim, - output_dim, - dropout=0.1, - activation=tf.nn.relu, - **kwargs): - """Initializes this layer. - - Args: - inner_dim: The number of units of the inner linear transformation. - output_dim: The number of units of the ouput linear transformation. - dropout: The probability to drop units from the activation output. - activation: The activation function to apply between the two linear - transformations. - kwargs: Additional layer arguments. - """ - super(FeedForwardNetwork, self).__init__(**kwargs) - self.inner = tf.keras.layers.Dense( - inner_dim, activation=activation, name='inner') - self.outer = tf.keras.layers.Dense(output_dim, name='outer') - self.dropout = dropout - - def call(self, inputs, training=None): # pylint: disable=arguments-differ - """Runs the layer.""" - inner = self.inner(inputs) - inner = tf.layers.dropout(inner, self.dropout, training=training) - return self.outer(inner) - - -class MultiHeadAttention(tf.keras.layers.Layer): - """Computes the multi-head attention as described in - https://arxiv.org/abs/1706.03762. - - Note: - Object-oriented implementation for TensorFlow 2.0. - """ - - def __init__(self, - num_heads, - num_units, - dropout=0.1, - return_attention=False, - **kwargs): - """Initializes this layers. - - Args: - num_heads: The number of attention heads. - num_units: The number of hidden units. - dropout: The probability to drop units from the inputs. - return_attention: If ``True``, also return the attention weights of the - first head. - kwargs: Additional layer arguments. - """ - super(MultiHeadAttention, self).__init__(**kwargs) - if num_units % num_heads != 0: - raise ValueError( - 'Multi head attention requires that num_units is a' - ' multiple of %s' % num_heads) - self.num_heads = num_heads - self.num_units = num_units - self.linear_queries = tf.keras.layers.Dense( - num_units, name='linear_queries') - self.linear_keys = tf.keras.layers.Dense(num_units, name='linear_keys') - self.linear_values = tf.keras.layers.Dense( - num_units, name='linear_values') - self.linear_output = tf.keras.layers.Dense( - num_units, name='linear_output') - self.dropout = dropout - self.return_attention = return_attention - - def call(self, inputs, memory=None, mask=None, cache=None, training=None): # pylint: disable=arguments-differ - """Runs the layer. - - Args: - inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. - memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. - If ``None``, computes self-attention. - mask: A ``tf.Tensor`` applied to the dot product. - cache: A dictionary containing pre-projected keys and values. - training: Run in training mode. - - Returns: - A tuple with the attention context, the updated cache and the attention - probabilities of the first head (if :obj:`return_attention` is ``True``). - """ - - def _compute_kv(x): - keys = self.linear_keys(x) - keys = split_heads(keys, self.num_heads) - values = self.linear_values(x) - values = split_heads(values, self.num_heads) - return keys, values - - # Compute queries. - queries = self.linear_queries(inputs) - queries = split_heads(queries, self.num_heads) - queries *= (self.num_units // self.num_heads)**-0.5 - - # Compute keys and values. - if memory is None: - keys, values = _compute_kv(inputs) - if cache: - keys = tf.concat([cache[0], keys], axis=2) - values = tf.concat([cache[1], values], axis=2) - else: - if cache: - if not self.linear_keys.built: - # Ensure that the variable names are not impacted by the tf.cond name - # scope if the layers have not already been built. - with tf.name_scope(self.linear_keys.name): - self.linear_keys.build(memory.shape) - with tf.name_scope(self.linear_values.name): - self.linear_values.build(memory.shape) - keys, values = tf.cond( - tf.equal(tf.shape(cache[0])[2], 0), - true_fn=lambda: _compute_kv(memory), - false_fn=lambda: cache) - else: - keys, values = _compute_kv(memory) - - cache = (keys, values) - - # Dot product attention. - dot = tf.matmul(queries, keys, transpose_b=True) - if mask is not None: - mask = tf.expand_dims(tf.cast(mask, tf.float32), - 1) # Broadcast on heads dimension. - dot = tf.cast( - tf.cast(dot, tf.float32) * mask - + ((1.0 - mask) * tf.float32.min), dot.dtype) # yapf:disable - attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) - drop_attn = tf.layers.dropout(attn, self.dropout, training=training) - heads = tf.matmul(drop_attn, values) - - # Concatenate all heads output. - combined = combine_heads(heads) - outputs = self.linear_output(combined) - if self.return_attention: - return outputs, cache, attn - return outputs, cache diff --git a/modelscope/models/audio/tts/models/utils.py b/modelscope/models/audio/tts/models/utils.py deleted file mode 100755 index 03e1ef8c..00000000 --- a/modelscope/models/audio/tts/models/utils.py +++ /dev/null @@ -1,59 +0,0 @@ -import glob -import os - -import matplotlib -import matplotlib.pylab as plt -import torch -from torch.nn.utils import weight_norm - -matplotlib.use('Agg') - - -def plot_spectrogram(spectrogram): - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow( - spectrogram, aspect='auto', origin='lower', interpolation='none') - plt.colorbar(im, ax=ax) - - fig.canvas.draw() - plt.close() - - return fig - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - m.weight.data.normal_(mean, std) - - -def apply_weight_norm(m): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - weight_norm(m) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def load_checkpoint(filepath, device): - assert os.path.isfile(filepath) - print("Loading '{}'".format(filepath)) - checkpoint_dict = torch.load(filepath, map_location=device) - print('Complete.') - return checkpoint_dict - - -def save_checkpoint(filepath, obj): - print('Saving checkpoint to {}'.format(filepath)) - torch.save(obj, filepath) - print('Complete.') - - -def scan_checkpoint(cp_dir, prefix): - pattern = os.path.join(cp_dir, prefix + '????????') - cp_list = glob.glob(pattern) - if len(cp_list) == 0: - return None - return sorted(cp_list)[-1] diff --git a/modelscope/models/audio/tts/models/utils/__init__.py b/modelscope/models/audio/tts/models/utils/__init__.py new file mode 100644 index 00000000..e07f08ea --- /dev/null +++ b/modelscope/models/audio/tts/models/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .utils import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/utils/utils.py b/modelscope/models/audio/tts/models/utils/utils.py new file mode 100755 index 00000000..17ac8aee --- /dev/null +++ b/modelscope/models/audio/tts/models/utils/utils.py @@ -0,0 +1,136 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import glob +import os +import shutil + +import matplotlib +import matplotlib.pylab as plt +import torch + +matplotlib.use('Agg') + + +class AttrDict(dict): + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, aspect='auto', origin='lower', interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def plot_alignment(alignment, info=None): + fig, ax = plt.subplots() + im = ax.imshow( + alignment, aspect='auto', origin='lower', interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Input timestep' + if info is not None: + xlabel += '\t' + info + plt.xlabel(xlabel) + plt.ylabel('Output timestep') + fig.canvas.draw() + plt.close() + + return fig + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + checkpoint_dict = torch.load(filepath, map_location=device) + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + torch.save(obj, filepath) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????.pkl') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ValueWindow(): + + def __init__(self, window_size=100): + self._window_size = window_size + self._values = [] + + def append(self, x): + self._values = self._values[-(self._window_size - 1):] + [x] + + @property + def sum(self): + return sum(self._values) + + @property + def count(self): + return len(self._values) + + @property + def average(self): + return self.sum / max(1, self.count) + + def reset(self): + self._values = [] + + +def get_model_size(model): + param_num = sum([p.numel() for p in model.parameters() if p.requires_grad]) + param_size = param_num * 4 / 1024 / 1024 + return param_size + + +def get_grad_norm(model): + total_norm = 0 + params = [ + p for p in model.parameters() if p.grad is not None and p.requires_grad + ] + for p in params: + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item()**2 + total_norm = total_norm**0.5 + return total_norm + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(mean, std) + + +def get_mask_from_lengths(lengths, max_len=None): + batch_size = lengths.shape[0] + if max_len is None: + max_len = torch.max(lengths).item() + + ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, + -1).to(lengths.device) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + + return mask diff --git a/modelscope/models/audio/tts/models/vocoder_models.py b/modelscope/models/audio/tts/models/vocoder_models.py deleted file mode 100755 index c46a9204..00000000 --- a/modelscope/models/audio/tts/models/vocoder_models.py +++ /dev/null @@ -1,516 +0,0 @@ -from distutils.version import LooseVersion - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d -from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm - -from .utils import get_padding, init_weights - -is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7') - - -def stft(x, fft_size, hop_size, win_length, window): - """Perform STFT and convert to magnitude spectrogram. - - Args: - x (Tensor): Input signal tensor (B, T). - fft_size (int): FFT size. - hop_size (int): Hop size. - win_length (int): Window length. - window (str): Window function type. - - Returns: - Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). - - """ - if is_pytorch_17plus: - x_stft = torch.stft( - x, fft_size, hop_size, win_length, window, return_complex=False) - else: - x_stft = torch.stft(x, fft_size, hop_size, win_length, window) - real = x_stft[..., 0] - imag = x_stft[..., 1] - - # NOTE(kan-bayashi): clamp is needed to avoid nan or inf - return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1) - - -LRELU_SLOPE = 0.1 - - -def get_padding_casual(kernel_size, dilation=1): - return int(kernel_size * dilation - dilation) - - -class Conv1dCasual(torch.nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode='zeros'): - super(Conv1dCasual, self).__init__() - self.pad = padding - self.conv1d = weight_norm( - Conv1d( - in_channels, - out_channels, - kernel_size, - stride, - padding=0, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode)) - self.conv1d.apply(init_weights) - - def forward(self, x): # bdt - # described starting from the last dimension and moving forward. - x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), 'constant') - x = self.conv1d(x) - return x - - def remove_weight_norm(self): - remove_weight_norm(self.conv1d) - - -class ConvTranspose1dCausal(torch.nn.Module): - """CausalConvTranspose1d module with customized initialization.""" - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding=0): - """Initialize CausalConvTranspose1d module.""" - super(ConvTranspose1dCausal, self).__init__() - self.deconv = weight_norm( - ConvTranspose1d(in_channels, out_channels, kernel_size, stride)) - self.stride = stride - self.deconv.apply(init_weights) - self.pad = kernel_size - stride - - def forward(self, x): - """Calculate forward propagation. - Args: - x (Tensor): Input tensor (B, in_channels, T_in). - Returns: - Tensor: Output tensor (B, out_channels, T_out). - """ - # x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant") - return self.deconv(x)[:, :, :-self.pad] - - def remove_weight_norm(self): - remove_weight_norm(self.deconv) - - -class ResBlock1(torch.nn.Module): - - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): - super(ResBlock1, self).__init__() - self.h = h - self.convs1 = nn.ModuleList([ - Conv1dCasual( - channels, - channels, - kernel_size, - 1, - dilation=dilation[i], - padding=get_padding_casual(kernel_size, dilation[i])) - for i in range(len(dilation)) - ]) - - self.convs2 = nn.ModuleList([ - Conv1dCasual( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding_casual(kernel_size, 1)) - for i in range(len(dilation)) - ]) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for layer in self.convs1: - layer.remove_weight_norm() - for layer in self.convs2: - layer.remove_weight_norm() - - -class Generator(torch.nn.Module): - - def __init__(self, h): - super(Generator, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - print('num_kernels={}, num_upsamples={}'.format( - self.num_kernels, self.num_upsamples)) - self.conv_pre = Conv1dCasual( - 80, h.upsample_initial_channel, 7, 1, padding=7 - 1) - resblock = ResBlock1 if h.resblock == '1' else ResBlock2 - - self.ups = nn.ModuleList() - self.repeat_ups = nn.ModuleList() - for i, (u, k) in enumerate( - zip(h.upsample_rates, h.upsample_kernel_sizes)): - upsample = nn.Sequential( - nn.Upsample(mode='nearest', scale_factor=u), - nn.LeakyReLU(LRELU_SLOPE), - Conv1dCasual( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2**(i + 1)), - kernel_size=7, - stride=1, - padding=7 - 1)) - self.repeat_ups.append(upsample) - self.ups.append( - ConvTranspose1dCausal( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2**(i + 1)), - k, - u, - padding=(k - u) // 2)) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel // (2**(i + 1)) - for j, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): - self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = Conv1dCasual(ch, 1, 7, 1, padding=7 - 1) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = torch.sin(x) + x - # transconv - x1 = F.leaky_relu(x, LRELU_SLOPE) - x1 = self.ups[i](x1) - # repeat - x2 = self.repeat_ups[i](x) - x = x1 + x2 - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - return x - - def remove_weight_norm(self): - print('Removing weight norm...') - for layer in self.ups: - layer.remove_weight_norm() - for layer in self.repeat_ups: - layer[-1].remove_weight_norm() - for layer in self.resblocks: - layer.remove_weight_norm() - self.conv_pre.remove_weight_norm() - self.conv_post.remove_weight_norm() - - -class DiscriminatorP(torch.nn.Module): - - def __init__(self, - period, - kernel_size=5, - stride=3, - use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList([ - norm_f( - Conv2d( - 1, - 32, (kernel_size, 1), (stride, 1), - padding=(get_padding(5, 1), 0))), - norm_f( - Conv2d( - 32, - 128, (kernel_size, 1), (stride, 1), - padding=(get_padding(5, 1), 0))), - norm_f( - Conv2d( - 128, - 512, (kernel_size, 1), (stride, 1), - padding=(get_padding(5, 1), 0))), - norm_f( - Conv2d( - 512, - 1024, (kernel_size, 1), (stride, 1), - padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), - ]) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), 'reflect') - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for layer in self.convs: - x = layer(x) - x = F.leaky_relu(x, LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiPeriodDiscriminator(torch.nn.Module): - - def __init__(self): - super(MultiPeriodDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ - DiscriminatorP(2), - DiscriminatorP(3), - DiscriminatorP(5), - DiscriminatorP(7), - DiscriminatorP(11), - ]) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorS(torch.nn.Module): - - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 128, 15, 1, padding=7)), - norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), - norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), - norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - for layer in self.convs: - x = layer(x) - x = F.leaky_relu(x, LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiScaleDiscriminator(torch.nn.Module): - - def __init__(self): - super(MultiScaleDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ - DiscriminatorS(use_spectral_norm=True), - DiscriminatorS(), - DiscriminatorS(), - ]) - from pytorch_wavelets import DWT1DForward - self.meanpools = nn.ModuleList( - [DWT1DForward(wave='db3', J=1), - DWT1DForward(wave='db3', J=1)]) - self.convs = nn.ModuleList([ - weight_norm(Conv1d(2, 1, 15, 1, padding=7)), - weight_norm(Conv1d(2, 1, 15, 1, padding=7)) - ]) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - if i != 0: - yl, yh = self.meanpools[i - 1](y) - y = torch.cat([yl, yh[0]], dim=1) - y = self.convs[i - 1](y) - y = F.leaky_relu(y, LRELU_SLOPE) - - yl_hat, yh_hat = self.meanpools[i - 1](y_hat) - y_hat = torch.cat([yl_hat, yh_hat[0]], dim=1) - y_hat = self.convs[i - 1](y_hat) - y_hat = F.leaky_relu(y_hat, LRELU_SLOPE) - - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorSTFT(torch.nn.Module): - - def __init__(self, - kernel_size=11, - stride=2, - use_spectral_norm=False, - fft_size=1024, - shift_size=120, - win_length=600, - window='hann_window'): - super(DiscriminatorSTFT, self).__init__() - self.fft_size = fft_size - self.shift_size = shift_size - self.win_length = win_length - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList([ - norm_f( - Conv2d( - fft_size // 2 + 1, - 32, (15, 1), (1, 1), - padding=(get_padding(15, 1), 0))), - norm_f( - Conv2d( - 32, - 32, (kernel_size, 1), (stride, 1), - padding=(get_padding(9, 1), 0))), - norm_f( - Conv2d( - 32, - 32, (kernel_size, 1), (stride, 1), - padding=(get_padding(9, 1), 0))), - norm_f( - Conv2d( - 32, - 32, (kernel_size, 1), (stride, 1), - padding=(get_padding(9, 1), 0))), - norm_f(Conv2d(32, 32, (5, 1), (1, 1), padding=(2, 0))), - ]) - self.conv_post = norm_f(Conv2d(32, 1, (3, 1), (1, 1), padding=(1, 0))) - self.register_buffer('window', getattr(torch, window)(win_length)) - - def forward(self, wav): - wav = torch.squeeze(wav, 1) - x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length, - self.window) - x = torch.transpose(x_mag, 2, 1).unsqueeze(-1) - fmap = [] - for layer in self.convs: - x = layer(x) - x = F.leaky_relu(x, LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = x.squeeze(-1) - - return x, fmap - - -class MultiSTFTDiscriminator(torch.nn.Module): - - def __init__( - self, - fft_sizes=[1024, 2048, 512], - hop_sizes=[120, 240, 50], - win_lengths=[600, 1200, 240], - window='hann_window', - ): - super(MultiSTFTDiscriminator, self).__init__() - self.discriminators = nn.ModuleList() - for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): - self.discriminators += [ - DiscriminatorSTFT(fft_size=fs, shift_size=ss, win_length=wl) - ] - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -def feature_loss(fmap_r, fmap_g): - loss = 0 - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - loss += torch.mean(torch.abs(rl - gl)) - - return loss * 2 - - -def discriminator_loss(disc_real_outputs, disc_generated_outputs): - loss = 0 - r_losses = [] - g_losses = [] - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - r_loss = torch.mean((1 - dr)**2) - g_loss = torch.mean(dg**2) - loss += (r_loss + g_loss) - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) - - return loss, r_losses, g_losses - - -def generator_loss(disc_outputs): - loss = 0 - gen_losses = [] - for dg in disc_outputs: - temp_loss = torch.mean((1 - dg)**2) - gen_losses.append(temp_loss) - loss += temp_loss - - return loss, gen_losses diff --git a/modelscope/models/audio/tts/sambert_hifi.py b/modelscope/models/audio/tts/sambert_hifi.py index 79f8068e..a9b55795 100644 --- a/modelscope/models/audio/tts/sambert_hifi.py +++ b/modelscope/models/audio/tts/sambert_hifi.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from __future__ import (absolute_import, division, print_function, unicode_literals) import os @@ -11,13 +13,11 @@ from modelscope.models.base import Model from modelscope.models.builder import MODELS from modelscope.utils.audio.tts_exceptions import ( TtsFrontendInitializeFailedException, - TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationExcetion, + TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException, TtsVoiceNotExistsException) from modelscope.utils.constant import Tasks from .voice import Voice -import tensorflow as tf # isort:skip - __all__ = ['SambertHifigan'] @@ -28,14 +28,15 @@ class SambertHifigan(Model): def __init__(self, model_dir, *args, **kwargs): super().__init__(model_dir, *args, **kwargs) if 'am' not in kwargs: - raise TtsModelConfigurationExcetion( - 'configuration model field missing am!') + raise TtsModelConfigurationException( + 'modelscope error: configuration model field missing am!') if 'vocoder' not in kwargs: - raise TtsModelConfigurationExcetion( - 'configuration model field missing vocoder!') + raise TtsModelConfigurationException( + 'modelscope error: configuration model field missing vocoder!') if 'lang_type' not in kwargs: - raise TtsModelConfigurationExcetion( - 'configuration model field missing lang_type!') + raise TtsModelConfigurationException( + 'modelscope error: configuration model field missing lang_type!' + ) am_cfg = kwargs['am'] voc_cfg = kwargs['vocoder'] # initialize frontend @@ -47,10 +48,12 @@ class SambertHifigan(Model): zip_ref.extractall(model_dir) if not frontend.initialize(self.__res_path): raise TtsFrontendInitializeFailedException( - 'resource invalid: {}'.format(self.__res_path)) + 'modelscope error: resource invalid: {}'.format( + self.__res_path)) if not frontend.set_lang_type(kwargs['lang_type']): raise TtsFrontendLanguageTypeInvalidException( - 'language type invalid: {}'.format(kwargs['lang_type'])) + 'modelscope error: language type invalid: {}'.format( + kwargs['lang_type'])) self.__frontend = frontend zip_file = os.path.join(model_dir, 'voices.zip') self.__voice_path = os.path.join(model_dir, 'voices') @@ -60,7 +63,8 @@ class SambertHifigan(Model): with open(voice_cfg_path, 'r') as f: voice_cfg = json.load(f) if 'voices' not in voice_cfg: - raise TtsModelConfigurationExcetion('voices invalid') + raise TtsModelConfigurationException( + 'modelscope error: voices invalid') self.__voice = {} for name in voice_cfg['voices']: voice_path = os.path.join(self.__voice_path, name) @@ -70,11 +74,13 @@ class SambertHifigan(Model): if voice_cfg['voices']: self.__default_voice_name = voice_cfg['voices'][0] else: - raise TtsVoiceNotExistsException('voices is empty in voices.json') + raise TtsVoiceNotExistsException( + 'modelscope error: voices is empty in voices.json') def __synthesis_one_sentences(self, voice_name, text): if voice_name not in self.__voice: - raise TtsVoiceNotExistsException(f'Voice {voice_name} not exists') + raise TtsVoiceNotExistsException( + f'modelscope error: Voice {voice_name} not exists') return self.__voice[voice_name].forward(text) def forward(self, text: str, voice_name: str = None): diff --git a/modelscope/models/audio/tts/text/cleaners.py b/modelscope/models/audio/tts/text/cleaners.py deleted file mode 100755 index 19d838d1..00000000 --- a/modelscope/models/audio/tts/text/cleaners.py +++ /dev/null @@ -1,89 +0,0 @@ -''' -Cleaners are transformations that run over the input text at both training and eval time. - -Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" -hyperparameter. Some cleaners are English-specific. You'll typically want to use: - 1. "english_cleaners" for English text - 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using - the Unidecode library (https://pypi.python.org/pypi/Unidecode) - 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update - the symbols in symbols.py to match your data). -''' - -import re - -from unidecode import unidecode - -from .numbers import normalize_numbers - -# Regular expression matching whitespace: -_whitespace_re = re.compile(r'\s+') - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) - for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), ]] # yapf:disable - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def expand_numbers(text): - return normalize_numbers(text) - - -def lowercase(text): - return text.lower() - - -def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) - - -def convert_to_ascii(text): - return unidecode(text) - - -def basic_cleaners(text): - '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' - text = lowercase(text) - text = collapse_whitespace(text) - return text - - -def transliteration_cleaners(text): - '''Pipeline for non-English text that transliterates to ASCII.''' - text = convert_to_ascii(text) - text = lowercase(text) - text = collapse_whitespace(text) - return text - - -def english_cleaners(text): - '''Pipeline for English text, including number and abbreviation expansion.''' - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_numbers(text) - text = expand_abbreviations(text) - text = collapse_whitespace(text) - return text diff --git a/modelscope/models/audio/tts/text/cmudict.py b/modelscope/models/audio/tts/text/cmudict.py deleted file mode 100755 index b4da4be9..00000000 --- a/modelscope/models/audio/tts/text/cmudict.py +++ /dev/null @@ -1,64 +0,0 @@ -import re - -valid_symbols = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', - 'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', - 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', - 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', - 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', - 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', - 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', - 'Y', 'Z', 'ZH' -] - -_valid_symbol_set = set(valid_symbols) - - -class CMUDict: - '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' - - def __init__(self, file_or_path, keep_ambiguous=True): - if isinstance(file_or_path, str): - with open(file_or_path, encoding='latin-1') as f: - entries = _parse_cmudict(f) - else: - entries = _parse_cmudict(file_or_path) - if not keep_ambiguous: - entries = { - word: pron - for word, pron in entries.items() if len(pron) == 1 - } - self._entries = entries - - def __len__(self): - return len(self._entries) - - def lookup(self, word): - '''Returns list of ARPAbet pronunciations of the given word.''' - return self._entries.get(word.upper()) - - -_alt_re = re.compile(r'\([0-9]+\)') - - -def _parse_cmudict(file): - cmudict = {} - for line in file: - if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): - parts = line.split(' ') - word = re.sub(_alt_re, '', parts[0]) - pronunciation = _get_pronunciation(parts[1]) - if pronunciation: - if word in cmudict: - cmudict[word].append(pronunciation) - else: - cmudict[word] = [pronunciation] - return cmudict - - -def _get_pronunciation(s): - parts = s.strip().split(' ') - for part in parts: - if part not in _valid_symbol_set: - return None - return ' '.join(parts) diff --git a/modelscope/models/audio/tts/text/symbols.py b/modelscope/models/audio/tts/text/symbols.py deleted file mode 100644 index 63975abb..00000000 --- a/modelscope/models/audio/tts/text/symbols.py +++ /dev/null @@ -1,105 +0,0 @@ -''' -Defines the set of symbols used in text input to the model. - -The default is a set of ASCII characters that works well for English or text that has been run -through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. -''' -import codecs -import os - -_pad = '_' -_eos = '~' -_mask = '@[MASK]' - - -def load_symbols(dict_path, has_mask=True): - _characters = '' - _ch_symbols = [] - sy_dict_name = 'sy_dict.txt' - sy_dict_path = os.path.join(dict_path, sy_dict_name) - f = codecs.open(sy_dict_path, 'r') - for line in f: - line = line.strip('\r\n') - _ch_symbols.append(line) - - _arpabet = ['@' + s for s in _ch_symbols] - - # Export all symbols: - sy = list(_characters) + _arpabet + [_pad, _eos] - if has_mask: - sy.append(_mask) - - _characters = '' - - _ch_tones = [] - tone_dict_name = 'tone_dict.txt' - tone_dict_path = os.path.join(dict_path, tone_dict_name) - f = codecs.open(tone_dict_path, 'r') - for line in f: - line = line.strip('\r\n') - _ch_tones.append(line) - - # Export all tones: - tone = list(_characters) + _ch_tones + [_pad, _eos] - if has_mask: - tone.append(_mask) - - _characters = '' - - _ch_syllable_flags = [] - syllable_flag_name = 'syllable_flag_dict.txt' - syllable_flag_path = os.path.join(dict_path, syllable_flag_name) - f = codecs.open(syllable_flag_path, 'r') - for line in f: - line = line.strip('\r\n') - _ch_syllable_flags.append(line) - - # Export all syllable_flags: - syllable_flag = list(_characters) + _ch_syllable_flags + [_pad, _eos] - if has_mask: - syllable_flag.append(_mask) - - _characters = '' - - _ch_word_segments = [] - word_segment_name = 'word_segment_dict.txt' - word_segment_path = os.path.join(dict_path, word_segment_name) - f = codecs.open(word_segment_path, 'r') - for line in f: - line = line.strip('\r\n') - _ch_word_segments.append(line) - - # Export all syllable_flags: - word_segment = list(_characters) + _ch_word_segments + [_pad, _eos] - if has_mask: - word_segment.append(_mask) - - _characters = '' - - _ch_emo_types = [] - emo_category_name = 'emo_category_dict.txt' - emo_category_path = os.path.join(dict_path, emo_category_name) - f = codecs.open(emo_category_path, 'r') - for line in f: - line = line.strip('\r\n') - _ch_emo_types.append(line) - - emo_category = list(_characters) + _ch_emo_types + [_pad, _eos] - if has_mask: - emo_category.append(_mask) - - _characters = '' - - _ch_speakers = [] - speaker_name = 'speaker_dict.txt' - speaker_path = os.path.join(dict_path, speaker_name) - f = codecs.open(speaker_path, 'r') - for line in f: - line = line.strip('\r\n') - _ch_speakers.append(line) - - # Export all syllable_flags: - speaker = list(_characters) + _ch_speakers + [_pad, _eos] - if has_mask: - speaker.append(_mask) - return sy, tone, syllable_flag, word_segment, emo_category, speaker diff --git a/modelscope/models/audio/tts/text/symbols_dict.py b/modelscope/models/audio/tts/text/symbols_dict.py deleted file mode 100644 index e8f7ed19..00000000 --- a/modelscope/models/audio/tts/text/symbols_dict.py +++ /dev/null @@ -1,200 +0,0 @@ -import re -import sys - -from .cleaners import (basic_cleaners, english_cleaners, - transliteration_cleaners) - - -class SymbolsDict: - - def __init__(self, sy, tone, syllable_flag, word_segment, emo_category, - speaker, inputs_dim, lfeat_type_list): - self._inputs_dim = inputs_dim - self._lfeat_type_list = lfeat_type_list - self._sy_to_id = {s: i for i, s in enumerate(sy)} - self._id_to_sy = {i: s for i, s in enumerate(sy)} - self._tone_to_id = {s: i for i, s in enumerate(tone)} - self._id_to_tone = {i: s for i, s in enumerate(tone)} - self._syllable_flag_to_id = {s: i for i, s in enumerate(syllable_flag)} - self._id_to_syllable_flag = {i: s for i, s in enumerate(syllable_flag)} - self._word_segment_to_id = {s: i for i, s in enumerate(word_segment)} - self._id_to_word_segment = {i: s for i, s in enumerate(word_segment)} - self._emo_category_to_id = {s: i for i, s in enumerate(emo_category)} - self._id_to_emo_category = {i: s for i, s in enumerate(emo_category)} - self._speaker_to_id = {s: i for i, s in enumerate(speaker)} - self._id_to_speaker = {i: s for i, s in enumerate(speaker)} - print('_sy_to_id: ') - print(self._sy_to_id) - print('_tone_to_id: ') - print(self._tone_to_id) - print('_syllable_flag_to_id: ') - print(self._syllable_flag_to_id) - print('_word_segment_to_id: ') - print(self._word_segment_to_id) - print('_emo_category_to_id: ') - print(self._emo_category_to_id) - print('_speaker_to_id: ') - print(self._speaker_to_id) - self._curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') - self._cleaners = { - basic_cleaners.__name__: basic_cleaners, - transliteration_cleaners.__name__: transliteration_cleaners, - english_cleaners.__name__: english_cleaners - } - - def _clean_text(self, text, cleaner_names): - for name in cleaner_names: - cleaner = self._cleaners.get(name) - if not cleaner: - raise Exception('Unknown cleaner: %s' % name) - text = cleaner(text) - return text - - def _sy_to_sequence(self, sy): - return [self._sy_to_id[s] for s in sy if self._should_keep_sy(s)] - - def _arpabet_to_sequence(self, text): - return self._sy_to_sequence(['@' + s for s in text.split()]) - - def _should_keep_sy(self, s): - return s in self._sy_to_id and s != '_' and s != '~' - - def symbol_to_sequence(self, this_lfeat_symbol, lfeat_type, cleaner_names): - sequence = [] - if lfeat_type == 'sy': - this_lfeat_symbol = this_lfeat_symbol.strip().split(' ') - this_lfeat_symbol_format = '' - index = 0 - while index < len(this_lfeat_symbol): - this_lfeat_symbol_format = this_lfeat_symbol_format + '{' + this_lfeat_symbol[ - index] + '}' + ' ' - index = index + 1 - sequence = self.text_to_sequence(this_lfeat_symbol_format, - cleaner_names) - elif lfeat_type == 'tone': - sequence = self.tone_to_sequence(this_lfeat_symbol) - elif lfeat_type == 'syllable_flag': - sequence = self.syllable_flag_to_sequence(this_lfeat_symbol) - elif lfeat_type == 'word_segment': - sequence = self.word_segment_to_sequence(this_lfeat_symbol) - elif lfeat_type == 'emo_category': - sequence = self.emo_category_to_sequence(this_lfeat_symbol) - elif lfeat_type == 'speaker': - sequence = self.speaker_to_sequence(this_lfeat_symbol) - else: - raise Exception('Unknown lfeat type: %s' % lfeat_type) - - return sequence - - def text_to_sequence(self, text, cleaner_names): - '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - - The text can optionally have ARPAbet sequences enclosed in curly braces embedded - in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." - - Args: - text: string to convert to a sequence - cleaner_names: names of the cleaner functions to run the text through - - Returns: - List of integers corresponding to the symbols in the text - ''' - sequence = [] - - # Check for curly braces and treat their contents as ARPAbet: - while len(text): - m = self._curly_re.match(text) - if not m: - sequence += self._sy_to_sequence( - self._clean_text(text, cleaner_names)) - break - sequence += self._sy_to_sequence( - self._clean_text(m.group(1), cleaner_names)) - sequence += self._arpabet_to_sequence(m.group(2)) - text = m.group(3) - - # Append EOS token - sequence.append(self._sy_to_id['~']) - return sequence - - def tone_to_sequence(self, tone): - tones = tone.strip().split(' ') - sequence = [] - for this_tone in tones: - sequence.append(self._tone_to_id[this_tone]) - sequence.append(self._tone_to_id['~']) - return sequence - - def syllable_flag_to_sequence(self, syllable_flag): - syllable_flags = syllable_flag.strip().split(' ') - sequence = [] - for this_syllable_flag in syllable_flags: - sequence.append(self._syllable_flag_to_id[this_syllable_flag]) - sequence.append(self._syllable_flag_to_id['~']) - return sequence - - def word_segment_to_sequence(self, word_segment): - word_segments = word_segment.strip().split(' ') - sequence = [] - for this_word_segment in word_segments: - sequence.append(self._word_segment_to_id[this_word_segment]) - sequence.append(self._word_segment_to_id['~']) - return sequence - - def emo_category_to_sequence(self, emo_type): - emo_categories = emo_type.strip().split(' ') - sequence = [] - for this_category in emo_categories: - sequence.append(self._emo_category_to_id[this_category]) - sequence.append(self._emo_category_to_id['~']) - return sequence - - def speaker_to_sequence(self, speaker): - speakers = speaker.strip().split(' ') - sequence = [] - for this_speaker in speakers: - sequence.append(self._speaker_to_id[this_speaker]) - sequence.append(self._speaker_to_id['~']) - return sequence - - def sequence_to_symbol(self, sequence): - result = '' - pre_lfeat_dim = 0 - for lfeat_type in self._lfeat_type_list: - current_one_hot_sequence = sequence[:, pre_lfeat_dim:pre_lfeat_dim - + self._inputs_dim[lfeat_type]] - current_sequence = current_one_hot_sequence.argmax(1) - length = current_sequence.shape[0] - - index = 0 - while index < length: - this_sequence = current_sequence[index] - s = '' - if lfeat_type == 'sy': - s = self._id_to_sy[this_sequence] - if len(s) > 1 and s[0] == '@': - s = s[1:] - elif lfeat_type == 'tone': - s = self._id_to_tone[this_sequence] - elif lfeat_type == 'syllable_flag': - s = self._id_to_syllable_flag[this_sequence] - elif lfeat_type == 'word_segment': - s = self._id_to_word_segment[this_sequence] - elif lfeat_type == 'emo_category': - s = self._id_to_emo_category[this_sequence] - elif lfeat_type == 'speaker': - s = self._id_to_speaker[this_sequence] - else: - raise Exception('Unknown lfeat type: %s' % lfeat_type) - - if index == 0: - result = result + lfeat_type + ': ' - - result = result + '{' + s + '}' - - if index == length - 1: - result = result + '; ' - - index = index + 1 - pre_lfeat_dim = pre_lfeat_dim + self._inputs_dim[lfeat_type] - return result diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py index deaebf11..dc830db5 100644 --- a/modelscope/models/audio/tts/voice.py +++ b/modelscope/models/audio/tts/voice.py @@ -1,286 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os +import pickle as pkl import json import numpy as np import torch -from sklearn.preprocessing import MultiLabelBinarizer +from modelscope.utils.audio.tts_exceptions import \ + TtsModelConfigurationException from modelscope.utils.constant import ModelFile, Tasks -from .models import Generator, create_am_model -from .text.symbols import load_symbols -from .text.symbols_dict import SymbolsDict - -import tensorflow as tf # isort:skip +from .models.datasets.units import KanTtsLinguisticUnit +from .models.models.hifigan import Generator +from .models.models.sambert import KanTtsSAMBERT +from .models.utils import (AttrDict, build_env, init_weights, load_checkpoint, + plot_spectrogram, save_checkpoint, scan_checkpoint) MAX_WAV_VALUE = 32768.0 -def multi_label_symbol_to_sequence(my_classes, my_symbol): - one_hot = MultiLabelBinarizer(classes=my_classes) - tokens = my_symbol.strip().split(' ') - sequences = [] - for token in tokens: - sequences.append(tuple(token.split('&'))) - return one_hot.fit_transform(sequences) - - -def load_checkpoint(filepath, device): - assert os.path.isfile(filepath) - checkpoint_dict = torch.load(filepath, map_location=device) - return checkpoint_dict - - -class AttrDict(dict): - - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - - class Voice: - def __init__(self, voice_name, voice_path, am_hparams, voc_config): + def __init__(self, voice_name, voice_path, am_config, voc_config): self.__voice_name = voice_name self.__voice_path = voice_path - self.__am_hparams = tf.contrib.training.HParams(**am_hparams) + self.__am_config = AttrDict(**am_config) self.__voc_config = AttrDict(**voc_config) self.__model_loaded = False + if 'am' not in self.__am_config: + raise TtsModelConfigurationException( + 'modelscope error: am configuration invalid') + if 'linguistic_unit' not in self.__am_config: + raise TtsModelConfigurationException( + 'modelscope error: am configuration invalid') + self.__am_lingustic_unit_config = self.__am_config['linguistic_unit'] def __load_am(self): - local_am_ckpt_path = os.path.join(self.__voice_path, - ModelFile.TF_CHECKPOINT_FOLDER) - self.__am_ckpt_path = os.path.join(local_am_ckpt_path, 'ckpt') - self.__dict_path = os.path.join(self.__voice_path, 'dicts') + local_am_ckpt_path = os.path.join(self.__voice_path, 'am') + self.__am_ckpt_path = os.path.join(local_am_ckpt_path, + ModelFile.TORCH_MODEL_BIN_FILE) has_mask = True - if self.__am_hparams.get('has_mask') is not None: - has_mask = self.__am_hparams.has_mask - model_name = 'robutrans' - self.__lfeat_type_list = self.__am_hparams.lfeat_type_list.strip( - ).split(',') - sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols( - self.__dict_path, has_mask) - self.__sy = sy - self.__tone = tone - self.__syllable_flag = syllable_flag - self.__word_segment = word_segment - self.__emo_category = emo_category - self.__speaker = speaker - self.__inputs_dim = dict() - for lfeat_type in self.__lfeat_type_list: - if lfeat_type == 'sy': - self.__inputs_dim[lfeat_type] = len(sy) - elif lfeat_type == 'tone': - self.__inputs_dim[lfeat_type] = len(tone) - elif lfeat_type == 'syllable_flag': - self.__inputs_dim[lfeat_type] = len(syllable_flag) - elif lfeat_type == 'word_segment': - self.__inputs_dim[lfeat_type] = len(word_segment) - elif lfeat_type == 'emo_category': - self.__inputs_dim[lfeat_type] = len(emo_category) - elif lfeat_type == 'speaker': - self.__inputs_dim[lfeat_type] = len(speaker) - - self.__symbols_dict = SymbolsDict(sy, tone, syllable_flag, - word_segment, emo_category, speaker, - self.__inputs_dim, - self.__lfeat_type_list) - dim_inputs = sum(self.__inputs_dim.values( - )) - self.__inputs_dim['speaker'] - self.__inputs_dim['emo_category'] - self.__graph = tf.Graph() - with self.__graph.as_default(): - inputs = tf.placeholder(tf.float32, [1, None, dim_inputs], - 'inputs') - inputs_emotion = tf.placeholder( - tf.float32, [1, None, self.__inputs_dim['emo_category']], - 'inputs_emotion') - inputs_speaker = tf.placeholder( - tf.float32, [1, None, self.__inputs_dim['speaker']], - 'inputs_speaker') - input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths') - pitch_contours_scale = tf.placeholder(tf.float32, [1, None], - 'pitch_contours_scale') - energy_contours_scale = tf.placeholder(tf.float32, [1, None], - 'energy_contours_scale') - duration_scale = tf.placeholder(tf.float32, [1, None], - 'duration_scale') - with tf.variable_scope('model') as _: - self.__model = create_am_model(model_name, self.__am_hparams) - self.__model.initialize( - inputs, - inputs_emotion, - inputs_speaker, - input_lengths, - duration_scales=duration_scale, - pitch_scales=pitch_contours_scale, - energy_scales=energy_contours_scale) - self.__mel_spec = self.__model.mel_outputs[0] - self.__duration_outputs = self.__model.duration_outputs[0] - self.__duration_outputs_ = self.__model.duration_outputs_[0] - self.__pitch_contour_outputs = self.__model.pitch_contour_outputs[ - 0] - self.__energy_contour_outputs = self.__model.energy_contour_outputs[ - 0] - self.__embedded_inputs_emotion = self.__model.embedded_inputs_emotion[ - 0] - self.__embedding_fsmn_outputs = self.__model.embedding_fsmn_outputs[ - 0] - self.__encoder_outputs = self.__model.encoder_outputs[0] - self.__pitch_embeddings = self.__model.pitch_embeddings[0] - self.__energy_embeddings = self.__model.energy_embeddings[0] - self.__LR_outputs = self.__model.LR_outputs[0] - self.__postnet_fsmn_outputs = self.__model.postnet_fsmn_outputs[ - 0] - self.__attention_h = self.__model.attention_h - self.__attention_x = self.__model.attention_x - - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - self.__session = tf.Session(config=config) - self.__session.run(tf.global_variables_initializer()) - - saver = tf.train.Saver() - saver.restore(self.__session, self.__am_ckpt_path) + if 'has_mask' in self.__am_lingustic_unit_config: + has_mask = self.__am_lingustic_unit_config.has_mask + self.__ling_unit = KanTtsLinguisticUnit( + self.__am_lingustic_unit_config, self.__voice_path, has_mask) + self.__am_net = KanTtsSAMBERT(self.__am_config, + self.__ling_unit.get_unit_size()).to( + self.__device) + state_dict_g = {} + try: + state_dict_g = load_checkpoint(self.__am_ckpt_path, self.__device) + except RuntimeError: + with open(self.__am_ckpt_path, 'rb') as f: + pth_var_dict = pkl.load(f) + state_dict_g['fsnet'] = { + k: torch.FloatTensor(v) + for k, v in pth_var_dict['fsnet'].items() + } + self.__am_net.load_state_dict(state_dict_g['fsnet'], strict=False) + self.__am_net.eval() def __load_vocoder(self): - self.__voc_ckpt_path = os.path.join(self.__voice_path, + local_voc_ckpy_path = os.path.join(self.__voice_path, 'vocoder') + self.__voc_ckpt_path = os.path.join(local_voc_ckpy_path, ModelFile.TORCH_MODEL_BIN_FILE) - if torch.cuda.is_available(): - torch.manual_seed(self.__voc_config.seed) - self.__device = torch.device('cuda') - else: - self.__device = torch.device('cpu') self.__generator = Generator(self.__voc_config).to(self.__device) state_dict_g = load_checkpoint(self.__voc_ckpt_path, self.__device) self.__generator.load_state_dict(state_dict_g['generator']) self.__generator.eval() self.__generator.remove_weight_norm() - def __am_forward(self, - text, - pitch_control_str='', - duration_control_str='', - energy_control_str=''): - duration_cfg_lst = [] - if len(duration_control_str) != 0: - for item in duration_control_str.strip().split('|'): - percent, scale = item.lstrip('(').rstrip(')').split(',') - duration_cfg_lst.append((float(percent), float(scale))) - pitch_contours_cfg_lst = [] - if len(pitch_control_str) != 0: - for item in pitch_control_str.strip().split('|'): - percent, scale = item.lstrip('(').rstrip(')').split(',') - pitch_contours_cfg_lst.append((float(percent), float(scale))) - energy_contours_cfg_lst = [] - if len(energy_control_str) != 0: - for item in energy_control_str.strip().split('|'): - percent, scale = item.lstrip('(').rstrip(')').split(',') - energy_contours_cfg_lst.append((float(percent), float(scale))) - cleaner_names = [ - x.strip() for x in self.__am_hparams.cleaners.split(',') - ] - - lfeat_symbol = text.strip().split(' ') - lfeat_symbol_separate = [''] * int(len(self.__lfeat_type_list)) - for this_lfeat_symbol in lfeat_symbol: - this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split( - '$') - if len(this_lfeat_symbol) != len(self.__lfeat_type_list): - raise Exception( - 'Length of this_lfeat_symbol in training data' - + ' is not equal to the length of lfeat_type_list, ' - + str(len(this_lfeat_symbol)) + ' VS. ' - + str(len(self.__lfeat_type_list))) - index = 0 - while index < len(lfeat_symbol_separate): - lfeat_symbol_separate[index] = lfeat_symbol_separate[ - index] + this_lfeat_symbol[index] + ' ' - index = index + 1 - - index = 0 - lfeat_type = self.__lfeat_type_list[index] - sequence = self.__symbols_dict.symbol_to_sequence( - lfeat_symbol_separate[index].strip(), lfeat_type, cleaner_names) - sequence_array = np.asarray( - sequence[:-1], - dtype=np.int32) # sequence length minus 1 to ignore EOS ~ - inputs = np.eye( - self.__inputs_dim[lfeat_type], dtype=np.float32)[sequence_array] - index = index + 1 - while index < len(self.__lfeat_type_list) - 2: - lfeat_type = self.__lfeat_type_list[index] - sequence = self.__symbols_dict.symbol_to_sequence( - lfeat_symbol_separate[index].strip(), lfeat_type, - cleaner_names) - sequence_array = np.asarray( - sequence[:-1], - dtype=np.int32) # sequence length minus 1 to ignore EOS ~ - inputs_temp = np.eye( - self.__inputs_dim[lfeat_type], - dtype=np.float32)[sequence_array] - inputs = np.concatenate((inputs, inputs_temp), axis=1) - index = index + 1 - seq = inputs - - lfeat_type = 'emo_category' - inputs_emotion = multi_label_symbol_to_sequence( - self.__emo_category, lfeat_symbol_separate[index].strip()) - # inputs_emotion = inputs_emotion * 1.5 - index = index + 1 - - lfeat_type = 'speaker' - inputs_speaker = multi_label_symbol_to_sequence( - self.__speaker, lfeat_symbol_separate[index].strip()) - - duration_scale = np.ones((len(seq), ), dtype=np.float32) - start_idx = 0 - for (percent, scale) in duration_cfg_lst: - duration_scale[start_idx:start_idx - + int(percent * len(seq))] = scale - start_idx += int(percent * len(seq)) - - pitch_contours_scale = np.ones((len(seq), ), dtype=np.float32) - start_idx = 0 - for (percent, scale) in pitch_contours_cfg_lst: - pitch_contours_scale[start_idx:start_idx - + int(percent * len(seq))] = scale - start_idx += int(percent * len(seq)) - - energy_contours_scale = np.ones((len(seq), ), dtype=np.float32) - start_idx = 0 - for (percent, scale) in energy_contours_cfg_lst: - energy_contours_scale[start_idx:start_idx - + int(percent * len(seq))] = scale - start_idx += int(percent * len(seq)) - - feed_dict = { - self.__model.inputs: [np.asarray(seq, dtype=np.float32)], - self.__model.inputs_emotion: - [np.asarray(inputs_emotion, dtype=np.float32)], - self.__model.inputs_speaker: - [np.asarray(inputs_speaker, dtype=np.float32)], - self.__model.input_lengths: - np.asarray([len(seq)], dtype=np.int32), - self.__model.duration_scales: [duration_scale], - self.__model.pitch_scales: [pitch_contours_scale], - self.__model.energy_scales: [energy_contours_scale] - } - - result = self.__session.run([ - self.__mel_spec, self.__duration_outputs, self.__duration_outputs_, - self.__pitch_contour_outputs, self.__embedded_inputs_emotion, - self.__embedding_fsmn_outputs, self.__encoder_outputs, - self.__pitch_embeddings, self.__LR_outputs, - self.__postnet_fsmn_outputs, self.__energy_contour_outputs, - self.__energy_embeddings, self.__attention_x, self.__attention_h - ], feed_dict=feed_dict) # yapf:disable - return result[0] + def __am_forward(self, symbol_seq): + with torch.no_grad(): + inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( + symbol_seq) + inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( + self.__device) + inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( + self.__device) + inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( + self.__device) + inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( + self.__device) + inputs_ling = torch.stack( + [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], + dim=-1).unsqueeze(0) + inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( + self.__device).unsqueeze(0) + inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( + self.__device).unsqueeze(0) + inputs_len = torch.zeros(1).to(self.__device).long( + ) + inputs_emo.size(1) - 1 # minus 1 for "~" + res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], + inputs_spk[:, :-1], inputs_len) + postnet_outputs = res['postnet_outputs'] + LR_length_rounded = res['LR_length_rounded'] + valid_length = int(LR_length_rounded[0].item()) + postnet_outputs = postnet_outputs[ + 0, :valid_length, :].cpu().numpy() + return postnet_outputs def __vocoder_forward(self, melspec): dim0 = list(melspec.shape)[-1] if dim0 != self.__voc_config.num_mels: raise TtsVocoderMelspecShapeMismatchException( - 'input melspec mismatch require {} but {}'.format( - self.__voc_config.num_mels, dim0)) + 'modelscope error: input melspec mismatch require {} but {}'. + format(self.__voc_config.num_mels, dim0)) with torch.no_grad(): x = melspec.T x = torch.FloatTensor(x).to(self.__device) @@ -292,9 +117,15 @@ class Voice: audio = audio.cpu().numpy().astype('int16') return audio - def forward(self, text): + def forward(self, symbol_seq): if not self.__model_loaded: + torch.manual_seed(self.__am_config.seed) + if torch.cuda.is_available(): + torch.manual_seed(self.__am_config.seed) + self.__device = torch.device('cuda') + else: + self.__device = torch.device('cpu') self.__load_am() self.__load_vocoder() self.__model_loaded = True - return self.__vocoder_forward(self.__am_forward(text)) + return self.__vocoder_forward(self.__am_forward(symbol_seq)) diff --git a/modelscope/pipelines/audio/text_to_speech_pipeline.py b/modelscope/pipelines/audio/text_to_speech_pipeline.py index f9e7d80a..2063da68 100644 --- a/modelscope/pipelines/audio/text_to_speech_pipeline.py +++ b/modelscope/pipelines/audio/text_to_speech_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from typing import Any, Dict, List import numpy as np @@ -42,3 +44,6 @@ class TextToSpeechSambertHifiganPipeline(Pipeline): def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: return inputs + + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} diff --git a/modelscope/utils/audio/tts_exceptions.py b/modelscope/utils/audio/tts_exceptions.py index 8c73b603..43ec994b 100644 --- a/modelscope/utils/audio/tts_exceptions.py +++ b/modelscope/utils/audio/tts_exceptions.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. """ Define TTS exceptions """ @@ -10,7 +11,7 @@ class TtsException(Exception): pass -class TtsModelConfigurationExcetion(TtsException): +class TtsModelConfigurationException(TtsException): """ TTS model configuration exceptions. """ diff --git a/requirements/audio.txt b/requirements/audio.txt index 5e4bc104..d22ad8f1 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,6 +1,5 @@ easyasr>=0.0.2 espnet>=202204 -#tts h5py inflect keras @@ -15,11 +14,7 @@ nltk numpy<=1.18 # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. protobuf>3,<3.21.0 -ptflops py_sound_connect -pytorch_wavelets -PyWavelets>=1.0.0 -scikit-learn SoundFile>0.10 sox torchaudio diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py index e82cf43e..f659e59b 100644 --- a/tests/pipelines/test_text_to_speech.py +++ b/tests/pipelines/test_text_to_speech.py @@ -9,6 +9,7 @@ import unittest import torch from scipy.io.wavfile import write +from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks @@ -33,7 +34,9 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase, text = '今天北京天气怎么样?' voice = 'zhitian_emo' - sambert_hifigan_tts = pipeline(task=self.task, model=self.model_id) + model = Model.from_pretrained( + model_name_or_path=self.model_id, revision='pytorch_am') + sambert_hifigan_tts = pipeline(task=self.task, model=model) self.assertTrue(sambert_hifigan_tts is not None) output = sambert_hifigan_tts(input=text, voice=voice) self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])