[Whisper] Add load from Hub. (#255)

* Add load from Hub.

* Up.
This commit is contained in:
Vaibhav Srivastav 2024-01-08 19:50:00 +05:30 committed by GitHub
parent d4c3a9cb54
commit bb35e878cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 6 deletions

View File

@ -7,14 +7,22 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten
from huggingface_hub import snapshot_download
from . import whisper from . import whisper
def load_model( def load_model(
folder: str, path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32, dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper: ) -> whisper.Whisper:
model_path = Path(folder) model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo
)
)
with open(str(model_path / "config.json"), "r") as f: with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read()) config = json.loads(f.read())

View File

@ -62,7 +62,7 @@ class ModelHolder:
def transcribe( def transcribe(
audio: Union[str, np.ndarray, mx.array], audio: Union[str, np.ndarray, mx.array],
*, *,
model_path: str = "mlx_models/tiny", path_or_hf_repo: str = "mlx_models",
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4, compression_ratio_threshold: Optional[float] = 2.4,
@ -85,8 +85,8 @@ def transcribe(
audio: Union[str, np.ndarray, mx.array] audio: Union[str, np.ndarray, mx.array]
The path to the audio file to open, or the audio waveform The path to the audio file to open, or the audio waveform
model_path: str path_or_hf_repo: str
The path to the Whisper model that has been converted to MLX format. The local path to the Whisper model or HF Hub repo with the MLX converted weights.
verbose: bool verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details, Whether to display the text being decoded to the console. If True, displays all the details,
@ -144,7 +144,7 @@ def transcribe(
""" """
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
model = ModelHolder.get_model(model_path, dtype) model = ModelHolder.get_model(path_or_hf_repo, dtype)
# Pad 30-seconds of silence to the input audio, for slicing # Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)