3 Commits

Alex Barron
ae53ed9090  add TesseraQ rounding  (2024-12-19 19:35:26 -08:00)

Alex Barron
d4ef909d4a  Length masking for batch inputs (#1173)  (2024-12-18 19:43:52 -08:00)
* length masking
* add mask to mlx_lm model interface
* remove lengths
* fix test:
* comment + fix

Awni Hannun
db109184b7  Fix no template prompt + top_k sampling (#1166)  (2024-12-18 18:46:50 -08:00)
* fix no template prompt
* add top_k sampling
* fix chinese

39 changed files with 863 additions and 83 deletions

llms/mlx_lm/awq.py  (new file, 613 lines added)
View File

@@ -0,0 +1,613 @@
# Learned quantization using AWQ and TesseraQ:
# References:
# AWQ:
# https://arxiv.org/abs/2306.00978
# https://github.com/mit-han-lab/llm-awq
# TesseraQ:
# https://arxiv.org/abs/2410.19103
# https://github.com/Intelligent-Computing-Lab-Yale/TesseraQ
import argparse
import glob
import shutil
from pathlib import Path
from typing import Callable
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from datasets import Dataset, load_dataset
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_map_with_path
from mlx_lm.models.base import create_attention_mask
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.utils import fetch_from_hub, get_model_path, save_config, save_weights
from tqdm import tqdm
ROUNDING_THRESHOLDS = [
0.8,
0.65,
0.5,
0.43,
0.38,
0.34,
0.3,
0.27,
0.24,
0.21,
0.18,
0.15,
0.12,
0.10,
0.08,
0.06,
0.04,
0.02,
0.01,
0.005,
]
def mse(x, y):
return ((x - y).astype(mx.float32)) ** 2
def sigmoid(x: mx.array, gamma: float = -0.1, zeta: float = 1.1):
return mx.clip(nn.sigmoid(x) * (zeta - gamma) + gamma, 0, 1)
def sigmoid_inverse(y: mx.array, gamma: float = -0.1, zeta: float = 1.1):
return -mx.log((zeta - gamma) / (y - gamma) - 1)
def run_layer(layer: nn.Module, x: mx.array, batch_size: int = 32, **kwargs):
y = []
for i in range(0, x.shape[0], batch_size):
y.append(layer(x[i : i + batch_size], **kwargs))
mx.eval(y)
y = mx.concatenate(y, axis=0)
return y
def dist_split(x: mx.array, group: mx.distributed.Group):
B = x.shape[0]
N = group.size()
assert B % N == 0
r = group.rank()
local_B = (B + N - 1) // N
return x[r * local_B : (r + 1) * local_B]
def search_best_scale(
layers: list[nn.Module],
x: mx.array,
quantize_func: Callable,
block: nn.Module | None = None,
layer_kwargs: dict | None = None,
n_grid: int = 1,
):
group = mx.distributed.init() if mx.distributed.is_available() else None
layer_kwargs = layer_kwargs or {}
block = block or layers[0]
out = block(x, **layer_kwargs)
x_max = x.abs().mean(axis=(0, 1))
best_error = float("inf")
best_scales = None
weights = tree_flatten(block.parameters())
for ratio in tqdm(range(n_grid)):
ratio = ratio * 1 / n_grid
scales = mx.maximum(x_max**ratio, 1e-4).reshape(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
for layer in layers:
layer.weight = quantize_func(layer.weight * scales) / scales
out_q = run_layer(block, x, **layer_kwargs)
loss = mse(out, out_q).sum()
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= out.size
mx.eval(loss)
is_best = loss < best_error
if is_best:
best_error = loss
best_scales = scales
# reload the original weights
block.load_weights(weights)
best_scales = best_scales.reshape(-1)
mx.eval(best_scales)
return best_scales
def apply_scale(prev_op, layers, scales):
# Apply the scales to the layers
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
prev_op.weight = prev_op.weight / scales[:, mx.newaxis]
if hasattr(prev_op, "bias"):
prev_op.bias = prev_op.bias / scales
layers[0].weight = layers[0].weight * scales[mx.newaxis]
elif isinstance(prev_op, (nn.LayerNorm, nn.RMSNorm)):
prev_op.weight = prev_op.weight / scales
if hasattr(prev_op, "bias"):
prev_op.bias = prev_op.bias / scales
for layer in layers:
layer.weight = layer.weight * scales
else:
raise NotImplementedError(f"Could not apply scale to prev_op: {prev_op}")
def scale_block(
block, input_feat, quantize_func: Callable, layer_kwargs: dict | None = None
):
layers = [
block.self_attn.q_proj,
block.self_attn.k_proj,
block.self_attn.v_proj,
]
scales = search_best_scale(
layers=layers,
block=block.self_attn,
x=input_feat["q_proj"],
quantize_func=quantize_func,
layer_kwargs=layer_kwargs,
)
apply_scale(block.input_layernorm, layers, scales)
for name in ["q_proj", "k_proj", "v_proj"]:
input_feat[name] = input_feat[name] / scales
layers = [
block.mlp.gate_proj,
block.mlp.up_proj,
]
scales = search_best_scale(
block=block.mlp,
layers=layers,
x=input_feat["gate_proj"],
quantize_func=quantize_func,
)
mlp_norm = getattr(
block, "pre_feedforward_layernorm", block.post_attention_layernorm
)
apply_scale(mlp_norm, layers, scales)
for name in ["gate_proj", "up_proj"]:
input_feat[name] = input_feat[name] / scales
layers = [block.mlp.down_proj]
scales = search_best_scale(
layers=layers,
x=input_feat["down_proj"],
quantize_func=quantize_func,
)
apply_scale(block.mlp.up_proj, layers, scales)
input_feat["down_proj"] = input_feat["down_proj"] / scales
def search_best_clip(
w: mx.array,
x: mx.array,
quantize_func: Callable,
group_size: int,
n_grid: int = 2,
max_shrink: float = 0.5,
subsample: int = 4,
batch_size: int = 64,
):
group = mx.distributed.init() if mx.distributed.is_available() else None
x = x[:, ::subsample]
x = x.reshape(*x.shape[:-1], -1, group_size)
w_all = w
w_max_all = []
w_min_all = []
for b in range(0, w.shape[0], batch_size):
w = w_all[b : b + batch_size]
group_shape = (w.shape[0], w.shape[-1] // group_size)
best_error = mx.full(group_shape, float("inf"))
best_w_max = mx.zeros((*group_shape, 1), dtype=x.dtype)
best_w_min = mx.zeros((*group_shape, 1), dtype=x.dtype)
w_shape = w.shape
w = w.reshape(*w.shape[:-1], -1, group_size)
out = mx.einsum("btdg,odg->btod", x, w)
for i in range(int(max_shrink * n_grid)):
p = 1 - i / n_grid
w_max = p * w.max(axis=-1, keepdims=True)
w_min = p * w.min(axis=-1, keepdims=True)
w_m = mx.clip(w, w_min, w_max).reshape(w_shape)
w_q = quantize_func(w_m)
w_q = w_q.reshape(*w_q.shape[:-1], -1, group_size)
out_q = mx.einsum("btdg,odg->btod", x, w_q)
# Take the mean across the input batch
loss = mse(out, out_q).sum(axis=(0, 1))
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= out.shape[0] * out.shape[1]
best_indices = loss < best_error
best_error = mx.where(best_indices, loss, best_error)
best_w_max = mx.where(best_indices[..., mx.newaxis], w_max, best_w_max)
best_w_min = mx.where(best_indices[..., mx.newaxis], w_min, best_w_min)
mx.eval(best_w_max, best_w_min, best_error)
w_max_all.append(best_w_max)
w_min_all.append(best_w_min)
best_w_max = mx.concatenate(w_max_all, axis=0)
best_w_min = mx.concatenate(w_min_all, axis=0)
w_r = w_all.reshape(*w_all.shape[:-1], -1, group_size)
best_w = mx.clip(w_r, best_w_min, best_w_max)
best_w = best_w.reshape(w_all.shape)
mx.eval(best_w)
return best_w
def clip_block(
block: nn.Module,
input_feat: dict[str, mx.array],
quantize_func: Callable,
group_size: int,
):
def apply_clip(path, module):
if (
isinstance(module, nn.Linear)
and "q_proj" not in path
and "k_proj" not in path
):
name = path.split(".")[-1]
best_weight = search_best_clip(
module.weight,
input_feat[name],
quantize_func=quantize_func,
group_size=group_size,
)
module.weight = best_weight
tree_map_with_path(apply_clip, block.leaf_modules(), is_leaf=nn.Module.is_module)
class RoundQuant(nn.Module):
def __init__(self, module, group_size: int = 64, bits: int = 3):
super().__init__()
self.bits = bits
self.group_size = group_size
self._weight = module.weight
if hasattr(module, "bias"):
self._bias = module.bias
_, self._scales, self._biases = mx.quantize(
self._weight, group_size=group_size, bits=bits
)
self._scales = self._scales[..., mx.newaxis]
self._biases = self._biases[..., mx.newaxis]
self._weight = self._weight.reshape(self._weight.shape[0], -1, group_size)
rounding = self._weight / self._scales
rounding = rounding - mx.floor(rounding)
self.rounding = sigmoid_inverse(rounding)
self.v = mx.zeros_like(self._scales)
def __call__(self, x: mx.array):
q = (self._weight - self._biases) / self._scales
q = mx.floor(q) + sigmoid(self.rounding)
q = mx.clip(q, 0, 2**self.bits - 1)
w = (q * self._scales * 2 * sigmoid(self.v)) + self._biases
w = w.reshape(w.shape[0], -1)
if hasattr(self, "_bias"):
x = mx.addmm(self._bias, x, w.T)
else:
x = x @ w.T
return x
def to_quantized(self, group_size: int = 64, bits: int = 3):
assert (
group_size == self.group_size and bits == self.bits
), "Quantization parameters must match"
w = self._weight
output_dims, input_dims = w.shape[0], w.shape[1] * w.shape[2]
use_bias = hasattr(self, "_bias")
q = (w - self._biases) / self._scales
q = mx.floor(q) + sigmoid(self.rounding)
q = mx.clip(q, 0, 2**bits - 1)
q = q.astype(mx.uint32)
w = q * self._scales * 2 * sigmoid(self.v) + self._biases
w = w.reshape(w.shape[0], -1)
q = q.reshape(q.shape[0], -1)
bitarr = (q[..., mx.newaxis] >> mx.arange(bits, dtype=mx.uint32)) & 1
w_q = bitarr.reshape((q.shape[0], -1, 32))
w_q = (w_q << mx.arange(32, dtype=mx.uint32)).sum(axis=-1)
qlayer = nn.QuantizedLinear(input_dims, output_dims, use_bias, group_size, bits)
new_scales = self._scales * 2 * sigmoid(self.v)
qlayer.weight = w_q
qlayer.scales = new_scales[..., 0]
qlayer.biases = self._biases[..., 0]
if use_bias:
qlayer.bias = self._bias
return qlayer
def round_block(
block: nn.Module,
inputs: mx.array,
outputs: mx.array,
group_size: int = 64,
bits: int = 3,
layer_kwargs: dict | None = None,
batch_size: int = 4,
):
layer_kwargs = layer_kwargs or {}
block.freeze()
leaves = block.leaf_modules()
rounded = tree_map(
lambda m: RoundQuant(m, group_size, bits) if isinstance(m, nn.Linear) else m,
leaves,
is_leaf=nn.Module.is_module,
)
block.update_modules(rounded)
def hard_round(module, threshold: float = 0):
if not isinstance(module, RoundQuant):
return module
score = mx.abs(sigmoid(module.rounding) - 0.5)
value = mx.array(np.quantile(score.astype(mx.float32), q=threshold))
rounding = mx.where(
sigmoid(module.rounding) > value + 0.5, float("inf"), module.rounding
)
module.rounding = mx.where(
sigmoid(module.rounding) <= 0.5 - value, -float("inf"), rounding
)
return module
for threshold in ROUNDING_THRESHOLDS:
print("threshold", threshold)
optimizer = optim.Adam(learning_rate=1e-3)
tree_map(
lambda m: hard_round(m, threshold),
block.leaf_modules(),
is_leaf=nn.Module.is_module,
)
def loss(block, inputs, outputs):
outputs_q = run_layer(block, inputs, **layer_kwargs)
return mse(outputs, outputs_q).mean()
loss_value_and_grad = nn.value_and_grad(block, loss)
for i in range(0, inputs.shape[0], batch_size):
lvalue, grad = loss_value_and_grad(
block, inputs[i : i + batch_size], outputs[i : i + batch_size]
)
if mx.distributed.is_available():
grad = average_gradients(grad)
optimizer.update(block, grad)
mx.eval(block.parameters(), optimizer.state, lvalue)
print(lvalue)
tree_map(hard_round, block.leaf_modules(), is_leaf=nn.Module.is_module)
def awq_quantize(
model,
inputs: mx.array,
group_size: int = 64,
bits: int = 3,
embed_group_size: int = 32,
embed_bits: int = 4,
):
group = mx.distributed.init() if mx.distributed.is_available() else None
def quantize_func(w):
wq = mx.quantize(w, bits=bits, group_size=group_size)
return mx.dequantize(*wq, bits=bits, group_size=group_size)
mask = create_attention_mask(inputs)
model.model.embed_tokens = model.model.embed_tokens.to_quantized(
group_size=embed_group_size, bits=embed_bits
)
inputs = model.model.embed_tokens(inputs)
input_feat = {}
def capture(path, module):
if not isinstance(module, nn.Linear):
return module
class Catcher(nn.Module):
def __call__(self, x: mx.array):
name = path.split(".")[-1]
input_feat[name] = x
return module(x)
return Catcher()
for i, layer in enumerate(model.model.layers):
import time
s = time.time()
print(f"Starting block {i}")
# capture the inputs to each layer
orig_leaves = layer.leaf_modules()
capture_leaves = tree_map_with_path(
capture, orig_leaves, is_leaf=nn.Module.is_module
)
layer.update_modules(capture_leaves)
outputs = run_layer(layer, inputs, mask=mask)
layer.update_modules(orig_leaves)
del capture_leaves
nn.quantize(layer, group_size=group_size, bits=bits)
outputs_q = run_layer(layer, inputs, mask=mask)
loss = mse(outputs, outputs_q).sum()
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= outputs.size
print("Before Loss", loss, flush=True)
layer.update_modules(orig_leaves)
del orig_leaves
print("Scaling block", flush=True)
scale_block(
block=layer,
input_feat=input_feat,
quantize_func=quantize_func,
layer_kwargs={"mask": mask},
)
print("Clipping block", flush=True)
clip_block(
block=layer,
input_feat=input_feat,
quantize_func=quantize_func,
group_size=group_size,
)
print("Rounding block")
round_block(
block=layer,
inputs=inputs,
outputs=outputs,
group_size=group_size,
bits=bits,
layer_kwargs={"mask": mask},
)
nn.quantize(layer, group_size=group_size, bits=bits)
outputs_q = run_layer(layer, inputs, mask=mask)
loss = mse(outputs, outputs_q).sum()
if group is not None:
loss = mx.distributed.all_sum(loss, stream=mx.cpu) / group.size()
loss /= outputs.size
print("After Loss", loss, flush=True)
input_feat = {}
inputs = outputs
mx.eval(layer)
mx.metal.clear_cache()
e = time.time()
print("Loop time: ", e - s)
if hasattr(model, "lm_head"):
model.lm_head = model.lm_head.to_quantized(
group_size=embed_group_size, bits=embed_bits
)
def load_wikitext(
tokenizer, num_samples: int = 32, sequence_length: int = 2048, split: str = "train"
) -> mx.array:
dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)
texts = "\n\n".join(dataset["text"])
tokens = tokenizer.encode(texts, return_tensors="mlx")[0]
# Select random chunks
starts = mx.random.randint(
0, len(tokens) - sequence_length - 1, shape=(num_samples, 1)
)
data = tokens[starts + mx.arange(sequence_length)]
if tokenizer.bos_token_id:
data = mx.concatenate(
[mx.full((*data.shape[:2], 1), tokenizer.bos_token_id), data], axis=-1
)
return data
def save_model(
model: nn.Module,
tokenizer: TokenizerWrapper,
config,
model_path: Path,
mlx_path: str,
):
weights = dict(tree_flatten(model.parameters()))
mlx_path = Path(mlx_path)
save_weights(mlx_path, weights, donate_weights=True)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
config["quantization"] = {"group_size": 64, "bits": 4}
def update_config(path, module):
if hasattr(module, "bits"):
config["quantization"][path] = {
"group_size": module.group_size,
"bits": module.bits,
}
else:
config["quantization"][path] = False
tree_map_with_path(update_config, model.leaf_modules(), is_leaf=nn.Module.is_module)
save_config(config, config_path=mlx_path / "config.json")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", "-m", default="mlx-community/Qwen2.5-7B-Instruct-bf16"
)
parser.add_argument("--mlx-path", default="mlx_model")
parser.add_argument("--bits", type=int, default=3)
parser.add_argument("--group-size", type=int, default=64)
parser.add_argument("--num-samples", type=int, default=32)
parser.add_argument("--sequence-length", type=int, default=2048)
parser.add_argument("--seed", type=int, default=123)
args = parser.parse_args()
group = mx.distributed.init() if mx.distributed.is_available() else None
num_samples = args.num_samples
if group is not None and num_samples % group.size() > 0:
num_samples += group.size() - num_samples % group.size()
mx.random.seed(args.seed)
model_path = get_model_path(args.model, revision=None)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
calibration_data = load_wikitext(tokenizer, num_samples, args.sequence_length)
if group is not None:
calibration_data = dist_split(calibration_data, group)
awq_quantize(model, calibration_data, bits=args.bits, group_size=args.group_size)
save_model(model, tokenizer, config, model_path, args.mlx_path)
if __name__ == "__main__":
main()
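
For reference, here is a minimal single-process usage sketch (not part of the commit) that mirrors the main() entry point above, using the argparse defaults shown in the file:

# Hypothetical usage sketch of the new AWQ + TesseraQ pipeline; the model name
# and output path are just the defaults from main() above.
from mlx_lm.awq import awq_quantize, load_wikitext, save_model
from mlx_lm.utils import fetch_from_hub, get_model_path

model_path = get_model_path("mlx-community/Qwen2.5-7B-Instruct-bf16", revision=None)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

# 32 random wikitext-2 chunks of 2048 tokens are used for calibration.
data = load_wikitext(tokenizer, num_samples=32, sequence_length=2048)

# Scale search, clipping search, and TesseraQ rounding run block by block in place.
awq_quantize(model, data, bits=3, group_size=64)
save_model(model, tokenizer, config, model_path, "mlx_model")

The setup.py change further down registers the same entry point as the mlx_lm.awq console script.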

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import codecs
import json
import sys
@@ -189,8 +188,8 @@ def main():
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
prompt = codecs.decode(args.prompt, "unicode_escape")
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
@@ -199,12 +198,7 @@ def main():
messages = [{"role": "system", "content": args.system_prompt}]
else:
messages = []
messages.append(
{
"role": "user",
"content": sys.stdin.read() if prompt == "-" else prompt,
}
)
messages.append({"role": "user", "content": prompt})
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

View File

@@ -23,7 +23,12 @@ class BaseModelArgs:
)
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
def create_causal_mask(
N: int,
offset: int = 0,
window_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
@@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
if lengths is not None:
lengths = lengths[:, None, None, None]
mask = mask | (rinds >= lengths)
return mask * -1e9
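
As an illustration (not part of the diff), the new lengths argument lets a batch of padded sequences share one additive mask; the values below follow the definition above:

import mlx.core as mx
from mlx_lm.models.base import create_causal_mask

# Two sequences padded to length 4; the second has only 2 valid tokens.
lengths = mx.array([4, 2])
mask = create_causal_mask(4, offset=0, lengths=lengths)
# mask has shape (2, 1, 4, 4): the strictly upper-triangular (future) positions
# are -1e9 as before, and for the second sequence every key position >= 2 is
# also -1e9, so attention to its padding tokens is suppressed.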

View File

@@ -155,11 +155,13 @@ class CohereModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -180,9 +182,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out

View File

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache
@@ -151,16 +151,13 @@ class CohereModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
T = h.shape[1]
if T > 1:
offset = cache[0].offset if cache else 0
mask = create_causal_mask(T, offset).astype(h.dtype)
else:
mask = None
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -181,9 +178,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out

View File

@@ -197,11 +197,13 @@ class DBRX(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
@@ -223,9 +225,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
return self.lm_head(out)
@property

View File

@@ -211,9 +211,11 @@ class DeepseekModel(nn.Module):
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -236,8 +238,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -370,9 +370,12 @@ class DeepseekV2Model(nn.Module):
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -395,8 +398,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -123,10 +123,12 @@ class ExaoneModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.h)
@@ -149,9 +151,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:

View File

@@ -138,12 +138,14 @@ class GemmaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -164,9 +166,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
return out

View File

@@ -160,12 +160,14 @@ class GemmaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -187,9 +189,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping

View File

@@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
@@ -138,7 +139,8 @@ class GPT2Model(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_attention_mask(hidden_states, cache)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@@ -159,9 +161,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.wte.as_linear(out)
return out

View File

@@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
B, L = inputs.shape
@@ -149,7 +150,8 @@ class GPTBigCodeModel(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_attention_mask(hidden_states, cache)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@@ -172,9 +174,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:

View File

@@ -146,13 +146,15 @@ class GPTNeoXModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
hidden_states = self.embed_in(inputs)
mask = create_attention_mask(hidden_states, cache)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@@ -176,9 +178,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return out
def sanitize(self, weights):

View File

@@ -239,11 +239,13 @@ class HunYuanModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -266,9 +268,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):

View File

@@ -193,11 +193,13 @@ class InternLM2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -220,9 +222,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.tok_embeddings.as_linear(out)
else:

View File

@@ -155,11 +155,13 @@ class LlamaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -182,9 +184,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -158,11 +158,13 @@ class MiniCPMModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs) * self.args.scale_emb
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -186,9 +188,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if not self.args.tie_word_embeddings:
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))

View File

@@ -162,11 +162,13 @@ class MixtralModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -188,9 +190,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -176,11 +176,13 @@ class NemotronModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -203,9 +205,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -124,11 +124,13 @@ class Transformer(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
@@ -152,9 +154,10 @@ class OlmoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.transformer(inputs, cache)
return self.transformer(inputs, mask, cache)
class Model(nn.Module):
@@ -167,9 +170,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.model(inputs, cache)
return self.model(inputs, mask, cache)
@property
def layers(self):

View File

@@ -163,10 +163,12 @@ class LlamaModel(nn.Module):
self,
inputs: mx.array,
cache=None,
mask=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -190,8 +192,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache=None,
mask=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -178,11 +178,13 @@ class OpenELMModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.token_embeddings(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -205,9 +207,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.share_input_output_layers:
out = self.transformer.token_embeddings.as_linear(out)
else:

View File

@@ -143,10 +143,11 @@ class PhiModel(nn.Module):
config.hidden_size, eps=config.layer_norm_eps
)
def __call__(self, x, cache):
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
mask = create_attention_mask(x, cache)
if mask is None:
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -167,9 +168,10 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
y = self.model(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@property

View File

@@ -168,11 +168,13 @@ class Phi3Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -194,9 +196,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
@property

View File

@@ -258,13 +258,15 @@ class Phi3Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -290,9 +292,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
if self.mup_width_multiplier:
out = out / self.mup_width_multiplier

View File

@@ -155,11 +155,13 @@ class PhiMoEModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -181,9 +183,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -175,7 +175,9 @@ class Model(nn.Module):
mask: mx.array = None,
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
if mask is None:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
return self.lm_head(y)

View File

@@ -174,10 +174,12 @@ class PlamoModel(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]
@@ -202,8 +204,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
@property

View File

@@ -123,7 +123,8 @@ class QwenModel(nn.Module):
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
mask = create_attention_mask(x, cache)
if mask is None:
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.h)

View File

@@ -149,11 +149,13 @@ class Qwen2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -176,9 +178,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -187,11 +187,13 @@ class Qwen2MoeModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -213,9 +215,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@@ -389,6 +389,7 @@ class Griffin(nn.Module):
def __call__(
self,
tokens,
mask: mx.array = None,
cache=None,
):
x = self.embed_tokens(tokens)
@@ -402,7 +403,8 @@ class Griffin(nn.Module):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
if mask is None:
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
@@ -418,12 +420,12 @@ class Model(nn.Module):
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
"""
Args:
tokens: Sequence of input tokens.
"""
logits = self.model(tokens, cache=cache)
logits = self.model(tokens, mask=mask, cache=cache)
if "lm_head" in self:
logits = self.lm_head(logits)
else:

View File

@@ -199,7 +199,10 @@ class Model(nn.Module):
mask: mx.array = None,
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
if mask is None:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)

View File

@@ -125,11 +125,13 @@ class Starcoder2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@@ -152,9 +154,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@@ -12,6 +12,7 @@ def make_sampler(
top_p: float = 0.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
top_k: int = -1,
) -> Callable[mx.array, mx.array]:
"""
Make a sampler function for use with ``generate_step``.
@@ -25,6 +26,8 @@ def make_sampler(
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
top_k (int, optional): The top k tokens ranked by probability to constrain
the sampling to.
Returns:
Callable[mx.array, mx.array]:
@@ -36,6 +39,8 @@ def make_sampler(
return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
elif top_k > 0:
return lambda x: top_k_sampling(x, top_k, temp)
else:
return lambda x: categorical_sampling(x, temp)
@@ -79,6 +84,33 @@ def make_logits_processors(
return logits_processors
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_k_sampling(
logprobs: mx.array,
top_k: int,
temperature=1.0,
) -> mx.array:
"""
Sample from only the top K tokens ranked by probability.
Args:
logprobs: A vector of log probabilities.
top_k (int): Top k tokens to sample from.
"""
vocab_size = logprobs.shape[-1]
if not isinstance(top_k, int) or not (0 < top_k < vocab_size):
raise ValueError(
f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
f" but is {top_k}."
)
logprobs = logprobs * (1 / temperature)
mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
masked_logprobs = mx.put_along_axis(
logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
)
return mx.random.categorical(masked_logprobs, axis=-1)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
logprobs: mx.array,
@@ -87,7 +119,7 @@ def min_p_sampling(
temperature=1.0,
) -> mx.array:
"""
Apply min-p sampling to the logits.
Apply min-p sampling to the logprobs.
Min-p keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter is more
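
A brief usage sketch (not part of the diff) for the new top-k option; the probabilities are made up, and it assumes temp is the first argument of make_sampler as in the existing signature:

import mlx.core as mx
from mlx_lm.sample_utils import make_sampler, top_k_sampling

logprobs = mx.log(mx.array([[0.6, 0.3, 0.05, 0.05]]))

# make_sampler dispatches to top_k_sampling when top_k > 0.
sampler = make_sampler(temp=1.0, top_k=2)
token = sampler(logprobs)  # only indices 0 or 1 can be drawn

# The primitive can also be called directly (top_k=2, temperature=1.0).
token = top_k_sampling(logprobs, 2, 1.0)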

View File

@@ -32,6 +32,7 @@ setup(
},
entry_points={
"console_scripts": [
"mlx_lm.awq = mlx_lm.awq:main",
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.chat = mlx_lm.chat:main",
"mlx_lm.convert = mlx_lm.convert:main",

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_causal_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
@@ -128,6 +129,22 @@ class TestModels(unittest.TestCase):
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def test_causal_mask_lengths(self):
mx.random.seed(8)
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
lengths = mx.array([1, 2, 3, 1])
q = mx.random.uniform(shape=(B, N_q, T_q, D))
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
v = k
mask = create_causal_mask(T_q, 0, lengths=lengths)
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
def test_rope(self):
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
self.assertTrue(isinstance(rope, nn.RoPE))
@@ -162,10 +179,16 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.dtype, t)
cache = make_prompt_cache(model)
outputs = model(inputs, cache)
outputs = model(inputs, cache=cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
if model_type != "mamba":
mask = create_causal_mask(inputs.shape[1], 0).astype(t)
outputs = model(inputs, mask=mask)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
self.assertEqual(outputs.shape, (1, 1, vocab_size))
self.assertEqual(outputs.dtype, t)

View File

@@ -1,7 +1,7 @@
import unittest
import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_p_sampling
from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling
class TestSampleUtils(unittest.TestCase):
@@ -42,6 +42,27 @@ class TestSampleUtils(unittest.TestCase):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
def test_top_k_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
token = top_k_sampling(logits, 1).item()
self.assertEqual(token, 0)
probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
tokens = set()
for _ in range(100):
token = top_k_sampling(logits, 2)
tokens.add(token.item())
self.assertEqual(tokens, {0, 3})
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = top_k_sampling(logits, 1)
self.assertEqual(tokens.tolist(), [0, 1])
if __name__ == "__main__":
unittest.main()