Add the possibility to cache the model instead of loading it from disk each time.

Haixuan Xavier Tao 2025-01-23 11:55:12 +01:00 committed by GitHub
parent 9a3ddc3e65
commit a1aace4d99


@@ -75,6 +75,7 @@ def transcribe(
     append_punctuations: str = "\"'.。,!?::”)]}、",
     clip_timestamps: Union[str, List[float]] = "0",
     hallucination_silence_threshold: Optional[float] = None,
+    cached_model: Optional[ModelHolder] = None,
     **decode_options,
 ):
     """
@@ -137,6 +138,9 @@ def transcribe(
         When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
         when a possible hallucination is detected
 
+    cached_model: Optional[ModelHolder]
+        Stored in memory whisper model to avoid having to load from disk each time.
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -144,7 +148,10 @@ def transcribe(
     """
     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-    model = ModelHolder.get_model(path_or_hf_repo, dtype)
+    if cached_model is None:
+        model = ModelHolder.get_model(path_or_hf_repo, dtype)
+    else:
+        model = cached_model
 
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
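
Below is a minimal usage sketch of the new parameter: load the model once, then reuse it across several transcriptions. The import path (mlx_whisper.transcribe), the model repo name, and the audio file names are illustrative assumptions, not taken from this commit.

import mlx.core as mx

from mlx_whisper.transcribe import ModelHolder, transcribe  # assumed import path

# Load the weights from disk once and keep the reference around.
# The dtype should match the fp16 decode option (True by default).
model = ModelHolder.get_model("mlx-community/whisper-tiny", mx.float16)  # hypothetical repo

# Reuse the in-memory model; transcribe() then skips ModelHolder.get_model entirely.
for audio_file in ["first.wav", "second.wav"]:  # placeholder inputs
    result = transcribe(audio_file, cached_model=model)
    print(result["text"])

Note that when cached_model is passed, the dtype computed from decode_options is never applied, so the caller is responsible for loading the model with a dtype consistent with the fp16 option it intends to use.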