Quantize example (#162)

* testing quantization

* conversion + quantization working

* one config processor

* quantization in mistral / nits in llama

* args for quantization

* llama / mistral conversion in good shape

* phi2 quantized

* mixtral

* qwen conversion
This commit is contained in:
Awni Hannun
2023-12-21 12:59:37 -08:00
committed by GitHub
parent 4c9db80ed2
commit 3cf436b529
17 changed files with 553 additions and 126 deletions

View File

@@ -1,4 +1,5 @@
import argparse
import json
import math
from dataclasses import dataclass
from pathlib import Path
@@ -158,8 +159,16 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
def load_model(model_path: str):
model = Phi2(ModelArgs())
model_path = Path(model_path)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
weights = mx.load(str(model_path / "weights.npz"))
model.update(tree_unflatten(list(weights.items())))
weights = tree_unflatten(list(weights.items()))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
return model, tokenizer
@@ -169,7 +178,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model-path",
type=str,
default=".",
default="mlx_model",
help="The path to the model weights",
)
parser.add_argument(