From 3883874a317883b449b299341aa90071e4ea8c7d Mon Sep 17 00:00:00 2001 From: haixuantao Date: Wed, 29 Jan 2025 11:39:40 +0100 Subject: [PATCH] Use mlx whisper --- node-hub/dora-distil-whisper/README.md | 4 ++-- .../dora_distil_whisper/main.py | 24 +++++++++---------- node-hub/dora-distil-whisper/pyproject.toml | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/node-hub/dora-distil-whisper/README.md b/node-hub/dora-distil-whisper/README.md index 0c1854a4..6aeadb0c 100644 --- a/node-hub/dora-distil-whisper/README.md +++ b/node-hub/dora-distil-whisper/README.md @@ -18,10 +18,10 @@ This node is supposed to be used as follows: ## Examples -- Speech to Text +- speech to text - github: https://github.com/dora-rs/dora/blob/main/examples/speech-to-text - website: https://dora-rs.ai/docs/examples/stt -- Vision Language Model +- vision language model - github: https://github.com/dora-rs/dora/blob/main/examples/vlm - website: https://dora-rs.ai/docs/examples/vlm diff --git a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py index aed0a5b8..cc2207e7 100644 --- a/node-hub/dora-distil-whisper/dora_distil_whisper/main.py +++ b/node-hub/dora-distil-whisper/dora_distil_whisper/main.py @@ -46,15 +46,8 @@ def load_model(): return pipe -def load_model_mlx(): - # noqa: disable: import-error - from lightning_whisper_mlx import LightningWhisperMLX - - whisper = LightningWhisperMLX(model="distil-large-v3", batch_size=12, quant=None) - return whisper - - BAD_SENTENCES = [ + "", "字幕", "字幕志愿", "中文字幕", @@ -75,6 +68,8 @@ BAD_SENTENCES = [ def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): + if len(text) == 0: + return text # Check if the text is primarily Chinese (you may need to adjust this threshold) if sum(1 for char in text if "\u4e00" <= char <= "\u9fff") / len(text) > 0.5: # Chinese text processing @@ -109,9 +104,7 @@ def main(): node = Node() # For macos use mlx: - if sys.platform == "darwin": - whisper = load_model_mlx() - else: + if sys.platform != "darwin": pipe = load_model() for event in node: @@ -125,7 +118,14 @@ def main(): } ) if sys.platform == "darwin": - result = whisper.transcribe(audio) + import mlx_whisper + + result = mlx_whisper.transcribe( + audio, + path_or_hf_repo="mlx-community/whisper-large-v3-turbo", + append_punctuations=".", + ) + else: result = pipe( audio, diff --git a/node-hub/dora-distil-whisper/pyproject.toml b/node-hub/dora-distil-whisper/pyproject.toml index f2fd138a..75969bbd 100644 --- a/node-hub/dora-distil-whisper/pyproject.toml +++ b/node-hub/dora-distil-whisper/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "accelerate >= 0.29.2", "torch >= 2.2.0", "modelscope >= 1.18.1", - "lightning-whisper-mlx >= 0.0.10; sys_platform == 'darwin'", + "mlx-whisper >= 0.4.1; sys_platform == 'darwin'", ]