From 2f89cf811d758b443c6b88cc3635a78a13de4ee9 Mon Sep 17 00:00:00 2001
From: haixuantao
Date: Sat, 2 Aug 2025 00:00:46 +0800
Subject: [PATCH] Improve the overall language pipeline to be more resilient

---
 .../dora_distil_whisper/main.py             |  85 +++---
 .../dora-kokoro-tts/dora_kokoro_tts/main.py |  37 +--
 node-hub/dora-openai-websocket/src/main.rs  | 250 ++++++++++--------
 node-hub/dora-vad/dora_vad/main.py          |  16 +-
 4 files changed, 218 insertions(+), 170 deletions(-)

diff --git a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py
index 26fcc9dd..3710bf0c 100644
--- a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py
+++ b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py
@@ -48,35 +48,6 @@ def remove_text_noise(text: str, text_noise="") -> str:
     text_words = normalized_text.split()
     noise_words = normalized_noise.split()
 
-    # Function to find and remove noise sequence flexibly
-    def remove_flexible(text_list, noise_list):
-        i = 0
-        while i <= len(text_list) - len(noise_list):
-            match = True
-            extra_words = 0
-            for j, noise_word in enumerate(noise_list):
-                if i + j + extra_words >= len(text_list):
-                    match = False
-                    break
-                # Allow skipping extra words in text_list
-                while (
-                    i + j + extra_words < len(text_list)
-                    and text_list[i + j + extra_words] != noise_word
-                ):
-                    extra_words += 1
-                    if i + j + extra_words >= len(text_list):
-                        match = False
-                        break
-                if not match:
-                    break
-            if match:
-                # Remove matched part
-                del text_list[i : i + len(noise_list) + extra_words]
-                i = max(0, i - len(noise_list))  # Adjust index after removal
-            else:
-                i += 1
-        return text_list
-
     # Only remove parts of text_noise that are found in text
     cleaned_words = text_words[:]
     for noise_word in noise_words:
@@ -126,7 +97,26 @@ BAD_SENTENCES = [
     "",
     " so",
     " So.",
+    " So, let's go.",
     " so so",
+    " What?",
+    " We'll see you next time.",
+    " I'll see you next time.",
+    " We're going to come back.",
+    " let's move on.",
+    " Here we go.",
+    " my",
+    " All right. Thank you.",
+    " That's what we're doing.",
+    " That's what I wanted to do.",
+    " I'll be back.",
+    " Hold this. Hold this.",
+    " Hold this one. Hold this one.",
+    " And we'll see you next time.",
+    " strength.",
+    " Length.",
+    " Let's go.",
+    " Let's do it.",
     "You",
     "You ",
     " You",
@@ -199,6 +189,12 @@ def main():
         append_punctuations=".",
         language=TARGET_LANGUAGE,
     )
+    result = mlx_whisper.transcribe(
+        np.array([]),
+        path_or_hf_repo="mlx-community/whisper-large-v3-turbo",
+        append_punctuations=".",
+        language=TARGET_LANGUAGE,
+    )
     node = Node()
     noise_timestamp = time.time()
@@ -244,6 +240,8 @@ def main():
                     generate_kwargs=confg,
                 )
             if result["text"] in BAD_SENTENCES:
+                print("Discarded text: ", result["text"])
+                # cache_audio = None
                 continue
             text = cut_repetition(result["text"])
@@ -258,20 +256,27 @@ def main():
                 continue
 
             if (
-                (
-                    text.endswith(".")
-                    or text.endswith("!")
-                    or text.endswith("?")
-                    or text.endswith('."')
-                    or text.endswith('!"')
-                    or text.endswith('?"')
-                )
-                and not text.endswith("...")  # Avoid ending with ellipsis
-            ):
+                text.endswith(".")
+                or text.endswith("!")
+                or text.endswith("?")
+                or text.endswith('."')
+                or text.endswith('!"')
+                or text.endswith('?"')
+            ) and not text.endswith("..."):
                 node.send_output(
                     "text",
                     pa.array([text]),
                 )
+                node.send_output(
+                    "stop",
+                    pa.array([text]),
+                )
                 cache_audio = None
-            else:
+                audio = None
+                print("Text:", text)
+            elif text.endswith("..."):
+                print(
+                    "Keeping audio in cache for next text output with punctuation"
+                )
+                print("Discarded text:", text)
                 cache_audio = audio
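Note on the dora-distil-whisper hunks above: a transcript is now only emitted once it ends in sentence-final punctuation, a trailing ellipsis keeps the audio cached so it is re-transcribed together with the next chunk, and known filler phrases are dropped via BAD_SENTENCES. (The added second `mlx_whisper.transcribe` call on an empty array reads like a model warm-up before the event loop; the patch does not say so explicitly.) A minimal sketch of the flush-or-cache rule in isolation, with an illustrative helper name that is not part of the patch:

    SENTENCE_END = (".", "!", "?", '."', '!"', '?"')

    def should_flush(text: str) -> bool:
        """Return True when `text` looks finished and can be sent downstream."""
        if text.endswith("..."):
            # A trailing ellipsis usually means Whisper stopped mid-sentence:
            # keep the audio cached and retry with the next chunk.
            return False
        return text.endswith(SENTENCE_END)

    assert should_flush("Hello there.")        # emitted; audio cache cleared
    assert not should_flush("Hello there...")  # held back; audio stays cached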
Hold this one.", + " And we'll see you next time.", + " strength.", + " Length.", + " Let's go.", + " Let's do it.", "You", "You ", " You", @@ -199,6 +189,12 @@ def main(): append_punctuations=".", language=TARGET_LANGUAGE, ) + result = mlx_whisper.transcribe( + np.array([]), + path_or_hf_repo="mlx-community/whisper-large-v3-turbo", + append_punctuations=".", + language=TARGET_LANGUAGE, + ) node = Node() noise_timestamp = time.time() @@ -244,6 +240,8 @@ def main(): generate_kwargs=confg, ) if result["text"] in BAD_SENTENCES: + print("Discarded text: ", result["text"]) + # cache_audio = None continue text = cut_repetition(result["text"]) @@ -258,20 +256,27 @@ def main(): continue if ( - ( - text.endswith(".") - or text.endswith("!") - or text.endswith("?") - or text.endswith('."') - or text.endswith('!"') - or text.endswith('?"') - ) - and not text.endswith("...") # Avoid ending with ellipsis - ): + text.endswith(".") + or text.endswith("!") + or text.endswith("?") + or text.endswith('."') + or text.endswith('!"') + or text.endswith('?"') + ) and not text.endswith("..."): node.send_output( "text", pa.array([text]), ) + node.send_output( + "stop", + pa.array([text]), + ) cache_audio = None - else: + audio = None + print("Text:", text) + elif text.endswith("..."): + print( + "Keeping audio in cache for next text output with punctuation" + ) + print("Discarded text", text) cache_audio = audio diff --git a/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py b/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py index 7762cfca..d1c28a4f 100644 --- a/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py +++ b/node-hub/dora-kokoro-tts/dora_kokoro_tts/main.py @@ -1,4 +1,5 @@ """TODO: Add docstring.""" + import os import re @@ -8,11 +9,12 @@ from kokoro import KPipeline LANGUAGE = os.getenv("LANGUAGE", "en") + def main(): """TODO: Add docstring.""" if LANGUAGE in ["en", "english"]: pipeline = KPipeline(lang_code="a") - elif LANGUAGE in ["zh","ch","chinese"]: + elif LANGUAGE in ["zh", "ch", "chinese"]: pipeline = KPipeline(lang_code="z") else: print("warning: Defaulting to english speaker as language not found") @@ -22,22 +24,23 @@ def main(): for event in node: if event["type"] == "INPUT": - if event["id"] == "text": - text = event["value"][0].as_py() - if re.findall(r'[\u4e00-\u9fff]+', text): - pipeline = KPipeline(lang_code="z") - elif pipeline.lang_code != "a": - pipeline = KPipeline(lang_code="a") # <= make sure lang_code matches voice - - generator = pipeline( - text, - voice="af_heart", # <= change voice here - speed=1.2, - split_pattern=r"\n+", - ) - for _, (_, _, audio) in enumerate(generator): - audio = audio.numpy() - node.send_output("audio", pa.array(audio), {"sample_rate": 24000}) + text = event["value"][0].as_py() + if re.findall(r"[\u4e00-\u9fff]+", text): + pipeline = KPipeline(lang_code="z") + elif pipeline.lang_code != "a": + pipeline = KPipeline( + lang_code="a" + ) # <= make sure lang_code matches voice + + generator = pipeline( + text, + voice="af_heart", # <= change voice here + speed=1.2, + split_pattern=r"\n+", + ) + for _, (_, _, audio) in enumerate(generator): + audio = audio.numpy() + node.send_output("audio", pa.array(audio), {"sample_rate": 24000}) if __name__ == "__main__": diff --git a/node-hub/dora-openai-websocket/src/main.rs b/node-hub/dora-openai-websocket/src/main.rs index 4eb473c6..3eda4d27 100644 --- a/node-hub/dora-openai-websocket/src/main.rs +++ b/node-hub/dora-openai-websocket/src/main.rs @@ -30,6 +30,10 @@ use fastwebsockets::Frame; use fastwebsockets::OpCode; use 
diff --git a/node-hub/dora-openai-websocket/src/main.rs b/node-hub/dora-openai-websocket/src/main.rs
index 4eb473c6..3eda4d27 100644
--- a/node-hub/dora-openai-websocket/src/main.rs
+++ b/node-hub/dora-openai-websocket/src/main.rs
@@ -30,6 +30,10 @@ use fastwebsockets::Frame;
 use fastwebsockets::OpCode;
 use fastwebsockets::Payload;
 use fastwebsockets::WebSocketError;
+use futures_concurrency::future::Race;
+use futures_util::future;
+use futures_util::future::Either;
+use futures_util::FutureExt;
 use http_body_util::Empty;
 use hyper::body::Bytes;
 use hyper::body::Incoming;
@@ -45,9 +49,7 @@ use std::fs;
 use std::io::{self, Write};
 use std::net::IpAddr;
 use std::net::Ipv4Addr;
-use std::time::Duration;
 use tokio::net::TcpListener;
-
 #[derive(Serialize, Deserialize, Debug)]
 pub struct ErrorDetails {
     pub code: Option<String>,
@@ -310,125 +312,153 @@ async fn handle_client(fut: upgrade::UpgradeFut) -> Result<(), WebSocketError> {
     let frame = Frame::text(payload);
     ws.write_frame(frame).await?;
     loop {
-        let mut frame = ws.read_frame().await?;
+        let event_fut = events.recv_async().map(Either::Left);
+        let frame_fut = ws.read_frame().map(Either::Right);
+        let event_stream = (event_fut, frame_fut).race();
         let mut finished = false;
-        match frame.opcode {
-            OpCode::Close => break,
-            OpCode::Text | OpCode::Binary => {
-                let data: OpenAIRealtimeMessage = serde_json::from_slice(&frame.payload).unwrap();
-
-                match data {
-                    OpenAIRealtimeMessage::InputAudioBufferAppend { audio } => {
-                        // println!("Received audio data: {}", audio);
-                        let f32_data = audio;
-                        // Decode base64 encoded audio data
-                        let f32_data = f32_data.trim();
-                        if f32_data.is_empty() {
-                            continue;
-                        }
-
-                        if let Ok(f32_data) = general_purpose::STANDARD.decode(f32_data) {
-                            let f32_data = convert_pcm16_to_f32(&f32_data);
-                            // Downsample to 16 kHz from 24 kHz
-                            let f32_data = f32_data
-                                .into_iter()
-                                .enumerate()
-                                .filter(|(i, _)| i % 3 != 0)
-                                .map(|(_, v)| v)
-                                .collect::<Vec<f32>>();
-                            let mut parameter = MetadataParameters::default();
-                            parameter.insert(
-                                "sample_rate".to_string(),
-                                dora_node_api::Parameter::Integer(16000),
-                            );
-                            node.send_output(
-                                DataId::from("audio".to_string()),
-                                parameter,
-                                f32_data.into_arrow(),
-                            )
-                            .unwrap();
-                            let ev = events.recv_async_timeout(Duration::from_millis(10)).await;
-
-                            // println!("Received event: {:?}", ev);
-                            let frame = match ev {
-                                Some(dora_node_api::Event::Input {
-                                    id,
-                                    metadata: _,
-                                    data,
-                                }) => {
-                                    if data.data_type() == &DataType::Utf8 {
-                                        let data = data.as_string::<i32>();
-                                        let str = data.value(0);
-                                        let serialized_data =
-                                            OpenAIRealtimeResponse::ResponseAudioTranscriptDelta {
-                                                response_id: "123".to_string(),
-                                                item_id: "123".to_string(),
-                                                output_index: 123,
-                                                content_index: 123,
-                                                delta: str.to_string(),
-                                            };
-
-                                        frame.payload = Payload::Bytes(
-                                            Bytes::from(
-                                                serde_json::to_string(&serialized_data).unwrap(),
-                                            )
-                                            .into(),
-                                        );
-                                        frame.opcode = OpCode::Text;
-                                        frame
-                                    } else if id.contains("audio") {
-                                        let data: Vec<f32> = into_vec(&data).unwrap();
-                                        let data = convert_f32_to_pcm16(&data);
-                                        let serialized_data =
-                                            OpenAIRealtimeResponse::ResponseAudioDelta {
-                                                response_id: "123".to_string(),
-                                                item_id: "123".to_string(),
-                                                output_index: 123,
-                                                content_index: 123,
-                                                delta: general_purpose::STANDARD.encode(data),
-                                            };
-                                        finished = true;
-
-                                        frame.payload = Payload::Bytes(
-                                            Bytes::from(
-                                                serde_json::to_string(&serialized_data).unwrap(),
-                                            )
-                                            .into(),
-                                        );
-                                        frame.opcode = OpCode::Text;
-                                        frame
-                                    } else {
-                                        unimplemented!()
-                                    }
-                                }
-                                Some(dora_node_api::Event::Error(_)) => {
-                                    // println!("Error in input: {}", s);
-                                    continue;
-                                }
-                                _ => break,
-                            };
-                            ws.write_frame(frame).await?;
-                            if finished {
-                                let serialized_data = OpenAIRealtimeResponse::ResponseDone {
-                                    response: serde_json::Value::Null,
-                                };
-
-                                let payload = Payload::Bytes(
-                                    Bytes::from(serde_json::to_string(&serialized_data).unwrap())
-                                        .into(),
-                                );
-                                println!("Sending response done: {:?}", serialized_data);
-                                let frame = Frame::text(payload);
-                                ws.write_frame(frame).await?;
-                            }
-                        }
-                    }
-                    OpenAIRealtimeMessage::InputAudioBufferCommit => break,
-                    _ => {}
-                }
-            }
-            _ => break,
-        }
+        let frame = match event_stream.await {
+            future::Either::Left(Some(ev)) => {
+                let frame = match ev {
+                    dora_node_api::Event::Input {
+                        id,
+                        metadata: _,
+                        data,
+                    } => {
+                        if data.data_type() == &DataType::Utf8 {
+                            let data = data.as_string::<i32>();
+                            let str = data.value(0);
+                            let serialized_data =
+                                OpenAIRealtimeResponse::ResponseAudioTranscriptDelta {
+                                    response_id: "123".to_string(),
+                                    item_id: "123".to_string(),
+                                    output_index: 123,
+                                    content_index: 123,
+                                    delta: str.to_string(),
+                                };
+
+                            let frame = Frame::text(Payload::Bytes(
+                                Bytes::from(serde_json::to_string(&serialized_data).unwrap())
+                                    .into(),
+                            ));
+                            frame
+                        } else if id.contains("audio") {
+                            let data: Vec<f32> = into_vec(&data).unwrap();
+                            let data = convert_f32_to_pcm16(&data);
+                            let serialized_data = OpenAIRealtimeResponse::ResponseAudioDelta {
+                                response_id: "123".to_string(),
+                                item_id: "123".to_string(),
+                                output_index: 123,
+                                content_index: 123,
+                                delta: general_purpose::STANDARD.encode(data),
+                            };
+                            finished = true;
+
+                            let frame = Frame::text(Payload::Bytes(
+                                Bytes::from(serde_json::to_string(&serialized_data).unwrap())
+                                    .into(),
+                            ));
+                            frame
+                        } else if id.contains("stop") {
+                            let serialized_data =
+                                OpenAIRealtimeResponse::InputAudioBufferSpeechStopped {
+                                    audio_end_ms: 123,
+                                    item_id: "123".to_string(),
+                                };
+                            finished = true;
+
+                            let frame = Frame::text(Payload::Bytes(
+                                Bytes::from(serde_json::to_string(&serialized_data).unwrap())
+                                    .into(),
+                            ));
+                            frame
+                        } else {
+                            unimplemented!()
+                        }
+                    }
+                    dora_node_api::Event::Error(_) => {
+                        // println!("Error in input: {}", s);
+                        continue;
+                    }
+                    _ => break,
+                };
+                Some(frame)
+            }
+            future::Either::Left(None) => break,
+            future::Either::Right(Ok(frame)) => {
+                match frame.opcode {
+                    OpCode::Close => break,
+                    OpCode::Text | OpCode::Binary => {
+                        let data: OpenAIRealtimeMessage =
+                            serde_json::from_slice(&frame.payload).unwrap();
+
+                        match data {
+                            OpenAIRealtimeMessage::InputAudioBufferAppend { audio } => {
+                                // println!("Received audio data: {}", audio);
+                                let f32_data = audio;
+                                // Decode base64 encoded audio data
+                                let f32_data = f32_data.trim();
+                                if f32_data.is_empty() {
+                                    continue;
+                                }
+
+                                if let Ok(f32_data) = general_purpose::STANDARD.decode(f32_data) {
+                                    let f32_data = convert_pcm16_to_f32(&f32_data);
+                                    // Downsample to 16 kHz from 24 kHz
+                                    let f32_data = f32_data
+                                        .into_iter()
+                                        .enumerate()
+                                        .filter(|(i, _)| i % 3 != 0)
+                                        .map(|(_, v)| v)
+                                        .collect::<Vec<f32>>();
+
+                                    let mut parameter = MetadataParameters::default();
+                                    parameter.insert(
+                                        "sample_rate".to_string(),
+                                        dora_node_api::Parameter::Integer(16000),
+                                    );
+                                    node.send_output(
+                                        DataId::from("audio".to_string()),
+                                        parameter,
+                                        f32_data.into_arrow(),
+                                    )
+                                    .unwrap();
+                                }
+                            }
+                            OpenAIRealtimeMessage::InputAudioBufferCommit => break,
+                            OpenAIRealtimeMessage::ResponseCreate { response } => {
+                                if let Some(text) = response.instructions {
+                                    node.send_output(
+                                        DataId::from("text".to_string()),
+                                        Default::default(),
+                                        text.into_arrow(),
+                                    )
+                                    .unwrap();
+                                }
+                            }
+                            _ => {}
+                        }
+                    }
+                    _ => break,
+                }
+                None
+            }
+            future::Either::Right(Err(_)) => break,
+        };
+        if let Some(frame) = frame {
+            ws.write_frame(frame).await?;
+        }
+        if finished {
+            let serialized_data = OpenAIRealtimeResponse::ResponseDone {
+                response: serde_json::Value::Null,
+            };
+
+            let payload = Payload::Bytes(
+                Bytes::from(serde_json::to_string(&serialized_data).unwrap()).into(),
+            );
+            println!("Sending response done: {:?}", serialized_data);
+            let frame = Frame::text(payload);
+            ws.write_frame(frame).await?;
+        };
     }
     Ok(())
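The dora-openai-websocket rewrite replaces the old read-then-poll sequence (a blocking `read_frame` followed by `recv_async_timeout` with a 10 ms budget) with a race between the dora event channel and the websocket read, so whichever side is ready first gets handled and at most one frame is written per loop iteration. The control flow, sketched in Python with asyncio rather than `futures_concurrency` (queues stand in for the event channel and the socket; the tagged return value plays the role of `Either::Left`/`Either::Right`):

    import asyncio

    async def race(*aws):
        """Run awaitables concurrently; return (index, result) of the first done."""
        tasks = [asyncio.ensure_future(a) for a in aws]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()  # drop the loser, as Race drops the unfinished future
        winner = done.pop()
        return tasks.index(winner), winner.result()

    async def main():
        events: asyncio.Queue = asyncio.Queue()  # dora events (Either::Left)
        frames: asyncio.Queue = asyncio.Queue()  # ws frames   (Either::Right)
        await events.put("transcript delta")

        which, value = await race(events.get(), frames.get())
        print("event" if which == 0 else "frame", value)  # -> event transcript delta

    asyncio.run(main())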
diff --git a/node-hub/dora-vad/dora_vad/main.py b/node-hub/dora-vad/dora_vad/main.py
index 0ee08b08..11f5f7b8 100644
--- a/node-hub/dora-vad/dora_vad/main.py
+++ b/node-hub/dora-vad/dora_vad/main.py
@@ -38,11 +38,21 @@ def main():
                 min_silence_duration_ms=MIN_SILENCE_DURATION_MS,
                 sampling_rate=sr,
             )
-
+            if len(speech_timestamps) == 0:
+                # If there is no speech, skip this chunk of audio
+                continue
+            arg_max = np.argmax([ts["end"] - ts["start"] for ts in speech_timestamps])
             # Check if there is a timestamp
             if (
                 len(speech_timestamps) > 0
-                and len(audio) > MIN_AUDIO_SAMPLING_DURATION_MS * sr / 1000
+                and len(
+                    audio[speech_timestamps[0]["start"] : speech_timestamps[-1]["end"]]
+                )
+                > MIN_AUDIO_SAMPLING_DURATION_MS * sr / 1000
+                and (
+                    (len(audio) - speech_timestamps[arg_max]["end"])
+                    > MIN_SILENCE_DURATION_MS / 1000 * sr * 5
+                )
             ):
                 # Check if the audio is not cut at the end. And only return if there is a long time spent
                 if speech_timestamps[-1]["end"] == len(audio):
@@ -51,7 +61,7 @@ def main():
                         pa.array([speech_timestamps[-1]["start"]]),
                         metadata={"sample_rate": sr},
                     )
-                    audio = audio[0 : speech_timestamps[-1]["end"]]
+                    audio = audio[: speech_timestamps[-1]["end"]]
                     node.send_output("audio", pa.array(audio), metadata={"sample_rate": sr})
                     last_audios = [audio[speech_timestamps[-1]["end"] :]]
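The dora-vad gate above no longer fires on any detected speech: the detected speech span has to be long enough, and the longest segment has to be followed by a stretch of trailing silence, before audio is forwarded. A standalone sketch of that condition (helper name and numeric defaults are illustrative; the real thresholds come from the MIN_AUDIO_SAMPLING_DURATION_MS and MIN_SILENCE_DURATION_MS environment variables, and the x5 factor is the patch's own safety margin):

    def should_emit(speech_timestamps, audio_len, sr,
                    min_speech_ms=200, min_silence_ms=100):
        """True when speech is long enough and the speaker seems to have finished."""
        if not speech_timestamps:
            return False
        longest = max(speech_timestamps, key=lambda ts: ts["end"] - ts["start"])
        span = speech_timestamps[-1]["end"] - speech_timestamps[0]["start"]
        trailing_silence = audio_len - longest["end"]
        return (span > min_speech_ms * sr / 1000
                and trailing_silence > min_silence_ms / 1000 * sr * 5)

    # 16 kHz audio: one 0.5 s segment ending 1 s before the end of the buffer.
    ts = [{"start": 8000, "end": 16000}]
    print(should_emit(ts, audio_len=32000, sr=16000))  # -> True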