Quantize example (#162)

* testing quantization

* conversion + quantization working

* one config processor

* quantization in mistral / nits in llama

* args for quantization

* llama / mistral conversion in good shape

* phi2 quantized

* mixtral

* qwen conversion
Awni Hannun 2023-12-21 12:59:37 -08:00 committed by GitHub
parent 4c9db80ed2
commit 3cf436b529
17 changed files with 553 additions and 126 deletions

llama/README.md

@ -30,24 +30,32 @@ Face](https://huggingface.co/TinyLlama).
Convert the weights with:
```
python convert.py --model-path <path_to_torch_model>
python convert.py --torch-path <path_to_torch_model>
```
To generate a 4-bit quantized model, use the `-q` flag:
```
python convert.py --torch-path <path_to_torch_model> -q
```
For TinyLlama use
```
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
python convert.py --torch-path <path_to_torch_model> --model-name tiny_llama
```
The conversion script will save the converted weights in the same location.
By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
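When `-q` is used, the converter also records the quantization settings under a
`quantization` key in `config.json`, and `llama.py` uses that key to rebuild the
quantized layers at load time. A minimal sketch for checking a converted model from
Python (assumes the default `mlx_model` output directory):
```
import json

import mlx.core as mx

# Inspect the output of convert.py (default directory: mlx_model/).
with open("mlx_model/config.json") as f:
    config = json.load(f)

# Present only when the model was converted with -q.
print(config.get("quantization"))  # e.g. {"group_size": 64, "bits": 4}

weights = mx.load("mlx_model/weights.npz")
print(f"{len(weights)} arrays in weights.npz")
```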
### Run
Once you've converted the weights to MLX format, you can interact with the
LlaMA model:
Llama model:
```
python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
python llama.py --prompt "hello"
```
Run `python llama.py --help` for more details.

llama/convert.py

@ -2,12 +2,18 @@
import argparse
import collections
import copy
import glob
import json
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def llama(model_path):
@ -57,9 +63,7 @@ def tiny_llama(model_path):
except ImportError as e:
print("The transformers package must be installed for this model conversion:")
print("pip install transformers")
import sys
sys.exit(0)
exit(0)
model = transformers.AutoModelForCausalLM.from_pretrained(
str(model_path)
@ -114,11 +118,40 @@ def tiny_llama(model_path):
return weights, params
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
config = sanitize_config(config, weights)
model = Llama(ModelArgs(**config))
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
"--model-path",
help="Path to the model. The MLX weights will also be saved there.",
"--torch-path",
type=str,
help="Path to the PyTorch model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"--model-name",
@ -130,12 +163,43 @@ if __name__ == "__main__":
choices=["tiny_llama", "llama"],
default="llama",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
model_path = Path(args.model_path)
weights, params = globals()[args.model_name](model_path)
torch_path = Path(args.torch_path)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
print("[INFO] Loading")
weights, params = globals()[args.model_name](torch_path)
params["model_type"] = "llama"
np.savez(str(model_path / "weights.npz"), **weights)
with open(model_path / "config.json", "w") as fid:
if args.quantize:
print("[INFO] Quantizing")
weights, params = quantize(weights, params, args)
print("[INFO] Saving")
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
np.savez(str(mlx_path / "weights.npz"), **weights)
with open(mlx_path / "config.json", "w") as fid:
json.dump(params, fid, indent=4)

llama/llama.py

@ -178,6 +178,12 @@ class Llama(nn.Module):
return self.output(x)
def generate(self, x, temp=1.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
cache = []
# Make an additive causal mask. We will need that to process the prompt.
@ -194,7 +200,7 @@ class Llama(nn.Module):
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
y = sample(y)
# y now has size [1]
# Since MLX is lazily evaluated nothing is computed yet.
@ -218,8 +224,7 @@ class Llama(nn.Module):
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
y = sample(self.output(x[:, -1]))
yield y
@ -326,38 +331,46 @@ def few_shot_generate(args):
print()
def sanitize_config(config, weights):
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
config.pop(k, None)
return config
def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
if k in config:
config.pop(k)
config = sanitize_config(json.loads(f.read()), weights)
quantization = config.pop("quantization", None)
model = Llama(ModelArgs(**config))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
return model
tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument(
"model", help="Path to the model directory containing the MLX weights"
"--model-path",
help="Path to the model directory containing the MLX weights",
default="mlx_model",
)
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument(
"--prompt",
help="The message to be processed by the model. Ignored when --few-shot is provided.",
@ -374,7 +387,7 @@ if __name__ == "__main__":
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
)
parser.add_argument(
"--temp", type=float, default=0.8, help="The sampling temperature"
"--temp", type=float, default=0.0, help="The sampling temperature"
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
@ -382,9 +395,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
print("[INFO] Loading model from disk.")
model = load_model(args.model)
model, tokenizer = load_model(args.model_path)
if args.few_shot:
few_shot_generate(args)
else:

mistral/README.md

@ -23,10 +23,17 @@ tar -xf mistral-7B-v0.1.tar
Then, convert the weights with:
```
python convert.py
python convert.py --torch-path <path_to_torch>
```
The conversion script will save the converted weights in the same location.
To generate a 4-bit quantized model, use ``-q``. For a full list of options:
```
python convert.py --help
```
By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from the
@ -40,7 +47,7 @@ Once you've converted the weights to MLX format, you can generate text with
the Mistral model:
```
python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
python mistral.py --prompt "It is a truth universally acknowledged,"
```
Run `python mistral.py --help` for more details.
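Since `load_model` now reads the `quantization` field from `config.json` and applies
`nn.QuantizedLinear.quantize_module` itself, loading a quantized checkpoint looks the
same as loading a full-precision one. A rough sketch of calling it directly (assumes
you run this from the `mistral` example directory with the default `mlx_model` output):
```
# Sketch only: load the converted (optionally quantized) checkpoint directly.
from mistral import load_model

model, tokenizer = load_model("mlx_model")
print(type(model).__name__)  # Mistral
```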

mistral/convert.py

@ -1,32 +1,98 @@
# Copyright © 2023 Apple Inc.
import argparse
import copy
import json
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mistral import Mistral, ModelArgs
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
config.pop("sliding_window", None)
model = Mistral(ModelArgs(**config))
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
parser.add_argument(
"--model-path",
"--torch-path",
type=str,
default="mistral-7B-v0.1/",
help="The path to the Mistral model. The MLX weights will also be saved there.",
default="mistral-7B-v0.1",
help="The path to the PyTorch model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="The path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pth"))
np.savez(
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
torch_path = Path(args.torch_path)
state = torch.load(str(torch_path / "consolidated.00.pth"))
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
weights = {k: v.to(torch.float16).numpy() for k, v in state.items()}
with open(torch_path / "params.json", "r") as f:
config = json.loads(f.read())
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
# Save weights
np.savez(str(mlx_path / "weights.npz"), **weights)
# Copy tokenizer
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
# Save config.json with model_type
with open(model_path / "params.json", "r") as f:
config = json.loads(f.read())
with open(mlx_path / "config.json", "w") as f:
config["model_type"] = "mistral"
with open(model_path / "config.json", "w") as f:
json.dump(config, f, indent=4)

mistral/mistral.py

@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
from mlx.utils import tree_unflatten
from sentencepiece import SentencePieceProcessor
@ -189,18 +189,20 @@ class Tokenizer:
return out
def load_model(folder: str, dtype=mx.float16):
def load_model(folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("sliding_window", None)
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mistral(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
return model, tokenizer
@ -227,7 +229,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model-path",
type=str,
default="mistral-7B-v0.1",
default="mlx_model",
help="The path to the model weights and tokenizer",
)
parser.add_argument(
@ -236,7 +238,7 @@ if __name__ == "__main__":
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--max_tokens",
"--max-tokens",
"-m",
type=int,
default=100,
@ -246,7 +248,7 @@ if __name__ == "__main__":
"--temp",
help="The sampling temperature.",
type=float,
default=1.0,
default=0.0,
)
parser.add_argument(
"--tokens_per_eval",

mixtral/README.md

@ -43,10 +43,18 @@ Now from `mlx-examples/mixtral` convert and save the weights as NumPy arrays so
MLX can read them:
```
python convert.py --model-path $MIXTRAL_MODEL/
python convert.py --torch-path $MIXTRAL_MODEL/
```
The conversion script will save the converted weights in the same location.
To generate a 4-bit quantized model, use ``-q``. For a full list of options:
```
python convert.py --help
```
By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz`, `tokenizer.model`, and `config.json` there.
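The converter writes one `weights.{i}.npz` shard per `consolidated.{i}.pt` file, and
`mixtral.py` globs the shards back together when loading. A small sketch for inspecting
the shards (assumes the default `mlx_model` output directory):
```
import glob

import mlx.core as mx

# One shard per original torch checkpoint: weights.0.npz, weights.1.npz, ...
shards = sorted(glob.glob("mlx_model/weights.*.npz"))
for shard in shards:
    arrays = mx.load(shard)
    print(shard, f"{len(arrays)} arrays")
```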
### Generate

mixtral/convert.py

@ -1,59 +1,152 @@
# Copyright © 2023 Apple Inc.
import argparse
import copy
import glob
import json
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mixtral import Mixtral, ModelArgs
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def convert(k, v, config):
v = v.to(torch.float16).numpy()
if "block_sparse_moe" not in k:
return [(k, v)]
if "gate" in k:
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
def convert(weights, config):
def convert_single(k, v):
v = v.to(torch.float16).numpy()
if "block_sparse_moe" not in k:
return [(k, v)]
if "gate" in k:
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
# From: layers.N.block_sparse_moe.w
# To: layers.N.experts.M.w
num_experts = args["moe"]["num_experts"]
key_path = k.split(".")
v = np.split(v, num_experts, axis=0)
if key_path[-1] == "w2":
v = [u.T for u in v]
# From: layers.N.block_sparse_moe.w
# To: layers.N.experts.M.w
num_experts = config["moe"]["num_experts"]
key_path = k.split(".")
v = np.split(v, num_experts, axis=0)
if key_path[-1] == "w2":
v = [u.T for u in v]
w_name = key_path.pop()
key_path[-1] = "feed_forward.experts"
return [
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
]
w_name = key_path.pop()
key_path[-1] = "feed_forward.experts"
return [
(".".join(key_path + [str(e), w_name, "weight"]), u)
for e, u in enumerate(v)
]
state = torch.load(tf)
weights = {}
for k, v in state.items():
weights.update(convert_single(k, v))
return weights
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model and update with the subset of weights:
config.pop("quantization", None)
model = Mixtral(ModelArgs(**config))
all_weights = dict(tree_flatten(model.parameters()))
weights = tree_map(mx.array, weights)
all_weights.update(weights)
all_weights = tree_unflatten(list(all_weights.items()))
model.update(all_weights)
# Quantize the model:
nn.QuantizedLinear.quantize_module(
model,
args.q_group_size,
args.q_bits,
# TODO: Quantize gate matrices when < 32 tiles supported
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != 8,
)
# Extract the subset of quantized weights:
all_weights = dict(tree_flatten(model.parameters()))
quantized_weights = {}
for k, v in all_weights.items():
if k not in weights:
continue
quantized_weights[k] = v
prefix = k.split(".")[:-1]
for qw in ["scales", "biases"]:
if (k := ".".join(prefix + [qw])) in all_weights:
quantized_weights[k] = all_weights[k]
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
return quantized_weights, quantized_config
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
parser.add_argument(
"--model-path",
"--torch-path",
type=str,
default="Mixtral-8x7B-v0.1/",
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
default="Mixtral-8x7B-v0.1",
help="The path to the PyTorch model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="The path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
model_path = Path(args.model_path)
torch_path = Path(args.torch_path)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
with open("params.json") as fid:
args = json.load(fid)
args["model_type"] = "mixtral"
with open(model_path / "config.json", "w") as f:
json.dump(args, f, indent=4)
config = json.load(fid)
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
# Copy tokenizer
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
# Convert and save model in shards
torch_files = glob.glob(str(torch_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
for e, tf in enumerate(torch_files):
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
state = torch.load(tf)
new_state = {}
for k, v in state.items():
new_state.update(convert(k, v, args))
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
weights = convert(tf, config)
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / f"weights.{e}.npz"), **weights)
# Save updated config
with open(mlx_path / "config.json", "w") as f:
config["model_type"] = "mixtral"
json.dump(config, f, indent=4)

mixtral/mixtral.py

@ -244,20 +244,27 @@ class Tokenizer:
return out
def load_model(folder: str, dtype=mx.float16):
def load_model(folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
weight_files = glob.glob(str(model_path / "weights.*.npz"))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mixtral(model_args)
if quantization is not None:
# TODO: Quantize gate matrices when < 32 tiles supported
quantization["linear_class_predicate"] = (
lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
return model, tokenizer
@ -284,7 +291,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model-path",
type=str,
default="Mixtral-8x7B-v0.1",
default="mlx_model",
help="The path to the model weights, tokenizer, and config",
)
parser.add_argument(

phi2/.gitignore

@ -1 +0,0 @@
weights.npz

phi2/README.md

@ -15,7 +15,14 @@ Download and convert the model:
python convert.py
```
This will make the `weights.npz` file which MLX can read.
To generate a 4-bit quantized model, use the `-q` flag:
```
python convert.py -q
```
By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz` and `config.json` there.
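A quantized linear layer is saved as separate `weight`, `scales`, and `biases` arrays
(the same key names the Mixtral converter above copies over), so listing the arrays in
`weights.npz` is a quick way to tell whether a converted Phi-2 model was quantized. A
rough sketch, assuming the default `mlx_model` output:
```
import mlx.core as mx

weights = mx.load("mlx_model/weights.npz")
quantized = [k for k in weights if k.endswith((".scales", ".biases"))]
print(f"{len(quantized)} quantization arrays")
print(quantized[:2])
```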
> [!TIP] Alternatively, you can also download a few converted checkpoints from
> the [MLX Community](https://huggingface.co/mlx-community) organization on

phi2/convert.py

@ -1,7 +1,37 @@
import argparse
import copy
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from phi2 import ModelArgs, Phi2
from transformers import AutoModelForCausalLM
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = Phi2(ModelArgs())
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def replace_key(key: str) -> str:
if "wte.weight" in key:
key = "wte.weight"
@ -12,12 +42,50 @@ def replace_key(key: str) -> str:
def convert():
parser = argparse.ArgumentParser(description="Convert Phi-2 weights to MLX")
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="The path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights)
params = {}
if args.quantize:
print("[INFO] Quantizing")
weights, params = quantize(weights, params, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
with open(mlx_path / "config.json", "w") as fid:
params["model_type"] = "phi2"
json.dump(params, fid, indent=4)
if __name__ == "__main__":

phi2/phi2.py

@ -1,4 +1,5 @@
import argparse
import json
import math
from dataclasses import dataclass
from pathlib import Path
@ -158,8 +159,16 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
def load_model(model_path: str):
model = Phi2(ModelArgs())
model_path = Path(model_path)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
weights = mx.load(str(model_path / "weights.npz"))
model.update(tree_unflatten(list(weights.items())))
weights = tree_unflatten(list(weights.items()))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
return model, tokenizer
@ -169,7 +178,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model-path",
type=str,
default=".",
default="mlx_model",
help="The path to the model weights",
)
parser.add_argument(

qwen/.gitignore

@ -1,2 +0,0 @@
weights.npz
config.json

qwen/README.md

@ -11,11 +11,15 @@ First download and convert the model with:
```sh
python convert.py
```
The script downloads the model from Hugging Face. The default model is
`Qwen/Qwen-1_8B`. Check out the [Hugging Face page](https://huggingface.co/Qwen) to see a list of available models.
The conversion script will make the `weights.npz` and `config.json` files in
the working directory.
To generate a 4-bit quantized model, use ``-q``. For a full list of options:
The script downloads the model from Hugging Face. The default model is
`Qwen/Qwen-1_8B`. Check out the [Hugging Face
page](https://huggingface.co/Qwen) to see a list of available models.
By default, the conversion script will make the directory `mlx_model` and save
the converted `weights.npz` and `config.json` there.
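Because the conversion is now driven entirely by the parsed arguments, it can also be
invoked programmatically. A hypothetical sketch (the `Namespace` fields mirror the flags
added in `convert.py`; assumes you run it from the `qwen` example directory, and it will
download the model from Hugging Face):
```
# Sketch only: roughly equivalent to `python convert.py -q` with the default model.
from argparse import Namespace

from convert import convert

convert(
    Namespace(
        model="Qwen/Qwen-1_8B",
        mlx_path="mlx_model",
        quantize=True,
        q_group_size=64,
        q_bits=4,
    )
)
```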
## Generate

qwen/convert.py

@ -1,8 +1,14 @@
import argparse
import copy
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from qwen import ModelArgs, Qwen
from transformers import AutoModelForCausalLM
@ -14,19 +20,58 @@ def replace_key(key: str) -> str:
return key
def convert(model_path: str = "Qwen/Qwen-1_8B"):
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model_args = ModelArgs()
model_args.vocab_size = config["vocab_size"]
model_args.hidden_size = config["hidden_size"]
model_args.num_attention_heads = config["num_attention_heads"]
model_args.num_hidden_layers = config["num_hidden_layers"]
model_args.kv_channels = config["kv_channels"]
model_args.max_position_embeddings = config["max_position_embeddings"]
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
model_args.intermediate_size = config["intermediate_size"]
model_args.no_bias = config["no_bias"]
model = Qwen(model_args)
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def convert(args):
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
model = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True, torch_dtype=torch.float16
args.model, trust_remote_code=True, torch_dtype=torch.float16
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights)
config = model.config.to_dict()
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
# write config
config = model.config
config_dict = config.to_dict()
with open("config.json", "w") as f:
json.dump(config_dict, f, indent=4)
with open(mlx_path / "config.json", "w") as f:
json.dump(config, f, indent=4)
if __name__ == "__main__":
@ -37,7 +82,29 @@ if __name__ == "__main__":
help="The huggingface model to be converted",
default="Qwen/Qwen-1_8B",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="The path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q_group_size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
args = parser.parse_args()
convert(args.model)
convert(args)

qwen/qwen.py

@ -1,6 +1,7 @@
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
@ -175,12 +176,11 @@ def generate(prompt: mx.array, model: Qwen, temp: 0.0):
yield y
def load_model(
tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
):
def load_model(model_path: str, tokenizer_path: str = "Qwen/Qwen-1_8B"):
model_args = ModelArgs()
with open(config_path, "r") as f:
model_path = Path(model_path)
with open(model_path / "config.json", "r") as f:
config = json.load(f)
model_args.vocab_size = config["vocab_size"]
model_args.hidden_size = config["hidden_size"]
@ -193,9 +193,11 @@ def load_model(
model_args.no_bias = config["no_bias"]
model = Qwen(model_args)
weights = mx.load("weights.npz")
weights = mx.load(str(model_path / "weights.npz"))
if quantization := config.get("quantization", False):
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
)
@ -204,6 +206,12 @@ def load_model(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Qwen inference script")
parser.add_argument(
"--model-path",
type=str,
default="mlx_model",
help="The path to the model weights and config",
)
parser.add_argument(
"--tokenizer",
help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
@ -216,7 +224,7 @@ if __name__ == "__main__":
default="蒙古国的首都是乌兰巴托Ulaanbaatar\n冰岛的首都是雷克雅未克Reykjavik\n埃塞俄比亚的首都是",
)
parser.add_argument(
"--max_tokens",
"--max-tokens",
"-m",
type=int,
default=100,
@ -233,7 +241,7 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
model, tokenizer = load_model(args.tokenizer)
model, tokenizer = load_model(args.model_path, args.tokenizer)
prompt = tokenizer(
args.prompt,