Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01.
[Whisper] Load customized MLX model & Quantization (#191)
* Add option to load customized mlx model
* Add quantization
* Apply reviews
* Separate model conversion and loading
* Update test
* Fix benchmark
* Add notes about conversion
* Improve doc
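In outline, the workflow this commit enables, as a minimal sketch (it assumes the convert.py helpers shown in the diff below; the audio path is illustrative):

import mlx.core as mx
import whisper
from convert import load_torch_model, torch_to_mlx

# Convert the PyTorch checkpoint to MLX once and save it to a directory
# (the tests below do this via _save_model), then hand that directory to
# transcription through the new model_path argument.
torch_model = load_torch_model("tiny")
mlx_model = torch_to_mlx(torch_model, dtype=mx.float32)

result = whisper.transcribe("audio.flac", model_path="mlx_models/tiny_fp32", fp16=False)
print(result["text"])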
@@ -1,22 +1,68 @@
# Copyright © 2023 Apple Inc.

import json
import os
import subprocess
import unittest
from dataclasses import asdict
from pathlib import Path

import mlx.core as mx
import numpy as np
import torch
from mlx.utils import tree_flatten

import whisper
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper

from convert import load_torch_model, quantize, torch_to_mlx

MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
TEST_AUDIO = "whisper/assets/ls_test.flac"


def _save_model(save_dir, weights, config):
    mlx_path = Path(save_dir)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type
    with open(str(mlx_path / "config.json"), "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)

    config.pop("model_type", None)
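
For reference, a minimal sketch of reading back what _save_model writes, namely a flat weights.npz plus a config.json carrying the model dims and a "model_type" key (the helper name is hypothetical; it only uses the imports above):

def _inspect_saved_model(save_dir):
    # Hypothetical helper: load the two artifacts _save_model produces.
    path = Path(save_dir)
    weights = np.load(str(path / "weights.npz"))  # flat name -> array mapping
    with open(path / "config.json") as f:
        config = json.load(f)  # model dims fields plus "model_type"
    return list(weights.keys()), config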


def load_torch_and_mlx():
    torch_model = load_torch_model(MODEL_NAME)

    fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
    config = asdict(fp32_model.dims)
    weights = dict(tree_flatten(fp32_model.parameters()))
    _save_model(MLX_FP32_MODEL_PATH, weights, config)

    fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
    config = asdict(fp16_model.dims)
    weights = dict(tree_flatten(fp16_model.parameters()))
    _save_model(MLX_FP16_MODEL_PATH, weights, config)

    args = type("", (), {})()
    args.q_group_size = 64
    args.q_bits = 4
    weights, config = quantize(weights, config, args)
    _save_model(MLX_4BITS_MODEL_PATH, weights, config)

    return torch_model, fp32_model, fp16_model
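
A side note on the `args = type("", (), {})()` idiom above: it builds an anonymous attribute bag so `quantize` can read argparse-style fields. The standard library's `types.SimpleNamespace` does the same thing more readably:

import types

args = types.SimpleNamespace(q_group_size=64, q_bits=4)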


def forward_torch(model, mels, tokens):
    mels = torch.Tensor(mels).to(torch.float32)
    tokens = torch.Tensor(tokens).to(torch.int32)

@@ -35,7 +81,7 @@ def forward_mlx(model, mels, tokens):
class TestWhisper(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
-       cls.model = load_models.load_model("tiny", dtype=mx.float32)
+       _, cls.model, _ = load_torch_and_mlx()
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data)
        cls.mels = audio.log_mel_spectrogram(data)

@@ -43,7 +89,7 @@ class TestWhisper(unittest.TestCase):
    def test_torch_mlx(self):
        np.random.seed(10)

-       torch_model = load_models.load_torch_model("tiny")
+       torch_model = load_torch_model(MODEL_NAME)
        dims = torch_model.dims

        mels = np.random.randn(1, dims.n_mels, 3_000)

@@ -51,19 +97,27 @@

        torch_logits = forward_torch(torch_model, mels, tokens)

-       mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
-       mlx_logits = forward_mlx(mlx_model, mels, tokens)
+       mlx_logits = forward_mlx(self.model, mels, tokens)

        self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))

    def test_fp16(self):
-       mlx_model = load_models.load_model("tiny", dtype=mx.float16)
+       mlx_model = load_models.load_model(MLX_FP16_MODEL_PATH, mx.float16)
        dims = mlx_model.dims
        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
        logits = mlx_model(mels, tokens)
        self.assertEqual(logits.dtype, mx.float16)

    def test_quantized_4bits(self):
        mlx_model = load_models.load_model(MLX_4BITS_MODEL_PATH, mx.float16)
        dims = mlx_model.dims
        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
        logits = mlx_model(mels, tokens)
        # We only check that the 4-bit model can run a forward pass, since
        # quantized tiny models struggle to transcribe accurately
        self.assertEqual(logits.dtype, mx.float16)
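
As a back-of-envelope check on what the 4-bit setting stores, assuming (as MLX quantization typically does) one fp16 scale and one fp16 bias per group of q_group_size weights:

q_bits, group_size, scale_bits = 4, 64, 16
bits_per_weight = q_bits + 2 * scale_bits / group_size  # 4 + 0.5 = 4.5 bits
fp16_ratio = 16 / bits_per_weight                       # ~3.6x smaller than fp16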

    def test_decode_lang(self):
        options = decoding.DecodingOptions(task="lang_id", fp16=False)
        result = decoding.decode(self.model, self.mels, options)

@@ -135,7 +189,7 @@ class TestWhisper(unittest.TestCase):
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

    def test_transcribe(self):
-       result = whisper.transcribe(TEST_AUDIO, fp16=False)
+       result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
        self.assertEqual(
            result["text"],
            (

@@ -154,7 +208,7 @@ class TestWhisper(unittest.TestCase):
            print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
            return

-       result = whisper.transcribe(audio_file, fp16=False)
+       result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
        self.assertEqual(len(result["text"]), 10920)
        self.assertEqual(result["language"], "en")
        self.assertEqual(len(result["segments"]), 77)