Add quantization

This commit is contained in:
bofenghuang
2023-12-26 19:57:16 +01:00
parent 43a68ee5e3
commit 39600eb383
3 changed files with 57 additions and 4 deletions

View File

@@ -1,18 +1,43 @@
# Copyright © 2023 Apple Inc.
import argparse
import copy
import json
from dataclasses import asdict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from whisper.load_models import load_torch_model, torch_to_mlx
from whisper.torch_whisper import ModelDimensions
from whisper.whisper import Whisper
MODEL_DTYPES = {"float16", "float32"}
def quantize(weights, config, dtype, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = Whisper(ModelDimensions(**config), dtype)
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
parser.add_argument(
@@ -33,19 +58,43 @@ if __name__ == "__main__":
default="float16",
help="The dtype to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
assert args.dtype in MODEL_DTYPES, f"dtype {args.dtype} not found in {MODEL_DTYPES}"
dtype = getattr(mx, args.dtype)
print("[INFO] Loading")
model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
config = asdict(model.dims)
weights = dict(tree_flatten(model.parameters()))
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, dtype, args)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
# Save weights
print("[INFO] Saving")
np.savez(str(mlx_path / "weights.npz"), **weights)
# Save config.json with model_type

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
from . import audio, decoding, load_models
from . import audio, decoding, load_models, torch_whisper, whisper
from .transcribe import transcribe

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import List
import mlx.core as mx
import mlx.nn as nn
import torch
from mlx.utils import tree_map, tree_unflatten
from tqdm import tqdm
@@ -227,12 +228,15 @@ def load_model(
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
model_args = torch_whisper.ModelDimensions(**config)
quantization = config.pop("quantization", None)
model_args = torch_whisper.ModelDimensions(**config)
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model.update(weights)
mx.eval(model.parameters())