mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Quantize example (#162)
* testing quantization * conversion + quantization working * one config processor * quantization in mistral / nits in llama * args for quantization * llama / mistral conversion in good shape * phi2 quantized * mixtral * qwen conversion
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -158,8 +159,16 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
|
||||
def load_model(model_path: str):
|
||||
model = Phi2(ModelArgs())
|
||||
model_path = Path(model_path)
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
quantization = config.pop("quantization", None)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
model.update(weights)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
@@ -169,7 +178,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default=".",
|
||||
default="mlx_model",
|
||||
help="The path to the model weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user