mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 03:19:23 +08:00
use safetensors in whisper
This commit is contained in:
parent
9f34fdbda4
commit
15dcebc36a
@ -387,7 +387,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Save weights
|
# Save weights
|
||||||
print("[INFO] Saving")
|
print("[INFO] Saving")
|
||||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
|
||||||
|
|
||||||
# Save config.json with model_type
|
# Save config.json with model_type
|
||||||
with open(str(mlx_path / "config.json"), "w") as f:
|
with open(str(mlx_path / "config.json"), "w") as f:
|
||||||
|
@ -432,9 +432,6 @@ class DecodingTask:
|
|||||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||||
if options.beam_size is not None:
|
if options.beam_size is not None:
|
||||||
raise NotImplementedError("Beam search decoder is not yet implemented")
|
raise NotImplementedError("Beam search decoder is not yet implemented")
|
||||||
# self.decoder = BeamSearchDecoder(
|
|
||||||
# options.beam_size, tokenizer.eot, self.inference, options.patience
|
|
||||||
# )
|
|
||||||
else:
|
else:
|
||||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||||
|
|
||||||
|
@ -26,7 +26,10 @@ def load_model(
|
|||||||
|
|
||||||
model_args = whisper.ModelDimensions(**config)
|
model_args = whisper.ModelDimensions(**config)
|
||||||
|
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
wf = model_path / "weights.safetensors"
|
||||||
|
if not wf.exists():
|
||||||
|
wf = model_path / "weights.npz"
|
||||||
|
weights = mx.load(str(wf))
|
||||||
|
|
||||||
model = whisper.Whisper(model_args, dtype)
|
model = whisper.Whisper(model_args, dtype)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user