mirror of https://github.com/ml-explore/mlx-examples.git
Quantize example (#162)
* testing quantization
* conversion + quantization working
* one config processor
* quantization in mistral / nits in llama
* args for quantization
* llama / mistral conversion in good shape
* phi2 quantized
* mixtral
* qwen conversion
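The inference-side change is small but order-sensitive, so here is a minimal, self-contained sketch of the new load path (not code from the repo): read config.json from the converted model directory, and if the conversion recorded quantization arguments, quantize the model before loading the weights. The two-layer toy model and the "mlx_model" directory name are hypothetical stand-ins for the real Qwen model; the API calls mirror the ones in the diff below.

import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

model_dir = Path("mlx_model")  # directory written by the conversion script

with open(model_dir / "config.json", "r") as f:
    config = json.load(f)

# Toy stand-in for Qwen(model_args); sizes chosen so the default
# quantization group size (64) divides the layer dimensions.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

# If the conversion stored quantization args (e.g. {"group_size": 64, "bits": 4}),
# swap each Linear for a QuantizedLinear *before* loading the weights so the
# quantized parameter names and shapes match the saved file.
if quantization := config.get("quantization", False):
    nn.QuantizedLinear.quantize_module(model, **quantization)

weights = mx.load(str(model_dir / "weights.npz"))
model.update(tree_unflatten(list(weights.items())))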
@@ -1,6 +1,7 @@
 import argparse
+import json
 from dataclasses import dataclass
 from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -175,12 +176,11 @@ def generate(prompt: mx.array, model: Qwen, temp: 0.0):
         yield y
 
 
-def load_model(
-    tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
-):
+def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
     model_args = ModelArgs()
 
-    with open(config_path, "r") as f:
+    model_path = Path(model_path)
+    with open(model_path / "config.json", "r") as f:
         config = json.load(f)
     model_args.vocab_size = config["vocab_size"]
     model_args.hidden_size = config["hidden_size"]
@@ -193,9 +193,11 @@ def load_model(
     model_args.no_bias = config["no_bias"]
 
     model = Qwen(model_args)
 
-    weights = mx.load("weights.npz")
+    weights = mx.load(str(model_path / "weights.npz"))
+    if quantization := config.get("quantization", False):
+        nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
 
     tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
     )
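The ordering in the hunk above is the crux of the change: quantization must run before model.update, because quantize_module changes the parameter set itself; each quantized layer carries a packed weight plus scales and biases, and those parameters must exist before the saved quantized weights can be loaded into them. A small illustration (the 64/4 values are example settings, and the printed names are approximate):

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(64, 64))
print([k for k, _ in tree_flatten(model.parameters())])
# before quantizing: ['layers.0.weight', 'layers.0.bias']

nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)
print([k for k, _ in tree_flatten(model.parameters())])
# after quantizing: the packed 'weight' plus 'scales' and 'biases' appear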
@@ -204,6 +206,12 @@ if __name__ == "__main__":
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Qwen inference script")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="mlx_model",
+        help="The path to the model weights and config",
+    )
     parser.add_argument(
         "--tokenizer",
         help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
@@ -216,7 +224,7 @@ if __name__ == "__main__":
         default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
     )
     parser.add_argument(
-        "--max_tokens",
+        "--max-tokens",
         "-m",
         type=int,
         default=100,
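One subtlety in the flag rename above: argparse derives the attribute name from the long option by replacing dashes with underscores, so downstream code that reads args.max_tokens keeps working unchanged. A standalone check (not from the repo):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max-tokens", "-m", type=int, default=100)
args = parser.parse_args(["--max-tokens", "25"])
assert args.max_tokens == 25  # dest is max_tokens, not max-tokens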
@@ -233,7 +241,7 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model(args.tokenizer)
+    model, tokenizer = load_model(args.model_path, args.tokenizer)
 
     prompt = tokenizer(
         args.prompt,
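With the new signature, the converted model directory is the first argument and the tokenizer name the second, matching the --model-path default added above. A hypothetical usage sketch, assuming the script is importable as a module named qwen:

from qwen import load_model

# "mlx_model" is the default output directory of the conversion step.
model, tokenizer = load_model("mlx_model", tokenizer_path="Qwen/Qwen-1_8B")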