mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Whisper improvements (#1080)
* use safetensors in whisper * speed up decoder * version
This commit is contained in:
@@ -26,7 +26,10 @@ def load_model(
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
weights = mx.load(str(wf))
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
|
Reference in New Issue
Block a user