mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
parent
d4c3a9cb54
commit
bb35e878cb
@ -7,14 +7,22 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from . import whisper
|
from . import whisper
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
folder: str,
|
path_or_hf_repo: str,
|
||||||
dtype: mx.Dtype = mx.float32,
|
dtype: mx.Dtype = mx.float32,
|
||||||
) -> whisper.Whisper:
|
) -> whisper.Whisper:
|
||||||
model_path = Path(folder)
|
model_path = Path(path_or_hf_repo)
|
||||||
|
if not model_path.exists():
|
||||||
|
model_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=path_or_hf_repo
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with open(str(model_path / "config.json"), "r") as f:
|
with open(str(model_path / "config.json"), "r") as f:
|
||||||
config = json.loads(f.read())
|
config = json.loads(f.read())
|
||||||
|
@ -62,7 +62,7 @@ class ModelHolder:
|
|||||||
def transcribe(
|
def transcribe(
|
||||||
audio: Union[str, np.ndarray, mx.array],
|
audio: Union[str, np.ndarray, mx.array],
|
||||||
*,
|
*,
|
||||||
model_path: str = "mlx_models/tiny",
|
path_or_hf_repo: str = "mlx_models",
|
||||||
verbose: Optional[bool] = None,
|
verbose: Optional[bool] = None,
|
||||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||||
compression_ratio_threshold: Optional[float] = 2.4,
|
compression_ratio_threshold: Optional[float] = 2.4,
|
||||||
@ -85,8 +85,8 @@ def transcribe(
|
|||||||
audio: Union[str, np.ndarray, mx.array]
|
audio: Union[str, np.ndarray, mx.array]
|
||||||
The path to the audio file to open, or the audio waveform
|
The path to the audio file to open, or the audio waveform
|
||||||
|
|
||||||
model_path: str
|
path_or_hf_repo: str
|
||||||
The path to the Whisper model that has been converted to MLX format.
|
The localpath to the Whisper model or HF Hub repo with the MLX converted weights.
|
||||||
|
|
||||||
verbose: bool
|
verbose: bool
|
||||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||||
@ -144,7 +144,7 @@ def transcribe(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
|
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
|
||||||
model = ModelHolder.get_model(model_path, dtype)
|
model = ModelHolder.get_model(path_or_hf_repo, dtype)
|
||||||
|
|
||||||
# Pad 30-seconds of silence to the input audio, for slicing
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
|
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
|
||||||
|
Loading…
Reference in New Issue
Block a user