You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

main.py 3.8 kB

10 months ago
1 year ago
10 months ago
10 months ago
10 months ago
1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. """TODO: Add docstring."""
  2. import argparse # Add argparse import
  3. import os
  4. import pathlib
  5. import outetts
  6. import pyarrow as pa
  7. import torch
  8. from dora import Node
  9. PATH_SPEAKER = os.getenv("PATH_SPEAKER", "speaker.json")
  10. device = "mps" if torch.backends.mps.is_available() else "cpu"
  11. device = "cuda:0" if torch.cuda.is_available() else device
  12. torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
  13. def load_interface():
  14. """TODO: Add docstring."""
  15. if os.getenv("INTERFACE", "HF") == "HF":
  16. model_config = outetts.HFModelConfig_v1(
  17. model_path="OuteAI/OuteTTS-0.2-500M",
  18. language="en",
  19. device=device,
  20. )
  21. interface = outetts.InterfaceHF(model_version="0.2", cfg=model_config)
  22. else:
  23. model_config = outetts.GGUFModelConfig_v1(
  24. model_path=os.getenv(
  25. "GGUF_MODEL_PATH",
  26. "~/.cache/huggingface/hub/models--OuteAI--OuteTTS-0.2-500M-GGUF/snapshots/e6d78720d2a8edce2bc8f5c5c2d0332e57091930/OuteTTS-0.2-500M-Q4_0.gguf",
  27. ),
  28. language="en", # Supported languages in v0.2: en, zh, ja, ko
  29. n_gpu_layers=0,
  30. )
  31. interface = outetts.InterfaceGGUF(model_version="0.2", cfg=model_config)
  32. return interface
  33. def create_speaker(interface, path):
  34. """TODO: Add docstring."""
  35. speaker = interface.create_speaker(
  36. audio_path=path,
  37. # If transcript is not provided, it will be automatically transcribed using Whisper
  38. transcript=None, # Set to None to use Whisper for transcription
  39. whisper_model="turbo", # Optional: specify Whisper model (default: "turbo")
  40. whisper_device=None, # Optional: specify device for Whisper (default: None)
  41. )
  42. interface.save_speaker(speaker, "speaker.json")
  43. print("saved speaker.json")
  44. def main(arg_list: list[str] | None = None):
  45. # Parse cli args
  46. """TODO: Add docstring."""
  47. parser = argparse.ArgumentParser(description="Dora Outetts Node")
  48. parser.add_argument("--create-speaker", type=str, help="Path to audio file")
  49. parser.add_argument("--test", action="store_true", help="Run tests")
  50. args = parser.parse_args(arg_list)
  51. if args.test:
  52. import pytest
  53. path = pathlib.Path(__file__).parent.resolve()
  54. pytest.main(["-x", path / "tests"])
  55. return
  56. interface = load_interface()
  57. if args.create_speaker:
  58. create_speaker(interface, args.create_speaker)
  59. return
  60. if os.path.exists(PATH_SPEAKER):
  61. print(f"Loading speaker from {PATH_SPEAKER}")
  62. # speaker = interface.load_speaker(PATH_SPEAKER)
  63. speaker = interface.load_default_speaker(name="male_1")
  64. else:
  65. # Load default speaker
  66. speaker = interface.load_default_speaker(name="male_1")
  67. node = Node()
  68. for event in node:
  69. if event["type"] == "INPUT":
  70. if event["id"] == "TICK":
  71. print(
  72. f"""Node received:
  73. id: {event["id"]},
  74. value: {event["value"]},
  75. metadata: {event["metadata"]}""",
  76. )
  77. elif event["id"] == "text":
  78. # Warning: Make sure to add my_output_id and my_input_id within the dataflow.
  79. text = event["value"][0].as_py()
  80. output = interface.generate(
  81. text=text,
  82. temperature=0.1,
  83. repetition_penalty=1.1,
  84. speaker=speaker, # Optional: speaker profile
  85. )
  86. node.send_output(
  87. "audio",
  88. pa.array(output.audio.cpu().numpy().ravel()),
  89. {"language": "en", "sample_rate": output.sr},
  90. )
  91. if __name__ == "__main__":
  92. main()