From 689595604e5bd0950bf51cb0e9908c501fe38d6d Mon Sep 17 00:00:00 2001
From: haixuantao
Date: Wed, 23 Jul 2025 15:46:22 +0200
Subject: [PATCH 1/8] make qwen model configurable

---
 node-hub/dora-vad/dora_vad/main.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/node-hub/dora-vad/dora_vad/main.py b/node-hub/dora-vad/dora_vad/main.py
index 9a674ddb..0ee08b08 100644
--- a/node-hub/dora-vad/dora_vad/main.py
+++ b/node-hub/dora-vad/dora_vad/main.py
@@ -36,6 +36,7 @@ def main():
                 threshold=THRESHOLD,
                 min_speech_duration_ms=MIN_SPEECH_DURATION_MS,
                 min_silence_duration_ms=MIN_SILENCE_DURATION_MS,
+                sampling_rate=sr,
             )

             # Check if there is a timestamp
@@ -48,8 +49,8 @@ def main():
                     node.send_output(
                         "timestamp_start",
                         pa.array([speech_timestamps[-1]["start"]]),
+                        metadata={"sample_rate": sr},
                     )
-                continue
             audio = audio[0 : speech_timestamps[-1]["end"]]
             node.send_output("audio", pa.array(audio), metadata={"sample_rate": sr})
             last_audios = [audio[speech_timestamps[-1]["end"] :]]

From 774ea1b1816f92aa90e3bc898579adfa74c5c5c4 Mon Sep 17 00:00:00 2001
From: haixuantao
Date: Wed, 23 Jul 2025 16:37:32 +0200
Subject: [PATCH 2/8] Default to llama cpp

---
 node-hub/dora-qwen/dora_qwen/main.py | 43 +++++++++++-----------------
 1 file changed, 17 insertions(+), 26 deletions(-)

diff --git a/node-hub/dora-qwen/dora_qwen/main.py b/node-hub/dora-qwen/dora_qwen/main.py
index 957abf42..c956105d 100644
--- a/node-hub/dora-qwen/dora_qwen/main.py
+++ b/node-hub/dora-qwen/dora_qwen/main.py
@@ -1,7 +1,6 @@
 """TODO: Add docstring."""

 import os
-import sys

 import pyarrow as pa
 from dora import Node
@@ -12,14 +11,24 @@
 SYSTEM_PROMPT = os.getenv(
     "SYSTEM_PROMPT",
     "You're a very succinct AI assistant with short answers.",
 )
+MODEL_NAME_OR_PATH = os.getenv("MODEL_NAME_OR_PATH", "Qwen/Qwen2.5-0.5B-Instruct-GGUF")
+MODEL_FILE_PATTERN = os.getenv("MODEL_FILE_PATTERN", "*fp16.gguf")
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", "512"))
+N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))
+N_THREADS = int(os.getenv("N_THREADS", "4"))
+CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "4096"))
+

 def get_model_gguf():
     """TODO: Add docstring."""
     from llama_cpp import Llama

     return Llama.from_pretrained(
-        repo_id="Qwen/Qwen2.5-0.5B-Instruct-GGUF",
-        filename="*fp16.gguf",
+        repo_id=MODEL_NAME_OR_PATH,
+        filename=MODEL_FILE_PATTERN,
+        n_gpu_layers=N_GPU_LAYERS,
+        n_ctx=CONTEXT_SIZE,
+        n_threads=N_THREADS,
         verbose=False,
     )
@@ -71,12 +80,7 @@ def main():
     """TODO: Add docstring."""
     history = []
     # If OS is not Darwin, use Huggingface model
-    if sys.platform == "darwin":
-        model = get_model_gguf()
-    elif sys.platform == "linux":
-        model, tokenizer = get_model_huggingface()
-    else:
-        model, tokenizer = get_model_darwin()
+    model = get_model_gguf()

     node = Node()

@@ -90,23 +94,10 @@ def main():
             word in ACTIVATION_WORDS for word in words
         ):
             # On linux, Windows
-            if sys.platform == "darwin":
-                response = model.create_chat_completion(
-                    messages=[{"role": "user", "content": text}],  # Prompt
-                    max_tokens=24,
-                )["choices"][0]["message"]["content"]
-            elif sys.platform == "linux":
-                response, history = generate_hf(model, tokenizer, text, history)
-            else:
-                from mlx_lm import generate
-
-                response = generate(
-                    model,
-                    tokenizer,
-                    prompt=text,
-                    verbose=False,
-                    max_tokens=50,
-                )
+            response = model.create_chat_completion(
+                messages=[{"role": "user", "content": text}],  # Prompt
+                max_tokens=24,
+            )["choices"][0]["message"]["content"]

             node.send_output(
                 output_id="text",

From 709e8fec0a9a4d0b2bd7b1fccd2899196fcc3838 Mon Sep 17 00:00:00 2001
From: haixuantao
Date: Tue, 29 Jul 2025 15:32:39 +0800 Subject: [PATCH 3/8] Make whisper better by making it output punctuation --- .../dora_distil_whisper/main.py | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py index 007b3b43..06fea704 100644 --- a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py +++ b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py @@ -6,6 +6,7 @@ import sys import time from pathlib import Path +import numpy as np import pyarrow as pa import torch from dora import Node @@ -125,6 +126,8 @@ BAD_SENTENCES = [ "", " so", " so so", + "You", + "You ", "字幕", "字幕志愿", "中文字幕", @@ -181,13 +184,14 @@ def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): def main(): """TODO: Add docstring.""" - node = Node() text_noise = "" - noise_timestamp = time.time() # For macos use mlx: if sys.platform != "darwin": pipe = load_model() + node = Node() + noise_timestamp = time.time() + cache_audio = None for event in node: if event["type"] == "INPUT": if "text_noise" in event["id"]: @@ -200,7 +204,12 @@ def main(): ) noise_timestamp = time.time() else: - audio = event["value"].to_numpy() + audio_input = event["value"].to_numpy() + if cache_audio is not None: + audio = np.concatenate([cache_audio, audio_input]) + else: + audio = audio_input + confg = ( {"language": TARGET_LANGUAGE, "task": "translate"} if TRANSLATE @@ -215,6 +224,7 @@ def main(): audio, path_or_hf_repo="mlx-community/whisper-large-v3-turbo", append_punctuations=".", + language=TARGET_LANGUAGE, ) else: @@ -235,6 +245,22 @@ def main(): if text.strip() == "" or text.strip() == ".": continue - node.send_output( - "text", pa.array([text]), {"language": TARGET_LANGUAGE}, - ) + + if ( + ( + text.endswith(".") + or text.endswith("!") + or text.endswith("?") + or text.endswith('."') + or text.endswith('!"') + or text.endswith('?"') + ) + and not text.endswith("...") # Avoid ending with ellipsis + ): + node.send_output( + "text", + pa.array([text]), + ) + cache_audio = None + else: + cache_audio = audio From c73e0ad6d86618003ed4d80967639004a386d26a Mon Sep 17 00:00:00 2001 From: haixuantao Date: Wed, 30 Jul 2025 15:21:30 +0800 Subject: [PATCH 4/8] Making command public --- binaries/cli/src/command/mod.rs | 32 +++++++++++++-------------- binaries/cli/src/command/start/mod.rs | 16 +++++++------- binaries/cli/src/lib.rs | 2 +- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/binaries/cli/src/command/mod.rs b/binaries/cli/src/command/mod.rs index 617f9b05..e0d776a4 100644 --- a/binaries/cli/src/command/mod.rs +++ b/binaries/cli/src/command/mod.rs @@ -16,22 +16,22 @@ mod up; pub use run::run_func; -use build::Build; -use check::Check; -use coordinator::Coordinator; -use daemon::Daemon; -use destroy::Destroy; -use eyre::Context; -use graph::Graph; -use list::ListArgs; -use logs::LogsArgs; -use new::NewArgs; -use run::Run; -use runtime::Runtime; -use self_::SelfSubCommand; -use start::Start; -use stop::Stop; -use up::Up; +pub use build::Build; +pub use check::Check; +pub use coordinator::Coordinator; +pub use daemon::Daemon; +pub use destroy::Destroy; +pub use eyre::Context; +pub use graph::Graph; +pub use list::ListArgs; +pub use logs::LogsArgs; +pub use new::NewArgs; +pub use run::Run; +pub use runtime::Runtime; +pub use self_::SelfSubCommand; +pub use start::Start; +pub use stop::Stop; +pub use up::Up; /// dora-rs cli client #[derive(Debug, clap::Subcommand)] 
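With these items re-exported, downstream crates can drive the CLI programmatically instead of shelling out to `dora start`; the websocket server added in patch 5 below does exactly that. A minimal sketch of the now-possible usage (the `spawn_dataflow` wrapper is hypothetical, and it assumes `Command` implements the `Executable` trait, as the CLI does internally):

```rust
use std::net::{IpAddr, Ipv4Addr};

use dora_cli::command::{Command, Executable, Start};
use dora_node_api::dora_core::topics::DORA_COORDINATOR_PORT_CONTROL_DEFAULT;

// Hypothetical helper: start a dataflow against a local coordinator,
// detached, with uv-managed Python nodes (the same flags patch 5 uses).
fn spawn_dataflow(dataflow: String, name: String) {
    Command::Start(Start {
        dataflow,
        name: Some(name),
        coordinator_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
        coordinator_port: DORA_COORDINATOR_PORT_CONTROL_DEFAULT,
        attach: false,
        detach: true,
        hot_reload: false,
        uv: true,
    })
    .execute()
    .expect("failed to start dataflow");
}
```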
diff --git a/binaries/cli/src/command/start/mod.rs b/binaries/cli/src/command/start/mod.rs index 077a67b4..db2933e7 100644 --- a/binaries/cli/src/command/start/mod.rs +++ b/binaries/cli/src/command/start/mod.rs @@ -31,28 +31,28 @@ mod attach; pub struct Start { /// Path to the dataflow descriptor file #[clap(value_name = "PATH")] - dataflow: String, + pub dataflow: String, /// Assign a name to the dataflow #[clap(long)] - name: Option, + pub name: Option, /// Address of the dora coordinator #[clap(long, value_name = "IP", default_value_t = LOCALHOST)] - coordinator_addr: IpAddr, + pub coordinator_addr: IpAddr, /// Port number of the coordinator control server #[clap(long, value_name = "PORT", default_value_t = DORA_COORDINATOR_PORT_CONTROL_DEFAULT)] - coordinator_port: u16, + pub coordinator_port: u16, /// Attach to the dataflow and wait for its completion #[clap(long, action)] - attach: bool, + pub attach: bool, /// Run the dataflow in background #[clap(long, action)] - detach: bool, + pub detach: bool, /// Enable hot reloading (Python only) #[clap(long, action)] - hot_reload: bool, + pub hot_reload: bool, // Use UV to run nodes. #[clap(long, action)] - uv: bool, + pub uv: bool, } impl Executable for Start { diff --git a/binaries/cli/src/lib.rs b/binaries/cli/src/lib.rs index 868d7a5a..c89f74fa 100644 --- a/binaries/cli/src/lib.rs +++ b/binaries/cli/src/lib.rs @@ -5,7 +5,7 @@ use std::{ path::PathBuf, }; -mod command; +pub mod command; mod common; mod formatting; pub mod output; From 733c86ae39993a97a0933faf50aeec40753f7cfa Mon Sep 17 00:00:00 2001 From: haixuantao Date: Wed, 30 Jul 2025 15:26:12 +0800 Subject: [PATCH 5/8] Add dora-openai-websocket --- Cargo.lock | 466 ++++++++++++++++++-- Cargo.toml | 1 + node-hub/dora-openai-websocket/Cargo.toml | 29 ++ node-hub/dora-openai-websocket/src/main.rs | 474 +++++++++++++++++++++ 4 files changed, 936 insertions(+), 34 deletions(-) create mode 100644 node-hub/dora-openai-websocket/Cargo.toml create mode 100644 node-hub/dora-openai-websocket/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 3d46e6c7..533a3e3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -264,6 +264,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "ansi_colours" version = "1.2.3" @@ -666,7 +672,7 @@ dependencies = [ "enumflags2", "futures-channel", "futures-util", - "rand 0.9.1", + "rand 0.9.2", "raw-window-handle 0.6.2", "serde", "serde_repr", @@ -713,6 +719,29 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "assert2" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6c710e60d14b07d8f42d0e702b16120865eea39edb751e75cd6bf401d18f14" +dependencies = [ + "assert2-macros", + "diff", + "yansi", +] + +[[package]] +name = "assert2-macros" +version = "0.3.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9008cbbba9e1d655538870b91fd93814bd82e6968f27788fc734375120ac6f57" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "syn 2.0.101", +] + [[package]] name = "assert_matches" version = "1.5.0" @@ -1179,14 +1208,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", "bytes", "futures-util", "http 1.3.1", "http-body 1.0.1", 
"http-body-util", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -1199,6 +1228,40 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core 0.5.2", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.4.5" @@ -1219,6 +1282,26 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "az" version = "1.2.1" @@ -1240,6 +1323,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a956d500c2380c818e09d3d7c79ba4a1d7fc6354464f1fceaa5705483a29930" + [[package]] name = "base64" version = "0.13.1" @@ -1640,7 +1729,7 @@ dependencies = [ "metal 0.27.0", "num-traits", "num_cpus", - "rand 0.9.1", + "rand 0.9.2", "rand_distr", "rayon", "safetensors", @@ -1732,6 +1821,12 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.17" @@ -1867,6 +1962,33 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -2303,6 +2425,42 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "criterion" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb" +dependencies = [ + "anes", + "atty", + "cast", + "ciborium", + "clap 3.2.25", + "criterion-plot", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + 
"serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -2809,6 +2967,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c" +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -3266,6 +3430,33 @@ dependencies = [ "uuid 1.16.0", ] +[[package]] +name = "dora-openai-websocket" +version = "0.1.0" +dependencies = [ + "anyhow", + "assert2", + "axum 0.8.4", + "base", + "base64 0.22.1", + "bytes", + "criterion", + "dora-cli", + "dora-node-api", + "fastwebsockets", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "rand 0.9.2", + "rustls-pemfile 1.0.4", + "serde", + "serde_json", + "tokio", + "tokio-rustls 0.24.1", + "trybuild", + "webpki-roots 0.23.1", +] + [[package]] name = "dora-operator-api" version = "0.3.12" @@ -3362,7 +3553,7 @@ dependencies = [ "ndarray 0.15.6", "pinyin", "pyo3", - "rand 0.9.1", + "rand 0.9.2", "rerun", "tokio", ] @@ -4161,6 +4352,26 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fastwebsockets" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "305d3ba574508e27190906d11707dad683e0494e6b85eae9b044cb2734a5e422" +dependencies = [ + "base64 0.21.7", + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "pin-project", + "rand 0.8.5", + "sha1", + "simdutf8", + "thiserror 1.0.69", + "tokio", + "utf-8", +] + [[package]] name = "fdeflate" version = "0.3.7" @@ -4272,7 +4483,7 @@ dependencies = [ "cudarc", "half", "num-traits", - "rand 0.9.1", + "rand 0.9.2", "rand_distr", ] @@ -5052,7 +5263,7 @@ dependencies = [ "cfg-if 1.0.0", "crunchy", "num-traits", - "rand 0.9.1", + "rand 0.9.2", "rand_distr", ] @@ -5392,7 +5603,7 @@ dependencies = [ "rustls 0.23.25", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tower-service", "webpki-roots 0.26.8", ] @@ -6575,6 +6786,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -6858,7 +7075,7 @@ dependencies = [ "image", "indexmap 2.8.0", "mistralrs-core", - "rand 0.9.1", + "rand 0.9.2", "reqwest", "serde", "serde_json", @@ -6910,7 +7127,7 @@ dependencies = [ "objc", "once_cell", "radix_trie", - "rand 0.9.1", + "rand 0.9.2", "rand_isaac", "rayon", "regex", @@ -6931,7 +7148,7 @@ dependencies = [ "tokio", "tokio-rayon", "toktrie_hf_tokenizers", - "toml", + "toml 0.8.20", "tqdm", "tracing", "tracing-subscriber", @@ -7893,6 +8110,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openssl-probe" version = "0.1.6" @@ -8081,7 +8304,7 @@ dependencies = [ "glob", "opentelemetry 0.29.1", "percent-encoding", - "rand 0.9.1", + "rand 0.9.2", "serde_json", "thiserror 2.0.12", "tokio", @@ -8531,6 +8754,34 @@ dependencies = [ "time", ] +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "ply-rs" version = "0.1.3" @@ -9161,7 +9412,7 @@ checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc" dependencies = [ "bytes", "getrandom 0.3.2", - "rand 0.9.1", + "rand 0.9.2", "ring 0.17.14", "rustc-hash 2.1.1", "rustls 0.23.25", @@ -9261,9 +9512,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -9314,7 +9565,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.9.1", + "rand 0.9.2", ] [[package]] @@ -9370,7 +9621,7 @@ dependencies = [ "simd_helpers", "system-deps", "thiserror 1.0.69", - "toml", + "toml 0.8.20", "v_frame", "y4m", ] @@ -10508,7 +10759,7 @@ dependencies = [ "serde", "syn 2.0.101", "tempfile", - "toml", + "toml 0.8.20", "unindent", "xshell", ] @@ -11260,7 +11511,7 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tokio-util", "tower 0.5.2", "tower-service", @@ -11660,7 +11911,7 @@ dependencies = [ "num-derive", "num-traits", "paste", - "rand 0.9.1", + "rand 0.9.2", "serde", "serde_repr", "socket2 0.5.8", @@ -11731,6 +11982,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring 0.17.14", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.25" @@ -11824,6 +12087,26 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" +[[package]] +name = "rustls-webpki" +version = "0.100.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6a5fc258f1c1276dfe3016516945546e2d5383911efc0fc4f1cdc5df3a4ae3" +dependencies = [ + "ring 0.16.20", + "untrusted 0.7.1", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring 0.17.14", + "untrusted 0.9.0", +] + [[package]] name = "rustls-webpki" version = "0.102.8" @@ -12228,9 +12511,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "indexmap 2.8.0", "itoa", @@ -12239,6 +12522,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -12268,6 +12561,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40734c41988f7306bb04f0ecf60ec0f3f1caa34290e4e8ea471dcd3346483b83" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -13221,7 +13523,7 @@ dependencies = [ "cfg-expr", "heck 0.5.0", "pkg-config", - "toml", + "toml 0.8.20", "version-compare", ] @@ -13257,6 +13559,12 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "target-triple" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ac9aa371f599d22256307c24a9d748c041e548cbf599f35d890f9d365361790" + [[package]] name = "tempfile" version = "3.19.1" @@ -13515,6 +13823,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.9.0" @@ -13540,7 +13858,7 @@ dependencies = [ "pin-project-lite", "thiserror 2.0.12", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", ] [[package]] @@ -13638,6 +13956,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.2" @@ -13720,11 +14048,26 @@ checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" dependencies = [ "indexmap 2.8.0", "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.8", + "toml_datetime 0.6.8", "toml_edit", ] +[[package]] +name = "toml" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41ae868b5a0f67631c14589f7e250c1ea2c574ee5ba21c6c8dd4b1485705a5a1" +dependencies = [ + "indexmap 2.8.0", + "serde", + "serde_spanned 1.0.0", + "toml_datetime 0.7.0", + "toml_parser", + "toml_writer", + "winnow", +] + [[package]] name = "toml_datetime" version = "0.6.8" @@ -13734,6 +14077,15 @@ dependencies = [ "serde", ] +[[package]] +name = "toml_datetime" +version = "0.7.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bade1c3e902f58d73d3f294cd7f20391c1cb2fbcb643b73566bc773971df91e3" +dependencies = [ + "serde", +] + [[package]] name = "toml_edit" version = "0.22.24" @@ -13742,11 +14094,26 @@ checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ "indexmap 2.8.0", "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.8", + "toml_datetime 0.6.8", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97200572db069e74c512a14117b296ba0a80a30123fbbb5aa1f4a348f639ca30" +dependencies = [ "winnow", ] +[[package]] +name = "toml_writer" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64" + [[package]] name = "tonic" version = "0.12.3" @@ -13755,7 +14122,7 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.7.9", "base64 0.22.1", "bytes", "h2 0.4.8", @@ -13772,7 +14139,7 @@ dependencies = [ "rustls-pemfile 2.2.0", "socket2 0.5.8", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tokio-stream", "tower 0.4.13", "tower-layer", @@ -13858,6 +14225,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -14005,6 +14373,21 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "trybuild" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65af40ad689f2527aebbd37a0a816aea88ff5f774ceabe99de5be02f2f91dae2" +dependencies = [ + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml 0.9.4", +] + [[package]] name = "ttf-parser" version = "0.25.1" @@ -14352,7 +14735,7 @@ checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" dependencies = [ "getrandom 0.3.2", "js-sys", - "rand 0.9.1", + "rand 0.9.2", "serde", "uuid-macro-internal", "wasm-bindgen", @@ -14810,6 +15193,15 @@ dependencies = [ "webpki", ] +[[package]] +name = "webpki-roots" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" +dependencies = [ + "rustls-webpki 0.100.3", +] + [[package]] name = "webpki-roots" version = "0.26.8" @@ -15676,9 +16068,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.4" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] @@ -15888,6 +16280,12 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yoke" version = "0.7.5" @@ -16549,7 +16947,7 @@ dependencies = [ "time", "tls-listener", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tokio-util", "tracing", "webpki-roots 0.26.8", diff --git a/Cargo.toml b/Cargo.toml index 9d705ba2..aebb0158 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ 
members = [ "node-hub/dora-rerun", "node-hub/terminal-print", "node-hub/openai-proxy-server", + "node-hub/dora-openai-websocket", "node-hub/dora-kit-car", "node-hub/dora-object-to-pose", "node-hub/dora-mistral-rs", diff --git a/node-hub/dora-openai-websocket/Cargo.toml b/node-hub/dora-openai-websocket/Cargo.toml new file mode 100644 index 00000000..ae85552e --- /dev/null +++ b/node-hub/dora-openai-websocket/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "dora-openai-websocket" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +dora-node-api = { workspace = true } +dora-cli = { workspace = true } +tokio = { version = "1.25.0", features = ["full", "macros"] } +tokio-rustls = "0.24.0" +rustls-pemfile = "1.0" +hyper-util = { version = "0.1.0", features = ["tokio"] } +http-body-util = { version = "0.1.0" } +hyper = { version = "1", features = ["http1", "server", "client"] } +assert2 = "0.3.4" +trybuild = "1.0.106" +criterion = "0.4.0" +anyhow = "1.0.71" +webpki-roots = "0.23.0" +bytes = "1.4.0" +axum = "0.8.1" +fastwebsockets = { version = "0.10.0", features = ["upgrade"] } +serde_json = "1.0.141" +serde = "1.0.219" +base = "0.1.0" +base64 = "0.22.1" +rand = "0.9.2" diff --git a/node-hub/dora-openai-websocket/src/main.rs b/node-hub/dora-openai-websocket/src/main.rs new file mode 100644 index 00000000..4eb473c6 --- /dev/null +++ b/node-hub/dora-openai-websocket/src/main.rs @@ -0,0 +1,474 @@ +// Copyright 2023 Divy Srivastava +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
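+
+// Overview: this binary bridges the OpenAI Realtime websocket protocol to a
+// dora dataflow. The first frame a client sends must be a `session.update`;
+// its `input_audio_transcription` model name selects a dataflow template,
+// which is filled in and started through the dora CLI. Subsequent
+// `input_audio_buffer.append` frames are base64-decoded, converted from
+// PCM16 to f32, downsampled from 24 kHz to 16 kHz, and forwarded into the
+// dataflow, while text and audio outputs stream back to the client as
+// Realtime `response.*` delta events.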
+ +use base64::engine::general_purpose; +use base64::Engine; +use dora_cli::command::Executable; +use dora_cli::command::Start; +use dora_node_api::arrow::array::AsArray; +use dora_node_api::arrow::datatypes::DataType; +use dora_node_api::dora_core::config::DataId; +use dora_node_api::dora_core::config::NodeId; +use dora_node_api::dora_core::topics::DORA_COORDINATOR_PORT_CONTROL_DEFAULT; +use dora_node_api::into_vec; +use dora_node_api::DoraNode; +use dora_node_api::IntoArrow; +use dora_node_api::MetadataParameters; +use fastwebsockets::upgrade; +use fastwebsockets::Frame; +use fastwebsockets::OpCode; +use fastwebsockets::Payload; +use fastwebsockets::WebSocketError; +use http_body_util::Empty; +use hyper::body::Bytes; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::Request; +use hyper::Response; +use rand::random; +use serde; +use serde::Deserialize; +use serde::Serialize; +use std::fs; +use std::io::{self, Write}; +use std::net::IpAddr; +use std::net::Ipv4Addr; +use std::time::Duration; +use tokio::net::TcpListener; + +#[derive(Serialize, Deserialize, Debug)] +pub struct ErrorDetails { + pub code: Option, + pub message: String, + pub param: Option, + #[serde(rename = "type")] + pub error_type: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +pub enum OpenAIRealtimeMessage { + #[serde(rename = "session.update")] + SessionUpdate { session: SessionConfig }, + #[serde(rename = "input_audio_buffer.append")] + InputAudioBufferAppend { + audio: String, // base64 encoded audio + }, + #[serde(rename = "input_audio_buffer.commit")] + InputAudioBufferCommit, + #[serde(rename = "response.create")] + ResponseCreate { response: ResponseConfig }, + #[serde(rename = "conversation.item.create")] + ConversationItemCreate { item: ConversationItem }, + #[serde(rename = "conversation.item.truncate")] + ConversationItemTruncate { + item_id: String, + content_index: u32, + audio_end_ms: u32, + #[serde(skip_serializing_if = "Option::is_none")] + event_id: Option, + }, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SessionConfig { + pub modalities: Vec, + pub instructions: String, + pub voice: String, + pub input_audio_format: String, + pub output_audio_format: String, + pub input_audio_transcription: Option, + pub turn_detection: Option, + pub tools: Vec, + pub tool_choice: String, + pub temperature: f32, + pub max_response_output_tokens: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct TranscriptionConfig { + pub model: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct TurnDetectionConfig { + #[serde(rename = "type")] + pub detection_type: String, + pub threshold: f32, + pub prefix_padding_ms: u32, + pub silence_duration_ms: u32, + pub interrupt_response: bool, + pub create_response: bool, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ResponseConfig { + pub modalities: Vec, + pub instructions: Option, + pub voice: Option, + pub output_audio_format: Option, + pub tools: Option>, + pub tool_choice: Option, + pub temperature: Option, + pub max_output_tokens: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ConversationItem { + pub id: Option, + #[serde(rename = "type")] + pub item_type: String, + pub status: Option, + pub role: String, + pub content: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "input_text")] + InputText { text: String }, + #[serde(rename = "input_audio")] + 
InputAudio { + audio: String, + transcript: Option, + }, + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "audio")] + Audio { + audio: String, + transcript: Option, + }, +} + +// Incoming message types from OpenAI +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] +pub enum OpenAIRealtimeResponse { + #[serde(rename = "error")] + Error { error: ErrorDetails }, + #[serde(rename = "session.created")] + SessionCreated { session: serde_json::Value }, + #[serde(rename = "session.updated")] + SessionUpdated { session: serde_json::Value }, + #[serde(rename = "conversation.item.created")] + ConversationItemCreated { item: serde_json::Value }, + #[serde(rename = "conversation.item.truncated")] + ConversationItemTruncated { item: serde_json::Value }, + #[serde(rename = "response.audio.delta")] + ResponseAudioDelta { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + delta: String, // base64 encoded audio + }, + #[serde(rename = "response.audio.done")] + ResponseAudioDone { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + }, + #[serde(rename = "response.text.delta")] + ResponseTextDelta { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + delta: String, + }, + #[serde(rename = "response.audio_transcript.delta")] + ResponseAudioTranscriptDelta { + response_id: String, + item_id: String, + output_index: u32, + content_index: u32, + delta: String, + }, + #[serde(rename = "response.done")] + ResponseDone { response: serde_json::Value }, + #[serde(rename = "input_audio_buffer.speech_started")] + InputAudioBufferSpeechStarted { + audio_start_ms: u32, + item_id: String, + }, + #[serde(rename = "input_audio_buffer.speech_stopped")] + InputAudioBufferSpeechStopped { audio_end_ms: u32, item_id: String }, + #[serde(other)] + Other, +} + +fn convert_pcm16_to_f32(bytes: &[u8]) -> Vec { + let mut samples = Vec::with_capacity(bytes.len() / 2); + + for chunk in bytes.chunks_exact(2) { + let pcm16_sample = i16::from_le_bytes([chunk[0], chunk[1]]); + let f32_sample = pcm16_sample as f32 / 32767.0; + samples.push(f32_sample); + } + + samples +} + +fn convert_f32_to_pcm16(samples: &[f32]) -> Vec { + let mut pcm16_bytes = Vec::with_capacity(samples.len() * 2); + + for &sample in samples { + // Clamp to [-1.0, 1.0] and convert to i16 + let clamped = sample.max(-1.0).min(1.0); + let pcm16_sample = (clamped * 32767.0) as i16; + pcm16_bytes.extend_from_slice(&pcm16_sample.to_le_bytes()); + } + + pcm16_bytes +} + +/// Replaces a placeholder in a file and writes the result to an output file. +/// +/// # Arguments +/// +/// * `input_path` - Path to the input file with placeholder text. +/// * `placeholder` - The placeholder text to search for (e.g., "{{PLACEHOLDER}}"). +/// * `replacement` - The text to replace the placeholder with. +/// * `output_path` - Path to write the modified content. 
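+///
+/// # Example
+///
+/// Filling a dataflow template the way `handle_client` does below (the file
+/// names here are illustrative):
+///
+/// ```ignore
+/// replace_placeholder_in_file(
+///     "whisper-template-metal.yml",
+///     "NODE_ID",
+///     "server-42",
+///     "whisper-42.yml",
+/// )?;
+/// ```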
+fn replace_placeholder_in_file( + input_path: &str, + placeholder: &str, + replacement: &str, + output_path: &str, +) -> io::Result<()> { + // Read the file content into a string + let content = fs::read_to_string(input_path)?; + + // Replace the placeholder + let modified_content = content.replace(placeholder, replacement); + + // Write the modified content to the output file + let mut file = fs::File::create(output_path)?; + file.write_all(modified_content.as_bytes())?; + + Ok(()) +} + +async fn handle_client(fut: upgrade::UpgradeFut) -> Result<(), WebSocketError> { + let mut ws = fastwebsockets::FragmentCollector::new(fut.await?); + + let frame = ws.read_frame().await?; + if frame.opcode != OpCode::Text { + return Err(WebSocketError::InvalidConnectionHeader); + } + let data: OpenAIRealtimeMessage = serde_json::from_slice(&frame.payload).unwrap(); + let OpenAIRealtimeMessage::SessionUpdate { session } = data else { + return Err(WebSocketError::InvalidConnectionHeader); + }; + + let input_audio_transcription = session + .input_audio_transcription + .map_or("moyoyo-whisper".to_string(), |t| t.model); + let id = random::(); + let node_id = format!("server-{id}"); + let dataflow = format!("{input_audio_transcription}-{}.yml", id); + let template = format!("{input_audio_transcription}-template-metal.yml"); + println!("Filling template: {}", template); + replace_placeholder_in_file(&template, "NODE_ID", &node_id, &dataflow).unwrap(); + // Copy configuration file but replace the node ID with "server-id" + // Read the configuration file and replace the node ID with "server-id" + dora_cli::command::Command::Start(Start { + dataflow, + name: Some(node_id.to_string()), + coordinator_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + coordinator_port: DORA_COORDINATOR_PORT_CONTROL_DEFAULT, + attach: false, + detach: true, + hot_reload: false, + uv: true, + }) + .execute() + .unwrap(); + let (mut node, mut events) = + DoraNode::init_from_node_id(NodeId::from(node_id.clone())).unwrap(); + let serialized_data = OpenAIRealtimeResponse::SessionCreated { + session: serde_json::Value::Null, + }; + + let payload = + Payload::Bytes(Bytes::from(serde_json::to_string(&serialized_data).unwrap()).into()); + let frame = Frame::text(payload); + ws.write_frame(frame).await?; + loop { + let mut frame = ws.read_frame().await?; + let mut finished = false; + match frame.opcode { + OpCode::Close => break, + OpCode::Text | OpCode::Binary => { + let data: OpenAIRealtimeMessage = serde_json::from_slice(&frame.payload).unwrap(); + + match data { + OpenAIRealtimeMessage::InputAudioBufferAppend { audio } => { + // println!("Received audio data: {}", audio); + let f32_data = audio; + // Decode base64 encoded audio data + let f32_data = f32_data.trim(); + if f32_data.is_empty() { + continue; + } + + if let Ok(f32_data) = general_purpose::STANDARD.decode(f32_data) { + let f32_data = convert_pcm16_to_f32(&f32_data); + // Downsample to 16 kHz from 24 kHz + let f32_data = f32_data + .into_iter() + .enumerate() + .filter(|(i, _)| i % 3 != 0) + .map(|(_, v)| v) + .collect::>(); + let mut parameter = MetadataParameters::default(); + parameter.insert( + "sample_rate".to_string(), + dora_node_api::Parameter::Integer(16000), + ); + node.send_output( + DataId::from("audio".to_string()), + parameter, + f32_data.into_arrow(), + ) + .unwrap(); + let ev = events.recv_async_timeout(Duration::from_millis(10)).await; + + // println!("Received event: {:?}", ev); + let frame = match ev { + Some(dora_node_api::Event::Input { + id, + metadata: _, + 
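+                                // `data` carries the dataflow output: UTF-8
+                                // text (a transcript delta) or f32 audio samples.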
data, + }) => { + if data.data_type() == &DataType::Utf8 { + let data = data.as_string::(); + let str = data.value(0); + let serialized_data = + OpenAIRealtimeResponse::ResponseAudioTranscriptDelta { + response_id: "123".to_string(), + item_id: "123".to_string(), + output_index: 123, + content_index: 123, + delta: str.to_string(), + }; + + frame.payload = Payload::Bytes( + Bytes::from( + serde_json::to_string(&serialized_data).unwrap(), + ) + .into(), + ); + frame.opcode = OpCode::Text; + frame + } else if id.contains("audio") { + let data: Vec = into_vec(&data).unwrap(); + let data = convert_f32_to_pcm16(&data); + let serialized_data = + OpenAIRealtimeResponse::ResponseAudioDelta { + response_id: "123".to_string(), + item_id: "123".to_string(), + output_index: 123, + content_index: 123, + delta: general_purpose::STANDARD.encode(data), + }; + finished = true; + + frame.payload = Payload::Bytes( + Bytes::from( + serde_json::to_string(&serialized_data).unwrap(), + ) + .into(), + ); + frame.opcode = OpCode::Text; + frame + } else { + unimplemented!() + } + } + Some(dora_node_api::Event::Error(_)) => { + // println!("Error in input: {}", s); + continue; + } + _ => break, + }; + ws.write_frame(frame).await?; + if finished { + let serialized_data = OpenAIRealtimeResponse::ResponseDone { + response: serde_json::Value::Null, + }; + + let payload = Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()) + .into(), + ); + println!("Sending response done: {:?}", serialized_data); + let frame = Frame::text(payload); + ws.write_frame(frame).await?; + } + } + } + OpenAIRealtimeMessage::InputAudioBufferCommit => break, + _ => {} + } + } + _ => break, + } + } + + Ok(()) +} +async fn server_upgrade( + mut req: Request, +) -> Result>, WebSocketError> { + let (response, fut) = upgrade::upgrade(&mut req)?; + + tokio::task::spawn(async move { + if let Err(e) = tokio::task::unconstrained(handle_client(fut)).await { + eprintln!("Error in websocket connection: {}", e); + } + }); + + Ok(response) +} + +fn main() -> Result<(), WebSocketError> { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + rt.block_on(async move { + let listener = TcpListener::bind("127.0.0.1:8123").await?; + println!("Server started, listening on {}", "127.0.0.1:8123"); + loop { + let (stream, _) = listener.accept().await?; + println!("Client connected"); + tokio::spawn(async move { + let io = hyper_util::rt::TokioIo::new(stream); + let conn_fut = http1::Builder::new() + .serve_connection(io, service_fn(server_upgrade)) + .with_upgrades(); + if let Err(e) = conn_fut.await { + println!("An error occurred: {:?}", e); + } + }); + } + }) +} From 27dca42dabc31e1ce3454f7b5f3bb4ad785eaca9 Mon Sep 17 00:00:00 2001 From: haixuantao Date: Wed, 30 Jul 2025 15:26:29 +0800 Subject: [PATCH 6/8] Add an example using whisper --- examples/openai-realtime/README.md | 78 +++++++++++++++++++ .../whisper-template-metal.yml | 49 ++++++++++++ .../dora_distil_whisper/main.py | 11 +++ 3 files changed, 138 insertions(+) create mode 100644 examples/openai-realtime/README.md create mode 100644 examples/openai-realtime/whisper-template-metal.yml diff --git a/examples/openai-realtime/README.md b/examples/openai-realtime/README.md new file mode 100644 index 00000000..3df5a733 --- /dev/null +++ b/examples/openai-realtime/README.md @@ -0,0 +1,78 @@ +# Dora-OpenAI-Realtime (ROOT Repo) + +## Front End + +### Build Client + +```bash +git clone https://github.com/haixuantao/makepad-realtime 
+cd makepad-realtime
+cargo build --release
+```
+
+### Run Client
+
+```bash
+cd makepad-realtime
+OPENAI_API_KEY=1 cargo run -r
+```
+
+## Server
+
+### Build Server
+
+```bash
+uv venv --seed -p 3.11
+dora build whisper-template-metal.yml --uv ## very long process
+```
+
+### Run Server
+
+```bash
+source .venv/bin/activate
+dora up
+cargo run --release -p dora-openai-websocket
+```
+
+## When finished
+
+```bash
+dora destroy
+```
+
+## GUI
+
+### Connect Client and Server (English Version)
+
+- Click on Connect to Whisper
+- When the session is ready
+- Click on Start Conversation
+- Start talking
+- When finished click on Stop Conversation
+
+### For the Chinese realtime server version
+
+- Click on Connect to FunASR (not working super well for now)
+- When the session is ready
+- Click on Start Conversation
+- Start talking
+- When finished click on Stop Conversation
+
+## Video tutorial
+
+https://github.com/user-attachments/assets/594ae453-4983-410d-8ba7-a11778322cfa
+
+### WIP: Moyoyo
+
+## (Recommended) Install git-lfs
+
+```bash
+brew install git-lfs # macOS
+```
+
+## Clone Moxin Voice Chat
+
+```bash
+git lfs install
+git clone https://github.com/moxin-org/moxin-voice-chat.git
+```
diff --git a/examples/openai-realtime/whisper-template-metal.yml b/examples/openai-realtime/whisper-template-metal.yml
new file mode 100644
index 00000000..c689bf90
--- /dev/null
+++ b/examples/openai-realtime/whisper-template-metal.yml
@@ -0,0 +1,49 @@
+nodes:
+  - id: NODE_ID
+    path: dynamic
+    inputs:
+      audio: tts/audio
+      text: stt/text
+    outputs:
+      - audio
+
+  - id: dora-vad
+    build: pip install -e ../../node-hub/dora-vad
+    path: dora-vad
+    inputs:
+      audio: NODE_ID/audio
+    outputs:
+      - audio
+    env:
+      MIN_SPEECH_DURATION_MS: 2000
+      MIN_SILENCE_DURATION_MS: 1200
+      THRESHOLD: 1.0
+
+  - id: stt
+    build: pip install -e ../../node-hub/dora-distil-whisper
+    path: dora-distil-whisper
+    inputs:
+      audio:
+        source: dora-vad/audio
+        queue_size: 1000000
+    outputs:
+      - text
+      - word
+
+  - id: llm
+    build: pip install -e ../../node-hub/dora-qwen
+    path: dora-qwen
+    inputs:
+      text: stt/text
+    outputs:
+      - text
+    env:
+      MODEL_NAME_OR_PATH: Qwen/Qwen2.5-0.5B-Instruct-GGUF
+
+  - id: tts
+    build: pip install -e ../../node-hub/dora-kokoro-tts
+    path: dora-kokoro-tts
+    inputs:
+      text: llm/text
+    outputs:
+      - audio
diff --git a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py
index 06fea704..26fcc9dd 100644
--- a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py
+++ b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py
@@ -125,9 +125,11 @@ def load_model():
 BAD_SENTENCES = [
     "",
     " so",
+    " So.",
     " so so",
     "You",
     "You ",
+    " You",
     "字幕",
     "字幕志愿",
     "中文字幕",
@@ -188,6 +190,15 @@ def main():
     # For macos use mlx:
     if sys.platform != "darwin":
         pipe = load_model()
+    else:
+        import mlx_whisper
+
+        result = mlx_whisper.transcribe(
+            np.array([]),
+            path_or_hf_repo="mlx-community/whisper-large-v3-turbo",
+            append_punctuations=".",
+            language=TARGET_LANGUAGE,
+        )

     node = Node()
     noise_timestamp = time.time()

From fa888e8b82175f5e2670caeb781fa07bd481fa1c Mon Sep 17 00:00:00 2001
From: haixuantao
Date: Wed, 30 Jul 2025 16:01:51 +0800
Subject: [PATCH 7/8] Minor fix and improvements

---
 examples/openai-realtime/whisper-template-metal.yml | 8 ++++----
 node-hub/dora-openai-websocket/Cargo.toml           | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/openai-realtime/whisper-template-metal.yml b/examples/openai-realtime/whisper-template-metal.yml
index c689bf90..44b1aacd 100644 --- a/examples/openai-realtime/whisper-template-metal.yml +++ b/examples/openai-realtime/whisper-template-metal.yml @@ -11,7 +11,9 @@ nodes: build: pip install -e ../../node-hub/dora-vad path: dora-vad inputs: - audio: NODE_ID/audio + audio: + source: NODE_ID/audio + queue_size: 1000000 outputs: - audio env: @@ -23,9 +25,7 @@ nodes: build: pip install -e ../../node-hub/dora-distil-whisper path: dora-distil-whisper inputs: - audio: - source: dora-vad/audio - queue_size: 1000000 + audio: dora-vad/audio outputs: - text - word diff --git a/node-hub/dora-openai-websocket/Cargo.toml b/node-hub/dora-openai-websocket/Cargo.toml index ae85552e..7a05fb12 100644 --- a/node-hub/dora-openai-websocket/Cargo.toml +++ b/node-hub/dora-openai-websocket/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] dora-node-api = { workspace = true } -dora-cli = { workspace = true } +dora-cli = { workspace = true, default-features = false } tokio = { version = "1.25.0", features = ["full", "macros"] } tokio-rustls = "0.24.0" rustls-pemfile = "1.0" From 2f89cf811d758b443c6b88cc3635a78a13de4ee9 Mon Sep 17 00:00:00 2001 From: haixuantao Date: Sat, 2 Aug 2025 00:00:46 +0800 Subject: [PATCH 8/8] Improvement of the overall language pipeline to be more resilient --- .../dora_distil_whisper/main.py | 85 +++--- .../dora-kokoro-tts/dora_kokoro_tts/main.py | 37 +-- node-hub/dora-openai-websocket/src/main.rs | 250 ++++++++++-------- node-hub/dora-vad/dora_vad/main.py | 16 +- 4 files changed, 218 insertions(+), 170 deletions(-) diff --git a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py index 26fcc9dd..3710bf0c 100644 --- a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py +++ b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py @@ -48,35 +48,6 @@ def remove_text_noise(text: str, text_noise="") -> str: text_words = normalized_text.split() noise_words = normalized_noise.split() - # Function to find and remove noise sequence flexibly - def remove_flexible(text_list, noise_list): - i = 0 - while i <= len(text_list) - len(noise_list): - match = True - extra_words = 0 - for j, noise_word in enumerate(noise_list): - if i + j + extra_words >= len(text_list): - match = False - break - # Allow skipping extra words in text_list - while ( - i + j + extra_words < len(text_list) - and text_list[i + j + extra_words] != noise_word - ): - extra_words += 1 - if i + j + extra_words >= len(text_list): - match = False - break - if not match: - break - if match: - # Remove matched part - del text_list[i : i + len(noise_list) + extra_words] - i = max(0, i - len(noise_list)) # Adjust index after removal - else: - i += 1 - return text_list - # Only remove parts of text_noise that are found in text cleaned_words = text_words[:] for noise_word in noise_words: @@ -126,7 +97,26 @@ BAD_SENTENCES = [ "", " so", " So.", + " So, let's go.", " so so", + " What?", + " We'll see you next time.", + " I'll see you next time.", + " We're going to come back.", + " let's move on.", + " Here we go.", + " my", + " All right. Thank you.", + " That's what we're doing.", + " That's what I wanted to do.", + " I'll be back.", + " Hold this. Hold this.", + " Hold this one. 
Hold this one.", + " And we'll see you next time.", + " strength.", + " Length.", + " Let's go.", + " Let's do it.", "You", "You ", " You", @@ -199,6 +189,12 @@ def main(): append_punctuations=".", language=TARGET_LANGUAGE, ) + result = mlx_whisper.transcribe( + np.array([]), + path_or_hf_repo="mlx-community/whisper-large-v3-turbo", + append_punctuations=".", + language=TARGET_LANGUAGE, + ) node = Node() noise_timestamp = time.time() @@ -244,6 +240,8 @@ def main(): generate_kwargs=confg, ) if result["text"] in BAD_SENTENCES: + print("Discarded text: ", result["text"]) + # cache_audio = None continue text = cut_repetition(result["text"]) @@ -258,20 +256,27 @@ def main(): continue if ( - ( - text.endswith(".") - or text.endswith("!") - or text.endswith("?") - or text.endswith('."') - or text.endswith('!"') - or text.endswith('?"') - ) - and not text.endswith("...") # Avoid ending with ellipsis - ): + text.endswith(".") + or text.endswith("!") + or text.endswith("?") + or text.endswith('."') + or text.endswith('!"') + or text.endswith('?"') + ) and not text.endswith("..."): node.send_output( "text", pa.array([text]), ) + node.send_output( + "stop", + pa.array([text]), + ) cache_audio = None - else: + audio = None + print("Text:", text) + elif text.endswith("..."): + print( + "Keeping audio in cache for next text output with punctuation" + ) + print("Discarded text", text) cache_audio = audio diff --git a/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py b/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py index 7762cfca..d1c28a4f 100644 --- a/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py +++ b/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py @@ -1,4 +1,5 @@ """TODO: Add docstring.""" + import os import re @@ -8,11 +9,12 @@ from kokoro import KPipeline LANGUAGE = os.getenv("LANGUAGE", "en") + def main(): """TODO: Add docstring.""" if LANGUAGE in ["en", "english"]: pipeline = KPipeline(lang_code="a") - elif LANGUAGE in ["zh","ch","chinese"]: + elif LANGUAGE in ["zh", "ch", "chinese"]: pipeline = KPipeline(lang_code="z") else: print("warning: Defaulting to english speaker as language not found") @@ -22,22 +24,23 @@ def main(): for event in node: if event["type"] == "INPUT": - if event["id"] == "text": - text = event["value"][0].as_py() - if re.findall(r'[\u4e00-\u9fff]+', text): - pipeline = KPipeline(lang_code="z") - elif pipeline.lang_code != "a": - pipeline = KPipeline(lang_code="a") # <= make sure lang_code matches voice - - generator = pipeline( - text, - voice="af_heart", # <= change voice here - speed=1.2, - split_pattern=r"\n+", - ) - for _, (_, _, audio) in enumerate(generator): - audio = audio.numpy() - node.send_output("audio", pa.array(audio), {"sample_rate": 24000}) + text = event["value"][0].as_py() + if re.findall(r"[\u4e00-\u9fff]+", text): + pipeline = KPipeline(lang_code="z") + elif pipeline.lang_code != "a": + pipeline = KPipeline( + lang_code="a" + ) # <= make sure lang_code matches voice + + generator = pipeline( + text, + voice="af_heart", # <= change voice here + speed=1.2, + split_pattern=r"\n+", + ) + for _, (_, _, audio) in enumerate(generator): + audio = audio.numpy() + node.send_output("audio", pa.array(audio), {"sample_rate": 24000}) if __name__ == "__main__": diff --git a/node-hub/dora-openai-websocket/src/main.rs b/node-hub/dora-openai-websocket/src/main.rs index 4eb473c6..3eda4d27 100644 --- a/node-hub/dora-openai-websocket/src/main.rs +++ b/node-hub/dora-openai-websocket/src/main.rs @@ -30,6 +30,10 @@ use fastwebsockets::Frame; use fastwebsockets::OpCode; use 
fastwebsockets::Payload; use fastwebsockets::WebSocketError; +use futures_concurrency::future::Race; +use futures_util::future; +use futures_util::future::Either; +use futures_util::FutureExt; use http_body_util::Empty; use hyper::body::Bytes; use hyper::body::Incoming; @@ -45,9 +49,7 @@ use std::fs; use std::io::{self, Write}; use std::net::IpAddr; use std::net::Ipv4Addr; -use std::time::Duration; use tokio::net::TcpListener; - #[derive(Serialize, Deserialize, Debug)] pub struct ErrorDetails { pub code: Option, @@ -310,125 +312,153 @@ async fn handle_client(fut: upgrade::UpgradeFut) -> Result<(), WebSocketError> { let frame = Frame::text(payload); ws.write_frame(frame).await?; loop { - let mut frame = ws.read_frame().await?; + let event_fut = events.recv_async().map(Either::Left); + let frame_fut = ws.read_frame().map(Either::Right); + let event_stream = (event_fut, frame_fut).race(); let mut finished = false; - match frame.opcode { - OpCode::Close => break, - OpCode::Text | OpCode::Binary => { - let data: OpenAIRealtimeMessage = serde_json::from_slice(&frame.payload).unwrap(); - - match data { - OpenAIRealtimeMessage::InputAudioBufferAppend { audio } => { - // println!("Received audio data: {}", audio); - let f32_data = audio; - // Decode base64 encoded audio data - let f32_data = f32_data.trim(); - if f32_data.is_empty() { - continue; - } + let frame = match event_stream.await { + future::Either::Left(Some(ev)) => { + let frame = match ev { + dora_node_api::Event::Input { + id, + metadata: _, + data, + } => { + if data.data_type() == &DataType::Utf8 { + let data = data.as_string::(); + let str = data.value(0); + let serialized_data = + OpenAIRealtimeResponse::ResponseAudioTranscriptDelta { + response_id: "123".to_string(), + item_id: "123".to_string(), + output_index: 123, + content_index: 123, + delta: str.to_string(), + }; - if let Ok(f32_data) = general_purpose::STANDARD.decode(f32_data) { - let f32_data = convert_pcm16_to_f32(&f32_data); - // Downsample to 16 kHz from 24 kHz - let f32_data = f32_data - .into_iter() - .enumerate() - .filter(|(i, _)| i % 3 != 0) - .map(|(_, v)| v) - .collect::>(); - let mut parameter = MetadataParameters::default(); - parameter.insert( - "sample_rate".to_string(), - dora_node_api::Parameter::Integer(16000), - ); - node.send_output( - DataId::from("audio".to_string()), - parameter, - f32_data.into_arrow(), - ) - .unwrap(); - let ev = events.recv_async_timeout(Duration::from_millis(10)).await; - - // println!("Received event: {:?}", ev); - let frame = match ev { - Some(dora_node_api::Event::Input { - id, - metadata: _, - data, - }) => { - if data.data_type() == &DataType::Utf8 { - let data = data.as_string::(); - let str = data.value(0); - let serialized_data = - OpenAIRealtimeResponse::ResponseAudioTranscriptDelta { - response_id: "123".to_string(), - item_id: "123".to_string(), - output_index: 123, - content_index: 123, - delta: str.to_string(), - }; - - frame.payload = Payload::Bytes( - Bytes::from( - serde_json::to_string(&serialized_data).unwrap(), - ) - .into(), - ); - frame.opcode = OpCode::Text; - frame - } else if id.contains("audio") { - let data: Vec = into_vec(&data).unwrap(); - let data = convert_f32_to_pcm16(&data); - let serialized_data = - OpenAIRealtimeResponse::ResponseAudioDelta { - response_id: "123".to_string(), - item_id: "123".to_string(), - output_index: 123, - content_index: 123, - delta: general_purpose::STANDARD.encode(data), - }; - finished = true; - - frame.payload = Payload::Bytes( - Bytes::from( - 
serde_json::to_string(&serialized_data).unwrap(), - ) - .into(), - ); - frame.opcode = OpCode::Text; - frame - } else { - unimplemented!() - } - } - Some(dora_node_api::Event::Error(_)) => { - // println!("Error in input: {}", s); - continue; - } - _ => break, + let frame = Frame::text(Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()) + .into(), + )); + frame + } else if id.contains("audio") { + let data: Vec = into_vec(&data).unwrap(); + let data = convert_f32_to_pcm16(&data); + let serialized_data = OpenAIRealtimeResponse::ResponseAudioDelta { + response_id: "123".to_string(), + item_id: "123".to_string(), + output_index: 123, + content_index: 123, + delta: general_purpose::STANDARD.encode(data), }; - ws.write_frame(frame).await?; - if finished { - let serialized_data = OpenAIRealtimeResponse::ResponseDone { - response: serde_json::Value::Null, + finished = true; + + let frame = Frame::text(Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()) + .into(), + )); + frame + } else if id.contains("stop") { + let serialized_data = + OpenAIRealtimeResponse::InputAudioBufferSpeechStopped { + audio_end_ms: 123, + item_id: "123".to_string(), }; + finished = true; + + let frame = Frame::text(Payload::Bytes( + Bytes::from(serde_json::to_string(&serialized_data).unwrap()) + .into(), + )); + frame + } else { + unimplemented!() + } + } + dora_node_api::Event::Error(_) => { + // println!("Error in input: {}", s); + continue; + } + _ => break, + }; + Some(frame) + } + future::Either::Left(None) => break, + future::Either::Right(Ok(frame)) => { + match frame.opcode { + OpCode::Close => break, + OpCode::Text | OpCode::Binary => { + let data: OpenAIRealtimeMessage = + serde_json::from_slice(&frame.payload).unwrap(); + + match data { + OpenAIRealtimeMessage::InputAudioBufferAppend { audio } => { + // println!("Received audio data: {}", audio); + let f32_data = audio; + // Decode base64 encoded audio data + let f32_data = f32_data.trim(); + if f32_data.is_empty() { + continue; + } - let payload = Payload::Bytes( - Bytes::from(serde_json::to_string(&serialized_data).unwrap()) - .into(), - ); - println!("Sending response done: {:?}", serialized_data); - let frame = Frame::text(payload); - ws.write_frame(frame).await?; + if let Ok(f32_data) = general_purpose::STANDARD.decode(f32_data) { + let f32_data = convert_pcm16_to_f32(&f32_data); + // Downsample to 16 kHz from 24 kHz + let f32_data = f32_data + .into_iter() + .enumerate() + .filter(|(i, _)| i % 3 != 0) + .map(|(_, v)| v) + .collect::>(); + + let mut parameter = MetadataParameters::default(); + parameter.insert( + "sample_rate".to_string(), + dora_node_api::Parameter::Integer(16000), + ); + node.send_output( + DataId::from("audio".to_string()), + parameter, + f32_data.into_arrow(), + ) + .unwrap(); + } + } + OpenAIRealtimeMessage::InputAudioBufferCommit => break, + OpenAIRealtimeMessage::ResponseCreate { response } => { + if let Some(text) = response.instructions { + node.send_output( + DataId::from("text".to_string()), + Default::default(), + text.into_arrow(), + ) + .unwrap(); + } } + _ => {} } } - OpenAIRealtimeMessage::InputAudioBufferCommit => break, - _ => {} + _ => break, } + None } - _ => break, + future::Either::Right(Err(_)) => break, + }; + if let Some(frame) = frame { + ws.write_frame(frame).await?; } + if finished { + let serialized_data = OpenAIRealtimeResponse::ResponseDone { + response: serde_json::Value::Null, + }; + + let payload = Payload::Bytes( + 
Bytes::from(serde_json::to_string(&serialized_data).unwrap()).into(),
            );
            println!("Sending response done: {:?}", serialized_data);
            let frame = Frame::text(payload);
            ws.write_frame(frame).await?;
        };
    }

    Ok(())
diff --git a/node-hub/dora-vad/dora_vad/main.py b/node-hub/dora-vad/dora_vad/main.py
index 0ee08b08..11f5f7b8 100644
--- a/node-hub/dora-vad/dora_vad/main.py
+++ b/node-hub/dora-vad/dora_vad/main.py
@@ -38,11 +38,21 @@ def main():
                 min_silence_duration_ms=MIN_SILENCE_DURATION_MS,
                 sampling_rate=sr,
             )
-
+            if len(speech_timestamps) == 0:
+                # If there is no speech, skip this input
+                continue
+            arg_max = np.argmax([ts["end"] - ts["start"] for ts in speech_timestamps])
             # Check if there is a timestamp
             if (
                 len(speech_timestamps) > 0
-                and len(audio) > MIN_AUDIO_SAMPLING_DURATION_MS * sr / 1000
+                and len(
+                    audio[speech_timestamps[0]["start"] : speech_timestamps[-1]["end"]]
+                )
+                > MIN_AUDIO_SAMPLING_DURATION_MS * sr / 1000
+                and (
+                    (len(audio) - speech_timestamps[arg_max]["end"])
+                    > MIN_SILENCE_DURATION_MS / 1000 * sr * 5
+                )
             ):
                 # Check that the audio is not cut off at the end, and only return after a long enough silence
                 if speech_timestamps[-1]["end"] == len(audio):
@@ -51,7 +61,7 @@ def main():
                     node.send_output(
                         "timestamp_start",
                         pa.array([speech_timestamps[-1]["start"]]),
                         metadata={"sample_rate": sr},
                     )
-            audio = audio[0 : speech_timestamps[-1]["end"]]
+            audio = audio[: speech_timestamps[-1]["end"]]
             node.send_output("audio", pa.array(audio), metadata={"sample_rate": sr})
             last_audios = [audio[speech_timestamps[-1]["end"] :]]
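The new gating logic in dora-vad carries most of the resilience improvement, but the combined condition is dense. Restated as a standalone predicate (the `should_emit` name and the constant values are illustrative; the node reads these constants from its environment):

```python
import numpy as np

sr = 16000                             # sample rate in Hz
MIN_AUDIO_SAMPLING_DURATION_MS = 300   # illustrative value
MIN_SILENCE_DURATION_MS = 1200         # illustrative value

def should_emit(audio: np.ndarray, speech_timestamps: list) -> bool:
    """Emit audio only when the detected speech span is long enough and the
    longest speech segment is followed by a generous trailing silence."""
    if len(speech_timestamps) == 0:
        return False
    # Index of the longest detected speech segment
    arg_max = np.argmax([ts["end"] - ts["start"] for ts in speech_timestamps])
    # Span from the first speech start to the last speech end, in samples
    span = speech_timestamps[-1]["end"] - speech_timestamps[0]["start"]
    long_enough = span > MIN_AUDIO_SAMPLING_DURATION_MS * sr / 1000
    # Silence after the longest segment; the 5x factor makes the node wait
    # well past the VAD's own silence threshold before emitting anything
    trailing = len(audio) - speech_timestamps[arg_max]["end"]
    silent_enough = trailing > MIN_SILENCE_DURATION_MS / 1000 * sr * 5
    return long_enough and silent_enough
```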