Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Whisper updates to allow HF models (#923)
* simplify conversion and update convert for HF models
* use npz for compat
* fixes
* fixes
* fix gguf
* allow user supplied path
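A minimal sketch of the new entry point at the call sites changed below. The convert() name and its dtype argument are taken from the test diff; the path/HF invocation is an assumption based on the commit message ("allow user supplied path", HF model support) and is not shown in this diff:

import mlx.core as mx
from convert import convert  # the whisper convert module in this repo

# One-step conversion, exactly as the updated test calls it:
model = convert("tiny", dtype=mx.float16)

# Hypothetical invocation per the commit message; the exact accepted
# identifiers (HF repo id, local checkpoint path) are not shown here:
# model = convert("openai/whisper-tiny", dtype=mx.float16)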
@@ -13,7 +13,7 @@ import mlx_whisper.decoding as decoding
 import mlx_whisper.load_models as load_models
 import numpy as np
 import torch
-from convert import load_torch_model, quantize, torch_to_mlx
+from convert import convert, load_torch_model, quantize
 from mlx.utils import tree_flatten
 
 MODEL_NAME = "tiny"
@@ -41,12 +41,12 @@ def _save_model(save_dir, weights, config):
 def load_torch_and_mlx():
     torch_model = load_torch_model(MODEL_NAME)
 
-    fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
+    fp32_model = convert(MODEL_NAME, dtype=mx.float32)
     config = asdict(fp32_model.dims)
     weights = dict(tree_flatten(fp32_model.parameters()))
     _save_model(MLX_FP32_MODEL_PATH, weights, config)
 
-    fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
+    fp16_model = convert(MODEL_NAME, dtype=mx.float16)
     config = asdict(fp16_model.dims)
     weights = dict(tree_flatten(fp16_model.parameters()))
     _save_model(MLX_FP16_MODEL_PATH, weights, config)
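For context, a minimal sketch of the convert-and-save flow the test exercises. The convert/asdict/tree_flatten sequence mirrors the diff above; the _save_model body is not shown here, so writing a weights.npz plus a JSON config is an assumption based on the "use npz for compat" note in the commit message:

import json
from dataclasses import asdict
from pathlib import Path

import mlx.core as mx
from mlx.utils import tree_flatten
from convert import convert

def convert_and_save(model_name, save_dir, dtype=mx.float16):
    model = convert(model_name, dtype=dtype)          # PyTorch/HF checkpoint -> MLX model
    config = asdict(model.dims)                       # dims dataclass -> plain dict
    weights = dict(tree_flatten(model.parameters()))  # nested params -> flat {name: array}
    out = Path(save_dir)
    out.mkdir(parents=True, exist_ok=True)
    mx.savez(str(out / "weights.npz"), **weights)     # npz for compatibility
    with open(out / "config.json", "w") as f:
        json.dump(config, f, indent=4)

convert_and_save("tiny", "mlx_models/tiny_fp16")  # hypothetical output path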