
evaluate.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """evaluation"""
import os
from os.path import join
import argparse
import glob
import numpy as np
from scipy.io import wavfile
from hparams import hparams, hparams_debug_string
import audio
from tqdm import tqdm
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from nnmnkwii import preprocessing as P
from nnmnkwii.datasets import FileSourceDataset
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_scalar_input
from src.dataset import RawAudioDataSource, MelSpecDataSource, DualDataset

parser = argparse.ArgumentParser(description='TTS evaluation')
parser.add_argument('--data_path', type=str, required=True, default='',
                    help='Directory containing the preprocessed features.')
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--is_numpy', action="store_true", default=False, help='Whether to use numpy for inference')
parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios')
parser.add_argument('--speaker_id', type=str, default='',
                    help='Use a specific speaker of the data, for multi-speaker datasets.')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
                    help='run platform, support GPU and CPU. Default: GPU')
args = parser.parse_args()


def get_data_loader(hparam, data_dir):
    """
    Build the test data loader.
    """
    wav_paths = glob.glob(os.path.join(data_dir, "*-wave.npy"))
    if wav_paths:
        X = FileSourceDataset(RawAudioDataSource(data_dir,
                                                 hop_size=audio.get_hop_size(),
                                                 max_steps=None, cin_pad=hparam.cin_pad))
    else:
        X = None
    C = FileSourceDataset(MelSpecDataSource(data_dir,
                                            hop_size=audio.get_hop_size(),
                                            max_steps=None, cin_pad=hparam.cin_pad))
    length_x = np.array(C.file_data_source.lengths)
    if C[0].shape[-1] != hparam.cin_channels:
        raise RuntimeError("Invalid cin_channels {}. Expected to be {}.".format(C[0].shape[-1],
                                                                                hparam.cin_channels))
    dataset = DualDataset(X, C, length_x, batch_size=hparam.batch_size, hparams=hparam)
    data_loader = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"])
    return data_loader, dataset
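
# Autoregressive generation: when the network upsamples the conditional
# features, each (unpadded) mel frame expands to hop_size output samples;
# otherwise the features are assumed to already be duplicated to sample
# resolution. The sampled output is then mapped back from mu-law space.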
def batch_wavegen(hparam, net, c_input=None, g_input=None, tqdm_=None, is_numpy=True):
    """
    Generate audio.
    """
    assert c_input is not None
    B = c_input.shape[0]
    net.set_train(False)
    if hparam.upsample_conditional_features:
        length = (c_input.shape[-1] - hparam.cin_pad * 2) * audio.get_hop_size()
    else:
        # already duplicated
        length = c_input.shape[-1]
    y_hat = net.incremental_forward(c=c_input, g=g_input, T=length, tqdm=tqdm_, softmax=True, quantize=True,
                                    log_scale_min=hparam.log_scale_min, is_numpy=is_numpy)
    if is_mulaw_quantize(hparam.input_type):
        # needs to be float since mulaw_inv returns in range of [-1, 1]
        y_hat = np.reshape(np.argmax(y_hat, 1), (B, -1))
        y_hat = y_hat.astype(np.float32)
        for k in range(B):
            y_hat[k] = P.inv_mulaw_quantize(y_hat[k], hparam.quantize_channels - 1)
    elif is_mulaw(hparam.input_type):
        y_hat = np.reshape(y_hat, (B, -1))
        for k in range(B):
            y_hat[k] = P.inv_mulaw(y_hat[k], hparam.quantize_channels - 1)
    else:
        y_hat = np.reshape(y_hat, (B, -1))
    if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
        for k in range(B):
            y_hat[k] = getattr(audio, hparam.postprocess)(y_hat[k])
    if hparam.global_gain_scale > 0:
        for k in range(B):
            y_hat[k] /= hparam.global_gain_scale
    return y_hat
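
# wavfile.write expects integer PCM, so float32 audio in [-1, 1] is scaled by
# 32767 to int16.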
def to_int16(x_):
    """
    Convert datatype to int16.
    """
    if x_.dtype == np.int16:
        return x_
    assert x_.dtype == np.float32
    assert x_.min() >= -1 and x_.max() <= 1.0
    return (x_ * 32767).astype(np.int16)
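
# Reference wav/feature paths are collected batch by batch; idx advances by
# batch_size so output names stay aligned with the dataset order.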
def get_reference_file(hparam, dataset_source, idx):
    """
    Get reference files.
    """
    reference_files = []
    reference_feats = []
    for _ in range(hparam.batch_size):
        if hasattr(dataset_source, "X"):
            reference_files.append(dataset_source.X.collected_files[idx][0])
        if hasattr(dataset_source, "Mel"):
            reference_feats.append(dataset_source.Mel.collected_files[idx][0])
        else:
            reference_feats.append(dataset_source.collected_files[idx][0])
        idx += 1
    return reference_files, reference_feats, idx


def get_saved_audio_name(has_ref_file_, ref_file, ref_feat, g_fp):
    """Get the paths used to save the generated and reference audio."""
    if has_ref_file_:
        target_audio_path = ref_file
        name = os.path.splitext(os.path.basename(target_audio_path))[0].replace("-wave", "")
    else:
        target_feat_path = ref_feat
        name = os.path.splitext(os.path.basename(target_feat_path))[0].replace("-feats", "")
    # Paths
    if g_fp is None:
        dst_wav_path_ = join(args.output_path, "{}_gen.wav".format(name))
        target_wav_path_ = join(args.output_path, "{}_ref.wav".format(name))
    else:
        dst_wav_path_ = join(args.output_path, "speaker{}_{}_gen.wav".format(g_fp, name))
        target_wav_path_ = join(args.output_path, "speaker{}_{}_ref.wav".format(g_fp, name))
    return dst_wav_path_, target_wav_path_
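
# The reference signal is stored in model input space, so it is inverted from
# mu-law (when applicable) and post-processed the same way as generated audio
# before being written to disk.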
def save_ref_audio(hparam, ref, length, target_wav_path_):
    """
    Save reference audio.
    """
    if is_mulaw_quantize(hparam.input_type):
        ref = np.reshape(np.argmax(ref, 0), (-1))[:length]
        ref = ref.astype(np.float32)
    else:
        ref = np.reshape(ref, (-1))[:length]
    if is_mulaw_quantize(hparam.input_type):
        ref = P.inv_mulaw_quantize(ref, hparam.quantize_channels - 1)
    elif is_mulaw(hparam.input_type):
        ref = P.inv_mulaw(ref, hparam.quantize_channels - 1)
    if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
        ref = getattr(audio, hparam.postprocess)(ref)
    if hparam.global_gain_scale > 0:
        ref /= hparam.global_gain_scale
    ref = np.clip(ref, -1.0, 1.0)
    wavfile.write(target_wav_path_, hparam.sample_rate, to_int16(ref))
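
# Main flow: parse the preset json into hparams, build the WaveNet from those
# hyperparameters, restore the pretrained checkpoint, then synthesize each
# batch and write the generated and reference waveforms side by side.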
if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
    speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
    if args.preset is not None:
        with open(args.preset) as f:
            hparams.parse_json(f.read())
    assert hparams.name == "wavenet_vocoder"
    print(hparams_debug_string())
    fs = hparams.sample_rate
    hparams.batch_size = 10
    hparams.max_time_sec = None
    hparams.max_time_steps = None
    data_loaders, source_dataset = get_data_loader(hparam=hparams, data_dir=args.data_path)
    upsample_params = hparams.upsample_params
    upsample_params["cin_channels"] = hparams.cin_channels
    upsample_params["cin_pad"] = hparams.cin_pad
    model = WaveNet(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        cin_pad=hparams.cin_pad,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
    )
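
    # The architecture hyperparameters above all come from the preset json;
    # only batch_size and the max_time_* limits were overridden for evaluation.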
    param_dict = load_checkpoint(args.pretrain_ckpt)
    load_param_into_net(model, param_dict)
    print('Successfully loaded the pre-trained model')
    os.makedirs(args.output_path, exist_ok=True)
    cin_pad = hparams.cin_pad
    file_idx = 0
    for data in data_loaders.create_dict_iterator():
        x, y, c, g, input_lengths = (data['x_batch'], data['y_batch'], data['c_batch'], data['g_batch'],
                                     data['input_lengths'])
        if cin_pad > 0:
            c = c.asnumpy()
            # pad only the time axis, assuming (batch, channels, time) layout
            c = np.pad(c, pad_width=((0, 0), (0, 0), (cin_pad, cin_pad)), mode="edge")
            c = Tensor(c)
        ref_files, ref_feats, file_idx = get_reference_file(hparams, source_dataset, file_idx)
        # Generate from the (possibly padded) conditional features
        y_hats = batch_wavegen(hparams, model, c_input=c, tqdm_=tqdm, is_numpy=args.is_numpy)
        x = x.asnumpy()
        input_lengths = input_lengths.asnumpy()
        # Save each utterance.
        has_ref_file = bool(ref_files)
        for i, (ref_, gen_, length_) in enumerate(zip(x, y_hats, input_lengths)):
            dst_wav_path, target_wav_path = get_saved_audio_name(
                has_ref_file_=has_ref_file,
                ref_file=ref_files[i] if has_ref_file else None,
                ref_feat=ref_feats[i], g_fp=g)
            save_ref_audio(hparams, ref_, length_, target_wav_path)
            gen = gen_[:length_]
            gen = np.clip(gen, -1.0, 1.0)
            wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen))
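
# Example invocation (the paths below are placeholders; point them at your own
# preprocessed-feature directory, preset json, and checkpoint):
#   python evaluate.py --data_path ./features --preset ./preset.json \
#       --pretrain_ckpt ./wavenet.ckpt --output_path ./out_wave/ --platform GPU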