Whisper improvements (#1080)

* use safetensors in whisper

* speed up decoder

* version
This commit is contained in:
Awni Hannun
2024-11-01 10:52:28 -07:00
committed by GitHub
parent 85ffd2c96a
commit 8160e0c4e5
6 changed files with 85 additions and 64 deletions

View File

@@ -26,7 +26,10 @@ def load_model(
model_args = whisper.ModelDimensions(**config)
weights = mx.load(str(model_path / "weights.npz"))
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)