[Whisper] Load customized MLX model & Quantization (#191)

* Add an option to load a customized MLX model

* Add quantization

* Address review feedback

* Separate model conversion and loading

* Update test

* Fix benchmark

* Add notes about conversion

* Improve doc
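
For reference, here is the conversion-and-quantization flow these commits introduce, distilled from the `load_torch_and_mlx()` test helper in the diff below. This is a minimal sketch rather than the full `convert.py` CLI; the `"tiny"` checkpoint name and the argparse-style namespace mirror the tests:

```python
from dataclasses import asdict

import mlx.core as mx
from mlx.utils import tree_flatten

from convert import load_torch_model, quantize, torch_to_mlx

# Load the PyTorch checkpoint and convert it to an MLX model.
torch_model = load_torch_model("tiny")
mlx_model = torch_to_mlx(torch_model, dtype=mx.float16)

config = asdict(mlx_model.dims)
weights = dict(tree_flatten(mlx_model.parameters()))

# quantize() takes an argparse-style namespace carrying the group size
# and bit width, as in the tests below.
args = type("", (), {})()
args.q_group_size = 64
args.q_bits = 4
weights, config = quantize(weights, config, args)

# The resulting weights/config are then written out as weights.npz plus
# config.json (see _save_model in the test file).
```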
Commit 581a5733a1 (parent 1cdbf9e886)
Author: bofeng huang
Date: 2023-12-29 19:22:15 +01:00
Committed by: GitHub
6 changed files with 421 additions and 211 deletions


@@ -1,22 +1,68 @@
# Copyright © 2023 Apple Inc.

import json
import os
import subprocess
import unittest
from dataclasses import asdict
from pathlib import Path

import mlx.core as mx
import numpy as np
import torch
from mlx.utils import tree_flatten

import whisper
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper
from convert import load_torch_model, quantize, torch_to_mlx

MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
TEST_AUDIO = "whisper/assets/ls_test.flac"


def _save_model(save_dir, weights, config):
    mlx_path = Path(save_dir)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type
    with open(str(mlx_path / "config.json"), "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)
    config.pop("model_type", None)


def load_torch_and_mlx():
    torch_model = load_torch_model(MODEL_NAME)

    fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
    config = asdict(fp32_model.dims)
    weights = dict(tree_flatten(fp32_model.parameters()))
    _save_model(MLX_FP32_MODEL_PATH, weights, config)

    fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
    config = asdict(fp16_model.dims)
    weights = dict(tree_flatten(fp16_model.parameters()))
    _save_model(MLX_FP16_MODEL_PATH, weights, config)

    # Quantize the fp16 weights to 4 bits with a group size of 64
    args = type("", (), {})()
    args.q_group_size = 64
    args.q_bits = 4
    weights, config = quantize(weights, config, args)
    _save_model(MLX_4BITS_MODEL_PATH, weights, config)

    return torch_model, fp32_model, fp16_model


def forward_torch(model, mels, tokens):
    mels = torch.Tensor(mels).to(torch.float32)
    tokens = torch.Tensor(tokens).to(torch.int32)
@@ -35,7 +81,7 @@ def forward_mlx(model, mels, tokens):
class TestWhisper(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
-       cls.model = load_models.load_model("tiny", dtype=mx.float32)
+       _, cls.model, _ = load_torch_and_mlx()
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data)
        cls.mels = audio.log_mel_spectrogram(data)
@@ -43,7 +89,7 @@ class TestWhisper(unittest.TestCase):
    def test_torch_mlx(self):
        np.random.seed(10)
-       torch_model = load_models.load_torch_model("tiny")
+       torch_model = load_torch_model(MODEL_NAME)
        dims = torch_model.dims
        mels = np.random.randn(1, dims.n_mels, 3_000)
@@ -51,19 +97,27 @@ class TestWhisper(unittest.TestCase):
        torch_logits = forward_torch(torch_model, mels, tokens)

-       mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
-       mlx_logits = forward_mlx(mlx_model, mels, tokens)
+       mlx_logits = forward_mlx(self.model, mels, tokens)

        self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))

    def test_fp16(self):
-       mlx_model = load_models.load_model("tiny", dtype=mx.float16)
+       mlx_model = load_models.load_model(MLX_FP16_MODEL_PATH, mx.float16)
        dims = mlx_model.dims
        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
        logits = mlx_model(mels, tokens)
        self.assertEqual(logits.dtype, mx.float16)

+   def test_quantized_4bits(self):
+       mlx_model = load_models.load_model(MLX_4BITS_MODEL_PATH, mx.float16)
+       dims = mlx_model.dims
+       mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
+       tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
+       logits = mlx_model(mels, tokens)
+       # Only check that the 4-bit model can run a forward pass; the quantized
+       # tiny model is too inaccurate for transcription checks.
+       self.assertEqual(logits.dtype, mx.float16)

    def test_decode_lang(self):
        options = decoding.DecodingOptions(task="lang_id", fp16=False)
        result = decoding.decode(self.model, self.mels, options)
@@ -135,7 +189,7 @@ class TestWhisper(unittest.TestCase):
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

    def test_transcribe(self):
-       result = whisper.transcribe(TEST_AUDIO, fp16=False)
+       result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
        self.assertEqual(
            result["text"],
            (
@@ -154,7 +208,7 @@ class TestWhisper(unittest.TestCase):
            print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
            return

-       result = whisper.transcribe(audio_file, fp16=False)
+       result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
        self.assertEqual(len(result["text"]), 10920)
        self.assertEqual(result["language"], "en")
        self.assertEqual(len(result["segments"]), 77)
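
With this change, models are loaded from a converted directory (containing `weights.npz` and `config.json`) rather than by checkpoint name, and `transcribe` accepts a `model_path`. A minimal usage sketch mirroring the tests above; the audio file and model directory are illustrative:

```python
import mlx.core as mx

import whisper
import whisper.load_models as load_models

# Load a converted MLX model from a local directory.
model = load_models.load_model("mlx_models/tiny_fp32", mx.float32)

# Or point transcribe() at the converted model directly.
result = whisper.transcribe("audio.flac", model_path="mlx_models/tiny_fp32", fp16=False)
print(result["text"])
```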