Quantize example (#162)

* testing quantization

* conversion + quantization working

* one config processor

* quantization in mistral / nits in llama

* args for quantization

* llama / mistral conversion in good shape

* phi2 quantized

* mixtral

* qwen conversion
Author: Awni Hannun
Committed by: GitHub
Date: 2023-12-21 12:59:37 -08:00
Parent: 4c9db80ed2
Commit: 3cf436b529
17 changed files with 553 additions and 126 deletions
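The hunks below show the inference-side changes to the Qwen script; the conversion scripts mentioned in the commit message ("conversion + quantization working", "args for quantization") are among the other changed files and are not excerpted here. As a rough, hypothetical sketch of what that conversion step plausibly looks like, reusing the same nn.QuantizedLinear.quantize_module API the loader below calls (the function name quantize_and_save, the flag defaults, and the file layout are assumptions, not taken from this commit):

# Hypothetical sketch, not from this commit: quantize an already-converted
# model and record the settings so load_model can rebuild the same setup.
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten


def quantize_and_save(model, config, mlx_path="mlx_model", group_size=64, bits=4):
    # Swap nn.Linear layers for quantized equivalents, in place.
    nn.QuantizedLinear.quantize_module(model, group_size=group_size, bits=bits)

    # Record the kwargs; the loader re-applies them via
    # nn.QuantizedLinear.quantize_module(model, **config["quantization"])
    # before model.update() so the module tree matches the saved arrays.
    config["quantization"] = {"group_size": group_size, "bits": bits}

    mlx_path = Path(mlx_path)
    mx.savez(str(mlx_path / "weights.npz"), **dict(tree_flatten(model.parameters())))
    with open(mlx_path / "config.json", "w") as f:
        json.dump(config, f, indent=4)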


@@ -1,6 +1,7 @@
 import argparse
 import json
 from dataclasses import dataclass
+from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
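The one added import above is what makes the new path handling in load_model work: pathlib.Path overloads the / operator for joining, so the loader can write model_path / "config.json" instead of concatenating strings. A standalone illustration:

from pathlib import Path

model_path = Path("mlx_model")
# "/" joins path components; str() converts back for APIs that take a
# plain string path, as mx.load does in the hunk below.
print(model_path / "config.json")        # mlx_model/config.json
print(str(model_path / "weights.npz"))   # mlx_model/weights.npz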
@@ -175,12 +176,11 @@ def generate(prompt: mx.array, model: Qwen, temp: 0.0):
         yield y
 
 
-def load_model(
-    tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
-):
+def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
     model_args = ModelArgs()
 
-    with open(config_path, "r") as f:
+    model_path = Path(model_path)
+    with open(model_path / "config.json", "r") as f:
         config = json.load(f)
         model_args.vocab_size = config["vocab_size"]
         model_args.hidden_size = config["hidden_size"]
@@ -193,9 +193,11 @@ def load_model(
         model_args.no_bias = config["no_bias"]
 
     model = Qwen(model_args)
-    weights = mx.load("weights.npz")
+    weights = mx.load(str(model_path / "weights.npz"))
+    if quantization := config.get("quantization", False):
+        nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
 
     tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
     )
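The ordering in this hunk matters: quantize_module runs before model.update because quantizing swaps nn.Linear submodules for nn.QuantizedLinear, which changes the parameter names the module tree expects, and the saved npz for a quantized model contains exactly those packed arrays. A small standalone sketch of that effect (the toy model and group size are assumptions for illustration):

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
print([k for k, _ in tree_flatten(model.parameters())])
# e.g. ['layers.0.weight', 'layers.0.bias', 'layers.2.weight', 'layers.2.bias']

nn.QuantizedLinear.quantize_module(model, group_size=32, bits=4)
print([k for k, _ in tree_flatten(model.parameters())])
# the linear entries now include 'scales' and 'biases' alongside the packed
# weights, so update() with quantized weights only fits the quantized tree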
@@ -204,6 +206,12 @@ def load_model(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Qwen inference script")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="mlx_model",
+        help="The path to the model weights and config",
+    )
     parser.add_argument(
         "--tokenizer",
         help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
@@ -216,7 +224,7 @@ if __name__ == "__main__":
         default="蒙古国的首都是乌兰巴托（Ulaanbaatar）\n冰岛的首都是雷克雅未克（Reykjavik）\n埃塞俄比亚的首都是",
     )
     parser.add_argument(
-        "--max_tokens",
+        "--max-tokens",
         "-m",
         type=int,
         default=100,
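Renaming --max_tokens to --max-tokens only changes the command-line spelling; argparse maps dashes in long options to underscores when it builds the namespace, so downstream reads of args.max_tokens keep working:

import argparse

parser = argparse.ArgumentParser()
# Dashes become underscores in the derived attribute name, so the flag
# rename does not touch the rest of the script.
parser.add_argument("--max-tokens", "-m", type=int, default=100)

args = parser.parse_args(["--max-tokens", "32"])
print(args.max_tokens)  # 32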
@@ -233,7 +241,7 @@ if __name__ == "__main__":
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model(args.tokenizer)
+    model, tokenizer = load_model(args.model_path, args.tokenizer)
 
     prompt = tokenizer(
         args.prompt,