mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
feat: add quantization support
This commit is contained in:
@@ -14,6 +14,9 @@ Next, download and convert the model.
|
||||
```sh
|
||||
python convert.py --model-path <path_to_huggingface_model> --mlx-path <path_to_save_converted_model>
|
||||
```
|
||||
To generate a 4-bit quantized model, use `-q`. For a full list of options:
|
||||
|
||||
The script downloads the model from Hugging Face. The default model is deepseek-ai/deepseek-coder-6.7b-instruct. Check out the [Hugging Face page](https://huggingface.co/deepseek-ai) to see a list of available models.
|
||||
|
||||
By default, the conversion script will save
|
||||
the converted `weights.npz`, `tokenizer`, and `config.json` in the mlx-path you specified.
|
||||
|
@@ -1,10 +1,47 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import copy
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from deepseek_coder import ModelArgs, DeepseekCoder
|
||||
|
||||
|
||||
def quantize(weights, config, args):
    """Quantize converted weights and return them with an updated config.

    A ``DeepseekCoder`` model is built from the hyperparameters found in
    ``config``, the given ``weights`` are loaded into it, and the model is
    then quantized with ``nn.QuantizedLinear.quantize_module``.

    Args:
        weights: Flat mapping of parameter name to array; each value is
            converted with ``mx.array`` before being loaded into the model.
        config: Model configuration mapping. It is not modified; a deep copy
            is returned instead.
        args: Parsed command-line arguments providing ``q_group_size`` and
            ``q_bits``.

    Returns:
        A ``(quantized_weights, quantized_config)`` tuple: the flattened
        parameters of the quantized model, and a copy of ``config`` with an
        added ``"quantization"`` entry recording the group size and bits.
    """
    # Work on a copy so the caller's config is left untouched.
    quantized_config = copy.deepcopy(config)

    # Build the model, copying same-named hyperparameters from the config.
    model_args = ModelArgs()
    for field in (
        "vocab_size",
        "hidden_size",
        "num_attention_heads",
        "num_key_value_heads",
        "num_hidden_layers",
        "max_position_embeddings",
        "rms_norm_eps",
        "intermediate_size",
    ):
        setattr(model_args, field, config[field])
    # The RoPE scaling factor is nested one level deeper in the config.
    model_args.rope_scaling_factor = config["rope_scaling"]["factor"]
    model = DeepseekCoder(model_args)

    # Load the (flat) weights into the model's nested parameter tree.
    mlx_weights = tree_map(mx.array, weights)
    flat_items = list(mlx_weights.items())
    model.update(tree_unflatten(flat_items))

    # Quantize the model in place.
    group_size = args.q_group_size
    bits = args.q_bits
    nn.QuantizedLinear.quantize_module(model, group_size, bits)

    # Record the quantization settings in the returned config.
    quantized_config["quantization"] = {"group_size": group_size, "bits": bits}

    return dict(tree_flatten(model.parameters())), quantized_config
|
||||
|
||||
|
||||
def convert(args):
|
||||
@@ -59,6 +96,10 @@ def convert(args):
|
||||
|
||||
weights = {k: v.numpy() for k, v in state_dict.items()}
|
||||
|
||||
if args.quantize:
|
||||
print("[INFO] Quantizing")
|
||||
weights, config = quantize(weights, config, args)
|
||||
|
||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
with open(mlx_path / "config.json", "w") as f:
|
||||
@@ -80,5 +121,23 @@ if __name__ == "__main__":
|
||||
default="mlx_model",
|
||||
help="The path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Generate a quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q_group_size",
|
||||
help="Group size for quantization.",
|
||||
type=int,
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q_bits",
|
||||
help="Bits per weight for quantization.",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
@@ -18,7 +18,7 @@ class ModelArgs:
|
||||
num_hidden_layers: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
max_position_embeddings: int = 16384
|
||||
layer_norm_epsilon: float = 1e-6
|
||||
rms_norm_eps: float = 1e-6
|
||||
intermediate_size: int = 11008
|
||||
rope_theta: float = 100000
|
||||
rope_scaling_factor: float = 4.0
|
||||
@@ -169,8 +169,8 @@ class TransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.ffn_norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -194,7 +194,7 @@ class DeepseekCoder(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
@@ -243,7 +243,7 @@ def load_model(model_path: str):
|
||||
model_args.num_key_value_heads = config["num_key_value_heads"]
|
||||
model_args.num_hidden_layers = config["num_hidden_layers"]
|
||||
model_args.max_position_embeddings = config["max_position_embeddings"]
|
||||
model_args.layer_norm_epsilon = config["rms_norm_eps"]
|
||||
model_args.rms_norm_eps = config["rms_norm_eps"]
|
||||
model_args.intermediate_size = config["intermediate_size"]
|
||||
model_args.rope_scaling_factor = config["rope_scaling"]["factor"]
|
||||
|
||||
@@ -275,7 +275,7 @@ if __name__ == "__main__":
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=500,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
Reference in New Issue
Block a user