qwen conversion

This commit is contained in:
Awni Hannun
2023-12-21 12:54:47 -08:00
parent 9dbbd8755b
commit 942a6ef620
4 changed files with 101 additions and 28 deletions

View File

@@ -258,14 +258,8 @@ def load_model(folder: str):
weights.update(mx.load(wf).items())
weights = tree_unflatten(list(weights.items()))
model = Mixtral(model_args)
# model.update(weights)
# quantization = {
# "group_size": 64,
# "bits": 4,
# }
if quantization is not None:
# TODO: Quantize gate matrices when < 32 tiles supported
print("QUANTIZING")
quantization["linear_class_predicate"] = (
lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)

View File

@@ -11,11 +11,15 @@ First download and convert the model with:
```sh
python convert.py
```
The script downloads the model from Hugging Face. The default model is
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
The conversion script will make the `weights.npz` and `config.json` files in
the working directory.
To generate a 4-bit quantized model, use ``-q``. For a full list of options:
The script downloads the model from Hugging Face. The default model is
`Qwen/Qwen-1_8B`. Check out the [Hugging Face
page](https://huggingface.co/Qwen) to see a list of available models.
By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz` and `config.json` there.
## Generate

View File

@@ -1,8 +1,14 @@
import argparse
import copy
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from qwen import ModelArgs, Qwen
from transformers import AutoModelForCausalLM
@@ -14,19 +20,58 @@ def replace_key(key: str) -> str:
return key
def convert(model_path: str = "Qwen/Qwen-1_8B"):
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model_args = ModelArgs()
model_args.vocab_size = config["vocab_size"]
model_args.hidden_size = config["hidden_size"]
model_args.num_attention_heads = config["num_attention_heads"]
model_args.num_hidden_layers = config["num_hidden_layers"]
model_args.kv_channels = config["kv_channels"]
model_args.max_position_embeddings = config["max_position_embeddings"]
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
model_args.intermediate_size = config["intermediate_size"]
model_args.no_bias = config["no_bias"]
model = Qwen(model_args)
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def convert(args):
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
model = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True, torch_dtype=torch.float16
args.model, trust_remote_code=True, torch_dtype=torch.float16
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights)
config = model.config.to_dict()
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
# write config
config = model.config
config_dict = config.to_dict()
with open("config.json", "w") as f:
json.dump(config_dict, f, indent=4)
with open(mlx_path / "config.json", "w") as f:
json.dump(config, f, indent=4)
if __name__ == "__main__":
@@ -37,7 +82,29 @@ if __name__ == "__main__":
help="The huggingface model to be converted",
default="Qwen/Qwen-1_8B",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="The path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
convert(args.model)
convert(args)

View File

@@ -1,6 +1,7 @@
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
@@ -175,12 +176,11 @@ def generate(prompt: mx.array, model: Qwen, temp: 0.0):
yield y
def load_model(
tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
):
def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
model_args = ModelArgs()
with open(config_path, "r") as f:
model_path = Path(model_path)
with open(model_path / "config.json", "r") as f:
config = json.load(f)
model_args.vocab_size = config["vocab_size"]
model_args.hidden_size = config["hidden_size"]
@@ -193,9 +193,11 @@ def load_model(
model_args.no_bias = config["no_bias"]
model = Qwen(model_args)
weights = mx.load("weights.npz")
weights = mx.load(str(model_path / "weights.npz"))
if quantization := config.get("quantization", False):
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
)
@@ -204,6 +206,12 @@ def load_model(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Qwen inference script")
parser.add_argument(
"--model-path",
type=str,
default="mlx_model",
help="The path to the model weights and config",
)
parser.add_argument(
"--tokenizer",
help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
@@ -216,7 +224,7 @@ if __name__ == "__main__":
default="蒙古国的首都是乌兰巴托Ulaanbaatar\n冰岛的首都是雷克雅未克Reykjavik\n埃塞俄比亚的首都是",
)
parser.add_argument(
"--max_tokens",
"--max-tokens",
"-m",
type=int,
default=100,
@@ -233,7 +241,7 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
model, tokenizer = load_model(args.tokenizer)
model, tokenizer = load_model(args.model_path, args.tokenizer)
prompt = tokenizer(
args.prompt,