diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py
index c92d9042..0d21ac1e 100644
--- a/whisper/whisper/load_models.py
+++ b/whisper/whisper/load_models.py
@@ -7,14 +7,22 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_unflatten
 
+from huggingface_hub import snapshot_download
+
 from . import whisper
 
 
 def load_model(
-    folder: str,
+    path_or_hf_repo: str,
     dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    model_path = Path(folder)
+    model_path = Path(path_or_hf_repo)
+    if not model_path.exists():
+        model_path = Path(
+            snapshot_download(
+                repo_id=path_or_hf_repo
+            )
+        )
 
     with open(str(model_path / "config.json"), "r") as f:
         config = json.loads(f.read())
diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py
index 704fd36c..05ff1fd1 100644
--- a/whisper/whisper/transcribe.py
+++ b/whisper/whisper/transcribe.py
@@ -62,7 +62,7 @@ class ModelHolder:
 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    model_path: str = "mlx_models/tiny",
+    path_or_hf_repo: str = "mlx_models",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
@@ -85,8 +85,8 @@ def transcribe(
     audio: Union[str, np.ndarray, mx.array]
         The path to the audio file to open, or the audio waveform
 
-    model_path: str
-        The path to the Whisper model that has been converted to MLX format.
+    path_or_hf_repo: str
+        The local path to the Whisper model or HF Hub repo with the MLX converted weights.
 
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,
@@ -144,7 +144,7 @@ def transcribe(
     """
     dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
 
-    model = ModelHolder.get_model(model_path, dtype)
+    model = ModelHolder.get_model(path_or_hf_repo, dtype)
 
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)