use safetensors in whisper

This commit is contained in:
Awni Hannun 2024-10-04 10:59:01 -07:00
parent 9f34fdbda4
commit 15dcebc36a
3 changed files with 5 additions and 5 deletions

View File

@@ -387,7 +387,7 @@ if __name__ == "__main__":
# Save weights
print("[INFO] Saving")
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
# Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f:

View File

@@ -432,9 +432,6 @@ class DecodingTask:
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
raise NotImplementedError("Beam search decoder is not yet implemented")
-            # self.decoder = BeamSearchDecoder(
-            #     options.beam_size, tokenizer.eot, self.inference, options.patience
-            # )
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

View File

@@ -26,7 +26,10 @@ def load_model(
model_args = whisper.ModelDimensions(**config)
-    weights = mx.load(str(model_path / "weights.npz"))
+    wf = model_path / "weights.safetensors"
+    if not wf.exists():
+        wf = model_path / "weights.npz"
+    weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)