Quantize example (#162)

* testing quantization

* conversion + quantization working

* one config processor

* quantization in mistral / nits in llama

* args for quantization

* llama / mistral conversion in good shape

* phi2 quantized

* mixtral

* qwen conversion
Authored by Awni Hannun, committed by GitHub on 2023-12-21 12:59:37 -08:00
parent 4c9db80ed2
commit 3cf436b529
17 changed files with 553 additions and 126 deletions

View File

@@ -30,24 +30,32 @@ Face](https://huggingface.co/TinyLlama).
Convert the weights with:
```
python convert.py --model-path <path_to_torch_model>
python convert.py --torch-path <path_to_torch_model>
```
To generate a 4-bit quantized model, use the `-q` flag:
```
python convert.py --torch-path <path_to_torch_model> -q
```
For TinyLlama, use:
```
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
python convert.py --torch-path <path_to_torch_model> --model-name tiny_llama
```
The conversion script will save the converted weights in the same location.
By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
### Run
Once you've converted the weights to MLX format, you can interact with the
LlaMA model:
Llama model:
```
python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
python llama.py --prompt "hello"
```
Run `python llama.py --help` for more details.
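
As a quick sanity check (not part of this commit), the converted `mlx_model` directory can be inspected directly from Python; the paths below assume the default `--mlx-path`:
```python
# Sketch only: verify the artifacts written by convert.py.
import json

import mlx.core as mx
from sentencepiece import SentencePieceProcessor

weights = mx.load("mlx_model/weights.npz")  # dict of mx.array
with open("mlx_model/config.json") as fid:
    config = json.load(fid)
tokenizer = SentencePieceProcessor(model_file="mlx_model/tokenizer.model")

print(len(weights), "weight arrays")
print(config.get("quantization"))  # e.g. {"group_size": 64, "bits": 4} when converted with -q
```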

View File

@@ -2,12 +2,18 @@
import argparse
import collections
import copy
import glob
import json
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def llama(model_path):
@@ -57,9 +63,7 @@ def tiny_llama(model_path):
except ImportError as e:
print("The transformers package must be installed for this model conversion:")
print("pip install transformers")
import sys
sys.exit(0)
exit(0)
model = transformers.AutoModelForCausalLM.from_pretrained(
str(model_path)
@@ -114,11 +118,40 @@ def tiny_llama(model_path):
return weights, params
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
config = sanitize_config(config, weights)
model = Llama(ModelArgs(**config))
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
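
The core of the quantization step is `nn.QuantizedLinear.quantize_module`, which swaps `nn.Linear` layers for quantized equivalents in place. A minimal standalone sketch on a toy model (layer sizes are arbitrary, chosen to be divisible by the group size):
```python
# Sketch: quantize a toy model the same way quantize() above treats Llama.
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)

# Linear weights are now packed low-bit arrays with per-group scales/biases.
for name, param in tree_flatten(model.parameters()):
    print(name, param.shape, param.dtype)
```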
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
"--model-path",
help="Path to the model. The MLX weights will also be saved there.",
"--torch-path",
type=str,
help="Path to the PyTorch model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"--model-name",
@@ -130,12 +163,43 @@ if __name__ == "__main__":
choices=["tiny_llama", "llama"],
default="llama",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
model_path = Path(args.model_path)
weights, params = globals()[args.model_name](model_path)
torch_path = Path(args.torch_path)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
print("[INFO] Loading")
weights, params = globals()[args.model_name](torch_path)
params["model_type"] = "llama"
np.savez(str(model_path / "weights.npz"), **weights)
with open(model_path / "config.json", "w") as fid:
if args.quantize:
print("[INFO] Quantizing")
weights, params = quantize(weights, params, args)
print("[INFO] Saving")
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
np.savez(str(mlx_path / "weights.npz"), **weights)
with open(mlx_path / "config.json", "w") as fid:
json.dump(params, fid, indent=4)
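After the `json.dump`, a quantized conversion leaves a config of roughly this shape (a hypothetical sketch; every key except `model_type` and `quantization` simply passes through from the source checkpoint's params):
```python
# Hypothetical shape of mlx_model/config.json after a `-q` conversion.
{
    "model_type": "llama",
    "quantization": {"group_size": 64, "bits": 4},
    # ... original model arguments (dim, n_heads, n_layers, ...) ...
}
```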

View File

@@ -178,6 +178,12 @@ class Llama(nn.Module):
return self.output(x)
def generate(self, x, temp=1.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
cache = []
# Make an additive causal mask. We will need that to process the prompt.
@@ -194,7 +200,7 @@ class Llama(nn.Module):
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
y = sample(y)
# y now has size [1]
# Since MLX is lazily evaluated nothing is computed yet.
@@ -218,8 +224,7 @@ class Llama(nn.Module):
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
y = sample(self.output(x[:, -1]))
yield y
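
The new `sample` helper makes the temperature behavior explicit: greedy decoding at `--temp 0.0` (now the default) and a categorical draw over temperature-scaled logits otherwise. In isolation it behaves like this sketch:
```python
# Standalone sketch of the sampling rule used in generate().
import mlx.core as mx

def sample(logits, temp):
    if temp == 0:
        return mx.argmax(logits, axis=-1)
    return mx.random.categorical(logits * (1 / temp))

logits = mx.array([[0.1, 3.0, 0.2]])
print(sample(logits, 0.0))  # always the index of the largest logit (1)
print(sample(logits, 0.8))  # stochastic; lower temperatures concentrate on index 1
```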
@@ -326,38 +331,46 @@ def few_shot_generate(args):
print()
def sanitize_config(config, weights):
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
config.pop(k, None)
return config
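`sanitize_config` centralizes the defaults that were previously filled in inside `load_model`. An illustrative call with Llama-7B-like values (the numbers and shapes here are only for demonstration):
```python
# Illustration: missing fields are derived from the config and weight shapes.
import numpy as np

config = {
    "dim": 4096, "n_layers": 32, "n_heads": 32,
    "norm_eps": 1e-5, "vocab_size": 32000, "multiple_of": 256,
}
weights = {"layers.0.feed_forward.w1.weight": np.zeros((11008, 4096))}
print(sanitize_config(config, weights))
# Adds n_kv_heads=32, head_dim=128, hidden_dim=11008, rope_theta=10000,
# and drops the unused "multiple_of" key.
```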
def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
if k in config:
config.pop(k)
config = sanitize_config(json.loads(f.read()), weights)
quantization = config.pop("quantization", None)
model = Llama(ModelArgs(**config))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
return model
tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
return model, tokenizer
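`load_model` now returns both the model and the tokenizer, and re-applies quantization whenever the config carries a `quantization` section. A usage sketch, assuming the default `mlx_model` output directory from `convert.py`:
```python
# Usage sketch: load a (possibly quantized) converted model and greedily
# decode a few tokens from a prompt.
model, tokenizer = load_model("mlx_model")
prompt = mx.array([tokenizer.bos_id()] + tokenizer.encode("hello"))[None]
for token, _ in zip(model.generate(prompt, temp=0.0), range(16)):
    print(tokenizer.id_to_piece(token.item()), end="", flush=True)
print()
```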
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument(
"model", help="Path to the model directory containing the MLX weights"
"--model-path",
help="Path to the model directory containing the MLX weights",
default="mlx_model",
)
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument(
"--prompt",
help="The message to be processed by the model. Ignored when --few-shot is provided.",
@@ -374,7 +387,7 @@ if __name__ == "__main__":
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
)
parser.add_argument(
"--temp", type=float, default=0.8, help="The sampling temperature"
"--temp", type=float, default=0.0, help="The sampling temperature"
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
@@ -382,9 +395,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
print("[INFO] Loading model from disk.")
model = load_model(args.model)
model, tokenizer = load_model(args.model_path)
if args.few_shot:
few_shot_generate(args)
else: