From 15dcebc36a8fda15467946a9b6421c8bde0c0fbb Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Fri, 4 Oct 2024 10:59:01 -0700
Subject: [PATCH] use safetensors in whisper

---
 whisper/convert.py                 | 2 +-
 whisper/mlx_whisper/decoding.py    | 3 ---
 whisper/mlx_whisper/load_models.py | 5 ++++-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/whisper/convert.py b/whisper/convert.py
index cdd50bc5..5c0477ec 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -387,7 +387,7 @@ if __name__ == "__main__":
 
     # Save weights
     print("[INFO] Saving")
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
 
     # Save config.json with model_type
     with open(str(mlx_path / "config.json"), "w") as f:
diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py
index 41c2ec6d..7203edad 100644
--- a/whisper/mlx_whisper/decoding.py
+++ b/whisper/mlx_whisper/decoding.py
@@ -432,9 +432,6 @@ class DecodingTask:
         # decoder: implements how to select the next tokens, given the autoregressive distribution
         if options.beam_size is not None:
             raise NotImplementedError("Beam search decoder is not yet implemented")
-            # self.decoder = BeamSearchDecoder(
-            #     options.beam_size, tokenizer.eot, self.inference, options.patience
-            # )
         else:
             self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
 
diff --git a/whisper/mlx_whisper/load_models.py b/whisper/mlx_whisper/load_models.py
index 6705385d..60766ab2 100644
--- a/whisper/mlx_whisper/load_models.py
+++ b/whisper/mlx_whisper/load_models.py
@@ -26,7 +26,10 @@ def load_model(
 
     model_args = whisper.ModelDimensions(**config)
 
-    weights = mx.load(str(model_path / "weights.npz"))
+    wf = model_path / "weights.safetensors"
+    if not wf.exists():
+        wf = model_path / "weights.npz"
+    weights = mx.load(str(wf))
 
     model = whisper.Whisper(model_args, dtype)