mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Quantize example (#162)
* testing quantization * conversion + quantization working * one config processor * quantization in mistral / nits in llama * args for quantization * llama / mistral conversion in good shape * phi2 quantized * mixtral * qwen conversion
This commit is contained in:
@@ -30,24 +30,32 @@ Face](https://huggingface.co/TinyLlama).
|
||||
Convert the weights with:
|
||||
|
||||
```
|
||||
python convert.py --model-path <path_to_torch_model>
|
||||
python convert.py --torch-path <path_to_torch_model>
|
||||
```
|
||||
|
||||
To generate a 4-bit quantized model use the `-q` flag:
|
||||
|
||||
```
|
||||
python convert.py --torch-path <path_to_torch_model> -q
|
||||
```
|
||||
|
||||
For TinyLlama use
|
||||
|
||||
```
|
||||
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
|
||||
python convert.py --torch-path <path_to_torch_model> --model-name tiny_llama
|
||||
```
|
||||
|
||||
The conversion script will save the converted weights in the same location.
|
||||
By default, the conversion script will make the directory `mlx_model` and save
|
||||
the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
|
||||
|
||||
|
||||
### Run
|
||||
|
||||
Once you've converted the weights to MLX format, you can interact with the
|
||||
LlaMA model:
|
||||
LlamA model:
|
||||
|
||||
```
|
||||
python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
|
||||
python llama.py --prompt "hello"
|
||||
```
|
||||
|
||||
Run `python llama.py --help` for more details.
|
||||
|
@@ -2,12 +2,18 @@
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
from llama import Llama, ModelArgs, sanitize_config
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def llama(model_path):
|
||||
@@ -57,9 +63,7 @@ def tiny_llama(model_path):
|
||||
except ImportError as e:
|
||||
print("The transformers package must be installed for this model conversion:")
|
||||
print("pip install transformers")
|
||||
import sys
|
||||
|
||||
sys.exit(0)
|
||||
exit(0)
|
||||
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
str(model_path)
|
||||
@@ -114,11 +118,40 @@ def tiny_llama(model_path):
|
||||
return weights, params
|
||||
|
||||
|
||||
def quantize(weights, config, args):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
# Load the model:
|
||||
config = sanitize_config(config, weights)
|
||||
model = Llama(ModelArgs(**config))
|
||||
weights = tree_map(mx.array, weights)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
# Quantize the model:
|
||||
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {
|
||||
"group_size": args.q_group_size,
|
||||
"bits": args.q_bits,
|
||||
}
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
help="Path to the model. The MLX weights will also be saved there.",
|
||||
"--torch-path",
|
||||
type=str,
|
||||
help="Path to the PyTorch model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
@@ -130,12 +163,43 @@ if __name__ == "__main__":
|
||||
choices=["tiny_llama", "llama"],
|
||||
default="llama",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Generate a quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q_group_size",
|
||||
help="Group size for quantization.",
|
||||
type=int,
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q_bits",
|
||||
help="Bits per weight for quantization.",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = Path(args.model_path)
|
||||
weights, params = globals()[args.model_name](model_path)
|
||||
torch_path = Path(args.torch_path)
|
||||
mlx_path = Path(args.mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("[INFO] Loading")
|
||||
weights, params = globals()[args.model_name](torch_path)
|
||||
params["model_type"] = "llama"
|
||||
np.savez(str(model_path / "weights.npz"), **weights)
|
||||
with open(model_path / "config.json", "w") as fid:
|
||||
if args.quantize:
|
||||
print("[INFO] Quantizing")
|
||||
weights, params = quantize(weights, params, args)
|
||||
|
||||
print("[INFO] Saving")
|
||||
shutil.copyfile(
|
||||
str(torch_path / "tokenizer.model"),
|
||||
str(mlx_path / "tokenizer.model"),
|
||||
)
|
||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(params, fid, indent=4)
|
||||
|
@@ -178,6 +178,12 @@ class Llama(nn.Module):
|
||||
return self.output(x)
|
||||
|
||||
def generate(self, x, temp=1.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
cache = []
|
||||
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
@@ -194,7 +200,7 @@ class Llama(nn.Module):
|
||||
x = self.norm(x)
|
||||
# We only care about the last logits that generate the next token
|
||||
y = self.output(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
y = sample(y)
|
||||
|
||||
# y now has size [1]
|
||||
# Since MLX is lazily evaluated nothing is computed yet.
|
||||
@@ -218,8 +224,7 @@ class Llama(nn.Module):
|
||||
# old cache the moment it is not needed anymore.
|
||||
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
|
||||
x = self.norm(x)
|
||||
y = self.output(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
y = sample(self.output(x[:, -1]))
|
||||
|
||||
yield y
|
||||
|
||||
@@ -326,38 +331,46 @@ def few_shot_generate(args):
|
||||
print()
|
||||
|
||||
|
||||
def sanitize_config(config, weights):
|
||||
config.pop("model_type", None)
|
||||
n_heads = config["n_heads"]
|
||||
if "n_kv_heads" not in config:
|
||||
config["n_kv_heads"] = n_heads
|
||||
if "head_dim" not in config:
|
||||
config["head_dim"] = config["dim"] // n_heads
|
||||
if "hidden_dim" not in config:
|
||||
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
||||
if config.get("vocab_size", -1) < 0:
|
||||
config["vocab_size"] = weights["output.weight"].shape[-1]
|
||||
if "rope_theta" not in config:
|
||||
config["rope_theta"] = 10000
|
||||
unused = ["multiple_of", "ffn_dim_multiplier"]
|
||||
for k in unused:
|
||||
config.pop(k, None)
|
||||
return config
|
||||
|
||||
|
||||
def load_model(model_path):
|
||||
model_path = Path(model_path)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
n_heads = config["n_heads"]
|
||||
if "n_kv_heads" not in config:
|
||||
config["n_kv_heads"] = n_heads
|
||||
if "head_dim" not in config:
|
||||
config["head_dim"] = config["dim"] // n_heads
|
||||
if "hidden_dim" not in config:
|
||||
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
||||
if config.get("vocab_size", -1) < 0:
|
||||
config["vocab_size"] = weights["output.weight"].shape[-1]
|
||||
if "rope_theta" not in config:
|
||||
config["rope_theta"] = 10000
|
||||
unused = ["multiple_of", "ffn_dim_multiplier"]
|
||||
for k in unused:
|
||||
if k in config:
|
||||
config.pop(k)
|
||||
config = sanitize_config(json.loads(f.read()), weights)
|
||||
quantization = config.pop("quantization", None)
|
||||
model = Llama(ModelArgs(**config))
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
return model
|
||||
tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Llama inference script")
|
||||
parser.add_argument(
|
||||
"model", help="Path to the model directory containing the MLX weights"
|
||||
"--model-path",
|
||||
help="Path to the model directory containing the MLX weights",
|
||||
default="mlx_model",
|
||||
)
|
||||
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model. Ignored when --few-shot is provided.",
|
||||
@@ -374,7 +387,7 @@ if __name__ == "__main__":
|
||||
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.8, help="The sampling temperature"
|
||||
"--temp", type=float, default=0.0, help="The sampling temperature"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
@@ -382,9 +395,8 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
|
||||
print("[INFO] Loading model from disk.")
|
||||
model = load_model(args.model)
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
if args.few_shot:
|
||||
few_shot_generate(args)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user