[Whisper] Add load from Hub. (#255)

* Add load from Hub.

* Up.
This commit is contained in:
Vaibhav Srivastav 2024-01-08 19:50:00 +05:30 committed by GitHub
parent d4c3a9cb54
commit bb35e878cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 6 deletions

View File

@ -7,14 +7,22 @@ import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from huggingface_hub import snapshot_download
from . import whisper
def load_model(
folder: str,
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(folder)
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo
)
)
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())

View File

@ -62,7 +62,7 @@ class ModelHolder:
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
model_path: str = "mlx_models/tiny",
path_or_hf_repo: str = "mlx_models",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
@ -85,8 +85,8 @@ def transcribe(
audio: Union[str, np.ndarray, mx.array]
The path to the audio file to open, or the audio waveform
model_path: str
The path to the Whisper model that has been converted to MLX format.
path_or_hf_repo: str
The localpath to the Whisper model or HF Hub repo with the MLX converted weights.
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
@ -144,7 +144,7 @@ def transcribe(
"""
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
model = ModelHolder.get_model(model_path, dtype)
model = ModelHolder.get_model(path_or_hf_repo, dtype)
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)