|
|
|
@@ -46,15 +46,8 @@ def load_model(): |
|
|
|
return pipe |
|
|
|
|
|
|
|
|
|
|
|
def load_model_mlx():
    """Build and return the LightningWhisperMLX transcriber.

    The import lives inside the function, so ``lightning_whisper_mlx`` only
    has to be installed when this loader is actually called (``main`` calls
    it when ``sys.platform == "darwin"``).

    Returns:
        A ``LightningWhisperMLX`` instance created with
        ``model="distil-large-v3"``, ``batch_size=12`` and ``quant=None``.
    """
    # noqa: disable: import-error
    from lightning_whisper_mlx import LightningWhisperMLX

    # Keep the constructor settings together so they are easy to spot.
    mlx_options = {
        "model": "distil-large-v3",
        "batch_size": 12,
        "quant": None,
    }
    return LightningWhisperMLX(**mlx_options)
|
|
|
|
|
|
|
|
|
|
|
BAD_SENTENCES = [ |
|
|
|
"", |
|
|
|
"字幕", |
|
|
|
"字幕志愿", |
|
|
|
"中文字幕", |
|
|
|
@@ -75,6 +68,8 @@ BAD_SENTENCES = [ |
|
|
|
|
|
|
|
|
|
|
|
def cut_repetition(text, min_repeat_length=4, max_repeat_length=50): |
|
|
|
if len(text) == 0: |
|
|
|
return text |
|
|
|
# Check if the text is primarily Chinese (you may need to adjust this threshold) |
|
|
|
if sum(1 for char in text if "\u4e00" <= char <= "\u9fff") / len(text) > 0.5: |
|
|
|
# Chinese text processing |
|
|
|
@@ -109,9 +104,7 @@ def main(): |
|
|
|
node = Node() |
|
|
|
|
|
|
|
# For macos use mlx: |
|
|
|
if sys.platform == "darwin": |
|
|
|
whisper = load_model_mlx() |
|
|
|
else: |
|
|
|
if sys.platform != "darwin": |
|
|
|
pipe = load_model() |
|
|
|
|
|
|
|
for event in node: |
|
|
|
@@ -125,7 +118,14 @@ def main(): |
|
|
|
} |
|
|
|
) |
|
|
|
if sys.platform == "darwin": |
|
|
|
result = whisper.transcribe(audio) |
|
|
|
import mlx_whisper |
|
|
|
|
|
|
|
result = mlx_whisper.transcribe( |
|
|
|
audio, |
|
|
|
path_or_hf_repo="mlx-community/whisper-large-v3-turbo", |
|
|
|
append_punctuations=".", |
|
|
|
) |
|
|
|
|
|
|
|
else: |
|
|
|
result = pipe( |
|
|
|
audio, |
|
|
|
|