mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
quantization in mistral / nits in llama
This commit is contained in:
@@ -33,6 +33,12 @@ Convert the weights with:
|
||||
python convert.py --model-path <path_to_torch_model>
|
||||
```
|
||||
|
||||
To generate a 4-bit quantized model use the `-q` flag:
|
||||
|
||||
```
|
||||
python convert.py --model-path <path_to_torch_model> -q
|
||||
```
|
||||
|
||||
For TinyLlama use
|
||||
|
||||
```
|
||||
|
@@ -7,8 +7,12 @@ import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
from llama import Llama, ModelArgs, sanitize_config
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def llama(model_path):
|
||||
@@ -116,11 +120,6 @@ def tiny_llama(model_path):
|
||||
|
||||
|
||||
def quantize(weights, config):
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from llama import Llama, ModelArgs, sanitize_config
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
# Load the model:
|
||||
@@ -133,7 +132,7 @@ def quantize(weights, config):
|
||||
nn.QuantizedLinear.quantize_module(model)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {"groups": 64, "width": 4}
|
||||
quantized_config["quantization"] = {"group_size": 64, "bits": 4}
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
@@ -158,7 +157,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Quantize the model before saving",
|
||||
help="Generate a 4-bit quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
|
@@ -358,9 +358,7 @@ def load_model(model_path):
|
||||
quantization = config.pop("quantization", None)
|
||||
model = Llama(ModelArgs(**config))
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(
|
||||
model, groups=quantization["groups"], width=quantization["width"]
|
||||
)
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
return model
|
||||
|
||||
|
@@ -26,6 +26,12 @@ Then, convert the weights with:
|
||||
python convert.py
|
||||
```
|
||||
|
||||
To generate a 4-bit quantized model, use:
|
||||
|
||||
```
|
||||
python convert.py -q
|
||||
```
|
||||
|
||||
The conversion script will save the converted weights in the same location.
|
||||
|
||||
> [!TIP]
|
||||
|
@@ -4,8 +4,32 @@ import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
from mistral import Mistral, ModelArgs
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def quantize(weights, config):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
# Load the model:
|
||||
config.pop("sliding_window", None)
|
||||
model = Mistral(ModelArgs(**config))
|
||||
weights = tree_map(mx.array, weights)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
# Quantize the model:
|
||||
nn.QuantizedLinear.quantize_module(model)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {"group_size": 64, "bits": 4}
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
|
||||
@@ -15,14 +39,22 @@ if __name__ == "__main__":
|
||||
default="mistral-7B-v0.1/",
|
||||
help="The path to the Mistral model. The MLX weights will also be saved there.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Generate a 4-bit quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = Path(args.model_path)
|
||||
state = torch.load(str(model_path / "consolidated.00.pth"))
|
||||
np.savez(
|
||||
str(model_path / "weights.npz"),
|
||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
|
||||
)
|
||||
weights = {k: v.to(torch.float16).numpy() for k, v in state.items()}
|
||||
if args.quantize:
|
||||
print("[INFO] Quantizing")
|
||||
weights, params = quantize(weights, params)
|
||||
|
||||
np.savez(str(model_path / "weights.npz"), **weights)
|
||||
|
||||
# Save config.json with model_type
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
|
@@ -196,11 +196,14 @@ def load_model(folder: str, dtype=mx.float16):
|
||||
config = json.loads(f.read())
|
||||
config.pop("sliding_window", None)
|
||||
config.pop("model_type", None)
|
||||
quantization = config.pop("quantization", None)
|
||||
model_args = ModelArgs(**config)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Mistral(model_args)
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
model.update(weights)
|
||||
return model, tokenizer
|
||||
|
||||
|
Reference in New Issue
Block a user