Mirror of https://github.com/ml-explore/mlx-examples.git
qwen conversion
mixtral/mixtral.py (file path inferred from the Mixtral references):

```diff
@@ -258,14 +258,8 @@ def load_model(folder: str):
         weights.update(mx.load(wf).items())
     weights = tree_unflatten(list(weights.items()))
     model = Mixtral(model_args)
-    # model.update(weights)
-    # quantization = {
-    #     "group_size": 64,
-    #     "bits": 4,
-    # }
     if quantization is not None:
         # TODO: Quantize gate matrices when < 32 tiles supported
-        print("QUANTIZING")
         quantization["linear_class_predicate"] = (
             lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
         )
```
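The predicate above exists to keep the MoE gate out of quantization: Mixtral's router is an `nn.Linear` projecting the hidden state onto 8 experts, so its weight has shape `(8, hidden_size)`, and the `# TODO` notes that such narrow matrices were not yet supported by the quantized kernels. A standalone sketch of the idea (illustrative, not part of the diff; the commented call assumes `quantize_module` accepts the predicate as a keyword, which is what merging it into the `quantization` kwargs suggests):

```python
# Illustrative sketch: skip the 8-row router, quantize every other linear.
import mlx.nn as nn

def linear_class_predicate(m):
    # The gate's weight has 8 output rows (one per expert); leave it in
    # full precision and quantize only the wide projection layers.
    return isinstance(m, nn.Linear) and m.weight.shape[0] != 8

# Assumed usage, mirroring how the dict built above is consumed:
# nn.QuantizedLinear.quantize_module(
#     model, group_size=64, bits=4,
#     linear_class_predicate=linear_class_predicate,
# )
```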
qwen/README.md (file path inferred):

````diff
@@ -11,11 +11,15 @@ First download and convert the model with:
 ```sh
 python convert.py
 ```
 
-The script downloads the model from Hugging Face. The default model is
-`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
-
-The conversion script will make the `weights.npz` and `config.json` files in
-the working directory.
+To generate a 4-bit quantized model, use ``-q``. For a full list of options:
+
+The script downloads the model from Hugging Face. The default model is
+`Qwen/Qwen-1_8B`. Check out the [Hugging Face
+page](https://huggingface.co/Qwen) to see a list of available models.
+
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz` and `config.json` there.
 
 ## Generate
````
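Taken together with the argparse flags added to `convert.py` below, a plausible quantized conversion would be invoked like this (flag names and defaults taken from this diff):

```sh
# Convert Qwen-1_8B and 4-bit quantize it into ./mlx_model
python convert.py --model Qwen/Qwen-1_8B --mlx-path mlx_model -q
```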
qwen/convert.py (file path inferred; `+` markers inferred from which imports the new code needs):

```diff
@@ -1,8 +1,14 @@
 import argparse
+import copy
 import json
+from pathlib import Path
 
+import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import torch
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
+from qwen import ModelArgs, Qwen
 from transformers import AutoModelForCausalLM
```
```diff
@@ -14,19 +20,58 @@ def replace_key(key: str) -> str:
     return key
 
 
-def convert(model_path: str = "Qwen/Qwen-1_8B"):
+def quantize(weights, config, args):
+    quantized_config = copy.deepcopy(config)
+
+    # Load the model:
+    model_args = ModelArgs()
+    model_args.vocab_size = config["vocab_size"]
+    model_args.hidden_size = config["hidden_size"]
+    model_args.num_attention_heads = config["num_attention_heads"]
+    model_args.num_hidden_layers = config["num_hidden_layers"]
+    model_args.kv_channels = config["kv_channels"]
+    model_args.max_position_embeddings = config["max_position_embeddings"]
+    model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
+    model_args.intermediate_size = config["intermediate_size"]
+    model_args.no_bias = config["no_bias"]
+    model = Qwen(model_args)
+
+    weights = tree_map(mx.array, weights)
+    model.update(tree_unflatten(list(weights.items())))
+
+    # Quantize the model:
+    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+
+    # Update the config:
+    quantized_config["quantization"] = {
+        "group_size": args.q_group_size,
+        "bits": args.q_bits,
+    }
+    quantized_weights = dict(tree_flatten(model.parameters()))
+
+    return quantized_weights, quantized_config
+
+
+def convert(args):
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, trust_remote_code=True, torch_dtype=torch.float16
+        args.model, trust_remote_code=True, torch_dtype=torch.float16
     )
     state_dict = model.state_dict()
     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
-    np.savez("weights.npz", **weights)
-
-    # write config
-    config = model.config
-    config_dict = config.to_dict()
-    with open("config.json", "w") as f:
-        json.dump(config_dict, f, indent=4)
+    config = model.config.to_dict()
+
+    if args.quantize:
+        print("[INFO] Quantizing")
+        weights, config = quantize(weights, config, args)
+
+    np.savez(str(mlx_path / "weights.npz"), **weights)
+
+    with open(mlx_path / "config.json", "w") as f:
+        json.dump(config, f, indent=4)
```
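A side note on the tree utilities `quantize` relies on (a standalone sketch, not part of the diff): `tree_unflatten` turns flat dotted keys back into nested dicts and lists of modules, and `tree_flatten` inverts it, which is what lets the script round-trip the npz-style weights through the model before saving the quantized parameters.

```python
# Standalone sketch: round-tripping flat npz-style keys through mlx's
# tree utilities (the key name below is hypothetical).
from mlx.utils import tree_flatten, tree_unflatten

flat = [("transformer.h.0.ln_1.weight", 1.0)]
nested = tree_unflatten(flat)        # {"transformer": {"h": [{"ln_1": {"weight": 1.0}}]}}
assert tree_flatten(nested) == flat  # lossless round trip
```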
```diff
@@ -37,7 +82,29 @@ if __name__ == "__main__":
         help="The huggingface model to be converted",
         default="Qwen/Qwen-1_8B",
     )
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_model",
+        help="The path to save the MLX model.",
+    )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        help="Generate a quantized model.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--q_group_size",
+        help="Group size for quantization.",
+        type=int,
+        default=64,
+    )
+    parser.add_argument(
+        "--q_bits",
+        help="Bits per weight for quantization.",
+        type=int,
+        default=4,
+    )
     args = parser.parse_args()
 
-    convert(args.model)
+    convert(args)
```
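Note that argparse normalizes hyphens in option names to underscores on the namespace, so `--mlx-path` becomes `args.mlx_path`, while `--q_group_size` and `--q_bits` already match their attribute names. A plausible invocation exercising every flag added above, with the defaults written out:

```sh
# Hypothetical full invocation of the new conversion flags
python convert.py \
    --model Qwen/Qwen-1_8B \
    --mlx-path mlx_model \
    -q --q_group_size 64 --q_bits 4
```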
qwen/qwen.py (file path inferred; the `+` marker is inferred, since `Path` is what the new `load_model` needs):

```diff
@@ -1,6 +1,7 @@
 import argparse
 import json
 from dataclasses import dataclass
+from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
```
```diff
@@ -175,12 +176,11 @@ def generate(prompt: mx.array, model: Qwen, temp: 0.0):
         yield y
 
 
-def load_model(
-    tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
-):
+def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
     model_args = ModelArgs()
 
-    with open(config_path, "r") as f:
+    model_path = Path(model_path)
+    with open(model_path / "config.json", "r") as f:
         config = json.load(f)
     model_args.vocab_size = config["vocab_size"]
     model_args.hidden_size = config["hidden_size"]
```
```diff
@@ -193,9 +193,11 @@ def load_model(
     model_args.no_bias = config["no_bias"]
 
     model = Qwen(model_args)
 
-    weights = mx.load("weights.npz")
+    weights = mx.load(str(model_path / "weights.npz"))
+    if quantization := config.get("quantization", False):
+        nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
 
     tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
     )
```
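The ordering here matters: when `config.json` carries the `quantization` dict written by `convert.py -q`, `quantize_module` swaps each `nn.Linear` for an `nn.QuantizedLinear` before `model.update(...)` runs, so the packed weights land on parameters of matching shape. A hypothetical call matching the new signature, using the defaults from the argparse hunks:

```python
# Hypothetical usage of the new load_model signature.
model, tokenizer = load_model("mlx_model", tokenizer_path="Qwen/Qwen-1_8B")
```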
```diff
@@ -204,6 +206,12 @@ def load_model(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Qwen inference script")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="mlx_model",
+        help="The path to the model weights and config",
+    )
     parser.add_argument(
         "--tokenizer",
         help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
@@ -216,7 +224,7 @@ if __name__ == "__main__":
         default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
     )
     parser.add_argument(
-        "--max_tokens",
+        "--max-tokens",
         "-m",
         type=int,
         default=100,
@@ -233,7 +241,7 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model(args.tokenizer)
+    model, tokenizer = load_model(args.model_path, args.tokenizer)
 
     prompt = tokenizer(
         args.prompt,
```
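After conversion, generation would be run along these lines (flags from the hunks above; the default prompt is a Chinese few-shot completion, "The capital of Mongolia is Ulaanbaatar\nThe capital of Iceland is Reykjavik\nThe capital of Ethiopia is"):

```sh
# Load the converted model from ./mlx_model and generate 100 tokens.
python qwen.py --model-path mlx_model --tokenizer Qwen/Qwen-1_8B --max-tokens 100
```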