From 942a6ef620ed560f67c2a31e3ee7464a306eaa06 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 21 Dec 2023 12:54:47 -0800
Subject: [PATCH] qwen conversion

---
 llms/mixtral/mixtral.py |  6 ---
 llms/qwen/README.md     | 12 ++++--
 llms/qwen/convert.py    | 87 ++++++++++++++++++++++++++++++++++++-----
 llms/qwen/qwen.py       | 24 ++++++++----
 4 files changed, 101 insertions(+), 28 deletions(-)

diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index 9201d9ea..30fa8d8b 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -258,14 +258,8 @@ def load_model(folder: str):
         weights.update(mx.load(wf).items())
     weights = tree_unflatten(list(weights.items()))
     model = Mixtral(model_args)
-    # model.update(weights)
-    # quantization = {
-    #     "group_size": 64,
-    #     "bits": 4,
-    # }
     if quantization is not None:
         # TODO: Quantize gate matrices when < 32 tiles supported
-        print("QUANTIZING")
         quantization["linear_class_predicate"] = (
             lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
         )
diff --git a/llms/qwen/README.md b/llms/qwen/README.md
index f9276098..75154325 100644
--- a/llms/qwen/README.md
+++ b/llms/qwen/README.md
@@ -11,11 +11,15 @@ First download and convert the model with:
 ```sh
 python convert.py
 ```
-The script downloads the model from Hugging Face. The default model is
-`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
 
-The conversion script will make the `weights.npz` and `config.json` files in
-the working directory.
+To generate a 4-bit quantized model, use ``-q``. Run `python convert.py --help` for a full list of options.
+
+The script downloads the model from Hugging Face. The default model is
+`Qwen/Qwen-1_8B`. Check out the [Hugging Face
+page](https://huggingface.co/Qwen) to see a list of available models.
+
+By default, the conversion script will make the directory `mlx_model` and save
+the converted `weights.npz` and `config.json` there.
 
 ## Generate
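As a usage sketch of the conversion flow the README hunk above describes (the flag spellings come from the argparse definitions in the `convert.py` hunk below, and `mlx_model` is simply the new default output directory):

```sh
# Download Qwen/Qwen-1_8B, quantize to 4 bits with group size 64, and write
# weights.npz and config.json into ./mlx_model.
python convert.py --model Qwen/Qwen-1_8B --mlx-path mlx_model -q --q_group_size 64 --q_bits 4
```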
diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py
index c5fec060..88135208 100644
--- a/llms/qwen/convert.py
+++ b/llms/qwen/convert.py
@@ -1,8 +1,14 @@
 import argparse
+import copy
 import json
+from pathlib import Path
 
+import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 import torch
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
+from qwen import ModelArgs, Qwen
 from transformers import AutoModelForCausalLM
 
 
@@ -14,19 +20,58 @@ def replace_key(key: str) -> str:
     return key
 
 
-def convert(model_path: str = "Qwen/Qwen-1_8B"):
+def quantize(weights, config, args):
+    quantized_config = copy.deepcopy(config)
+
+    # Load the model:
+    model_args = ModelArgs()
+    model_args.vocab_size = config["vocab_size"]
+    model_args.hidden_size = config["hidden_size"]
+    model_args.num_attention_heads = config["num_attention_heads"]
+    model_args.num_hidden_layers = config["num_hidden_layers"]
+    model_args.kv_channels = config["kv_channels"]
+    model_args.max_position_embeddings = config["max_position_embeddings"]
+    model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
+    model_args.intermediate_size = config["intermediate_size"]
+    model_args.no_bias = config["no_bias"]
+    model = Qwen(model_args)
+
+    weights = tree_map(mx.array, weights)
+    model.update(tree_unflatten(list(weights.items())))
+
+    # Quantize the model:
+    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+
+    # Update the config:
+    quantized_config["quantization"] = {
+        "group_size": args.q_group_size,
+        "bits": args.q_bits,
+    }
+    quantized_weights = dict(tree_flatten(model.parameters()))
+
+    return quantized_weights, quantized_config
+
+
+def convert(args):
+    mlx_path = Path(args.mlx_path)
+    mlx_path.mkdir(parents=True, exist_ok=True)
+
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, trust_remote_code=True, torch_dtype=torch.float16
+        args.model, trust_remote_code=True, torch_dtype=torch.float16
     )
     state_dict = model.state_dict()
     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
-    np.savez("weights.npz", **weights)
+    config = model.config.to_dict()
+
+    if args.quantize:
+        print("[INFO] Quantizing")
+        weights, config = quantize(weights, config, args)
+
+    np.savez(str(mlx_path / "weights.npz"), **weights)
 
     # write config
-    config = model.config
-    config_dict = config.to_dict()
-    with open("config.json", "w") as f:
-        json.dump(config_dict, f, indent=4)
+    with open(mlx_path / "config.json", "w") as f:
+        json.dump(config, f, indent=4)
 
 
 if __name__ == "__main__":
@@ -37,7 +82,29 @@ if __name__ == "__main__":
         help="The huggingface model to be converted",
         default="Qwen/Qwen-1_8B",
     )
-
+    parser.add_argument(
+        "--mlx-path",
+        type=str,
+        default="mlx_model",
+        help="The path to save the MLX model.",
+    )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        help="Generate a quantized model.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--q_group_size",
+        help="Group size for quantization.",
+        type=int,
+        default=64,
+    )
+    parser.add_argument(
+        "--q_bits",
+        help="Bits per weight for quantization.",
+        type=int,
+        default=4,
+    )
     args = parser.parse_args()
-
-    convert(args.model)
+    convert(args)
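One low-effort way to check what `convert.py` wrote before moving on to inference is to look for the quantization metadata it stores in the saved config; this is the same key the `qwen.py` changes below read back with `config.get("quantization", False)`. A sketch, assuming the default `mlx_model` output directory:

```sh
# Prints {'group_size': 64, 'bits': 4} for a quantized conversion, or None otherwise.
python -c "import json; print(json.load(open('mlx_model/config.json')).get('quantization'))"
```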
diff --git a/llms/qwen/qwen.py b/llms/qwen/qwen.py
index 3b153eb9..532a8031 100644
--- a/llms/qwen/qwen.py
+++ b/llms/qwen/qwen.py
@@ -1,6 +1,7 @@
 import argparse
 import json
 from dataclasses import dataclass
+from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -175,12 +176,11 @@ def generate(prompt: mx.array, model: Qwen, temp: 0.0):
         yield y
 
 
-def load_model(
-    tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
-):
+def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
     model_args = ModelArgs()
-    with open(config_path, "r") as f:
+    model_path = Path(model_path)
+    with open(model_path / "config.json", "r") as f:
         config = json.load(f)
     model_args.vocab_size = config["vocab_size"]
     model_args.hidden_size = config["hidden_size"]
@@ -193,9 +193,11 @@ def load_model(
     model_args.no_bias = config["no_bias"]
 
     model = Qwen(model_args)
-
-    weights = mx.load("weights.npz")
+    weights = mx.load(str(model_path / "weights.npz"))
+    if quantization := config.get("quantization", False):
+        nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
+
     tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
     )
@@ -204,6 +206,12 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Qwen inference script")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="mlx_model",
+        help="The path to the model weights and config",
+    )
     parser.add_argument(
         "--tokenizer",
         help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
@@ -216,7 +224,7 @@
         default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
     )
     parser.add_argument(
-        "--max_tokens",
+        "--max-tokens",
         "-m",
         type=int,
         default=100,
@@ -233,7 +241,7 @@
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model(args.tokenizer)
+    model, tokenizer = load_model(args.model_path, args.tokenizer)
 
     prompt = tokenizer(
         args.prompt,