From a1aace4d99948acc7be878e8684dba316e159fc3 Mon Sep 17 00:00:00 2001 From: Haixuan Xavier Tao Date: Thu, 23 Jan 2025 11:55:12 +0100 Subject: [PATCH] Add the possibility to cache the model instead of loading it from disk each time. --- whisper/mlx_whisper/transcribe.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index 7057679b..67850240 100644 --- a/whisper/mlx_whisper/transcribe.py +++ b/whisper/mlx_whisper/transcribe.py @@ -75,6 +75,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + cached_model: Optional[ModelHolder] = None, **decode_options, ): """ @@ -137,6 +138,9 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + cached_model: Optional[ModelHolder] + The whisper model kept in memory, to avoid having to load it from disk each time. + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and """ dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 - model = ModelHolder.get_model(path_or_hf_repo, dtype) + if cached_model is None: + model = ModelHolder.get_model(path_or_hf_repo, dtype) + else: + model = cached_model # Pad 30-seconds of silence to the input audio, for slicing mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)