mlx-examples/llms/llama/convert.py
Awni Hannun 3cf436b529
Quantize example (#162)
* testing quantization

* conversion + quantization working

* one config processor

* quantization in mistral / nits in llama

* args for quantization

* llama / mistral conversion in good shape

* phi2 quantized

* mixtral

* qwen conversion
2023-12-21 12:59:37 -08:00

206 lines
6.2 KiB
Python

# Copyright © 2023 Apple Inc.
import argparse
import collections
import copy
import glob
import json
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def llama(model_path):
    """Load a sharded PyTorch Llama checkpoint and merge its shards.

    Reads every ``consolidated.*.pth`` file under ``model_path``, converts
    each tensor to a float16 NumPy array, and concatenates the pieces of the
    weights that are split across files.

    Args:
        model_path: ``pathlib.Path`` of the directory holding the
            ``consolidated.*.pth`` files and ``params.json``.

    Returns:
        A ``(weights, params)`` tuple: ``weights`` maps parameter names to
        NumPy arrays and ``params`` is the parsed content of ``params.json``.

    Raises:
        FileNotFoundError: If no ``consolidated.*.pth`` file is found, or if
            ``params.json`` is missing.
    """
    # Concatenation axis for every weight that is split across shard files,
    # keyed by the second-to-last component of the parameter name.
    shard_axis = dict.fromkeys(["wv", "wq", "wk", "w1", "w3", "output"], 0)
    shard_axis.update(dict.fromkeys(["tok_embeddings", "wo", "w2"], 1))

    def shard_key(k):
        # "layers.0.attention.wq.weight" -> "wq"; names without a dot have
        # no module component and are never treated as sharded.
        keys = k.split(".")
        return keys[-2] if len(keys) >= 2 else None

    # glob returns files in a filesystem-dependent order, but the pieces must
    # be concatenated in shard order, so sort the file names explicitly.
    torch_files = sorted(glob.glob(str(model_path / "consolidated.*.pth")))
    if not torch_files:
        raise FileNotFoundError(
            f"No consolidated.*.pth files found in {model_path}"
        )

    collected = {}
    for wf in torch_files:
        # NOTE(security): torch.load unpickles the file; only convert
        # checkpoints obtained from a trusted source.
        state = torch.load(wf, map_location=torch.device("cpu"))
        for k, v in state.items():
            v = v.to(torch.float16).numpy()
            if shard_key(k) in shard_axis:
                collected.setdefault(k, []).append(v)
            else:
                # Weights that are not sharded are kept as a single array;
                # a later file simply overwrites an earlier one.
                collected[k] = v

    weights = {}
    for k, v in collected.items():
        axis = shard_axis.get(shard_key(k))
        weights[k] = v if axis is None else np.concatenate(v, axis=axis)

    with open(model_path / "params.json", "r") as f:
        params = json.load(f)
    return weights, params
def tiny_llama(model_path):
    """Load a checkpoint through ``transformers`` and rename it for this example.

    The state dict returned by ``AutoModelForCausalLM`` is re-keyed so the
    parameter names match the ones used by the ``llama`` loader in this file
    (``wq``/``wk``/``wv``/``wo``, ``w1``/``w2``/``w3``, ``tok_embeddings``,
    ``output``, ...), and the model hyper-parameters are read from the
    ``transformers`` config.

    Args:
        model_path: Path (or model identifier) passed to
            ``transformers`` ``from_pretrained``.

    Returns:
        A ``(weights, params)`` tuple: ``weights`` maps the renamed parameter
        names to float16 NumPy arrays and ``params`` is a dict of model
        hyper-parameters.

    Raises:
        SystemExit: With status 1 if the ``transformers`` package is not
            installed.
    """
    try:
        import transformers
    except ImportError:
        print("The transformers package must be installed for this model conversion:")
        print("pip install transformers")
        # A missing dependency is a failure, so report a non-zero status
        # instead of pretending the conversion succeeded.
        raise SystemExit(1)

    model = transformers.AutoModelForCausalLM.from_pretrained(
        str(model_path)
    ).state_dict()
    config = transformers.AutoConfig.from_pretrained(str(model_path))

    # Substring substitutions applied to every weight name, in this order.
    renames = [
        # 1. there's no "model." in the weight names
        ("model.", ""),
        # 2. mlp is called feed_forward
        ("mlp", "feed_forward"),
        # 3. up_proj, down_proj, gate_proj
        ("down_proj", "w2"),
        ("up_proj", "w3"),
        ("gate_proj", "w1"),
        # 4. layernorms
        ("input_layernorm", "attention_norm"),
        ("post_attention_layernorm", "ffn_norm"),
        # 5. lm head
        ("lm_head", "output"),
        # 6. token emb
        ("embed_tokens", "tok_embeddings"),
        # 7. attention
        ("self_attn", "attention"),
        ("q_proj", "wq"),
        ("k_proj", "wk"),
        ("v_proj", "wv"),
        ("o_proj", "wo"),
    ]

    def rename(key):
        for old, new in renames:
            key = key.replace(old, new)
        return key

    params = {}
    params["dim"] = config.hidden_size
    params["hidden_dim"] = config.intermediate_size
    params["n_heads"] = config.num_attention_heads
    # Only set n_kv_heads when the config declares it.
    if hasattr(config, "num_key_value_heads"):
        params["n_kv_heads"] = config.num_key_value_heads
    params["n_layers"] = config.num_hidden_layers
    params["vocab_size"] = config.vocab_size
    params["norm_eps"] = config.rms_norm_eps
    params["rope_traditional"] = False

    weights = {rename(k): v.to(torch.float16).numpy() for k, v in model.items()}
    return weights, params
def quantize(weights, config, args):
    """Quantize the linear layers of a Llama model built from ``weights``.

    Args:
        weights: Flat mapping of dotted parameter names to arrays.
        config: Model configuration dict; a deep copy of it (taken before
            ``sanitize_config`` is called) becomes the returned config.
        args: Object exposing ``q_group_size`` and ``q_bits``.

    Returns:
        A ``(quantized_weights, quantized_config)`` tuple: the flattened
        parameters of the quantized model and the copied config with an added
        ``"quantization"`` entry.
    """
    # Snapshot the caller's config first so the saved config starts from the
    # values that were passed in.
    saved_config = copy.deepcopy(config)

    # Build the model and load the weights into it.
    model_args = ModelArgs(**sanitize_config(config, weights))
    model = Llama(model_args)
    mlx_weights = tree_map(mx.array, weights)
    model.update(tree_unflatten(list(mlx_weights.items())))

    # Quantize the model in place.
    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)

    # Record the quantization settings alongside the original config.
    saved_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }
    return dict(tree_flatten(model.parameters())), saved_config
if __name__ == "__main__":
    # Command-line entry point: load a PyTorch checkpoint, optionally
    # quantize it, and write tokenizer.model, weights.npz and config.json
    # into the output directory.

    # Explicit table of supported loaders (instead of looking names up in
    # globals()); its keys are the accepted --model-name values.
    loaders = {"tiny_llama": tiny_llama, "llama": llama}

    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    parser.add_argument(
        "--torch-path",
        type=str,
        # Without a value Path(None) would fail with an opaque TypeError, so
        # let argparse report the missing argument instead.
        required=True,
        help="Path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--model-name",
        help=(
            "Name of the model to convert. Use 'llama' for models in the "
            "Llama family distributed by Meta including Llama 1, Llama 2, "
            "Code Llama, and Llama chat."
        ),
        choices=list(loaders),
        default="llama",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        help="Generate a quantized model.",
        action="store_true",
    )
    # The hyphenated spellings match the other options; the original
    # underscore spellings are kept so existing command lines still work.
    parser.add_argument(
        "--q-group-size",
        "--q_group_size",
        dest="q_group_size",
        help="Group size for quantization.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--q-bits",
        "--q_bits",
        dest="q_bits",
        help="Bits per weight for quantization.",
        type=int,
        default=4,
    )
    args = parser.parse_args()

    torch_path = Path(args.torch_path)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    print("[INFO] Loading")
    weights, params = loaders[args.model_name](torch_path)
    params["model_type"] = "llama"

    if args.quantize:
        print("[INFO] Quantizing")
        weights, params = quantize(weights, params, args)

    print("[INFO] Saving")
    shutil.copyfile(
        str(torch_path / "tokenizer.model"),
        str(mlx_path / "tokenizer.model"),
    )
    np.savez(str(mlx_path / "weights.npz"), **weights)
    with open(mlx_path / "config.json", "w") as fid:
        json.dump(params, fid, indent=4)