mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
parent
d4c3a9cb54
commit
bb35e878cb
@ -7,14 +7,22 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from . import whisper
|
||||
|
||||
|
||||
def load_model(
|
||||
folder: str,
|
||||
path_or_hf_repo: str,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> whisper.Whisper:
|
||||
model_path = Path(folder)
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo
|
||||
)
|
||||
)
|
||||
|
||||
with open(str(model_path / "config.json"), "r") as f:
|
||||
config = json.loads(f.read())
|
||||
|
@ -62,7 +62,7 @@ class ModelHolder:
|
||||
def transcribe(
|
||||
audio: Union[str, np.ndarray, mx.array],
|
||||
*,
|
||||
model_path: str = "mlx_models/tiny",
|
||||
path_or_hf_repo: str = "mlx_models",
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
@ -85,8 +85,8 @@ def transcribe(
|
||||
audio: Union[str, np.ndarray, mx.array]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
model_path: str
|
||||
The path to the Whisper model that has been converted to MLX format.
|
||||
path_or_hf_repo: str
|
||||
The localpath to the Whisper model or HF Hub repo with the MLX converted weights.
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
@ -144,7 +144,7 @@ def transcribe(
|
||||
"""
|
||||
|
||||
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
|
||||
model = ModelHolder.get_model(model_path, dtype)
|
||||
model = ModelHolder.get_model(path_or_hf_repo, dtype)
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
|
||||
|
Loading…
Reference in New Issue
Block a user