Quantize example (#162)

* testing quantization

* conversion + quantization working

* one config processor

* quantization in mistral / nits in llama

* args for quantization

* llama / mistral conversion in good shape

* phi2 quantized

* mixtral

* qwen conversion
Authored by Awni Hannun on 2023-12-21 12:59:37 -08:00, committed by GitHub
parent 4c9db80ed2
commit 3cf436b529
17 changed files with 553 additions and 126 deletions


@@ -43,10 +43,18 @@ Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
 MLX can read them:
 
 ```
-python convert.py --model-path $MIXTRAL_MODEL/
+python convert.py --torch-path $MIXTRAL_MODEL/
 ```
-The conversion script will save the converted weights in the same location.
+
+To generate a 4-bit quantized model, use ``-q``. For a full list of options:
+
+```
+python convert.py --help
+```
+
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
+
 ### Generate
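
For a quick sanity check of the converted output, something like the following works (a sketch, assuming the default `mlx_model` directory and a conversion run with `-q`):

```python
import glob
import json
from pathlib import Path

mlx_path = Path("mlx_model")  # default output directory of convert.py

# config.json records the model type and, for -q conversions, the
# quantization parameters, e.g. {"group_size": 64, "bits": 4}.
with open(mlx_path / "config.json") as f:
    config = json.load(f)
print(config["model_type"], config.get("quantization"))

# The weights are saved as one .npz shard per consolidated.*.pt input file.
print(sorted(glob.glob(str(mlx_path / "weights.*.npz"))))
```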


@@ -1,59 +1,152 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import copy
 import glob
 import json
+import shutil
 from pathlib import Path
 
+import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import torch
+from mixtral import Mixtral, ModelArgs
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
-def convert(k, v, config):
-    v = v.to(torch.float16).numpy()
-    if "block_sparse_moe" not in k:
-        return [(k, v)]
-    if "gate" in k:
-        return [(k.replace("block_sparse_moe", "feed_forward"), v)]
+def convert(tf, config):
+    def convert_single(k, v):
+        v = v.to(torch.float16).numpy()
+        if "block_sparse_moe" not in k:
+            return [(k, v)]
+        if "gate" in k:
+            return [(k.replace("block_sparse_moe", "feed_forward"), v)]
 
-    # From: layers.N.block_sparse_moe.w
-    # To: layers.N.experts.M.w
-    num_experts = args["moe"]["num_experts"]
-    key_path = k.split(".")
-    v = np.split(v, num_experts, axis=0)
-    if key_path[-1] == "w2":
-        v = [u.T for u in v]
+        # From: layers.N.block_sparse_moe.w
+        # To: layers.N.experts.M.w
+        num_experts = config["moe"]["num_experts"]
+        key_path = k.split(".")
+        v = np.split(v, num_experts, axis=0)
+        if key_path[-1] == "w2":
+            v = [u.T for u in v]
 
-    w_name = key_path.pop()
-    key_path[-1] = "feed_forward.experts"
-    return [
-        (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
-    ]
+        w_name = key_path.pop()
+        key_path[-1] = "feed_forward.experts"
+        return [
+            (".".join(key_path + [str(e), w_name, "weight"]), u)
+            for e, u in enumerate(v)
+        ]
+
+    state = torch.load(tf)
+    weights = {}
+    for k, v in state.items():
+        weights.update(convert_single(k, v))
+    return weights
+
+
+def quantize(weights, config, args):
+    quantized_config = copy.deepcopy(config)
+
+    # Load the model and update with the subset of weights:
+    config.pop("quantization", None)
+    model = Mixtral(ModelArgs(**config))
+    all_weights = dict(tree_flatten(model.parameters()))
+    weights = tree_map(mx.array, weights)
+    all_weights.update(weights)
+    all_weights = tree_unflatten(list(all_weights.items()))
+    model.update(all_weights)
+
+    # Quantize the model:
+    nn.QuantizedLinear.quantize_module(
+        model,
+        args.q_group_size,
+        args.q_bits,
+        # TODO: Quantize gate matrices when < 32 tiles supported
+        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
+        and m.weight.shape[0] != 8,
+    )
+
+    # Extract the subset of quantized weights:
+    all_weights = dict(tree_flatten(model.parameters()))
+    quantized_weights = {}
+    for k, v in all_weights.items():
+        if k not in weights:
+            continue
+        quantized_weights[k] = v
+        prefix = k.split(".")[:-1]
+        for qw in ["scales", "biases"]:
+            if (k := ".".join(prefix + [qw])) in all_weights:
+                quantized_weights[k] = all_weights[k]
+
+    # Update the config:
+    quantized_config["quantization"] = {
+        "group_size": args.q_group_size,
+        "bits": args.q_bits,
+    }
+    return quantized_weights, quantized_config
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
     parser.add_argument(
-        "--model-path",
+        "--torch-path",
         type=str,
-        default="Mixtral-8x7B-v0.1/",
-        help="The path to the Mixtral model. The MLX model weights will also be saved there.",
+        default="Mixtral-8x7B-v0.1",
+        help="The path to the PyTorch model.",
     )
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_model",
+        help="The path to save the MLX model.",
+    )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        help="Generate a quantized model.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--q_group_size",
+        help="Group size for quantization.",
+        type=int,
+        default=64,
+    )
+    parser.add_argument(
+        "--q_bits",
+        help="Bits per weight for quantization.",
+        type=int,
+        default=4,
+    )
     args = parser.parse_args()
 
-    model_path = Path(args.model_path)
+    torch_path = Path(args.torch_path)
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
     with open("params.json") as fid:
-        args = json.load(fid)
-        args["model_type"] = "mixtral"
-    with open(model_path / "config.json", "w") as f:
-        json.dump(args, f, indent=4)
+        config = json.load(fid)
 
-    torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
+    # Copy tokenizer
+    shutil.copyfile(
+        str(torch_path / "tokenizer.model"),
+        str(mlx_path / "tokenizer.model"),
+    )
+
+    # Convert and save model in shards
+    torch_files = glob.glob(str(torch_path / "consolidated.*.pt"))
     torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
     for e, tf in enumerate(torch_files):
         print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
-        state = torch.load(tf)
-        new_state = {}
-        for k, v in state.items():
-            new_state.update(convert(k, v, args))
-        np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
+        weights = convert(tf, config)
+        if args.quantize:
+            print("[INFO] Quantizing")
+            weights, config = quantize(weights, config, args)
+        np.savez(str(mlx_path / f"weights.{e}.npz"), **weights)
+
+    # Save updated config
+    with open(mlx_path / "config.json", "w") as f:
+        config["model_type"] = "mixtral"
+        json.dump(config, f, indent=4)
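
The quantization step above leans on `nn.QuantizedLinear.quantize_module`, which swaps matching `nn.Linear` modules for quantized ones in place; the `linear_class_predicate` keeps the 8-way expert gates in float16 because their weight matrices are too small for the current quantized kernels (the TODO in the script). A minimal standalone sketch of that call, using a made-up `SmallModel` rather than Mixtral (only the `quantize_module` invocation mirrors the script above):

```python
import mlx.nn as nn
from mlx.utils import tree_flatten


class SmallModel(nn.Module):
    """Hypothetical toy model: two plain linears plus an 8-way gate, like the MoE router."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 256)
        self.gate = nn.Linear(256, 8)

    def __call__(self, x):
        return self.gate(self.fc2(self.fc1(x)))


model = SmallModel()

# Same call pattern as quantize() above: 4-bit weights in groups of 64,
# skipping any Linear whose output dimension is 8 (the expert gate).
nn.QuantizedLinear.quantize_module(
    model,
    group_size=64,
    bits=4,
    linear_class_predicate=lambda m: isinstance(m, nn.Linear)
    and m.weight.shape[0] != 8,
)

# fc1 and fc2 now carry packed weights plus scales/biases; gate is untouched.
print({k: v.shape for k, v in tree_flatten(model.parameters())})
```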


@@ -244,20 +244,27 @@ class Tokenizer:
         return out
 
 
-def load_model(folder: str, dtype=mx.float16):
+def load_model(folder: str):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
     with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
         config.pop("model_type", None)
+        quantization = config.pop("quantization", None)
         model_args = ModelArgs(**config)
     weight_files = glob.glob(str(model_path / "weights.*.npz"))
     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf).items())
     weights = tree_unflatten(list(weights.items()))
-    weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mixtral(model_args)
+    if quantization is not None:
+        # TODO: Quantize gate matrices when < 32 tiles supported
+        quantization["linear_class_predicate"] = (
+            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
+        )
+        nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(weights)
     return model, tokenizer
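
With this change the loader no longer takes a dtype; whether quantization is applied is decided entirely by the `quantization` block that `convert.py` wrote into `config.json`. A small usage sketch (it assumes you run from the mixtral example directory with the default `mlx_model` output; `from mixtral import load_model` is just one way to reach the function):

```python
from mlx.utils import tree_flatten

from mixtral import load_model  # the updated loader above

# If config.json contains a "quantization" entry, load_model re-applies
# QuantizedLinear.quantize_module before loading the saved weights, so the
# same call handles both float16 and 4-bit exports.
model, tokenizer = load_model("mlx_model")

# A quantized export shows scales/biases entries in the parameter tree;
# a float16 export does not.
print([k for k, _ in tree_flatten(model.parameters()) if k.endswith(".scales")][:5])
```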
@@ -284,7 +291,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model-path",
         type=str,
-        default="Mixtral-8x7B-v0.1",
+        default="mlx_model",
         help="The path to the model weights, tokenizer, and config",
     )
     parser.add_argument(