Whisper updates to allow HF models (#923)

* simplify conversion and update convert for HF models

* use npz for compatibility

* fixes

* fixes

* fix gguf

* allow user-supplied path
This commit is contained in:
Awni Hannun
2024-08-09 11:11:58 -07:00
committed by GitHub
parent df744c98e6
commit 33905447f9
5 changed files with 116 additions and 75 deletions

View File

@@ -13,7 +13,7 @@ import mlx_whisper.decoding as decoding
import mlx_whisper.load_models as load_models
import numpy as np
import torch
from convert import load_torch_model, quantize, torch_to_mlx
from convert import convert, load_torch_model, quantize
from mlx.utils import tree_flatten
MODEL_NAME = "tiny"
@@ -41,12 +41,12 @@ def _save_model(save_dir, weights, config):
def load_torch_and_mlx():
torch_model = load_torch_model(MODEL_NAME)
fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
fp32_model = convert(MODEL_NAME, dtype=mx.float32)
config = asdict(fp32_model.dims)
weights = dict(tree_flatten(fp32_model.parameters()))
_save_model(MLX_FP32_MODEL_PATH, weights, config)
fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
fp16_model = convert(MODEL_NAME, dtype=mx.float16)
config = asdict(fp16_model.dims)
weights = dict(tree_flatten(fp16_model.parameters()))
_save_model(MLX_FP16_MODEL_PATH, weights, config)