
main.py 4.7 kB

  1. """TODO: Add docstring."""
import os
import time
from pathlib import Path
from threading import Thread

import numpy as np
import pyaudio
import torch
from dora import Node
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    set_seed,
)

  18. device = "cuda:0" # if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
  19. torch_dtype = torch.float16 if device != "cpu" else torch.float32
DEFAULT_PATH = "ylacombe/parler-tts-mini-jenny-30H"
MODEL_NAME_OR_PATH = os.getenv("MODEL_NAME_OR_PATH", DEFAULT_PATH)

if os.getenv("USE_MODELSCOPE_HUB") in ["True", "true"]:
    from modelscope import snapshot_download

    if not Path(MODEL_NAME_OR_PATH).exists():
        MODEL_NAME_OR_PATH = snapshot_download(MODEL_NAME_OR_PATH)

model = ParlerTTSForConditionalGeneration.from_pretrained(
    MODEL_NAME_OR_PATH, torch_dtype=torch_dtype, low_cpu_mem_usage=True,
).to(device)

# A static KV cache plus a compiled forward pass keeps streaming generation fast.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="default")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME_OR_PATH)

SAMPLE_RATE = feature_extractor.sampling_rate
SEED = 42

default_text = "Hello, my name is Reachy, the best robot in the world!"
default_description = (
    "Jenny delivers her words quite expressively, in a very confined sounding "
    "environment with clear audio quality."
)

p = pyaudio.PyAudio()
sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate
stream = p.open(format=pyaudio.paInt16, channels=1, rate=sampling_rate, output=True)


def play_audio(audio_array):
    """Normalize a generated audio chunk to 16-bit PCM and write it to the output stream."""
    if np.issubdtype(audio_array.dtype, np.floating):
        max_val = np.max(np.abs(audio_array))
        audio_array = (audio_array / max_val) * 32767
        audio_array = audio_array.astype(np.int16)
    stream.write(audio_array.tobytes())


class InterruptStoppingCriteria(StoppingCriteria):
    """Stopping criteria that lets another thread interrupt generation on demand."""

    def __init__(self):
        super().__init__()
        self.stop_signal = False

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs,
    ) -> bool:
        return self.stop_signal

    def stop(self):
        """Ask the generation loop to stop at the next decoding step."""
        self.stop_signal = True


def generate_base(
    node,
    text=default_text,
    description=default_description,
    play_steps_in_s=0.5,
):
    """Synthesize `text` in a background thread and play the audio as it streams.

    While chunks arrive, the dora node (if provided) is polled so that a `stop`
    or new `text` input can interrupt the current utterance.
    """
    prev_time = time.time()
    play_steps = int(frame_rate * play_steps_in_s)
    inputs = tokenizer(description, return_tensors="pt").to(device)
    prompt = tokenizer(text, return_tensors="pt").to(device)
    streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
    stopping_criteria = InterruptStoppingCriteria()
    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        prompt_input_ids=prompt.input_ids,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,
        min_new_tokens=10,
        stopping_criteria=StoppingCriteriaList([stopping_criteria]),
    )
    set_seed(SEED)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for new_audio in streamer:
        current_time = time.time()
        print(f"Time between iterations: {round(current_time - prev_time, 2)} seconds")
        prev_time = current_time
        play_audio(new_audio)
        if node is None:
            continue
        # Poll the dataflow briefly so playback can be interrupted or replaced.
        event = node.next(timeout=0.01)
        if event is None:
            continue
        if event["type"] == "ERROR":
            pass
        elif event["type"] == "INPUT":
            if event["id"] == "stop":
                stopping_criteria.stop()
                break
            if event["id"] == "text":
                stopping_criteria.stop()
                text = event["value"][0].as_py()
                generate_base(node, text, default_description, 0.5)


def main():
    """Announce readiness, then speak every `text` input received from the dataflow."""
    generate_base(None, "Ready!", default_description, 0.5)
    node = Node()
    while True:
        event = node.next()
        if event is None:
            break
        if event["type"] == "INPUT" and event["id"] == "text":
            text = event["value"][0].as_py()
            generate_base(node, text, default_description, 0.5)
    stream.stop_stream()
    stream.close()
    p.terminate()


if __name__ == "__main__":
    main()
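
The script above consumes `text` (and, mid-utterance, `stop`) events as one-element Arrow string arrays, reading them back with `event["value"][0].as_py()`. Below is a minimal, hypothetical sketch of a separate companion node that could feed it text; the `tick` input, the `speak` output name, and the assumption that the dataflow maps `speak` to this node's `text` input are illustrative, not taken from this repository.

"""Hypothetical companion node: forwards a string to the TTS node's `text` input."""
import pyarrow as pa
from dora import Node

node = Node()
for event in node:
    # React to an assumed `tick` timer input and emit a one-element string array,
    # which the TTS node unpacks with event["value"][0].as_py().
    if event["type"] == "INPUT" and event["id"] == "tick":
        node.send_output("speak", pa.array(["Hello from another node!"]))

A `stop` signal could be sent the same way, with a second output mapped to the TTS node's `stop` input, to cut off the current utterance.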